cubecl_matmul/components/tile/
base.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::error::MatmulSetupError;
5use crate::components::{
6    AvailableLineSizes, InvalidConfigError, MatmulProblem, MatrixLayout, TileSize,
7    resource::ComputeResources,
8    tile::io::{Tile, TileKind},
9};
10use crate::components::{MatmulLineSizes, MatmulSelection};
11use crate::components::{StageIdent, tile::io::TileMut};
12use std::{fmt::Debug, hash::Hash};
13
14/// A family of [TileMatmul] implementations that operate with any [precision](MatmulPrecision).
15pub trait TileMatmulFamily: Send + Sync + 'static {
16    /// The specific [TileMatmul] implementation associated with this family.
17    type Matmul<L: Numeric, R: Numeric, A: Numeric>: TileMatmul<
18            L,
19            R,
20            A,
21            Config = Self::Config,
22            LhsTile = Self::LhsTile,
23            RhsTile = Self::RhsTile,
24            AccTile = Self::AccTile,
25            OutTile = Self::OutTile,
26        >;
27
28    /// Tile kind for Lhs
29    type LhsTile: TileKind;
30    /// Tile kind for Rhs
31    type RhsTile: TileKind;
32    /// Tile kind for Acc
33    type AccTile: TileKind;
34    /// Tile kind for Out
35    type OutTile: TileKind<ReadWrite>;
36
37    /// The configuration type associated with this matmul family.
38    type Config: TileConfig;
39
40    /// Returns whether this tile matmul requires specialized hardware accelerators (e.g., tensor cores).
41    fn requires_accelerator() -> bool;
42
43    /// Returns the compute resources required to run this tile matmul.
44    fn computation_resources() -> Result<ComputeResources, InvalidConfigError>;
45
46    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
47    ///
48    /// This function may return an error if the configuration cannot be supported on the current runtime.
49    fn setup<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, R: Runtime>(
50        client: &ComputeClient<R::Server>,
51        problem: &MatmulProblem,
52        selection: &MatmulSelection,
53        matmul_line_sizes: &MatmulLineSizes,
54    ) -> Result<Self::Config, MatmulSetupError>;
55
56    /// Filters out line sizes that are incompatible with this matmul family.
57    ///
58    /// By default, returns the input unchanged.
59    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
60        available_line_sizes
61    }
62}
63
64/// Provides matrix multiplication operations at the tile level.
65///
66/// At the tile level,
67///  - Dimensions M, N and K are fixed to an integer, and the
68///    matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
69///
70/// Assumptions:
71///  - Inputs must always be valid. If the actual matrix multiplication
72///    should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
73///  - Enough units are present to perform the whole computation
74#[cube]
75pub trait TileMatmul<L: Numeric, R: Numeric, A: Numeric>: 'static + Send + Sync {
76    /// The configuration type associated with this Matmul.
77    type Config: TileConfig;
78
79    /// Contains Lhs data for computation
80    type LhsFragment: CubeType;
81    /// Contains Rhs data for computation
82    type RhsFragment: CubeType;
83    /// Contains and accumulates results of the Tile Matmul execution
84    type AccFragment: CubeType;
85
86    /// Tile for the lhs data
87    type LhsTile: TileKind;
88    /// Tile for the rhs data
89    type RhsTile: TileKind;
90    /// Tile for the accumulator data
91    type AccTile: TileKind;
92    /// Tile for the output data
93    type OutTile: TileKind<ReadWrite>;
94
95    /// Executes the matrix multiplication of Lhs and Rhs, adding the result to the accumulator
96    fn execute(
97        lhs: &Self::LhsFragment,
98        rhs: &Self::RhsFragment,
99        out: &mut Self::AccFragment,
100        #[comptime] config: Self::Config,
101    );
102
103    /// Create the container for Lhs
104    ///
105    /// # Safety
106    ///
107    /// This may point towards uninitialized memory.
108    /// Make sure to call [load_lhs](TileMatmul::load_lhs) prior to [execute](TileMatmul::execute).
109    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment;
110
111    /// Load the container of Lhs from tile data
112    fn load_lhs<E: Numeric>(
113        tile: &Tile<Self::LhsTile, E>,
114        lhs: &mut Self::LhsFragment,
115        #[comptime] config: Self::Config,
116    );
117
118    /// Create the container for Rhs
119    ///
120    /// # Safety
121    ///
122    /// This may point towards uninitialized memory.
123    /// Make sure to call [load_rhs](TileMatmul::load_rhs) prior to [execute](TileMatmul::execute).
124    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment;
125
126    /// Load the container of Rhs from tile data
127    fn load_rhs<E: Numeric>(
128        tile: &Tile<Self::RhsTile, E>,
129        rhs: &mut Self::RhsFragment,
130        #[comptime] config: Self::Config,
131    );
132
133    /// Allocate the container to receive the execution output.
134    ///
135    /// # Safety
136    ///
137    /// The output container must be initialized to some value (typically 0),
138    /// because the execution adds to the already present value.
139    /// Make sure to call [load_acc](TileMatmul::load_acc) prior to [execute](TileMatmul::execute).
140    fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment;
141
142    /// Load the container of Acc from tile data
143    fn load_acc<E: Numeric>(
144        tile: &Tile<Self::AccTile, E>,
145        acc: &mut Self::AccFragment,
146        #[comptime] config: Self::Config,
147    );
148
149    /// Write the content of the output container to the given slice
150    fn write_results<E: Numeric>(
151        tile: &mut TileMut<Self::OutTile, E>,
152        out: &Self::AccFragment,
153        #[comptime] config: Self::Config,
154    );
155}
156
157/// Configuration for the Tile Matmul level
158pub trait TileConfig: Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static {
159    /// Returns the number of units in a plane
160    fn plane_dim(&self) -> u32;
161
162    /// Returns the [MatrixLayout] for the given ident
163    fn matrix_layout(&self, ident: StageIdent) -> MatrixLayout;
164
165    /// Returns the line size for the given ident
166    fn stage_line_size(&self, ident: StageIdent) -> u32;
167
168    /// Returns the line size for the given ident
169    fn global_line_size(&self, ident: StageIdent) -> u32;
170
171    /// Returns the (m,n,k) shape of the tiles
172    fn tile_size(&self) -> &TileSize;
173}