Skip to main content

rlx_ir/
layout.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Shared layout vocabulary (plan #3).
//!
//! Tile, coordinate, and stride types used by every kernel-authoring
//! crate. They live in `rlx-ir` (the leaf crate) so the CPU and Metal
//! backends stop re-deriving stride math independently.
//!
//! Backend-specific I/O (CPU pointer reads, Metal threadgroup
//! loads) lives in each backend's own crate behind a `TileIO` trait
//! — only the *vocabulary* is shared here.

/// Shape of a 2-D tile, measured in elements.
///
/// The shape is usable for both row-major and strided tiles; the
/// memory mapping itself is described separately by [`Strides2`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Tile2 {
    /// Number of rows in the tile.
    pub rows: usize,
    /// Number of columns in the tile.
    pub cols: usize,
}

impl Tile2 {
    /// Builds a `rows` x `cols` tile shape.
    pub const fn new(rows: usize, cols: usize) -> Self {
        Tile2 { rows, cols }
    }

    /// Number of elements covered by the tile (`rows * cols`).
    pub const fn area(self) -> usize {
        let Tile2 { rows, cols } = self;
        rows * cols
    }
}
41
/// Position of one element inside a 2-D tile, as a (row, column) pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Coord2 {
    /// Zero-based row index.
    pub row: usize,
    /// Zero-based column index.
    pub col: usize,
}
48
/// Strides of a 2-D view, counted in **elements** (not bytes).
///
/// `row` is the distance between two consecutive rows and `col` the
/// distance between two consecutive columns. A dense row-major tile
/// of shape (R, C) therefore has `Strides2 { row: C, col: 1 }`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Strides2 {
    /// Element distance between consecutive rows.
    pub row: usize,
    /// Element distance between consecutive columns.
    pub col: usize,
}

impl Strides2 {
    /// Dense row-major strides for rows that are `cols` elements wide.
    pub const fn row_major(cols: usize) -> Self {
        Strides2 { col: 1, row: cols }
    }

    /// Dense column-major strides for columns that are `rows` elements tall.
    pub const fn col_major(rows: usize) -> Self {
        Strides2 { col: rows, row: 1 }
    }
}
67
/// Hierarchical (nestable) shape tuple (plan #38).
///
/// Modelled on MAX's `layout/int_tuple.mojo`: a shape may contain
/// sub-shapes, so `((B, S), (H, D))` records the "outer batch+seq,
/// inner heads+head_dim" grouping of a tiled layout. Kernels that sweep
/// block-tiled data can reason about that grouping directly instead of
/// re-deriving the implied stride math each time.
///
/// This sits next to the flat [`crate::Shape`] that every op carries
/// today rather than replacing it. The codebase is built around `Shape`,
/// and the benefit of hierarchy is concentrated in advanced layout and
/// fusion code, which can opt into [`ShapeTuple`] where it helps.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShapeTuple {
    /// A single concrete dimension.
    Leaf(usize),
    /// An ordered list of sub-tuples; nesting depth is unbounded.
    Nested(Vec<ShapeTuple>),
}

impl ShapeTuple {
    /// Builds a single-dimension leaf, e.g. `ShapeTuple::leaf(8)`.
    pub fn leaf(n: usize) -> Self {
        ShapeTuple::Leaf(n)
    }

    /// Wraps already-built parts into a nested tuple.
    pub fn nested(parts: Vec<ShapeTuple>) -> Self {
        ShapeTuple::Nested(parts)
    }

    /// Builds a one-level tuple whose children are all leaves.
    ///
    /// `flat(&[2, 3, 4])` equals
    /// `Nested(vec![Leaf(2), Leaf(3), Leaf(4)])`.
    pub fn flat(dims: &[usize]) -> Self {
        let leaves: Vec<ShapeTuple> = dims.iter().copied().map(ShapeTuple::Leaf).collect();
        ShapeTuple::Nested(leaves)
    }

    /// `true` only for the `Leaf` variant.
    pub fn is_leaf(&self) -> bool {
        match self {
            ShapeTuple::Leaf(_) => true,
            ShapeTuple::Nested(_) => false,
        }
    }

    /// Rank of the outermost level: 1 for a leaf, otherwise the number
    /// of direct children.
    pub fn rank(&self) -> usize {
        if let ShapeTuple::Nested(children) = self {
            children.len()
        } else {
            1
        }
    }

    /// Product of every leaf in the whole hierarchy (an empty nested
    /// tuple contributes 1).
    pub fn product(&self) -> usize {
        match self {
            ShapeTuple::Leaf(n) => *n,
            ShapeTuple::Nested(children) => {
                children.iter().fold(1, |acc, child| acc * child.product())
            }
        }
    }

    /// Leaves listed left to right in depth-first order; handy when
    /// converting to the flat `Shape` type.
    pub fn flatten(&self) -> Vec<usize> {
        let mut leaves = Vec::new();
        self.flatten_into(&mut leaves);
        leaves
    }

