rlx_ir/layout.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared layout vocabulary (plan #3).
17//!
18//! Tile / coordinate / stride types used by every kernel-author
19//! crate. Lives in `rlx-ir` (the leaf) so CPU and Metal stop
20//! re-deriving stride math independently.
21//!
22//! Backend-specific I/O (CPU pointer reads, Metal threadgroup
23//! loads) lives in the backend's own crate behind a `TileIO` trait
24//! — only the *vocabulary* is shared here.
25
/// Shape of a 2-D tile, measured in elements.
///
/// Only the extents are stored; whether the tile is row-major or strided
/// is described separately (see [`Strides2`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile2 {
    /// Number of rows in the tile.
    pub rows: usize,
    /// Number of columns in the tile.
    pub cols: usize,
}

impl Tile2 {
    /// Builds a tile shape from its row and column extents.
    pub const fn new(rows: usize, cols: usize) -> Self {
        Tile2 { rows, cols }
    }

    /// Element count of the tile, i.e. `rows * cols`.
    pub const fn area(self) -> usize {
        let Tile2 { rows, cols } = self;
        rows * cols
    }
}
41
/// 2-D coordinate within a tile.
///
/// A plain `(row, col)` index pair; it carries no stride or bounds
/// information of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Coord2 {
    /// Row component of the coordinate.
    pub row: usize,
    /// Column component of the coordinate.
    pub col: usize,
}

impl Coord2 {
    /// Builds a coordinate from its row and column components.
    ///
    /// Provided as a `const fn` so coordinates can be formed in constant
    /// contexts, the same way `Tile2::new` and `Tile3::new` work for
    /// tile shapes. Struct-literal construction keeps working unchanged.
    pub const fn new(row: usize, col: usize) -> Self {
        Self { row, col }
    }
}
48
/// Strides for a 2-D layout, expressed in **elements** rather than bytes.
///
/// `row` is the step between two consecutive rows and `col` the step
/// between two consecutive columns. A contiguous row-major tile of shape
/// (R, C) therefore uses `Strides2 { row: C, col: 1 }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Strides2 {
    /// Step, in elements, from one row to the next.
    pub row: usize,
    /// Step, in elements, from one column to the next.
    pub col: usize,
}

impl Strides2 {
    /// Row-major preset: rows are `cols` elements apart, columns are
    /// adjacent.
    pub const fn row_major(cols: usize) -> Self {
        let (row, col) = (cols, 1);
        Strides2 { row, col }
    }

    /// Column-major preset: columns are `rows` elements apart, rows are
    /// adjacent.
    pub const fn col_major(rows: usize) -> Self {
        let (row, col) = (1, rows);
        Strides2 { row, col }
    }
}
67
/// Hierarchical shape tuple (plan #38).
///
/// Modelled on MAX's `layout/int_tuple.mojo`: because shapes can nest,
/// an expression such as `((B, S), (H, D))` records the "outer
/// batch+seq, inner heads+head_dim" structure of a tiled layout. That
/// lets kernels reason about block-tiled sweeps without re-deriving the
/// implied stride math every time.
///
/// This type lives next to the flat [`crate::Shape`], which is what
/// every op carries today. `Shape` is deliberately not migrated: the
/// codebase is built around it, and the benefit of hierarchy is
/// concentrated in advanced layout / fusion code, which is where
/// [`ShapeTuple`] is meant to be used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeTuple {
    /// A single concrete dimension.
    Leaf(usize),
    /// An ordered list of sub-tuples; nesting depth is unbounded.
    Nested(Vec<ShapeTuple>),
}

impl ShapeTuple {
    /// Creates a one-dimension leaf, e.g. `ShapeTuple::leaf(8)`.
    pub fn leaf(n: usize) -> Self {
        ShapeTuple::Leaf(n)
    }

    /// Wraps already-built sub-tuples into a nested tuple.
    pub fn nested(parts: Vec<ShapeTuple>) -> Self {
        ShapeTuple::Nested(parts)
    }

    /// Builds a single-level tuple from a slice of dimensions.
    ///
    /// Every entry becomes a `Leaf`, so `flat(&[2, 3, 4])` equals
    /// `Nested(vec![Leaf(2), Leaf(3), Leaf(4)])`.
    pub fn flat(dims: &[usize]) -> Self {
        let leaves: Vec<ShapeTuple> = dims.iter().copied().map(ShapeTuple::Leaf).collect();
        ShapeTuple::Nested(leaves)
    }

    /// Returns `true` when this tuple is a single `Leaf`.
    pub fn is_leaf(&self) -> bool {
        matches!(self, ShapeTuple::Leaf(..))
    }

