1use cubecl::{intrinsic, prelude::*, std::Swizzle};
2
3use crate::{MatrixLayout, stage::StageMemoryConfig, stage::as_swizzle_object};
4
/// A view of one tile stored inside a (possibly larger) slice of vectors.
///
/// The tile occupies container indices `[start, end)` (see `as_slice`). The vector at
/// `(coor_strided, coor_contiguous)` sits at tile-relative offset
/// `coor_strided * stride + coor_contiguous`; that offset is shifted by `start` and passed
/// through `swizzle` before it indexes `container` (see `get_vector` / `stage_offset`).
///
/// `IO` selects whether the backing slice is read-only (the default) or read-write.
#[derive(CubeType, Clone, Copy)]
pub struct StridedTile<ES: Numeric, N: Size, IO: SliceVisibility = ReadOnly> {
    /// Backing slice of vectors that the tile indexes into.
    pub container: Slice<Vector<ES, N>, IO>,
    /// Container index (in vectors, before swizzling) of the first vector of the tile.
    pub start: u32,
    /// Exclusive end of the tile in the container, in vectors (`start + len` for
    /// contiguous tiles).
    pub end: u32,
    /// Distance, in vectors, between two consecutive lines along the strided dimension.
    pub stride: u32,
    /// Swizzle applied to absolute offsets (`start + relative offset`) before they index
    /// `container`.
    pub swizzle: Swizzle,
    /// Comptime layout of the tile; the contiguous constructors use it to pick which tile
    /// extent becomes the stride.
    #[cube(comptime)]
    pub layout: MatrixLayout,
}
23
24#[cube]
25impl<ES: Numeric, N: Size> StridedTile<ES, N> {
26 pub fn new_contiguous(
30 container: Slice<Vector<ES, N>>,
31 start: u32,
32 #[comptime] config: StageMemoryConfig,
33 ) -> StridedTile<ES, N> {
34 let len = config.elements_per_tile() / config.vector_size;
35 let layout = config.matrix_layout;
36 let stride = match layout {
37 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
38 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
39 };
40
41 let stride = stride / config.vector_size;
42
43 StridedTile::<ES, N> {
44 container,
45 start,
46 end: start + len,
47 stride,
48 swizzle: as_swizzle_object(config.swizzle),
49 layout,
50 }
51 }
52
53 pub fn new_contiguous_mut(
57 container: Slice<Vector<ES, N>, ReadWrite>,
58 start: u32,
59 #[comptime] config: StageMemoryConfig,
60 ) -> StridedTile<ES, N, ReadWrite> {
61 let len = config.elements_per_tile() / config.vector_size;
62 let layout = config.matrix_layout;
63 let stride = match layout {
64 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
65 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
66 };
67
68 let stride = stride / config.vector_size;
69
70 StridedTile::<ES, N, ReadWrite> {
71 container,
72 start,
73 end: start + len,
74 stride,
75 swizzle: as_swizzle_object(config.swizzle),
76 layout,
77 }
78 }
79
80 pub fn new_strided(
84 container: Slice<Vector<ES, N>>,
85 start: u32,
86 end: u32,
87 stride: u32,
88 swizzle: Swizzle,
89 #[comptime] layout: MatrixLayout,
90 ) -> StridedTile<ES, N> {
91 StridedTile::<ES, N> {
92 container,
93 start,
94 end,
95 stride,
96 swizzle,
97 layout,
98 }
99 }
100
101 pub fn new_strided_mut(
105 container: Slice<Vector<ES, N>, ReadWrite>,
106 start: u32,
107 end: u32,
108 stride: u32,
109 swizzle: Swizzle,
110 #[comptime] layout: MatrixLayout,
111 ) -> StridedTile<ES, N, ReadWrite> {
112 StridedTile::<ES, N, ReadWrite> {
113 container,
114 start,
115 end,
116 stride,
117 swizzle,
118 layout,
119 }
120 }
121}
122
123#[cube]
124impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
125 pub fn unvectorized_stride(&self) -> u32 {
126 let stage_vector_size = self.container.vector_size();
127 self.stride * stage_vector_size as u32
128 }
129}
130
131#[cube]
132impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
133 pub fn as_slice(&self) -> Slice<Vector<ES, N>, ReadOnly> {
136 self.container.slice(self.start as usize, self.end as usize)
137 }
138
139 pub fn to_read_only(&self) -> StridedTile<ES, N, ReadOnly> {
141 StridedTile::<ES, N, ReadOnly> {
142 container: self.container.to_slice(),
143 start: self.start,
144 end: self.end,
145 stride: self.stride,
146 swizzle: self.swizzle,
147 layout: self.layout,
148 }
149 }
150}
151
152#[cube]
153impl<ES: Numeric, N: Size> StridedTile<ES, N, ReadWrite> {
154 pub fn as_slice_mut(&self) -> Slice<Vector<ES, N>, ReadWrite> {
157 self.container
158 .slice(self.start as usize, self.end as usize)
159 .as_mut_unchecked()
160 }
161}
162
163#[cube]
164impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
165 pub fn get_vector(&self, coor_strided: u32, coor_contiguous: u32) -> Vector<ES, N> {
167 let offset = coor_strided * self.stride + coor_contiguous;
168 let offset_abs = self.start + offset;
169 let type_size = Vector::<ES, N>::type_size();
170 let offset_swizzled = self.swizzle.apply(offset_abs, type_size);
171 self.container[offset_swizzled as usize]
172 }
173
174 pub fn stage_offset(&self, relative_offset: u32) -> u32 {
175 let offset = self.start + relative_offset;
176 let type_size = Vector::<ES, N>::type_size();
177 self.swizzle.apply(offset, type_size)
178 }
179
180 #[allow(unused_variables)]
181 pub fn with_vector_size<N2: Size>(&self) -> StridedTile<ES, N2, IO> {
182 let vector_size = N2::value();
183 intrinsic!(|scope| {
184 let stage_vector_size = self.container.vector_size();
185
186 if vector_size == self.container.vector_size() {
187 return self.__expand_with_stage_vector_size_method(scope);
188 }
189
190 let current = stage_vector_size;
191 let mut out: StridedTileExpand<ES, N2, IO> =
192 self.clone().__expand_with_stage_vector_size_method(scope);
193
194 if current < vector_size {
195 let ratio = (vector_size / current) as u32;
196 let end = cubecl::frontend::div::expand(scope, self.end, ratio.into());
197 let start = cubecl::frontend::div::expand(scope, self.start, ratio.into());
198 let stride =
199 cubecl::frontend::div::expand(scope, self.stride, (ratio as u32).into());
200 out.start = start;
201 out.end = end;
202 out.stride = stride;
203 } else {
204 let ratio = (current / vector_size) as u32;
205 let start = cubecl::frontend::mul::expand(scope, self.start, ratio.into());
206 let end = cubecl::frontend::mul::expand(scope, self.end, ratio.into());
207 let stride = cubecl::frontend::mul::expand(scope, self.stride, ratio.into());
208 out.start = start;
209 out.end = end;
210 out.stride = stride;
211 }
212
213 out
214 })
215 }
216
217 #[allow(unused)]
222 unsafe fn with_stage_vector_size<N2: Size>(self) -> StridedTile<ES, N2, IO> {
223 StridedTile::<ES, N2, IO> {
224 container: self.container.with_vector_size::<N2>(),
225 start: self.start,
226 end: self.end,
227 stride: self.stride,
228 swizzle: self.swizzle,
229 layout: self.layout,
230 }
231 }
232}