1use cubecl::{intrinsic, prelude::*, std::Swizzle};
2
3use crate::{MatrixLayout, stage::StageMemoryConfig, stage::as_swizzle_object};
4
/// A view of one tile that lives inside a larger, vectorized stage slice.
///
/// The tile covers the vector range `start..end` of `stage`. Consecutive strided
/// coordinates are `stride` vectors apart. The element accessors (`get_vector`,
/// `stage_offset`) pass absolute stage offsets through `swizzle` before indexing.
#[derive(CubeType, Clone, Copy)]
pub struct StridedTile<ES: Numeric, N: Size, IO: SliceVisibility = ReadOnly> {
    /// Backing stage memory, vectorized as `Vector<ES, N>`.
    pub stage: Slice<Vector<ES, N>, IO>,
    /// Index of the tile's first vector within `stage` (before swizzling).
    pub start: u32,
    /// Index one past the tile's last vector within `stage` (before swizzling).
    pub end: u32,
    /// Distance, in vectors, between consecutive strided coordinates.
    pub stride: u32,
    /// Swizzle applied to absolute stage offsets by the element accessors.
    pub swizzle: Swizzle,
    /// Matrix layout (row- or column-major) of the tile; a compile-time value.
    #[cube(comptime)]
    pub layout: MatrixLayout,
}
23
24#[cube]
25impl<ES: Numeric, N: Size> StridedTile<ES, N> {
26 pub fn new_contiguous(
30 stage: Slice<Vector<ES, N>>,
31 start: u32,
32 #[comptime] config: StageMemoryConfig,
33 ) -> StridedTile<ES, N> {
34 let len = config.elements_per_tile() / config.vector_size;
35 let layout = config.matrix_layout;
36 let stride = match layout {
37 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
38 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
39 };
40
41 let stride = stride / config.vector_size;
42
43 StridedTile::<ES, N> {
44 stage,
45 start,
46 end: start + len,
47 stride,
48 swizzle: as_swizzle_object(config.swizzle),
49 layout,
50 }
51 }
52
53 pub fn new_contiguous_mut(
57 stage: Slice<Vector<ES, N>, ReadWrite>,
58 start: u32,
59 #[comptime] config: StageMemoryConfig,
60 ) -> StridedTile<ES, N, ReadWrite> {
61 let len = config.elements_per_tile() / config.vector_size;
62 let layout = config.matrix_layout;
63 let stride = match layout {
64 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
65 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
66 };
67
68 let stride = stride / config.vector_size;
69
70 StridedTile::<ES, N, ReadWrite> {
71 stage,
72 start,
73 end: start + len,
74 stride,
75 swizzle: as_swizzle_object(config.swizzle),
76 layout,
77 }
78 }
79
80 pub fn new_strided(
84 stage: Slice<Vector<ES, N>>,
85 start: u32,
86 end: u32,
87 stride: u32,
88 swizzle: Swizzle,
89 #[comptime] layout: MatrixLayout,
90 ) -> StridedTile<ES, N> {
91 StridedTile::<ES, N> {
92 stage,
93 start,
94 end,
95 stride,
96 swizzle,
97 layout,
98 }
99 }
100
101 pub fn new_strided_mut(
105 stage: Slice<Vector<ES, N>, ReadWrite>,
106 start: u32,
107 end: u32,
108 stride: u32,
109 swizzle: Swizzle,
110 #[comptime] layout: MatrixLayout,
111 ) -> StridedTile<ES, N, ReadWrite> {
112 StridedTile::<ES, N, ReadWrite> {
113 stage,
114 start,
115 end,
116 stride,
117 swizzle,
118 layout,
119 }
120 }
121}
122
123#[cube]
124impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
125 pub fn unvectorized_stride(&self) -> u32 {
126 let stage_vector_size = self.stage.vector_size();
127 self.stride * stage_vector_size as u32
128 }
129}
130
131#[cube]
132impl<ES: Numeric, N: Size> StridedTile<ES, N, ReadOnly> {
133 pub fn as_slice(&self) -> Slice<Vector<ES, N>, ReadOnly> {
136 self.stage.slice(self.start as usize, self.end as usize)
137 }
138}
139
140#[cube]
141impl<ES: Numeric, N: Size> StridedTile<ES, N, ReadWrite> {
142 pub fn as_slice_mut(&self) -> Slice<Vector<ES, N>, ReadWrite> {
145 self.stage
146 .slice(self.start as usize, self.end as usize)
147 .as_mut_unchecked()
148 }
149}
150
151#[cube]
152impl<ES: Numeric, N: Size, IO: SliceVisibility> StridedTile<ES, N, IO> {
153 pub fn get_vector(&self, coor_strided: u32, coor_contiguous: u32) -> Vector<ES, N> {
155 let offset = coor_strided * self.stride + coor_contiguous;
156 let offset_abs = self.start + offset;
157 let type_size = Vector::<ES, N>::type_size();
158 let offset_swizzled = self.swizzle.apply(offset_abs, type_size);
159 self.stage[offset_swizzled as usize]
160 }
161
162 pub fn stage_offset(&self, relative_offset: u32) -> u32 {
163 let offset = self.start + relative_offset;
164 let type_size = Vector::<ES, N>::type_size();
165 self.swizzle.apply(offset, type_size)
166 }
167
168 #[allow(unused_variables)]
169 pub fn with_vector_size<N2: Size>(&self) -> StridedTile<ES, N2, IO> {
170 let vector_size = N2::value();
171 intrinsic!(|scope| {
172 let stage_vector_size = self.stage.vector_size();
173
174 if vector_size == self.stage.vector_size() {
175 return self.__expand_with_stage_vector_size_method(scope);
176 }
177
178 let current = stage_vector_size;
179 let mut out: StridedTileExpand<ES, N2, IO> =
180 self.clone().__expand_with_stage_vector_size_method(scope);
181
182 if current < vector_size {
183 let ratio = (vector_size / current) as u32;
184 let end = cubecl::frontend::div::expand(scope, self.end, ratio.into());
185 let start = cubecl::frontend::div::expand(scope, self.start, ratio.into());
186 let stride =
187 cubecl::frontend::div::expand(scope, self.stride, (ratio as u32).into());
188 out.start = start;
189 out.end = end;
190 out.stride = stride;
191 } else {
192 let ratio = (current / vector_size) as u32;
193 let start = cubecl::frontend::mul::expand(scope, self.start, ratio.into());
194 let end = cubecl::frontend::mul::expand(scope, self.end, ratio.into());
195 let stride = cubecl::frontend::mul::expand(scope, self.stride, ratio.into());
196 out.start = start;
197 out.end = end;
198 out.stride = stride;
199 }
200
201 out
202 })
203 }
204
205 #[allow(unused)]
210 unsafe fn with_stage_vector_size<N2: Size>(self) -> StridedTile<ES, N2, IO> {
211 StridedTile::<ES, N2, IO> {
212 stage: self.stage.with_vector_size::<N2>(),
213 start: self.start,
214 end: self.end,
215 stride: self.stride,
216 swizzle: self.swizzle,
217 layout: self.layout,
218 }
219 }
220}