    /// Top-level rank: the number of direct children of a nested tuple,
    /// or 1 for a leaf.
    pub fn rank(&self) -> usize {
        match self {
            ShapeTuple::Nested(children) => children.len(),
            ShapeTuple::Leaf(_) => 1,
        }
    }

    /// Total element count: the product of every leaf in the hierarchy.
    ///
    /// A `Nested` with no children yields 1 (the empty product).
    pub fn product(&self) -> usize {
        match self {
            ShapeTuple::Leaf(extent) => *extent,
            ShapeTuple::Nested(children) => {
                let mut total = 1;
                for child in children {
                    total *= child.product();
                }
                total
            }
        }
    }

    /// Collects all leaves, depth-first and left to right, into a flat
    /// list. Handy when converting to the existing `Shape` type.
    pub fn flatten(&self) -> Vec<usize> {
        let mut leaves = Vec::new();
        self.flatten_into(&mut leaves);
        leaves
    }

    /// Appends this tuple's leaves to `out`, recursing into nested
    /// children in order.
    fn flatten_into(&self, out: &mut Vec<usize>) {
        match self {
            ShapeTuple::Leaf(extent) => out.push(*extent),
            ShapeTuple::Nested(children) => {
                for child in children {
                    child.flatten_into(out);
                }
            }
        }
    }

    /// Follows `path`, one child index per level, and returns the
    /// sub-tuple it lands on.
    ///
    /// An empty path yields `Some(self)`; `[0]` yields the first child.
    /// The result is `None` when an index is out of bounds at any level
    /// or when the path tries to descend into a leaf.
    pub fn get(&self, path: &[usize]) -> Option<&ShapeTuple> {
        let mut node = self;
        for &index in path {
            node = match node {
                // A leaf has no children, so any remaining step fails.
                ShapeTuple::Leaf(_) => return None,
                ShapeTuple::Nested(children) => children.get(index)?,
            };
        }
        Some(node)
    }
}
157
/// Descriptor for a ragged tensor (plan #4): a batch of variable-length
/// sequences stored back to back with no padding.
///
/// ```text
/// data:    [total_elems, trailing_dim]   flat
/// offsets: [batch + 1]                   cumulative starts
/// ```
///
/// Row `i` occupies `data[offsets[i]..offsets[i+1]]`, i.e. it holds
/// `offsets[i+1] - offsets[i]` elements, each `trailing` wide.
///
/// The idea is borrowed from MAX's `nn/_ragged_utils.mojo`,
/// `kv_cache_ragged.mojo`, and `gemv_partial_norm.mojo`. It matters for
/// serving throughput when the sequences in a batch differ a lot in
/// length: padding everything to the maximum wastes most of the work,
/// while ragged, offset-driven kernels process each row at its actual
/// length.
///
/// For now this is only the type vocabulary; kernel paths arrive per-op
/// as ragged use-cases land (the cumsum primitive #44 already covers
/// building the offsets).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ragged {
    /// Row count (= batch size).
    pub rows: usize,
    /// Width of each element along the trailing dimension: the hidden
    /// dimension for BERT, `num_heads * head_dim` for a KV cache, and 1
    /// when the tensor is a flat sequence of scalars.
    pub trailing: usize,
    /// Sum of the per-row lengths over all rows. Matches
    /// `offsets[rows]` once the offsets table is materialized.
    pub total: usize,
}

impl Ragged {
    /// Builds a descriptor from the row count, the trailing width, and
    /// the summed row length.
    pub const fn new(rows: usize, trailing: usize, total: usize) -> Self {
        Ragged {
            rows,
            trailing,
            total,
        }
    }

    /// Number of scalar entries in the data buffer (`total * trailing`).
    /// The offsets table is not included.
    pub const fn data_elements(self) -> usize {
        let Ragged {
            total, trailing, ..
        } = self;
        total * trailing
    }

    /// Number of entries in the offsets table, which is `rows + 1`.
    pub const fn offsets_elements(self) -> usize {
        let Ragged { rows, .. } = self;
        rows + 1
    }
}
209
/// Shape of a 3-D `[batch, rows, cols]` tile; the 3-D counterpart of
/// `Tile2`. Commonly used for per-head attention sweeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tile3 {
    /// Extent of the leading (batch) axis.
    pub batch: usize,
    /// Number of rows in each batch entry.
    pub rows: usize,
    /// Number of columns in each batch entry.
    pub cols: usize,
}

impl Tile3 {
    /// Builds a tile shape from its batch, row, and column extents.
    pub const fn new(batch: usize, rows: usize, cols: usize) -> Self {
        Tile3 { batch, rows, cols }
    }

