1use crate::matmul::components::config::InputIdent;
2use crate::matmul::components::global;
3use crate::matmul::components::{Ident, MatrixLayout};
4use cubecl_core as cubecl;
5use cubecl_core::io::read_masked;
6use cubecl_core::prelude::*;
7use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
8
9#[derive(Clone, CubeType)]
10pub struct TensorReader<EI: Numeric> {
14 pub tensor: VirtualTensor<EI>,
15 pub x_offset: u32,
16 pub y_offset: u32,
17 pub stride_x: u32,
18 pub stride_y: u32,
19 pub shape_x: u32,
20 pub shape_y: u32,
21 pub batch_offset: u32,
22}
23
24#[derive(CubeType)]
25pub struct MappedTensorReader<EG: Numeric> {
28 pub tensor: TensorMap<EG>,
29 pub tile_x: u32,
30 pub tile_y: u32,
31 pub batch: u32,
32}
33
34#[derive(CubeType)]
35pub struct TensorWriter<EO: Numeric> {
39 pub tensor: VirtualTensor<EO, ReadWrite>,
40 pub x_offset: u32,
41 pub y_offset: u32,
42 pub stride_x: u32,
43 pub stride_y: u32,
44 pub shape_x: u32,
45 pub shape_y: u32,
46 pub batch_offset: u32,
47}
48
49unsafe impl<EG: Numeric> Sync for TensorReader<EG> {}
50unsafe impl<EG: Numeric> Send for TensorReader<EG> {}
51unsafe impl<EG: Numeric> Sync for MappedTensorReader<EG> {}
52unsafe impl<EG: Numeric> Send for MappedTensorReader<EG> {}
53unsafe impl<EG: Numeric> Sync for TensorWriter<EG> {}
54unsafe impl<EG: Numeric> Send for TensorWriter<EG> {}
55
56#[derive(CubeType)]
57pub struct Window<EG: Numeric> {
59 pub slice: Slice<Line<EG>>,
61 pub size: u32,
63}
64
65#[cube]
66impl<EG: Numeric> MappedTensorReader<EG> {
67 pub fn new(tensor: TensorMap<EG>, tile_x: u32, tile_y: u32, batch: u32) -> Self {
69 MappedTensorReader::<EG> {
70 tensor,
71 tile_x,
72 tile_y,
73 batch,
74 }
75 }
76
77 pub fn update_view(&mut self, k_offset: u32, #[comptime] ident: Ident) {
79 match ident.as_input_ident() {
80 InputIdent::Lhs => {
81 self.tile_y += k_offset;
82 }
83 InputIdent::Rhs => {
84 self.tile_x += k_offset;
85 }
86 }
87 }
88}
89
90#[cube]
91impl<EG: Numeric> TensorReader<EG> {
92 pub fn new(tensor: VirtualTensor<EG>, x_offset: u32, y_offset: u32, batch_offset: u32) -> Self {
94 let rank = tensor.rank();
95 let stride_x = tensor.stride(rank - 2);
96 let stride_y = tensor.stride(rank - 1);
97 let shape_x = tensor.shape(rank - 2);
98 let shape_y = tensor.shape(rank - 1);
99
100 TensorReader::<EG> {
101 tensor,
102 x_offset,
103 y_offset,
104 stride_x,
105 stride_y,
106 shape_x,
107 shape_y,
108 batch_offset,
109 }
110 }
111
112 pub fn update_view(&mut self, k_offset: u32, #[comptime] ident: InputIdent) {
114 match ident {
115 InputIdent::Lhs => {
116 self.y_offset += k_offset;
117 }
118 InputIdent::Rhs => {
119 self.x_offset += k_offset;
120 }
121 }
122 }
123
124 pub fn load_window_in_tile<G: global::GlobalConfig>(
134 &self,
135 tile: (u32, u32),
136 nth_window: u32,
137 #[comptime] input_ident: InputIdent,
138 #[comptime] config: G,
139 ) -> Window<EG> {
140 let line_size = config.global_line_size(input_ident);
141 let tiling_dimensions = config.tiling_dimensions(input_ident);
142 let matrix_layout = config.matrix_layout(input_ident);
143
144 let tile_size_x = tiling_dimensions.tile_shape_row();
145 let tile_size_y = tiling_dimensions.tile_shape_col();
146
147 let num_lines_in_window = comptime! {match matrix_layout {
148 MatrixLayout::RowMajor => tile_size_y / line_size,
149 MatrixLayout::ColMajor => tile_size_x / line_size,
150 }};
151
152 self.load_window::<G>(
153 nth_window,
154 (tile.0 * tile_size_x, tile.1 * tile_size_y),
155 num_lines_in_window,
156 input_ident,
157 config,
158 )
159 }
160
161 pub fn load_window_in_stage<G: global::GlobalConfig>(
170 &self,
171 nth_window: u32,
172 #[comptime] input_ident: InputIdent,
173 #[comptime] config: G,
174 ) -> Window<EG> {
175 let line_size = config.global_line_size(input_ident);
176 let tiling_dimensions = config.tiling_dimensions(input_ident);
177 let matrix_layout = config.matrix_layout(input_ident);
178
179 let num_lines_in_window = comptime! {match matrix_layout {
180 MatrixLayout::RowMajor =>
181 tiling_dimensions.total_col() / line_size
182 ,
183 MatrixLayout::ColMajor =>
184 tiling_dimensions.total_row() / line_size
185 ,
186 }};
187
188 self.load_window::<G>(
189 nth_window,
190 (0u32, 0u32).runtime(),
191 num_lines_in_window,
192 input_ident,
193 config,
194 )
195 }
196
197 fn load_window<G: global::GlobalConfig>(
198 &self,
199 nth_window: u32,
200 tile_offsets: (u32, u32),
201 #[comptime] num_lines_in_window: u32,
202 #[comptime] ident: InputIdent,
203 #[comptime] config: G,
204 ) -> Window<EG> {
205 let line_size = config.global_line_size(ident);
206 let matrix_layout = config.matrix_layout(ident);
207
208 let (load_x, load_y) = match matrix_layout {
209 MatrixLayout::RowMajor => (nth_window, 0),
210 MatrixLayout::ColMajor => (0, nth_window),
211 };
212
213 let view_tile_x = tile_offsets.0 + self.x_offset;
214 let view_tile_y = tile_offsets.1 + self.y_offset;
215
216 let view_x = view_tile_x + load_x;
217 let view_y = view_tile_y + load_y;
218
219 let read_pos =
220 (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset) / line_size;
221
222 let (check_h_bounds, view_h, shape_h, check_w_bounds, view_w, shape_w) =
223 match config.matrix_layout(ident) {
224 MatrixLayout::RowMajor => (
225 config.check_row_bounds(ident),
226 view_x,
227 self.shape_x,
228 config.check_col_bounds(ident),
229 view_y,
230 self.shape_y,
231 ),
232 MatrixLayout::ColMajor => (
233 config.check_col_bounds(ident),
234 view_y,
235 self.shape_y,
236 config.check_row_bounds(ident),
237 view_x,
238 self.shape_x,
239 ),
240 };
241
242 let max_lines_in_window = if comptime!(check_h_bounds) {
244 num_lines_in_window * u32::cast_from(view_h < shape_h)
245 } else {
246 num_lines_in_window.runtime()
247 };
248
249 let size = if comptime!(check_w_bounds) {
251 slice_length_clamp(shape_w / line_size, view_w / line_size, max_lines_in_window)
252 } else {
253 max_lines_in_window
254 };
255
256 Window::<EG> {
257 slice: self.tensor.as_slice(read_pos, read_pos + size),
258 size,
259 }
260 }
261
262 pub fn load_coalesced_in_tile<G: global::GlobalConfig>(
272 &self,
273 tile_x: u32,
274 tile_y: u32,
275 position: u32,
276 #[comptime] input_ident: InputIdent,
277 #[comptime] config: G,
278 ) -> Line<EG> {
279 let tile_shape_x = config.tiling_dimensions(input_ident).tile_shape_row();
280 let tile_shape_y = config.tiling_dimensions(input_ident).tile_shape_col();
281
282 let view_tile_x = tile_x * tile_shape_x;
283 let view_tile_y = tile_y * tile_shape_y;
284
285 let (load_x, load_y) = match config.matrix_layout(input_ident) {
286 MatrixLayout::RowMajor => (position / tile_shape_y, position % tile_shape_y),
287 MatrixLayout::ColMajor => (position % tile_shape_x, position / tile_shape_x),
288 };
289
290 self.load_coalesced::<G>(
291 (load_x + view_tile_x, load_y + view_tile_y),
292 input_ident,
293 config,
294 )
295 }
296
297 pub fn load_coalesced_in_stage<G: global::GlobalConfig>(
308 &self,
309 position: u32,
310 #[comptime] input_ident: InputIdent,
311 #[comptime] config: G,
312 ) -> Line<EG> {
313 let stage_shape_x = config.tiling_dimensions(input_ident).total_row();
314 let stage_shape_y = config.tiling_dimensions(input_ident).total_col();
315
316 let load_offsets = match config.matrix_layout(input_ident) {
317 MatrixLayout::RowMajor => (position / stage_shape_y, position % stage_shape_y),
318 MatrixLayout::ColMajor => (position % stage_shape_x, position / stage_shape_x),
319 };
320
321 self.load_coalesced::<G>(load_offsets, input_ident, config)
322 }
323
324 fn load_coalesced<G: global::GlobalConfig>(
325 &self,
326 load_offsets: (u32, u32),
327 #[comptime] input_ident: InputIdent,
328 #[comptime] config: G,
329 ) -> Line<EG> {
330 let line_size = config.global_line_size(input_ident);
331
332 let view_x = load_offsets.0 + self.x_offset;
333 let view_y = load_offsets.1 + self.y_offset;
334
335 let read_pos =
336 (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset) / line_size;
337
338 match comptime!((
339 config.check_row_bounds(input_ident),
340 config.check_col_bounds(input_ident)
341 )) {
342 (true, true) => read_masked::<Line<EG>>(
343 view_x < self.shape_x && view_y < self.shape_y,
344 self.tensor.as_slice(0, self.tensor.len()),
345 read_pos,
346 Line::cast_from(0),
347 ),
348 (true, false) => read_masked::<Line<EG>>(
349 view_x < self.shape_x,
350 self.tensor.as_slice(0, self.tensor.len()),
351 read_pos,
352 Line::cast_from(0),
353 ),
354 (false, true) => read_masked::<Line<EG>>(
355 view_y < self.shape_y,
356 self.tensor.as_slice(0, self.tensor.len()),
357 read_pos,
358 Line::cast_from(0),
359 ),
360 (false, false) => self.tensor.read(read_pos),
361 }
362 }
363}
364
365#[cube]
366impl<EG: Numeric> TensorWriter<EG> {
367 pub fn new(
369 tensor: VirtualTensor<EG, ReadWrite>,
370 x_offset: u32,
371 y_offset: u32,
372 batch_offset: u32,
373 ) -> Self {
374 let rank = tensor.rank();
375 let stride_x = tensor.stride(rank - 2);
376 let stride_y = tensor.stride(rank - 1);
377 let shape_x = tensor.shape(rank - 2);
378 let shape_y = tensor.shape(rank - 1);
379
380 TensorWriter::<EG> {
381 tensor,
382 x_offset,
383 y_offset,
384 stride_x,
385 stride_y,
386 shape_x,
387 shape_y,
388 batch_offset,
389 }
390 }
391
392 pub fn write_coalesced<ES: Numeric, G: global::GlobalConfig>(
396 &mut self,
397 tile_x: u32,
398 tile_y: u32,
399 unit_id: u32,
400 value: Line<ES>,
401 #[comptime] config: G,
402 ) {
403 let tiling = config.tiling_dimensions(Ident::Out);
404
405 let view_x =
406 tile_x * tiling.tile_shape_row() + unit_id / tiling.tile_shape_col() + self.x_offset;
407 let view_y =
408 tile_y * tiling.tile_shape_col() + unit_id % tiling.tile_shape_col() + self.y_offset;
409
410 let write_position = (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset)
411 / config.global_line_size(Ident::Out);
412
413 match comptime!((
414 config.check_row_bounds(Ident::Out),
415 config.check_col_bounds(Ident::Out)
416 )) {
417 (true, true) => {
418 if view_x < self.shape_x && view_y < self.shape_y {
419 self.write(write_position, Line::cast_from(value));
420 }
421 }
422 (true, false) => {
423 if view_x < self.shape_x {
424 self.write(write_position, Line::cast_from(value));
425 }
426 }
427 (false, true) => {
428 if view_y < self.shape_y {
429 self.write(write_position, Line::cast_from(value));
430 }
431 }
432 (false, false) => {
433 self.write(write_position, Line::cast_from(value));
434 }
435 }
436 }
437
438 fn write(&mut self, position: u32, value: Line<EG>) {
439 self.tensor.write(position, value)
440 }
441}
442
443#[cube]
444fn slice_length_clamp(shape: u32, offset: u32, max_length: u32) -> u32 {
446 Min::min(select(shape > offset, shape - offset, 0), max_length)
447}