cubecl_matmul/components/stage/
layout.rs

1use std::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl};
5
6use crate::components::tile::Tile;
7use crate::components::{Ident, InputIdent, MatrixLayout};
8
9use super::{StageConfig, StageMemory};
10
#[cube]
/// Determines the order in which tiles are stored in shared memory,
/// if [TilingLayout] is contiguous.
///
/// An implementation provides a two-way mapping between a linear tile index
/// (`nth`) and the (row, col) coordinates of that tile. In the implementations
/// found in this file, `to_row_col` and `to_nth_tile` are inverses of each other.
pub trait TilingOrder: 'static + Send + Sync + Clone + Copy {
    /// Returns the coordinates (row, col) of the `nth` tile.
    ///
    /// `tile_count_rows` and `tile_count_cols` are the number of tiles along
    /// the row and column dimensions. `ident` and `config` are made available
    /// for orders whose mapping depends on them.
    fn to_row_col<C: StageConfig>(
        nth: u32,
        #[comptime] tile_count_rows: u32,
        #[comptime] tile_count_cols: u32,
        #[comptime] ident: Ident,
        #[comptime] config: C,
    ) -> (u32, u32);

    /// Given the coordinates (row, col) of the tile,
    /// returns its index in shared memory.
    ///
    /// Takes the same tile counts, `ident` and `config` as `to_row_col`.
    fn to_nth_tile<C: StageConfig>(
        row: u32,
        col: u32,
        #[comptime] tile_count_rows: u32,
        #[comptime] tile_count_cols: u32,
        #[comptime] ident: Ident,
        #[comptime] config: C,
    ) -> u32;

    /// Returns the [TilingOrderEnum] variant that identifies this tiling order.
    fn to_enum() -> comptime_type!(TilingOrderEnum);
}
38
/// Identifies a tiling order as a plain value.
///
/// This is the type returned by `TilingOrder::to_enum`, so that code can
/// branch on which tiling order is in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TilingOrderEnum {
    /// Tiles of the same row are side by side
    RowMajor,
    /// Tiles of the same column are side by side
    ColMajor,
    /// Tiles are laid out in column-major order across a fixed number of rows,
    /// with all tiles from those rows placed contiguously side by side.
    Ordered,
    /// If the matrix data layout is row-major, the tiling order is col-major
    /// If the matrix data layout is col-major, the tiling order is row-major
    Tma,
}
52
#[derive(CubeType, Clone, Copy)]
/// Tiles laid out in row-major order.
///
/// Each tile is contiguous, and tiles are placed side by side,
/// row by row (left to right, top to bottom).
/// The index of the tile at (row, col) is `row * tile_count_cols + col`.
///
/// Example tile indices:
///
/// ```text
/// ┌───┬───┐
/// │ 0 │ 1 │
/// ├───┼───┤
/// │ 2 │ 3 │
/// ├───┼───┤
/// │ 4 │ 5 │
/// ├───┼───┤
/// │ 6 │ 7 │
/// └───┴───┘
/// ```
pub struct RowMajorTilingOrder {}
72
#[derive(CubeType, Clone, Copy)]
/// Tiles laid out in column-major order.
///
/// Each tile is contiguous, and tiles are placed top to bottom,
/// column by column (like reading columns left to right).
/// The index of the tile at (row, col) is `col * tile_count_rows + row`.
///
/// Example tile indices:
///
/// ```text
/// ┌───┬───┐
/// │ 0 │ 4 │
/// ├───┼───┤
/// │ 1 │ 5 │
/// ├───┼───┤
/// │ 2 │ 6 │
/// ├───┼───┤
/// │ 3 │ 7 │
/// └───┴───┘
/// ```
pub struct ColMajorTilingOrder {}
93
#[derive(CubeType, Clone, Copy)]
/// Tiles are laid out in column-major order across a fixed number of rows,
/// with all tiles from those rows placed contiguously side by side.
///
/// The grouping should match the set of tiles processed by a warp,
/// so warp-local tile memory remains contiguous.
/// The number of rows per group is the tile row count divided by
/// `config.num_main_flow_planes()`.
///
/// This layout ensures that for Lhs data, all tiles needed for a given
/// `k` iteration are stored contiguously, before moving to the next iteration.
///
/// Note: use only for Lhs (the mapping functions panic for any other ident)
///
/// Example tile indices for 4 rows grouped 2 at a time:
///
/// ```text
/// ┌───┬───┐
/// │ 0 │ 2 │
/// ├───┼───┤
/// │ 1 │ 3 │
/// ├───┼───┤
/// │ 4 │ 6 │
/// ├───┼───┤
/// │ 5 │ 7 │
/// └───┴───┘
/// ```
pub struct OrderedTilingOrder {}
120
121#[cube]
122impl TilingOrder for RowMajorTilingOrder {
123    fn to_row_col<C: StageConfig>(
124        nth: u32,
125        #[comptime] _tile_count_rows: u32,
126        #[comptime] tile_count_cols: u32,
127        #[comptime] _ident: Ident,
128        #[comptime] _config: C,
129    ) -> (u32, u32) {
130        (nth / tile_count_cols, nth % tile_count_cols)
131    }
132    fn to_nth_tile<C: StageConfig>(
133        row: u32,
134        col: u32,
135        #[comptime] _tile_count_rows: u32,
136        #[comptime] tile_count_cols: u32,
137        #[comptime] _ident: Ident,
138        #[comptime] _config: C,
139    ) -> u32 {
140        row * tile_count_cols + col
141    }
142
143    fn to_enum() -> comptime_type!(TilingOrderEnum) {
144        TilingOrderEnum::RowMajor
145    }
146}
147
148#[cube]
149impl TilingOrder for ColMajorTilingOrder {
150    fn to_row_col<C: StageConfig>(
151        nth: u32,
152        #[comptime] num_rows: u32,
153        #[comptime] _num_cols: u32,
154        #[comptime] _ident: Ident,
155        #[comptime] _config: C,
156    ) -> (u32, u32) {
157        (nth % num_rows, nth / num_rows)
158    }
159    fn to_nth_tile<C: StageConfig>(
160        row: u32,
161        col: u32,
162        #[comptime] tile_count_rows: u32,
163        #[comptime] _tile_count_cols: u32,
164        #[comptime] _ident: Ident,
165        #[comptime] _config: C,
166    ) -> u32 {
167        col * tile_count_rows + row
168    }
169
170    fn to_enum() -> comptime_type!(TilingOrderEnum) {
171        TilingOrderEnum::ColMajor
172    }
173}
174
175#[cube]
176impl TilingOrder for OrderedTilingOrder {
177    fn to_row_col<C: StageConfig>(
178        nth: u32,
179        #[comptime] tile_count_rows: u32,
180        #[comptime] tile_count_cols: u32,
181        #[comptime] ident: Ident,
182        #[comptime] config: C,
183    ) -> (u32, u32) {
184        if Ident::Lhs != ident {
185            panic!("Ordered tiling order should be used only on Lhs")
186        }
187
188        let group_rows = tile_count_rows / config.num_main_flow_planes();
189        let tiles_per_group = group_rows * tile_count_cols;
190
191        let group = nth / tiles_per_group;
192        let pos_within_group = nth % tiles_per_group;
193
194        let local_row = pos_within_group % group_rows;
195        let row = group * group_rows + local_row;
196        let col = pos_within_group / group_rows;
197
198        (row, col)
199    }
200
201    fn to_nth_tile<C: StageConfig>(
202        row: u32,
203        col: u32,
204        #[comptime] tile_count_rows: u32,
205        #[comptime] tile_count_cols: u32,
206        #[comptime] ident: Ident,
207        #[comptime] config: C,
208    ) -> u32 {
209        if Ident::Lhs != ident {
210            panic!("Ordered tiling order should be used only on Lhs")
211        }
212
213        let group_rows = tile_count_rows / config.num_main_flow_planes();
214        let group = row / group_rows;
215
216        let local_row = row % group_rows;
217        let tiles_per_group = group_rows * tile_count_cols;
218        let pos_within_group = col * group_rows + local_row;
219
220        group * tiles_per_group + pos_within_group
221    }
222
223    fn to_enum() -> comptime_type!(TilingOrderEnum) {
224        TilingOrderEnum::Ordered
225    }
226}
227
#[cube]
/// Describes how tiles are arranged in shared memory.
pub trait TilingLayout: 'static + Send + Sync + Clone + Copy {
    /// Returns the tile at the given (row, col) tile coordinates of the stage.
    ///
    /// `buffer_index` selects a stage buffer. The implementations in this file
    /// treat it differently: the contiguous layout offsets the tile coordinates
    /// with it, while the strided layout ignores it.
    fn get_tile<ES: Numeric, S: StageConfig>(
        stage: &StageMemory<ES, Self>,
        row: u32,
        col: u32,
        #[comptime] buffer_index: u32,
        #[comptime] ident: Ident,
        #[comptime] config: S,
    ) -> Tile<ES>;
}
241
#[derive(Clone, Copy)]
/// Each tile is stored contiguously in shared memory.
/// Global memory loads may require remapping to match this layout.
///
/// The type parameter `T` selects the [TilingOrder] that decides where each
/// tile goes.
pub struct ContiguousTilingLayout<T: TilingOrder> {
    // Marker only: no value of `T` is stored, the order is used through its
    // associated functions.
    tiling_order: PhantomData<T>,
}
248
#[derive(Clone, Copy)]
/// Tiles follow a strided layout that often mirrors global memory layout.
/// Not all tiles are contiguous in shared memory, but mapping is more direct.
///
/// A tile is addressed as a start offset plus a stride between its
/// consecutive rows (row-major) or columns (col-major).
pub struct StridedTilingLayout {}
253
254#[cube]
255impl<T: TilingOrder> ContiguousTilingLayout<T> {
256    /// Converts a tile index in the stage to its (x,y) position
257    pub fn to_x_y<S: StageConfig>(
258        nth: u32,
259        #[comptime] ident: Ident,
260        #[comptime] config: S,
261    ) -> (u32, u32) {
262        let num_x = config.tiling_scheme().tiles_in_stage_row(ident);
263        let num_y = config.tiling_scheme().tiles_in_stage_col(ident);
264
265        T::to_row_col::<S>(nth, num_x, num_y, ident, config)
266    }
267}
268
269#[cube]
270impl<TO: TilingOrder> TilingLayout for ContiguousTilingLayout<TO> {
271    fn get_tile<ES: Numeric, S: StageConfig>(
272        stage_memory: &StageMemory<ES, Self>,
273        row: u32,
274        col: u32,
275        #[comptime] buffer_index: u32,
276        #[comptime] ident: Ident,
277        #[comptime] config: S,
278    ) -> Tile<ES> {
279        let stage_line_size = config.stage_line_size(ident);
280        let tiling_scheme = config.tiling_scheme();
281        let matrix_layout = config.matrix_layout(ident);
282
283        let (row_buffer_offset, col_buffer_offset, total_tile_count_row, total_tile_count_col) =
284            match ident.as_input_ident() {
285                InputIdent::Lhs => {
286                    let x_tile_offset = 0;
287                    let y_tile_offset = tiling_scheme.tiles_in_stage_col(ident) * buffer_index;
288                    let total_tile_count_x = tiling_scheme.tiles_in_stage_row(ident);
289                    let total_tile_count_y = tiling_scheme.tiles_in_stage_col(ident)
290                        * config.num_stages(InputIdent::Lhs);
291                    (
292                        x_tile_offset,
293                        y_tile_offset,
294                        total_tile_count_x,
295                        total_tile_count_y,
296                    )
297                }
298                InputIdent::Rhs => {
299                    let x_tile_offset = tiling_scheme.tiles_in_stage_row(ident) * buffer_index;
300                    let y_tile_offset = 0;
301                    let total_tile_count_x = tiling_scheme.tiles_in_stage_row(ident)
302                        * config.num_stages(InputIdent::Rhs);
303                    let total_tile_count_y = tiling_scheme.tiles_in_stage_col(ident);
304                    (
305                        x_tile_offset,
306                        y_tile_offset,
307                        total_tile_count_x,
308                        total_tile_count_y,
309                    )
310                }
311            };
312
313        let (tile_size_x, tile_size_y, tile_slice_length) = match matrix_layout {
314            MatrixLayout::RowMajor => {
315                let tile_size_x = tiling_scheme.elements_in_tile_row(ident);
316                let tile_size_y = tiling_scheme.elements_in_tile_col(ident) / stage_line_size;
317                let stride_x = comptime!(tile_size_y * total_tile_count_col);
318                let length = (tile_size_x - 1) * stride_x + tile_size_y;
319
320                (tile_size_x, tile_size_y, length)
321            }
322            MatrixLayout::ColMajor => {
323                let tile_size_x = tiling_scheme.elements_in_tile_row(ident) / stage_line_size;
324                let tile_size_y = tiling_scheme.elements_in_tile_col(ident);
325                let stride_y = comptime!(tile_size_x * total_tile_count_row);
326                let length = (tile_size_y - 1) * stride_y + tile_size_x;
327
328                (tile_size_x, tile_size_y, length)
329            }
330        };
331
332        let start = tile_size_x
333            * tile_size_y
334            * TO::to_nth_tile::<S>(
335                row + row_buffer_offset,
336                col + col_buffer_offset,
337                total_tile_count_row,
338                total_tile_count_col,
339                ident,
340                config,
341            );
342
343        Tile::new_contiguous::<S::TileConfig>(
344            stage_memory
345                .as_slice(stage_line_size)
346                .slice(start, start + tile_slice_length),
347            ident,
348            config.tile_config(),
349        )
350    }
351}
352
353#[cube]
354impl StridedTilingLayout {
355    /// Returns the nth slice of the stage
356    pub fn nth_slice<ES: Numeric, S: StageConfig>(
357        stage: &mut StageMemory<ES, Self>,
358        nth: u32,
359        #[comptime] ident: Ident,
360        #[comptime] config: S,
361    ) -> SliceMut<Line<ES>> {
362        let matrix_layout = config.matrix_layout(ident);
363        let stage_line_size = config.stage_line_size(ident);
364
365        let slice_length = match comptime!(matrix_layout) {
366            MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_col(ident),
367            MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_row(ident),
368        } / stage_line_size;
369
370        let start = slice_length * nth;
371        stage
372            .as_slice_mut(stage_line_size)
373            .slice_mut(start, start + slice_length)
374    }
375}
376
377#[cube]
378impl TilingLayout for StridedTilingLayout {
379    fn get_tile<ES: Numeric, S: StageConfig>(
380        stage: &StageMemory<ES, Self>,
381        x: u32,
382        y: u32,
383        #[comptime] _buffer_index: u32,
384        #[comptime] ident: Ident,
385        #[comptime] config: S,
386    ) -> Tile<ES> {
387        if comptime!(config.num_stages(ident.as_input_ident()) > 1) {
388            unimplemented!()
389        }
390
391        let stage_line_size = config.stage_line_size(ident);
392        let tiling_scheme = config.tiling_scheme();
393        let matrix_layout = config.matrix_layout(ident);
394
395        let tile_count_x = tiling_scheme.tiles_in_stage_row(ident);
396        let tile_count_y = tiling_scheme.tiles_in_stage_col(ident);
397
398        match matrix_layout {
399            MatrixLayout::RowMajor => {
400                let tile_size_x = tiling_scheme.elements_in_tile_row(ident);
401                let tile_size_y = tiling_scheme.elements_in_tile_col(ident) / stage_line_size;
402
403                let stride = tile_count_y * tile_size_y;
404                let length = (tile_size_x - 1) * stride + tile_size_y;
405                let start = x * tile_size_x * stride + y * tile_size_y;
406
407                Tile::new_strided(
408                    stage.as_slice(stage_line_size).slice(start, start + length),
409                    stride,
410                    matrix_layout,
411                )
412            }
413            MatrixLayout::ColMajor => {
414                let tile_size_x = tiling_scheme.elements_in_tile_row(ident) / stage_line_size;
415                let tile_size_y = tiling_scheme.elements_in_tile_col(ident);
416
417                let stride = tile_count_x * tile_size_x;
418                let length = (tile_size_y - 1) * stride + tile_size_x;
419                let start = x * tile_size_x + y * tile_size_y * stride;
420
421                Tile::new_strided(
422                    stage.as_slice(stage_line_size).slice(start, start + length),
423                    stride,
424                    matrix_layout,
425                )
426            }
427        }
428    }
429}