cubecl_linalg/matmul/components/tile/
base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::matmul::components::{
5    Ident, InputIdent, MatmulConfigFactory, MatmulPrecision, MatmulSize, MatrixLayout,
6    config::MatmulConfig, stage::shared::StageVectorization,
7};
8
9#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
10pub struct TileMatmulConfigInput {
11    pub vectorization: StageVectorization,
12    pub size: MatmulSize,
13}
14
15pub trait TileMatmulFamily:
16    MatmulConfigFactory<Input = TileMatmulConfigInput, Config: TileConfig>
17{
18    fn tile_shape(config: &Self::Config) -> MatmulSize;
19    fn requires_tensor_cores() -> bool;
20
21    type Matmul<MP: MatmulPrecision>: TileMatmul<MP, Config = Self::Config>;
22}
23
24/// Provides matrix multiplication operations at the tile level.
25///
26/// At the tile level,
27///  - Inputs are raw slices of data, called tiles.
28///  - units within one plane can collaborate to solve the problem
29///  - dimensions M, N and K are fixed to an integer, and the
30///    matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
31///
32/// Assumptions:
33///  - Slices given as inputs must always be valid. If the actual matrix multiplication
34///    should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
35///  - Enough units are present to perform the whole computation
36#[cube]
37pub trait TileMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
38    type Config: TileConfig;
39    /// Contains LHS data that can be split across the units
40    type Lhs: CubeType;
41    /// Contains RHS data that can be split across the units
42    type Rhs: CubeType;
43    /// Contains output data that can be split across the units
44    type Accumulator: CubeType + Copy + Clone;
45
46    /// Executes the matrix multiplication of LHS and RHS, adding the result to the output
47    fn execute(
48        lhs: &Self::Lhs,
49        rhs: &Self::Rhs,
50        out: &mut Self::Accumulator,
51        #[comptime] config: Self::Config,
52    );
53
54    /// Create the container for LHS data
55    ///
56    /// # Safety
57    ///
58    /// This may point towards uninitialized memory.
59    /// Make sure to call [fill_lhs](TileMatmul::fill_lhs) prior to [execute](TileMatmul::execute).
60    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::Lhs;
61
62    /// Create the container for RHS data
63    ///
64    /// # Safety
65    ///
66    /// This may point towards uninitialized memory.
67    /// Make sure to call [fill_rhs](TileMatmul::fill_lhs) prior to [execute](TileMatmul::execute).
68    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::Rhs;
69
70    /// Fill the container of LHS with data
71    fn fill_lhs(slice: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Self::Config);
72
73    /// Fill the container of RHS with data
74    fn fill_rhs(slice: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Self::Config);
75
76    /// Fill the accumulator with data
77    fn fill_accumulator(
78        tile: &Tile<MP::EA>,
79        acc: &mut Self::Accumulator,
80        #[comptime] config: Self::Config,
81    );
82
83    /// Write the content of the output container to the given slice
84    fn read_accumulator<C: Numeric>(
85        // TODO is this always MP::EG?
86        out: &Self::Accumulator,
87        slice: &mut SliceMut<Line<C>>,
88        #[comptime] config: Self::Config,
89    );
90
91    /// Allocate the container to receive the execution output.
92    ///
93    /// # Safety
94    ///
95    /// The output container must be initialized to some value (typically 0),
96    /// because the execution adds to the already present value.
97    /// Make sure to call either [fill_accumulator](TileMatmul::fill_accumulator)
98    /// or [zero_accumulator](TileMatmul::zero_accumulator).
99    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
100
101    /// Fill the accumulator with zeros.
102    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
103}
104
105/// Configuration for the Tile matmul (TMM) level
106pub trait TileConfig: MatmulConfig {
107    /// Returns the size of the plane dimension
108    fn plane_dim(&self) -> u32;
109
110    /// Returns the [MatrixLayout] for the given ident
111    fn matrix_layout(&self, ident: Ident) -> MatrixLayout;
112
113    /// Returns the line size for the given ident
114    fn stage_line_size(&self, ident: Ident) -> u32;
115
116    /// Returns the shape of the tiles in the three axes m, k and n.
117    fn tile_shape(&self) -> &MatmulSize;
118}
119
120#[derive(CubeType)]
121/// Data to be handed to the tile matmul
122pub struct Tile<ES: Numeric> {
123    /// Slice containing all data
124    pub slice: Slice<Line<ES>>,
125    /// Stride between each row/col, depending on MatrixLayout (the other is assumed to be 1)
126    pub stride: u32,
127}
128
129#[cube]
130impl<ES: Numeric> Tile<ES> {
131    pub fn new_contiguous<T: TileConfig>(
132        slice: Slice<Line<ES>>,
133        #[comptime] ident: Ident,
134        #[comptime] config: T,
135    ) -> Tile<ES> {
136        let stride = comptime! {
137            (match ident.as_input_ident() {
138            InputIdent::Lhs => match config.matrix_layout(ident) {
139                MatrixLayout::RowMajor => config.tile_shape().k,
140                MatrixLayout::ColMajor => config.tile_shape().m,
141            },
142            InputIdent::Rhs => match config.matrix_layout(ident) {
143                MatrixLayout::RowMajor => config.tile_shape().n,
144                MatrixLayout::ColMajor => config.tile_shape().k,
145            },
146        }) / config.stage_line_size(ident)};
147
148        Tile::<ES> { slice, stride }
149    }
150
151    pub fn new_strided(slice: Slice<Line<ES>>, stride: u32) -> Tile<ES> {
152        Tile::<ES> { slice, stride }
153    }
154
155    pub fn as_unlined<T: TileConfig>(
156        &self,
157        #[comptime] ident: Ident,
158        #[comptime] config: T,
159    ) -> (Slice<ES>, u32) {
160        (
161            self.slice.try_cast_unchecked(),
162            self.stride * config.stage_line_size(ident),
163        )
164    }
165}