cubecl_matmul/components/tile/
tile_data.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{Ident, InputIdent, MatrixLayout, tile::TileConfig};
5
/// A tile of matrix data: a slice of `Line<ES>` together with the stride and
/// (comptime) layout needed to address individual lines in it.
#[derive(CubeType, Clone)]
pub struct Tile<ES: Numeric> {
    /// Backing data, stored as `Line<ES>` values.
    pub slice: Slice<Line<ES>>,
    /// Offset, in lines, between consecutive positions along the strided axis:
    /// line `(s, c)` is read from `slice[s * stride + c]` (see `get_line`).
    pub stride: u32,
    /// Layout of the data (`RowMajor` or `ColMajor`); a comptime field.
    #[cube(comptime)]
    pub layout: MatrixLayout,
}
17
// Constructors and accessors for `Tile`, expanded by the `#[cube]` macro.
#[cube]
impl<ES: Numeric> Tile<ES> {
    /// Creates a tile over `slice` with a stride computed from `config`.
    ///
    /// The layout is `config.matrix_layout(ident)`. The stride is the element
    /// count of the tile's contiguous axis, taken from `config.tile_size()`
    /// according to `ident` and that layout. It is then converted to lines by
    /// dividing by `config.stage_line_size(ident)`. As a result, consecutive
    /// positions along the strided axis are packed back to back.
    ///
    /// NOTE(review): `ident` goes through `as_input_ident()`, whose result is
    /// matched exhaustively on `Lhs`/`Rhs`. How other idents (e.g. an output
    /// ident) are handled is decided by `as_input_ident`, which is not visible
    /// here.
    pub fn new_contiguous<T: TileConfig>(
        slice: Slice<Line<ES>>,
        #[comptime] ident: Ident,
        #[comptime] config: T,
    ) -> Tile<ES> {
        let layout = config.matrix_layout(ident);
        // Assuming the usual convention that Lhs is m x k and Rhs is k x n:
        // - a row-major row has length k (Lhs) or n (Rhs);
        // - a col-major column has length m (Lhs) or k (Rhs).
        // Dividing by the stage line size expresses that length in lines.
        // NOTE(review): presumably integer division, so this truncates silently
        // if the dimension is not a multiple of the line size.
        let stride = comptime! {
            (match ident.as_input_ident() {
                InputIdent::Lhs => match layout {
                    MatrixLayout::RowMajor => config.tile_size().k(),
                    MatrixLayout::ColMajor => config.tile_size().m(),
                },
                InputIdent::Rhs => match layout {
                    MatrixLayout::RowMajor => config.tile_size().n(),
                    MatrixLayout::ColMajor => config.tile_size().k(),
                },
            }) / config.stage_line_size(ident)};

        Tile::<ES> {
            slice,
            stride,
            layout,
        }
    }

    /// Creates a tile from an explicit `stride` (in lines) and `layout`.
    ///
    /// Use this when the stride is not implied by the tile size, unlike
    /// `new_contiguous`. No validation of `stride` against `slice` is done here.
    pub fn new_strided(
        slice: Slice<Line<ES>>,
        stride: u32,
        #[comptime] layout: MatrixLayout,
    ) -> Tile<ES> {
        Tile::<ES> {
            slice,
            stride,
            layout,
        }
    }

    /// Returns the tile's data as a slice of scalar `ES` values, together with
    /// the stride rescaled from lines to scalar elements.
    ///
    /// The slice is reinterpreted with `try_cast_unchecked`. The stride is
    /// multiplied by `config.stage_line_size(ident)`.
    ///
    /// NOTE(review): the line size comes from `config`, not from `self.slice`.
    /// For tiles built with `new_strided`, this assumes the slice's lines have
    /// that same size; confirm at call sites.
    pub fn as_unlined<T: TileConfig>(
        &self,
        #[comptime] ident: Ident,
        #[comptime] config: T,
    ) -> (Slice<ES>, u32) {
        (
            self.slice.try_cast_unchecked(),
            self.stride * config.stage_line_size(ident),
        )
    }

    /// Reads one line from the tile.
    ///
    /// `coor_strided` selects the position along the strided axis (scaled by
    /// `self.stride`). `coor_contiguous` is the line offset within it. The
    /// flat index into `self.slice` is `coor_strided * self.stride + coor_contiguous`.
    pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
        self.slice[coor_strided * self.stride + coor_contiguous]
    }
}