cubecl_matmul/components/tile/
tile_data.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{MatrixLayout, stage::StageMemoryConfig};
5
6#[derive(CubeType, Clone, Copy)]
7/// Tile with a linear major dimension, and a strided minor dimension.
8/// Basic tile kind supported by all stage matmuls.
9pub struct StridedTile<ES: Numeric, IO: SliceVisibility = ReadOnly> {
10    /// Slice containing all data
11    pub slice: Slice<Line<ES>, IO>,
12    /// Stride between each row/col, depending on MatrixLayout (the other is assumed to be 1)
13    pub stride: u32,
14    #[cube(comptime)]
15    /// Layout of the tile (row-major or column-major).
16    pub layout: MatrixLayout,
17}
18
19#[cube]
20impl<ES: Numeric> StridedTile<ES> {
21    /// Creates a tile from a contiguous slice of data.
22    ///
23    /// The slice length must exactly match the tile size.
24    pub fn new_contiguous(
25        slice: Slice<Line<ES>>,
26        #[comptime] config: StageMemoryConfig,
27    ) -> StridedTile<ES> {
28        let layout = config.matrix_layout;
29        let stride = match layout {
30            MatrixLayout::RowMajor => config.elements_in_tile_col,
31            MatrixLayout::ColMajor => config.elements_in_tile_row,
32        };
33
34        let stride = comptime![stride / config.stage_line_size];
35
36        StridedTile::<ES> {
37            slice,
38            stride,
39            layout,
40        }
41    }
42
43    /// Creates a tile from a contiguous slice of data.
44    ///
45    /// The slice length must exactly match the tile size.
46    pub fn new_contiguous_mut(
47        slice: Slice<Line<ES>, ReadWrite>,
48        #[comptime] config: StageMemoryConfig,
49    ) -> StridedTile<ES, ReadWrite> {
50        let layout = config.matrix_layout;
51        let stride = match layout {
52            MatrixLayout::RowMajor => config.elements_in_tile_col,
53            MatrixLayout::ColMajor => config.elements_in_tile_row,
54        };
55
56        let stride = comptime![stride / config.stage_line_size];
57
58        StridedTile::<ES, ReadWrite> {
59            slice,
60            stride,
61            layout,
62        }
63    }
64
65    /// Creates a tile from a strided slice of data.
66    ///
67    /// The slice must include all elements of the tile, though it may include unused gaps.
68    pub fn new_strided(
69        slice: Slice<Line<ES>>,
70        stride: u32,
71        #[comptime] layout: MatrixLayout,
72    ) -> StridedTile<ES> {
73        StridedTile::<ES> {
74            slice,
75            stride,
76            layout,
77        }
78    }
79
80    /// Creates a tile from a strided slice of data.
81    ///
82    /// The slice must include all elements of the tile, though it may include unused gaps.
83    pub fn new_strided_mut(
84        slice: Slice<Line<ES>, ReadWrite>,
85        stride: u32,
86        #[comptime] layout: MatrixLayout,
87    ) -> StridedTile<ES, ReadWrite> {
88        StridedTile::<ES, ReadWrite> {
89            slice,
90            stride,
91            layout,
92        }
93    }
94}
95
96#[cube]
97impl<ES: Numeric, IO: SliceVisibility> StridedTile<ES, IO> {
98    /// Returns the tile as an unlined (scalar) slice.
99    ///
100    /// Returns:
101    /// - The unlined slice
102    /// - The updated stride to account for line width removal
103    pub fn as_unlined(&self) -> (Slice<ES, IO>, u32) {
104        let stage_line_size = comptime![self.slice.line_size()];
105        (
106            self.slice.try_cast_unchecked(),
107            self.stride * stage_line_size,
108        )
109    }
110
111    /// Returns a specific line from the tile based on coordinates.
112    pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
113        self.slice[coor_strided * self.stride + coor_contiguous]
114    }
115}