cubecl_matmul/components/tile/tile_data.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, intrinsic};
3use cubecl_std::{Swizzle, type_size};
4
5use crate::components::{
6    MatrixLayout,
7    stage::{StageMemoryConfig, as_swizzle_object},
8};
9
#[derive(CubeType, Clone, Copy)]
/// Tile whose major dimension is contiguous (unit stride) and whose minor dimension is strided.
/// This is the basic tile kind supported by all stage matmuls.
///
/// The tile is a window into `stage`. The offsets `start`, `end` and `stride` are all expressed
/// in lines of `stage` (units of `Line<ES>`), not in scalars.
pub struct StridedTile<ES: Numeric, IO: SliceVisibility = ReadOnly> {
    /// Backing slice holding the data of the whole stage; the tile is a window into it.
    pub stage: Slice<Line<ES>, IO>,
    /// Offset, in lines, of the tile's first line within `stage` (before swizzling).
    pub start: u32,
    /// Exclusive end offset, in lines, of the tile within `stage`. It is computed without
    /// swizzling, so when a swizzle is active the tile's lines may lie outside `start..end`.
    pub end: u32,
    /// Distance, in lines, between consecutive rows (row-major) or columns (column-major),
    /// depending on `layout`; the other dimension is assumed to have a stride of 1.
    pub stride: u32,
    /// Swizzle applied to absolute stage offsets before `stage` is indexed.
    pub swizzle: Swizzle,
    #[cube(comptime)]
    /// Layout of the tile (row-major or column-major).
    pub layout: MatrixLayout,
    #[cube(comptime)]
    /// Comptime line size associated with `stage` (for contiguous tiles, `config.line_size`).
    pub line_size: u32,
}
31
32#[cube]
33impl<ES: Numeric> StridedTile<ES> {
34    /// Creates a tile from a contiguous slice of data.
35    ///
36    /// The slice length must exactly match the tile size.
37    pub fn new_contiguous(
38        stage: Slice<Line<ES>>,
39        start: u32,
40        #[comptime] config: StageMemoryConfig,
41    ) -> StridedTile<ES> {
42        let len = config.elements_per_tile() / config.line_size;
43        let layout = config.matrix_layout;
44        let stride = match layout {
45            MatrixLayout::RowMajor => config.elements_per_tile_along_col,
46            MatrixLayout::ColMajor => config.elements_per_tile_along_row,
47        };
48
49        let stride = comptime![stride / config.line_size];
50
51        StridedTile::<ES> {
52            stage,
53            start,
54            end: start + len,
55            stride,
56            swizzle: as_swizzle_object(config.swizzle),
57            layout,
58            line_size: config.line_size,
59        }
60    }
61
62    /// Creates a tile from a contiguous slice of data.
63    ///
64    /// The slice length must exactly match the tile size.
65    pub fn new_contiguous_mut(
66        stage: Slice<Line<ES>, ReadWrite>,
67        start: u32,
68        #[comptime] config: StageMemoryConfig,
69    ) -> StridedTile<ES, ReadWrite> {
70        let len = config.elements_per_tile() / config.line_size;
71        let layout = config.matrix_layout;
72        let stride = match layout {
73            MatrixLayout::RowMajor => config.elements_per_tile_along_col,
74            MatrixLayout::ColMajor => config.elements_per_tile_along_row,
75        };
76
77        let stride = comptime![stride / config.line_size];
78
79        StridedTile::<ES, ReadWrite> {
80            stage,
81            start,
82            end: start + len,
83            stride,
84            swizzle: as_swizzle_object(config.swizzle),
85            layout,
86            line_size: config.line_size,
87        }
88    }
89
90    /// Creates a tile from a strided slice of data.
91    ///
92    /// The slice must include all elements of the tile, though it may include unused gaps.
93    pub fn new_strided(
94        stage: Slice<Line<ES>>,
95        start: u32,
96        end: u32,
97        stride: u32,
98        swizzle: Swizzle,
99        #[comptime] layout: MatrixLayout,
100        #[comptime] line_size: u32,
101    ) -> StridedTile<ES> {
102        StridedTile::<ES> {
103            stage,
104            start,
105            end,
106            stride,
107            swizzle,
108            layout,
109            line_size,
110        }
111    }
112
113    /// Creates a tile from a strided slice of data.
114    ///
115    /// The slice must include all elements of the tile, though it may include unused gaps.
116    pub fn new_strided_mut(
117        stage: Slice<Line<ES>, ReadWrite>,
118        start: u32,
119        end: u32,
120        stride: u32,
121        swizzle: Swizzle,
122        #[comptime] layout: MatrixLayout,
123        #[comptime] line_size: u32,
124    ) -> StridedTile<ES, ReadWrite> {
125        StridedTile::<ES, ReadWrite> {
126            stage,
127            start,
128            end,
129            stride,
130            swizzle,
131            layout,
132            line_size,
133        }
134    }
135}
136
137#[cube]
138impl<ES: Numeric> StridedTile<ES, ReadOnly> {
139    /// Returns the tile as an unlined (scalar) slice.
140    ///
141    /// Returns:
142    /// - The unlined slice
143    /// - The updated stride to account for line width removal
144    pub fn as_unlined(&self) -> (Slice<ES, ReadOnly>, u32) {
145        let stage_line_size = comptime![self.stage.line_size()];
146        (
147            self.stage.slice(self.start, self.end).try_cast_unchecked(),
148            self.stride * stage_line_size,
149        )
150    }
151}
152
153#[cube]
154impl<ES: Numeric> StridedTile<ES, ReadWrite> {
155    /// Returns the tile as an unlined (scalar) slice.
156    ///
157    /// Returns:
158    /// - The unlined slice
159    /// - The updated stride to account for line width removal
160    pub fn as_unlined_mut(&self) -> (Slice<ES, ReadWrite>, u32) {
161        let stage_line_size = comptime![self.stage.line_size()];
162        (
163            self.stage
164                .slice(self.start, self.end)
165                .as_mut_unchecked()
166                .try_cast_unchecked(),
167            self.stride * stage_line_size,
168        )
169    }
170
171    /// Returns the tile as an offset slice. Should only be used when swizzling is definitely not
172    /// applicable.
173    pub fn as_slice_mut(&self) -> Slice<Line<ES>, ReadWrite> {
174        self.stage.slice(self.start, self.end).as_mut_unchecked()
175    }
176}
177
178#[cube]
179impl<ES: Numeric, IO: SliceVisibility> StridedTile<ES, IO> {
180    /// Returns a specific line from the tile based on coordinates.
181    pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
182        let offset = coor_strided * self.stride + coor_contiguous;
183        let offset_abs = self.start + offset;
184        let type_size = type_size::<ES>(self.stage.line_size());
185        let offset_swizzled = self.swizzle.apply(offset_abs, type_size);
186        self.stage[offset_swizzled]
187    }
188
189    pub fn stage_offset(&self, relative_offset: u32) -> u32 {
190        let offset = self.start + relative_offset;
191        let type_size = type_size::<ES>(self.stage.line_size());
192        self.swizzle.apply(offset, type_size)
193    }
194
195    #[allow(unused_variables)]
196    pub fn with_line_size(&self, #[comptime] line_size: u32) -> Self {
197        intrinsic!(|scope| {
198            let stage_line_size = self.stage.line_size();
199
200            if line_size == self.stage.line_size() {
201                return self;
202            }
203
204            let current = stage_line_size;
205            let mut out = self.clone();
206
207            if current < line_size {
208                let ratio = line_size / current;
209                let end = cubecl::frontend::div::expand(scope, self.end, ratio.into());
210                let start = cubecl::frontend::div::expand(scope, self.start, ratio.into());
211                let stride = cubecl::frontend::div::expand(scope, self.stride, ratio.into());
212                out.start = start;
213                out.end = end;
214                out.stride = stride;
215            } else {
216                let ratio = current / line_size;
217                let start = cubecl::frontend::mul::expand(scope, self.start, ratio.into());
218                let end = cubecl::frontend::mul::expand(scope, self.end, ratio.into());
219                let stride = cubecl::frontend::mul::expand(scope, self.stride, ratio.into());
220                out.start = start;
221                out.end = end;
222                out.stride = stride;
223            }
224
225            out.stage = out.stage.__expand_with_line_size_method(scope, line_size);
226            out
227        })
228    }
229}