Skip to main content

rlx_cpu/
tile.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! CPU `TileIO` impls (plans #23 + #27).
17//!
18//! Borrowed from MAX's
19//! `structured_kernels/{kernel_common, tile_types, smem_types}.mojo`
20//! and `layout/tile_io.mojo`. Lifts the "kernel-author standard
21//! library" pattern: typed primitives kernels compose, instead of
22//! re-deriving stride math and load/store loops per kernel.
23//!
24//! The vocabulary types (`Tile2`, `Coord2`, `Strides2`) live in
25//! `rlx_ir::layout` (plan #3 — shared layout IR) so Metal kernels
26//! can use the same names. CPU-specific `TileIO` impls live here.
27
28pub use rlx_ir::{Coord2, Strides2, Tile2};
29
30/// Tile I/O trait — load / store / prefetch parameterized over the
31/// physical layout. Two impls today: [`RowMajorTile`] (the standard
32/// flat layout) and [`StridedTile`] (when reading a non-contiguous
33/// view, e.g. last-axis Narrow into Attention).
34///
35/// Methods take pointers (not slices) so the abstraction works for
36/// both owned and aliased buffers.
37pub trait TileIO {
38    /// Compute the byte address for a coordinate. Used by
39    /// `load` / `store` / `prefetch` so impls only need to define
40    /// the address arithmetic once.
41    /// SAFETY: caller checks bounds.
42    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32;
43
44    /// Load a tile element by `(row, col)`.
45    /// SAFETY: caller ensures the address is valid for read.
46    #[inline(always)]
47    unsafe fn load(&self, base: *const f32, c: Coord2) -> f32 {
48        unsafe { *self.address(base, c) }
49    }
50
51    /// Store an element by `(row, col)`.
52    /// SAFETY: caller ensures the address is valid for write.
53    #[inline(always)]
54    unsafe fn store(&self, base: *mut f32, c: Coord2, v: f32) {
55        unsafe {
56            *(self.address(base, c) as *mut f32) = v;
57        }
58    }
59
60    /// Hint to the prefetcher. On aarch64 issues a single
61    /// `prfm pldl1keep` (load into L1, retain). Elsewhere a no-op.
62    /// SAFETY: caller ensures the address is in a valid mapping.
63    #[inline(always)]
64    unsafe fn prefetch(&self, base: *const f32, c: Coord2) {
65        unsafe {
66            let addr = self.address(base, c);
67            #[cfg(target_arch = "aarch64")]
68            {
69                std::arch::asm!("prfm pldl1keep, [{0}]", in(reg) addr,
70                    options(nostack, readonly));
71            }
72            #[cfg(not(target_arch = "aarch64"))]
73            {
74                let _ = addr;
75            }
76        }
77    }
78}
79
80/// Row-major contiguous tile: `addr = base + row * cols + col`.
81#[derive(Debug, Clone, Copy)]
82pub struct RowMajorTile {
83    pub shape: Tile2,
84}
85
86impl TileIO for RowMajorTile {
87    #[inline(always)]
88    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
89        unsafe { base.add(c.row * self.shape.cols + c.col) }
90    }
91}
92
93/// Strided tile: each row stride is configurable. Lets a kernel
94/// read a non-contiguous view (e.g. last-axis Narrow output) with
95/// the same TileIO interface as a contiguous tile.
96#[derive(Debug, Clone, Copy)]
97pub struct StridedTile {
98    pub shape: Tile2,
99    pub strides: Strides2,
100}
101
102impl TileIO for StridedTile {
103    #[inline(always)]
104    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
105        unsafe { base.add(c.row * self.strides.row + c.col * self.strides.col) }
106    }
107}
108
109/// Walk every element of a tile in row-major order, calling `f`.
110/// Convenience for kernels that don't care about iteration order.
111#[inline(always)]
112pub fn for_each_coord(shape: Tile2, mut f: impl FnMut(Coord2)) {
113    for r in 0..shape.rows {
114        for c in 0..shape.cols {
115            f(Coord2 { row: r, col: c });
116        }
117    }
118}
119
120/// Tile copy via TileIO. Source and destination layouts can differ
121/// (the typical use: read strided source, write contiguous dst).
122///
123/// # Safety
124/// `src_base` and `dst_base` must point into allocations large enough
125/// for `shape`'s extents under the IO layouts in `src_io` / `dst_io`.
126/// The two ranges may not overlap.
127#[inline]
128pub unsafe fn copy_tile<S: TileIO, D: TileIO>(
129    src_io: &S,
130    src_base: *const f32,
131    dst_io: &D,
132    dst_base: *mut f32,
133    shape: Tile2,
134) {
135    for_each_coord(shape, |c| unsafe {
136        dst_io.store(dst_base, c, src_io.load(src_base, c));
137    });
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A store followed by a load at the same coordinate round-trips, and
    /// lands at `row * cols + col` in the backing buffer.
    #[test]
    fn row_major_round_trip() {
        let mut buf = [0f32; 12]; // 3×4
        let io = RowMajorTile {
            shape: Tile2::new(3, 4),
        };
        unsafe {
            io.store(buf.as_mut_ptr(), Coord2 { row: 1, col: 2 }, 42.0);
            assert_eq!(io.load(buf.as_ptr(), Coord2 { row: 1, col: 2 }), 42.0);
        }
        assert_eq!(buf[4 + 2], 42.0);
    }

    /// Reading a column-narrowed view through a row stride larger than the
    /// view's width.
    #[test]
    fn strided_reads_non_contig_view() {
        // Source: 4-row tile inside a 4-row × 8-col parent.
        // Pretend we narrowed cols 2..6 of each row, so the row stride is 8.
        let parent: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let view = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        // Offset the base pointer to col 2 of row 0.
        let base = unsafe { parent.as_ptr().add(2) };
        let v = unsafe { view.load(base, Coord2 { row: 1, col: 1 }) };
        // Expected: parent[1*8 + 2 + 1] = 11.
        assert_eq!(v, 11.0);
    }

    #[test]
    fn prefetch_doesnt_panic() {
        // Prefetch is a hint: it must not crash for any in-bounds
        // coordinate. This only checks that the call sequence compiles and
        // runs on the current target, for both layouts.
        let buf = vec![0f32; 64];
        let io = RowMajorTile {
            shape: Tile2::new(8, 8),
        };
        let strided = StridedTile {
            shape: Tile2::new(8, 8),
            strides: Strides2 { row: 1, col: 8 },
        };
        unsafe {
            io.prefetch(buf.as_ptr(), Coord2 { row: 0, col: 0 });
            io.prefetch(buf.as_ptr(), Coord2 { row: 7, col: 7 });
            strided.prefetch(buf.as_ptr(), Coord2 { row: 7, col: 7 });
        }
    }

    /// Copying a column-narrowed strided view into a contiguous tile must
    /// reproduce every source row, not just the first ones.
    #[test]
    fn copy_tile_strided_to_contig() {
        let parent: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut dst = vec![0f32; 16]; // 4×4 contiguous
        let src_io = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        let dst_io = RowMajorTile {
            shape: Tile2::new(4, 4),
        };
        let base = unsafe { parent.as_ptr().add(2) };
        unsafe {
            copy_tile(&src_io, base, &dst_io, dst.as_mut_ptr(), Tile2::new(4, 4));
        }
        // The first row of dst should be parent[2..6] = [2, 3, 4, 5].
        assert_eq!(&dst[0..4], &[2.0, 3.0, 4.0, 5.0]);
        // Every row r of dst should be parent[r*8 + 2 .. r*8 + 6].
        for r in 0..4 {
            let src_start = r * 8 + 2;
            assert_eq!(&dst[r * 4..r * 4 + 4], &parent[src_start..src_start + 4]);
        }
    }

    /// A non-unit column stride: reading a 4×8 row-major parent as its
    /// 8×4 transpose by swapping the strides.
    #[test]
    fn copy_tile_transposed_view() {
        let parent: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut dst = vec![0f32; 32]; // 8×4 contiguous
        let src_io = StridedTile {
            shape: Tile2::new(8, 4),
            strides: Strides2 { row: 1, col: 8 },
        };
        let dst_io = RowMajorTile {
            shape: Tile2::new(8, 4),
        };
        unsafe {
            copy_tile(&src_io, parent.as_ptr(), &dst_io, dst.as_mut_ptr(), Tile2::new(8, 4));
        }
        for r in 0..8 {
            for c in 0..4 {
                assert_eq!(dst[r * 4 + c], parent[c * 8 + r]);
            }
        }
    }
}