    /// Appends every leaf below `self` to `out`, left to right.
    fn flatten_into(&self, out: &mut Vec<usize>) {
        // Explicit work stack instead of recursion. Children are pushed in
        // reverse so they pop back off in their original left-to-right order.
        let mut pending = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                ShapeTuple::Leaf(n) => out.push(*n),
                ShapeTuple::Nested(children) => pending.extend(children.iter().rev()),
            }
        }
    }

    /// Follows `path` one child index per level and returns the sub-tuple
    /// it lands on.
    ///
    /// An empty path yields `Some(self)`; `[0]` yields the first child.
    /// Returns `None` if any index is out of range or the path tries to
    /// descend below a leaf.
    pub fn get(&self, path: &[usize]) -> Option<&ShapeTuple> {
        let mut node = self;
        for &index in path {
            node = match node {
                ShapeTuple::Leaf(_) => return None,
                ShapeTuple::Nested(children) => children.get(index)?,
            };
        }
        Some(node)
    }
}
157
/// Ragged-tensor descriptor (plan #4). Represents a tensor of
/// variable-length sequences laid out without padding:
///
///   data:    [total, trailing]   stored flat (`total * trailing` elements)
///   offsets: [rows + 1]          cumulative row starts, counted in
///                                `trailing`-wide entries
///
/// Row `i` spans entries `offsets[i]..offsets[i + 1]` along `data`'s
/// leading axis. In the flat buffer that is elements
/// `offsets[i] * trailing .. offsets[i + 1] * trailing`, so the row
/// holds `(offsets[i + 1] - offsets[i]) * trailing` scalars.
///
/// Borrowed from MAX's `nn/_ragged_utils.mojo`, `kv_cache_ragged.mojo`,
/// and `gemv_partial_norm.mojo`. Essential for serving throughput when
/// sequences in a batch have very different lengths — padding to max
/// wastes most of the work; ragged + offset-driven kernels process each
/// row at its actual length.
///
/// Today this is the type vocabulary; kernel paths come per-op as
/// the ragged use-case lands (the cumsum primitive #44 already
/// covers offset construction).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ragged {
    /// Number of rows (= batch).
    pub rows: usize,
    /// Trailing per-entry width. For BERT it's the hidden
    /// dimension; for KV cache it's `num_heads * head_dim`. 1 if
    /// the tensor is a flat sequence of scalars.
    pub trailing: usize,
    /// Total entries across all rows (sum of per-row lengths; each
    /// entry is `trailing` scalars wide). Equals `offsets[rows]`
    /// when offsets are materialized.
    pub total: usize,
}

impl Ragged {
    /// Describes `rows` sequences of `trailing`-wide entries, holding
    /// `total` entries in all.
    pub const fn new(rows: usize, trailing: usize, total: usize) -> Self {
        Self {
            rows,
            trailing,
            total,
        }
    }

    /// Scalar element count of the flat `data` buffer
    /// (`total * trailing`). The offsets table is not included.
    /// The descriptor is dtype-agnostic; multiply by the element
    /// size to get bytes.
    pub const fn data_elements(self) -> usize {
        self.total * self.trailing
    }

    /// Element count of the offsets table (`rows + 1`: one start per
    /// row plus the closing end offset).
    pub const fn offsets_elements(self) -> usize {
        self.rows + 1
    }
}
209
/// 3-D tile shape laid out as `[batch, rows, cols]`, in elements —
/// the 2-D tile with an extra leading batch axis. Common for per-head
/// attention sweeps.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Tile3 {
    /// Number of stacked 2-D planes.
    pub batch: usize,
    /// Rows per plane.
    pub rows: usize,
    /// Columns per plane.
    pub cols: usize,
}

impl Tile3 {
    /// Builds a `batch` x `rows` x `cols` tile shape.
    pub const fn new(batch: usize, rows: usize, cols: usize) -> Self {
        Tile3 { batch, rows, cols }
    }

    /// Number of elements covered by the tile (`batch * rows * cols`).
    pub const fn area(self) -> usize {
        let Tile3 { batch, rows, cols } = self;
        batch * rows * cols
    }
}
227
/// Per-axis strides, in elements, for a `[batch, rows, cols]` view.
///
/// `batch` is the distance between consecutive batch planes, `row`
/// between consecutive rows, and `col` between consecutive columns.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Strides3 {
    /// Element distance between consecutive batch planes.
    pub batch: usize,
    /// Element distance between consecutive rows.
    pub row: usize,
    /// Element distance between consecutive columns.
    pub col: usize,
}