    /// Element count of the whole tile, i.e. `batch * rows * cols`.
    pub const fn area(self) -> usize {
        let Tile3 { batch, rows, cols } = self;
        batch * rows * cols
    }
}
227
/// Per-axis strides for a 3-D `[batch, rows, cols]` layout; the 3-D
/// counterpart of `Strides2`.
///
/// NOTE(review): units are presumably elements (not bytes), as for
/// `Strides2` — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Strides3 {
    /// Step from one batch entry to the next.
    pub batch: usize,
    /// Step from one row to the next.
    pub row: usize,
    /// Step from one column to the next.
    pub col: usize,
}

impl Strides3 {
    /// Row-major preset for batch entries of shape `(rows, cols)`:
    /// batch entries are `rows * cols` apart, rows are `cols` apart, and
    /// columns are adjacent.
    pub const fn row_major(rows: usize, cols: usize) -> Self {
        let plane = rows * cols;
        Strides3 {
            batch: plane,
            row: cols,
            col: 1,
        }
    }
}
244
/// Unit tests for the layout vocabulary types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tile2_area() {
        assert_eq!(Tile2::new(3, 4).area(), 12);
    }

    // A tile with a zero extent on either axis holds no elements.
    #[test]
    fn tile2_area_zero_extent() {
        assert_eq!(Tile2::new(0, 4).area(), 0);
        assert_eq!(Tile2::new(3, 0).area(), 0);
    }

    #[test]
    fn tile3_area() {
        assert_eq!(Tile3::new(2, 3, 4).area(), 24);
        assert_eq!(Tile3::new(0, 3, 4).area(), 0);
    }

    #[test]
    fn strides2_presets() {
        assert_eq!(Strides2::row_major(8), Strides2 { row: 8, col: 1 });
        assert_eq!(Strides2::col_major(8), Strides2 { row: 1, col: 8 });
    }

    #[test]
    fn strides3_row_major() {
        assert_eq!(
            Strides3::row_major(3, 4),
            Strides3 {
                batch: 12,
                row: 4,
                col: 1
            }
        );
    }

    // Tuple tests live here so `tuple` test names cover the new
    // hierarchical type (the runtime check covers the const fns).
    #[test]
    fn tuple_leaf_constructors() {
        let a = ShapeTuple::leaf(8);
        assert_eq!(a.flatten(), vec![8]);
        assert_eq!(a.product(), 8);
        assert!(a.is_leaf());
    }

    // A leaf counts as rank 1 even though it has no children.
    #[test]
    fn tuple_leaf_rank() {
        assert_eq!(ShapeTuple::leaf(8).rank(), 1);
    }

    #[test]
    fn tuple_flat_constructor() {
        let s = ShapeTuple::flat(&[2, 3, 4]);
        assert_eq!(s.flatten(), vec![2, 3, 4]);
        assert_eq!(s.product(), 24);
        assert_eq!(s.rank(), 3);
        assert!(!s.is_leaf());
    }

    #[test]
    fn tuple_nested_product_and_flatten() {
        // BERT-shape: ((batch, seq), (heads, head_dim)).
        let bs = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15)]);
        let nh = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let s = ShapeTuple::nested(vec![bs, nh]);
        assert_eq!(s.product(), 8 * 15 * 12 * 64);
        assert_eq!(s.flatten(), vec![8, 15, 12, 64]);
        assert_eq!(s.rank(), 2); // top-level rank
    }

    #[test]
    fn tuple_get_resolves_path() {
        let inner = ShapeTuple::nested(vec![ShapeTuple::leaf(12), ShapeTuple::leaf(64)]);
        let s = ShapeTuple::nested(vec![ShapeTuple::leaf(8), ShapeTuple::leaf(15), inner]);
        assert_eq!(s.get(&[0]), Some(&ShapeTuple::Leaf(8)));
        assert_eq!(s.get(&[2, 1]), Some(&ShapeTuple::Leaf(64)));
        assert_eq!(s.get(&[2, 99]), None);
    }

    // Boundary behaviour of `get`: the empty path is the identity, a path
    // cannot continue past a leaf, and a top-level index past the end
    // misses.
    #[test]
    fn tuple_get_edge_cases() {
        let s = ShapeTuple::flat(&[2, 3]);
        assert_eq!(s.get(&[]), Some(&s));
        assert_eq!(s.get(&[0, 0]), None);
        assert_eq!(s.get(&[2]), None);

        let leaf = ShapeTuple::leaf(5);
        assert_eq!(leaf.get(&[]), Some(&leaf));
        assert_eq!(leaf.get(&[0]), None);
    }

    #[test]
    fn ragged_element_counts() {
        // 4 rows with total 30 elements; trailing = 8 (hidden dim).
        let r = Ragged::new(4, 8, 30);
        assert_eq!(r.data_elements(), 240); // 30 * 8 floats
        assert_eq!(r.offsets_elements(), 5); // rows + 1
    }
}