cubecl_matmul/components/tile/
base.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::error::MatmulSetupError;
5use crate::components::{
6    AvailableLineSizes, Ident, InvalidConfigError, MatmulPrecision, MatmulProblem, MatrixLayout,
7    TileSize, resource::ComputeResources, tile::tile_data::Tile,
8};
9use crate::components::{MatmulLineSizes, MatmulSelection};
10use std::{fmt::Debug, hash::Hash};
11
12/// A family of [TileMatmul] implementations that operate with any [precision](MatmulPrecision).
13pub trait TileMatmulFamily: Send + Sync + 'static {
14    /// The specific [TileMatmul] implementation associated with this family.
15    type Matmul<MP: MatmulPrecision>: TileMatmul<MP, Config = Self::Config>;
16
17    /// The configuration type associated with this matmul family.
18    type Config: TileConfig;
19
20    /// Returns whether this tile matmul requires specialized hardware accelerators (e.g., tensor cores).
21    fn requires_accelerator() -> bool;
22
23    /// Returns the compute resources required to run this tile matmul.
24    fn computation_resources() -> Result<ComputeResources, InvalidConfigError>;
25
26    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
27    ///
28    /// This function may return an error if the configuration cannot be supported on the current runtime.
29    fn setup<MP: MatmulPrecision, R: Runtime>(
30        client: &ComputeClient<R::Server, R::Channel>,
31        problem: &MatmulProblem,
32        selection: &MatmulSelection,
33        line_sizes: &MatmulLineSizes,
34    ) -> Result<Self::Config, MatmulSetupError>;
35
36    /// Filters out line sizes that are incompatible with this matmul family.
37    ///
38    /// By default, returns the input unchanged.
39    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
40        available_line_sizes
41    }
42}
43
44/// Provides matrix multiplication operations at the tile level.
45///
46/// At the tile level,
47///  - Dimensions M, N and K are fixed to an integer, and the
48///    matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
49///
50/// Assumptions:
51///  - Inputs must always be valid. If the actual matrix multiplication
52///    should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
53///  - Enough units are present to perform the whole computation
54#[cube]
55pub trait TileMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
56    /// The configuration type associated with this Matmul.
57    type Config: TileConfig;
58    /// Contains Lhs data for computation
59    type Lhs: CubeType;
60    /// Contains Rhs data for computation
61    type Rhs: CubeType;
62    /// Contains and accumulates results of the Tile Matmul execution
63    type Accumulator: CubeType;
64
65    /// Executes the matrix multiplication of Lhs and Rhs, adding the result to the accumulator
66    fn execute(
67        lhs: &Self::Lhs,
68        rhs: &Self::Rhs,
69        out: &mut Self::Accumulator,
70        #[comptime] config: Self::Config,
71    );
72
73    /// Create the container for Lhs
74    ///
75    /// # Safety
76    ///
77    /// This may point towards uninitialized memory.
78    /// Make sure to call [fill_lhs](TileMatmul::fill_lhs) prior to [execute](TileMatmul::execute).
79    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::Lhs;
80
81    /// Fill the container of Lhs with tile data
82    fn fill_lhs(tile: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Self::Config);
83
84    /// Create the container for Rhs
85    ///
86    /// # Safety
87    ///
88    /// This may point towards uninitialized memory.
89    /// Make sure to call [fill_rhs](TileMatmul::fill_rhs) prior to [execute](TileMatmul::execute).
90    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::Rhs;
91
92    /// Fill the container of Rhs with tile data
93    fn fill_rhs(tile: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Self::Config);
94
95    /// Allocate the container to receive the execution output.
96    ///
97    /// # Safety
98    ///
99    /// The output container must be initialized to some value (typically 0),
100    /// because the execution adds to the already present value.
101    /// Make sure to call either [fill_accumulator](TileMatmul::fill_accumulator)
102    /// or [zero_accumulator](TileMatmul::zero_accumulator).
103    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
104
105    /// Fill the accumulator with data
106    fn fill_accumulator(
107        tile: &Tile<MP::EA>,
108        acc: &mut Self::Accumulator,
109        #[comptime] config: Self::Config,
110    );
111
112    /// Fill the accumulator with zeros.
113    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
114
115    /// Write the content of the output container to the given slice
116    fn write_results(
117        out: &Self::Accumulator,
118        slice: &mut SliceMut<Line<MP::EO>>,
119        #[comptime] config: Self::Config,
120    );
121}
122
123/// Configuration for the Tile Matmul level
124pub trait TileConfig: Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static {
125    /// Returns the number of units in a plane
126    fn plane_dim(&self) -> u32;
127
128    /// Returns the [MatrixLayout] for the given ident
129    fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout;
130
131    /// Returns the line size for the given ident
132    fn stage_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
133
134    /// Returns the line size for the given ident
135    fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
136
137    /// Returns the (m,n,k) shape of the tiles
138    fn tile_size(&self) -> &TileSize;
139}