impl Strides3 {
    /// Dense row-major strides for planes of `rows` x `cols` elements:
    /// columns are adjacent, rows are `cols` apart, and planes are
    /// `rows * cols` apart.
    pub const fn row_major(rows: usize, cols: usize) -> Self {
        let plane = rows * cols;
        Strides3 {
            col: 1,
            row: cols,
            batch: plane,
        }
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tile2_area() {
        assert_eq!(Tile2::new(3, 4).area(), 12);
        assert_eq!(Tile2::new(0, 7).area(), 0);
    }

    #[test]
    fn tile_areas_are_const_evaluable() {
        // `area` is a `const fn`; keep it usable in const context.
        const A2: usize = Tile2::new(2, 5).area();
        const A3: usize = Tile3::new(2, 3, 4).area();
        assert_eq!(A2, 10);
        assert_eq!(A3, 24);
    }

    #[test]
    fn tile3_area() {
        assert_eq!(Tile3::new(2, 3, 4).area(), 24);
        assert_eq!(Tile3::new(0, 3, 4).area(), 0);
    }

    #[test]
    fn strides2_presets() {
        assert_eq!(Strides2::row_major(8), Strides2 { row: 8, col: 1 });
        assert_eq!(Strides2::col_major(8), Strides2 { row: 1, col: 8 });
    }

    #[test]
    fn strides2_presets_address_dense_tiles() {
        // Walking a (2, 3) tile in its storage order must visit the
        // offsets 0..area exactly once, in order.
        let tile = Tile2::new(2, 3);
        let expected: Vec<usize> = (0..tile.area()).collect();

        let rm = Strides2::row_major(tile.cols);
        let row_major: Vec<usize> = (0..tile.rows)
            .flat_map(|row| (0..tile.cols).map(move |col| Coord2 { row, col }))
            .map(|c| c.row * rm.row + c.col * rm.col)
            .collect();
        assert_eq!(row_major, expected);

        let cm = Strides2::col_major(tile.rows);
        let col_major: Vec<usize> = (0..tile.cols)
            .flat_map(|col| (0..tile.rows).map(move |row| Coord2 { row, col }))
            .map(|c| c.row * cm.row + c.col * cm.col)
            .collect();
        assert_eq!(col_major, expected);
    }

    #[test]
    fn strides3_row_major() {
        assert_eq!(
            Strides3::row_major(3, 4),
            Strides3 {
                batch: 12,
                row: 4,
                col: 1
            }
        );
    }

    #[test]
    fn strides3_last_element_is_area_minus_one() {
        let t = Tile3::new(2, 3, 4);
        let s = Strides3::row_major(t.rows, t.cols);
        let last = (t.batch - 1) * s.batch + (t.rows - 1) * s.row + (t.cols - 1) * s.col;
        assert_eq!(last, t.area() - 1);
    }

    // ShapeTuple (hierarchical shape) tests.
    #[test]
    fn tuple_leaf_constructors() {
        let a = ShapeTuple::leaf(8);
        assert_eq!(a.flatten(), vec![8]);
        assert_eq!(a.product(), 8);
        assert_eq!(a.rank(), 1);
        assert!(a.is_leaf());
    }

    #[test]
    fn tuple_flat_constructor() {
        let s = ShapeTuple::flat(&[2, 3, 4]);
        assert_eq!(s.flatten(), vec![2, 3, 4]);
        assert_eq!(s.product(), 24);
        assert_eq!(s.rank(), 3);
        assert_eq!(
            s,
            ShapeTuple::Nested(vec![
                ShapeTuple::Leaf(2),
                ShapeTuple::Leaf(3),
                ShapeTuple::Leaf(4)
            ])
        );
    }

    #[test]
    fn tuple_empty_flat_is_unit() {
        let s = ShapeTuple::flat(&[]);
        assert_eq!(s.rank(), 0);
        assert_eq!(s.product(), 1); // empty product
        assert!(s.flatten().is_empty());
        assert!(!s.is_leaf());
    }

    #[test]
    fn tuple_nested_product_and_flatten() {
        // BERT-shape: ((batch, seq), (heads, head_dim)).
        let bs = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15)]);
        let nh = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let s = ShapeTuple::nested(vec![bs, nh]);
        assert_eq!(s.product(), 8 * 15 * 12 * 64);
        assert_eq!(s.flatten(), vec![8, 15, 12, 64]);
        assert_eq!(s.rank(), 2); // top-level rank
    }

    #[test]
    fn tuple_get_resolves_path() {
        let inner = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let s = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15), inner]);
        assert_eq!(s.get(&[0]), Some(&ShapeTuple::Leaf(8)));
        assert_eq!(s.get(&[2, 1]), Some(&ShapeTuple::Leaf(64)));
        assert_eq!(s.get(&[2, 99]), None);
    }

    #[test]
    fn tuple_get_edge_cases() {
        let s = ShapeTuple::flat(&[2, 3]);
        assert_eq!(s.get(&[]), Some(&s));
        // Leaves cannot be descended into.
        assert_eq!(s.get(&[0, 0]), None);
        assert_eq!(ShapeTuple::leaf(4).get(&[0]), None);
        assert_eq!(s.get(&[2]), None);
    }

    #[test]
    fn ragged_element_counts() {
        // 4 rows with total 30 entries; trailing = 8 (hidden dim).
        let r = Ragged::new(4, 8, 30);
        assert_eq!(r.data_elements(), 240); // 30 * 8 scalars
        assert_eq!(r.offsets_elements(), 5); // rows + 1
    }

    #[test]
    fn ragged_empty_batch() {
        let r = Ragged::new(0, 16, 0);
        assert_eq!(r.data_elements(), 0);
        assert_eq!(r.offsets_elements(), 1);
    }
}