cubecl_matmul/components/tile/
tile_data.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{MatrixLayout, stage::StageMemoryConfig};
5
6#[derive(CubeType, Clone, Copy)]
7pub struct StridedTile<ES: Numeric, IO: SliceVisibility = ReadOnly> {
10 pub slice: Slice<Line<ES>, IO>,
12 pub stride: u32,
14 #[cube(comptime)]
15 pub layout: MatrixLayout,
17}
18
19#[cube]
20impl<ES: Numeric> StridedTile<ES> {
21 pub fn new_contiguous(
25 slice: Slice<Line<ES>>,
26 #[comptime] config: StageMemoryConfig,
27 ) -> StridedTile<ES> {
28 let layout = config.matrix_layout;
29 let stride = match layout {
30 MatrixLayout::RowMajor => config.elements_in_tile_col,
31 MatrixLayout::ColMajor => config.elements_in_tile_row,
32 };
33
34 let stride = comptime![stride / config.stage_line_size];
35
36 StridedTile::<ES> {
37 slice,
38 stride,
39 layout,
40 }
41 }
42
43 pub fn new_contiguous_mut(
47 slice: Slice<Line<ES>, ReadWrite>,
48 #[comptime] config: StageMemoryConfig,
49 ) -> StridedTile<ES, ReadWrite> {
50 let layout = config.matrix_layout;
51 let stride = match layout {
52 MatrixLayout::RowMajor => config.elements_in_tile_col,
53 MatrixLayout::ColMajor => config.elements_in_tile_row,
54 };
55
56 let stride = comptime![stride / config.stage_line_size];
57
58 StridedTile::<ES, ReadWrite> {
59 slice,
60 stride,
61 layout,
62 }
63 }
64
65 pub fn new_strided(
69 slice: Slice<Line<ES>>,
70 stride: u32,
71 #[comptime] layout: MatrixLayout,
72 ) -> StridedTile<ES> {
73 StridedTile::<ES> {
74 slice,
75 stride,
76 layout,
77 }
78 }
79
80 pub fn new_strided_mut(
84 slice: Slice<Line<ES>, ReadWrite>,
85 stride: u32,
86 #[comptime] layout: MatrixLayout,
87 ) -> StridedTile<ES, ReadWrite> {
88 StridedTile::<ES, ReadWrite> {
89 slice,
90 stride,
91 layout,
92 }
93 }
94}
95
96#[cube]
97impl<ES: Numeric, IO: SliceVisibility> StridedTile<ES, IO> {
98 pub fn as_unlined(&self) -> (Slice<ES, IO>, u32) {
104 let stage_line_size = comptime![self.slice.line_size()];
105 (
106 self.slice.try_cast_unchecked(),
107 self.stride * stage_line_size,
108 )
109 }
110
111 pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
113 self.slice[coor_strided * self.stride + coor_contiguous]
114 }
115}