cubecl_matmul/components/tile/
tile_data.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, intrinsic};
3use cubecl_std::{Swizzle, type_size};
4
5use crate::components::{
6 MatrixLayout,
7 stage::{StageMemoryConfig, as_swizzle_object},
8};
9
/// A tile viewed as a strided window into a larger stage slice.
///
/// `start`, `end` and `stride` are counted in lines of `stage`. The line at tile coordinates
/// `(strided, contiguous)` sits at tile-relative offset `strided * stride + contiguous`; that
/// offset is shifted by `start` and then passed through `swizzle` to obtain the index into
/// `stage` (see `get_line` / `stage_offset`).
#[derive(CubeType, Clone, Copy)]
pub struct StridedTile<ES: Numeric, IO: SliceVisibility = ReadOnly> {
    /// Whole stage slice this tile is a view into.
    pub stage: Slice<Line<ES>, IO>,
    /// Line offset of the tile's first line within `stage`, before swizzling.
    pub start: u32,
    /// Exclusive end line offset of the tile within `stage`; bounds the slices returned by
    /// `as_unlined`, `as_unlined_mut` and `as_slice_mut`.
    /// NOTE(review): swizzling is not applied to this bound — confirm it is only relied on
    /// for unswizzled tiles.
    pub end: u32,
    /// Number of lines between the starts of consecutive rows (row-major) or columns
    /// (column-major).
    pub stride: u32,
    /// Swizzle applied to absolute line offsets (`start + relative offset`) before indexing
    /// `stage`.
    pub swizzle: Swizzle,
    /// Comptime memory layout of the tile; for contiguous tiles it selects which tile extent
    /// is used to derive `stride`.
    #[cube(comptime)]
    pub layout: MatrixLayout,
    /// Comptime line size recorded at construction (`config.line_size` for contiguous tiles).
    /// NOTE(review): presumably equal to the line size of `stage` — confirm at call sites.
    #[cube(comptime)]
    pub line_size: u32,
}
31
32#[cube]
33impl<ES: Numeric> StridedTile<ES> {
34 pub fn new_contiguous(
38 stage: Slice<Line<ES>>,
39 start: u32,
40 #[comptime] config: StageMemoryConfig,
41 ) -> StridedTile<ES> {
42 let len = config.elements_per_tile() / config.line_size;
43 let layout = config.matrix_layout;
44 let stride = match layout {
45 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
46 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
47 };
48
49 let stride = comptime![stride / config.line_size];
50
51 StridedTile::<ES> {
52 stage,
53 start,
54 end: start + len,
55 stride,
56 swizzle: as_swizzle_object(config.swizzle),
57 layout,
58 line_size: config.line_size,
59 }
60 }
61
62 pub fn new_contiguous_mut(
66 stage: Slice<Line<ES>, ReadWrite>,
67 start: u32,
68 #[comptime] config: StageMemoryConfig,
69 ) -> StridedTile<ES, ReadWrite> {
70 let len = config.elements_per_tile() / config.line_size;
71 let layout = config.matrix_layout;
72 let stride = match layout {
73 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
74 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
75 };
76
77 let stride = comptime![stride / config.line_size];
78
79 StridedTile::<ES, ReadWrite> {
80 stage,
81 start,
82 end: start + len,
83 stride,
84 swizzle: as_swizzle_object(config.swizzle),
85 layout,
86 line_size: config.line_size,
87 }
88 }
89
90 pub fn new_strided(
94 stage: Slice<Line<ES>>,
95 start: u32,
96 end: u32,
97 stride: u32,
98 swizzle: Swizzle,
99 #[comptime] layout: MatrixLayout,
100 #[comptime] line_size: u32,
101 ) -> StridedTile<ES> {
102 StridedTile::<ES> {
103 stage,
104 start,
105 end,
106 stride,
107 swizzle,
108 layout,
109 line_size,
110 }
111 }
112
113 pub fn new_strided_mut(
117 stage: Slice<Line<ES>, ReadWrite>,
118 start: u32,
119 end: u32,
120 stride: u32,
121 swizzle: Swizzle,
122 #[comptime] layout: MatrixLayout,
123 #[comptime] line_size: u32,
124 ) -> StridedTile<ES, ReadWrite> {
125 StridedTile::<ES, ReadWrite> {
126 stage,
127 start,
128 end,
129 stride,
130 swizzle,
131 layout,
132 line_size,
133 }
134 }
135}
136
137#[cube]
138impl<ES: Numeric> StridedTile<ES, ReadOnly> {
139 pub fn as_unlined(&self) -> (Slice<ES, ReadOnly>, u32) {
145 let stage_line_size = comptime![self.stage.line_size()];
146 (
147 self.stage.slice(self.start, self.end).try_cast_unchecked(),
148 self.stride * stage_line_size,
149 )
150 }
151}
152
153#[cube]
154impl<ES: Numeric> StridedTile<ES, ReadWrite> {
155 pub fn as_unlined_mut(&self) -> (Slice<ES, ReadWrite>, u32) {
161 let stage_line_size = comptime![self.stage.line_size()];
162 (
163 self.stage
164 .slice(self.start, self.end)
165 .as_mut_unchecked()
166 .try_cast_unchecked(),
167 self.stride * stage_line_size,
168 )
169 }
170
171 pub fn as_slice_mut(&self) -> Slice<Line<ES>, ReadWrite> {
174 self.stage.slice(self.start, self.end).as_mut_unchecked()
175 }
176}
177
178#[cube]
179impl<ES: Numeric, IO: SliceVisibility> StridedTile<ES, IO> {
180 pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
182 let offset = coor_strided * self.stride + coor_contiguous;
183 let offset_abs = self.start + offset;
184 let type_size = type_size::<ES>(self.stage.line_size());
185 let offset_swizzled = self.swizzle.apply(offset_abs, type_size);
186 self.stage[offset_swizzled]
187 }
188
189 pub fn stage_offset(&self, relative_offset: u32) -> u32 {
190 let offset = self.start + relative_offset;
191 let type_size = type_size::<ES>(self.stage.line_size());
192 self.swizzle.apply(offset, type_size)
193 }
194
195 #[allow(unused_variables)]
196 pub fn with_line_size(&self, #[comptime] line_size: u32) -> Self {
197 intrinsic!(|scope| {
198 let stage_line_size = self.stage.line_size();
199
200 if line_size == self.stage.line_size() {
201 return self;
202 }
203
204 let current = stage_line_size;
205 let mut out = self.clone();
206
207 if current < line_size {
208 let ratio = line_size / current;
209 let end = cubecl::frontend::div::expand(scope, self.end, ratio.into());
210 let start = cubecl::frontend::div::expand(scope, self.start, ratio.into());
211 let stride = cubecl::frontend::div::expand(scope, self.stride, ratio.into());
212 out.start = start;
213 out.end = end;
214 out.stride = stride;
215 } else {
216 let ratio = current / line_size;
217 let start = cubecl::frontend::mul::expand(scope, self.start, ratio.into());
218 let end = cubecl::frontend::mul::expand(scope, self.end, ratio.into());
219 let stride = cubecl::frontend::mul::expand(scope, self.stride, ratio.into());
220 out.start = start;
221 out.end = end;
222 out.stride = stride;
223 }
224
225 out.stage = out.stage.__expand_with_line_size_method(scope, line_size);
226 out
227 })
228 }
229}