Skip to main content

cubek_std/tile/data/
strided.rs

1use cubecl::{intrinsic, prelude::*, std::Swizzle};
2
3use crate::{MatrixLayout, stage::StageMemoryConfig, stage::as_swizzle_object};
4
5#[derive(CubeType, Clone, Copy)]
6/// Tile with a linear major dimension, and a strided minor dimension.
7/// Basic tile kind supported by all stage matmuls.
8pub struct StridedTile<ES: Numeric, N: Size, IO: SliceVisibility = ReadOnly> {
9    /// Slice containing all data for the stage
10    pub container: Slice<Vector<ES, N>, IO>,
11    /// Offset of the tile in the stage
12    pub start: u32,
13    /// End of the tile in the stage, may be wrong with swizzle
14    pub end: u32,
15    /// Stride between each row/col, depending on MatrixLayout (the other is assumed to be 1)
16    pub stride: u32,
17    /// Swizzle object to transform the index
18    pub swizzle: Swizzle,
19    #[cube(comptime)]
20    /// Layout of the tile (row-major or column-major).
21    pub layout: MatrixLayout,
22}
23
24#[cube]
25impl<ES: Numeric, N: Size> StridedTile<ES, N> {
26    /// Creates a tile from a contiguous slice of data.
27    ///
28    /// The slice length must exactly match the tile size.
29    pub fn new_contiguous(
30        container: Slice<Vector<ES, N>>,
31        start: u32,
32        #[comptime] config: StageMemoryConfig,
33    ) -> StridedTile<ES, N> {
34        let len = config.elements_per_tile() / config.vector_size;
35        let layout = config.matrix_layout;
36        let stride = match layout {
37            MatrixLayout::RowMajor => config.elements_per_tile_along_col,
38            MatrixLayout::ColMajor => config.elements_per_tile_along_row,
39        };
40
41        let stride = stride / config.vector_size;
42
43        StridedTile::<ES, N> {
44            container,
45            start,
46            end: start + len,
47            stride,
48            swizzle: as_swizzle_object(config.swizzle),
49            layout,
50        }
51    }
52
53    /// Creates a tile from a contiguous slice of data.
54    ///
55    /// The slice length must exactly match the tile size.
56    pub fn new_contiguous_mut(
57        container: Slice<Vector<ES, N>, ReadWrite>,
58        start: u32,
59        #[comptime] config: StageMemoryConfig,
60    ) -> StridedTile<ES, N, ReadWrite> {
61        let len = config.elements_per_tile() / config.vector_size;
62        let layout = config.matrix_layout;
63        let stride = match layout {
64            MatrixLayout::RowMajor => config.elements_per_tile_along_col,
65            MatrixLayout::ColMajor => config.elements_per_tile_along_row,
66        };
67
68        let stride = stride / config.vector_size;
69
70        StridedTile::<ES, N, ReadWrite> {
71            container,
72            start,
73            end: start + len,
74            stride,
75            swizzle: as_swizzle_object(config.swizzle),
76            layout,
77        }
78    }
79
80    /// Creates a tile from a strided slice of data.
81    ///
82    /// The slice must include all elements of the tile, though it may include unused gaps.
83    pub fn new_strided(
84        container: Slice<Vector<ES, N>>,
85        start: u32,
86        end: u32,
87        stride: u32,
88        swizzle: Swizzle,
89        #[comptime] layout: MatrixLayout,
90    ) -> StridedTile<ES, N> {
91        StridedTile::<ES, N> {
92            container,
93            start,
94            end,
95            stride,
96            swizzle,
97            layout,
98        }
99    }
100
101    /// Creates a tile from a strided slice of data.
102    ///
103    /// The slice must include all elements of the tile, though it may include unused gaps.
104    pub fn new_strided_mut(
105        container: Slice<Vector<ES, N>, ReadWrite>,
106        start: u32,
107        end: u32,
108        stride: u32,
109        swizzle: Swizzle,
110        #[comptime] layout: MatrixLayout,
111    ) -> StridedTile<ES, N, ReadWrite> {
112        StridedTile::<ES, N, ReadWrite> {
113            container,
114            start,
115            end,
116            stride,
117            swizzle,
118            layout,
119        }
120    }
121}
122
123#[cube]
124impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
125    pub fn unvectorized_stride(&self) -> u32 {
126        let stage_vector_size = self.container.vector_size();
127        self.stride * stage_vector_size as u32
128    }
129}
130
131#[cube]
132impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
133    /// Returns the tile as an offset read-only slice. Should only be used when swizzling is
134    /// definitely not applicable.
135    pub fn as_slice(&self) -> Slice<Vector<ES, N>, ReadOnly> {
136        self.container.slice(self.start as usize, self.end as usize)
137    }
138
139    /// Returns a read-only view of this tile, dropping write permission on the container.
140    pub fn to_read_only(&self) -> StridedTile<ES, N, ReadOnly> {
141        StridedTile::<ES, N, ReadOnly> {
142            container: self.container.to_slice(),
143            start: self.start,
144            end: self.end,
145            stride: self.stride,
146            swizzle: self.swizzle,
147            layout: self.layout,
148        }
149    }
150}
151
152#[cube]
153impl<ES: Numeric, N: Size> StridedTile<ES, N, ReadWrite> {
154    /// Returns the tile as an offset slice. Should only be used when swizzling is definitely not
155    /// applicable.
156    pub fn as_slice_mut(&self) -> Slice<Vector<ES, N>, ReadWrite> {
157        self.container
158            .slice(self.start as usize, self.end as usize)
159            .as_mut_unchecked()
160    }
161}
162
163#[cube]
164impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
165    /// Returns a specific vector from the tile based on coordinates.
166    pub fn get_vector(&self, coor_strided: u32, coor_contiguous: u32) -> Vector<ES, N> {
167        let offset = coor_strided * self.stride + coor_contiguous;
168        let offset_abs = self.start + offset;
169        let type_size = Vector::<ES, N>::type_size();
170        let offset_swizzled = self.swizzle.apply(offset_abs, type_size);
171        self.container[offset_swizzled as usize]
172    }
173
174    pub fn stage_offset(&self, relative_offset: u32) -> u32 {
175        let offset = self.start + relative_offset;
176        let type_size = Vector::<ES, N>::type_size();
177        self.swizzle.apply(offset, type_size)
178    }
179
180    #[allow(unused_variables)]
181    pub fn with_vector_size<N2: Size>(&self) -> StridedTile<ES, N2, IO> {
182        let vector_size = N2::value();
183        intrinsic!(|scope| {
184            let stage_vector_size = self.container.vector_size();
185
186            if vector_size == self.container.vector_size() {
187                return self.__expand_with_stage_vector_size_method(scope);
188            }
189
190            let current = stage_vector_size;
191            let mut out: StridedTileExpand<ES, N2, IO> =
192                self.clone().__expand_with_stage_vector_size_method(scope);
193
194            if current < vector_size {
195                let ratio = (vector_size / current) as u32;
196                let end = cubecl::frontend::div::expand(scope, self.end, ratio.into());
197                let start = cubecl::frontend::div::expand(scope, self.start, ratio.into());
198                let stride =
199                    cubecl::frontend::div::expand(scope, self.stride, (ratio as u32).into());
200                out.start = start;
201                out.end = end;
202                out.stride = stride;
203            } else {
204                let ratio = (current / vector_size) as u32;
205                let start = cubecl::frontend::mul::expand(scope, self.start, ratio.into());
206                let end = cubecl::frontend::mul::expand(scope, self.end, ratio.into());
207                let stride = cubecl::frontend::mul::expand(scope, self.stride, ratio.into());
208                out.start = start;
209                out.end = end;
210                out.stride = stride;
211            }
212
213            out
214        })
215    }
216
217    /// Cast only the stage vector size. This leaves the tile in an invalid state - start, end and
218    /// stride must be adjusted accordingly.
219    /// # Safety
220    /// Must not be used without further metadata adjustments
221    #[allow(unused)]
222    unsafe fn with_stage_vector_size<N2: Size>(self) -> StridedTile<ES, N2, IO> {
223        StridedTile::<ES, N2, IO> {
224            container: self.container.with_vector_size::<N2>(),
225            start: self.start,
226            end: self.end,
227            stride: self.stride,
228            swizzle: self.swizzle,
229            layout: self.layout,
230        }
231    }
232}
233
234/// V-erased view over a shared-memory tile. Wraps a `StridedTile` and hides
235/// its vectorization from the type system. The underlying slice is downcast
236/// to a scalar `Slice<E, IO>` only at the Rust type level — the runtime
237/// vector_size on the cubecl slice is preserved, so projecting back via
238/// `view::<V>()` is a pure retype with no metadata change. `V` must match
239/// the original `V` the tile was wrapped with.
240#[derive(CubeType, Clone, Copy)]
241pub struct SharedTile<E: Numeric, IO: SliceVisibility = ReadOnly> {
242    container: Slice<E, IO>,
243    start: u32,
244    end: u32,
245    stride: u32,
246    swizzle: Swizzle,
247    #[cube(comptime)]
248    layout: MatrixLayout,
249}
250
251#[cube]
252impl<E: Numeric, IO: SliceVisibility> SharedTile<E, IO> {
253    /// Wrap a `StridedTile` whose vectorization is `V`. The slice is type-erased
254    /// to scalar `Slice<E, IO>` while preserving the runtime vector_size set at
255    /// allocation time. No metadata scaling is performed.
256    pub fn wrap<V: Size>(tile: StridedTile<E, V, IO>) -> SharedTile<E, IO> {
257        let container: Slice<E, IO> = unsafe { tile.container.downcast_unchecked::<E>() };
258        SharedTile::<E, IO> {
259            container,
260            start: tile.start,
261            end: tile.end,
262            stride: tile.stride,
263            swizzle: tile.swizzle,
264            layout: tile.layout,
265        }
266    }
267
268    /// Project the wrapped tile back to a typed `StridedTile<E, V, IO>`.
269    /// `V` must match the original `V` the tile was wrapped with — only
270    /// the Rust type changes, the runtime layout is unchanged.
271    pub fn view<V: Size>(&self) -> StridedTile<E, V, IO> {
272        let container: Slice<Vector<E, V>, IO> =
273            unsafe { self.container.downcast_unchecked::<Vector<E, V>>() };
274        StridedTile::<E, V, IO> {
275            container,
276            start: self.start,
277            end: self.end,
278            stride: self.stride,
279            swizzle: self.swizzle,
280            layout: self.layout,
281        }
282    }
283}