rlx_cpu/tile.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! CPU `TileIO` impls (plans #23 + #27).
17//!
18//! Borrowed from MAX's
19//! `structured_kernels/{kernel_common, tile_types, smem_types}.mojo`
20//! and `layout/tile_io.mojo`. Lifts the "kernel-author standard
21//! library" pattern: typed primitives kernels compose, instead of
22//! re-deriving stride math and load/store loops per kernel.
23//!
24//! The vocabulary types (`Tile2`, `Coord2`, `Strides2`) live in
25//! `rlx_ir::layout` (plan #3 — shared layout IR) so Metal kernels
26//! can use the same names. CPU-specific `TileIO` impls live here.
27
28pub use rlx_ir::{Coord2, Strides2, Tile2};
29
30/// Tile I/O trait — load / store / prefetch parameterized over the
31/// physical layout. Two impls today: [`RowMajorTile`] (the standard
32/// flat layout) and [`StridedTile`] (when reading a non-contiguous
33/// view, e.g. last-axis Narrow into Attention).
34///
35/// Methods take pointers (not slices) so the abstraction works for
36/// both owned and aliased buffers.
37pub trait TileIO {
38 /// Compute the byte address for a coordinate. Used by
39 /// `load` / `store` / `prefetch` so impls only need to define
40 /// the address arithmetic once.
41 /// SAFETY: caller checks bounds.
42 unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32;
43
44 /// Load a tile element by `(row, col)`.
45 /// SAFETY: caller ensures the address is valid for read.
46 #[inline(always)]
47 unsafe fn load(&self, base: *const f32, c: Coord2) -> f32 {
48 unsafe { *self.address(base, c) }
49 }
50
51 /// Store an element by `(row, col)`.
52 /// SAFETY: caller ensures the address is valid for write.
53 #[inline(always)]
54 unsafe fn store(&self, base: *mut f32, c: Coord2, v: f32) {
55 unsafe {
56 *(self.address(base, c) as *mut f32) = v;
57 }
58 }
59
60 /// Hint to the prefetcher. On aarch64 issues a single
61 /// `prfm pldl1keep` (load into L1, retain). Elsewhere a no-op.
62 /// SAFETY: caller ensures the address is in a valid mapping.
63 #[inline(always)]
64 unsafe fn prefetch(&self, base: *const f32, c: Coord2) {
65 unsafe {
66 let addr = self.address(base, c);
67 #[cfg(target_arch = "aarch64")]
68 {
69 std::arch::asm!("prfm pldl1keep, [{0}]", in(reg) addr,
70 options(nostack, readonly));
71 }
72 #[cfg(not(target_arch = "aarch64"))]
73 {
74 let _ = addr;
75 }
76 }
77 }
78}
79
80/// Row-major contiguous tile: `addr = base + row * cols + col`.
81#[derive(Debug, Clone, Copy)]
82pub struct RowMajorTile {
83 pub shape: Tile2,
84}
85
86impl TileIO for RowMajorTile {
87 #[inline(always)]
88 unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
89 unsafe { base.add(c.row * self.shape.cols + c.col) }
90 }
91}
92
93/// Strided tile: each row stride is configurable. Lets a kernel
94/// read a non-contiguous view (e.g. last-axis Narrow output) with
95/// the same TileIO interface as a contiguous tile.
96#[derive(Debug, Clone, Copy)]
97pub struct StridedTile {
98 pub shape: Tile2,
99 pub strides: Strides2,
100}
101
102impl TileIO for StridedTile {
103 #[inline(always)]
104 unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
105 unsafe { base.add(c.row * self.strides.row + c.col * self.strides.col) }
106 }
107}
108
109/// Walk every element of a tile in row-major order, calling `f`.
110/// Convenience for kernels that don't care about iteration order.
111#[inline(always)]
112pub fn for_each_coord(shape: Tile2, mut f: impl FnMut(Coord2)) {
113 for r in 0..shape.rows {
114 for c in 0..shape.cols {
115 f(Coord2 { row: r, col: c });
116 }
117 }
118}
119
120/// Tile copy via TileIO. Source and destination layouts can differ
121/// (the typical use: read strided source, write contiguous dst).
122///
123/// # Safety
124/// `src_base` and `dst_base` must point into allocations large enough
125/// for `shape`'s extents under the IO layouts in `src_io` / `dst_io`.
126/// The two ranges may not overlap.
127#[inline]
128pub unsafe fn copy_tile<S: TileIO, D: TileIO>(
129 src_io: &S,
130 src_base: *const f32,
131 dst_io: &D,
132 dst_base: *mut f32,
133 shape: Tile2,
134) {
135 for_each_coord(shape, |c| unsafe {
136 dst_io.store(dst_base, c, src_io.load(src_base, c));
137 });
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn row_major_round_trip() {
        let mut buf = [0f32; 12]; // 3×4
        let io = RowMajorTile {
            shape: Tile2::new(3, 4),
        };
        unsafe {
            io.store(buf.as_mut_ptr(), Coord2 { row: 1, col: 2 }, 42.0);
            assert_eq!(io.load(buf.as_ptr(), Coord2 { row: 1, col: 2 }), 42.0);
        }
        // (1, 2) in a 4-wide tile is flat index 1 * 4 + 2.
        assert_eq!(buf[4 + 2], 42.0);
        // The store touched exactly one element.
        assert_eq!(buf.iter().filter(|&&x| x != 0.0).count(), 1);
    }

    #[test]
    fn strided_reads_non_contig_view() {
        // Source: 4-row tile inside a 4-row × 8-col parent.
        // Pretending we narrowed cols 2..6 of each row; row stride = 8.
        let parent: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let view = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        // base pointer offset to col=2 of row 0
        let base = unsafe { parent.as_ptr().add(2) };
        let v = unsafe { view.load(base, Coord2 { row: 1, col: 1 }) };
        // expected: parent[1*8 + 2 + 1] = 11
        assert_eq!(v, 11.0);
    }

    #[test]
    fn strided_honours_column_stride() {
        // Column-major 3×2 tile: consecutive rows are adjacent (row
        // stride 1), consecutive columns are 3 elements apart.
        let buf: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let view = StridedTile {
            shape: Tile2::new(3, 2),
            strides: Strides2 { row: 1, col: 3 },
        };
        let v = unsafe { view.load(buf.as_ptr(), Coord2 { row: 2, col: 1 }) };
        // expected: buf[2*1 + 1*3] = 5
        assert_eq!(v, 5.0);
    }

    #[test]
    fn for_each_coord_is_row_major() {
        let mut seen = Vec::new();
        for_each_coord(Tile2::new(2, 3), |c| seen.push((c.row, c.col)));
        assert_eq!(seen, [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]);
    }

    #[test]
    fn prefetch_doesnt_panic() {
        // Prefetch is a hint — it should not crash, and should
        // accept any in-bounds address. We just verify the call
        // sequence compiles + runs on the current target.
        let buf = vec![0f32; 64];
        let io = RowMajorTile {
            shape: Tile2::new(8, 8),
        };
        unsafe {
            io.prefetch(buf.as_ptr(), Coord2 { row: 0, col: 0 });
            io.prefetch(buf.as_ptr(), Coord2 { row: 7, col: 7 });
        }
    }

    #[test]
    fn copy_tile_strided_to_contig() {
        let parent: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut dst = vec![0f32; 16]; // 4×4 contiguous
        let src_io = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        let dst_io = RowMajorTile {
            shape: Tile2::new(4, 4),
        };
        let base = unsafe { parent.as_ptr().add(2) };
        unsafe {
            copy_tile(&src_io, base, &dst_io, dst.as_mut_ptr(), Tile2::new(4, 4));
        }
        // Row r of dst must equal parent[r*8 + 2 .. r*8 + 6]; check every
        // row, not just the first two.
        let expected: Vec<f32> = (0..4)
            .flat_map(|r| (0..4).map(move |c| (r * 8 + 2 + c) as f32))
            .collect();
        assert_eq!(dst, expected);
    }
}