cubecl_matmul/components/tile/
tile_data.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{Ident, InputIdent, MatrixLayout, tile::TileConfig};
5
6#[derive(CubeType, Clone)]
7/// Data to be handed to the Tile Matmul
8pub struct Tile<ES: Numeric> {
9    /// Slice containing all data
10    pub slice: Slice<Line<ES>>,
11    /// Stride between each row/col, depending on MatrixLayout (the other is assumed to be 1)
12    pub stride: u32,
13    #[cube(comptime)]
14    /// Layout of the tile (row-major or column-major).
15    pub layout: MatrixLayout,
16}
17
18#[cube]
19impl<ES: Numeric> Tile<ES> {
20    /// Creates a tile from a contiguous slice of data.
21    ///
22    /// The slice length must exactly match the tile size.
23    pub fn new_contiguous<T: TileConfig>(
24        slice: Slice<Line<ES>>,
25        #[comptime] ident: Ident,
26        #[comptime] config: T,
27    ) -> Tile<ES> {
28        let layout = config.matrix_layout(ident);
29        let stride = comptime! {
30            (match ident.as_input_ident() {
31            InputIdent::Lhs => match layout {
32                MatrixLayout::RowMajor => config.tile_size().k(),
33                MatrixLayout::ColMajor => config.tile_size().m(),
34            },
35            InputIdent::Rhs => match layout {
36                MatrixLayout::RowMajor => config.tile_size().n(),
37                MatrixLayout::ColMajor => config.tile_size().k(),
38            },
39        }) / config.stage_line_size(ident)};
40
41        Tile::<ES> {
42            slice,
43            stride,
44            layout,
45        }
46    }
47
48    /// Creates a tile from a strided slice of data.
49    ///
50    /// The slice must include all elements of the tile, though it may include unused gaps.
51    pub fn new_strided(
52        slice: Slice<Line<ES>>,
53        stride: u32,
54        #[comptime] layout: MatrixLayout,
55    ) -> Tile<ES> {
56        Tile::<ES> {
57            slice,
58            stride,
59            layout,
60        }
61    }
62
63    /// Returns the tile as an unlined (scalar) slice.
64    ///
65    /// Returns:
66    /// - The unlined slice
67    /// - The updated stride to account for line width removal
68    pub fn as_unlined<T: TileConfig>(
69        &self,
70        #[comptime] ident: Ident,
71        #[comptime] config: T,
72    ) -> (Slice<ES>, u32) {
73        (
74            self.slice.try_cast_unchecked(),
75            self.stride * config.stage_line_size(ident),
76        )
77    }
78
79    /// Returns a specific line from the tile based on coordinates.
80    pub fn get_line(&self, coor_strided: u32, coor_contiguous: u32) -> Line<ES> {
81        self.slice[coor_strided * self.stride + coor_contiguous]
82    }
83}