cubecl_matmul/components/tile/base.rs
1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::error::MatmulSetupError;
5use crate::components::{
6 AvailableLineSizes, InvalidConfigError, MatmulProblem, MatrixLayout, TileSize,
7 resource::ComputeResources,
8 tile::io::{Tile, TileKind},
9};
10use crate::components::{MatmulLineSizes, MatmulSelection};
11use crate::components::{StageIdent, tile::io::TileMut};
12use std::{fmt::Debug, hash::Hash};
13
14/// A family of [TileMatmul] implementations that operate with any [precision](MatmulPrecision).
15pub trait TileMatmulFamily: Send + Sync + 'static {
16 /// The specific [TileMatmul] implementation associated with this family.
17 type Matmul<L: Numeric, R: Numeric, A: Numeric>: TileMatmul<
18 L,
19 R,
20 A,
21 Config = Self::Config,
22 LhsTile = Self::LhsTile,
23 RhsTile = Self::RhsTile,
24 AccTile = Self::AccTile,
25 OutTile = Self::OutTile,
26 >;
27
28 /// Tile kind for Lhs
29 type LhsTile: TileKind;
30 /// Tile kind for Rhs
31 type RhsTile: TileKind;
32 /// Tile kind for Acc
33 type AccTile: TileKind;
34 /// Tile kind for Out
35 type OutTile: TileKind<ReadWrite>;
36
37 /// The configuration type associated with this matmul family.
38 type Config: TileConfig;
39
40 /// Returns whether this tile matmul requires specialized hardware accelerators (e.g., tensor cores).
41 fn requires_accelerator() -> bool;
42
43 /// Returns the compute resources required to run this tile matmul.
44 fn computation_resources() -> Result<ComputeResources, InvalidConfigError>;
45
46 /// Constructs the configuration based on the matmul problem, selection, and line sizes.
47 ///
48 /// This function may return an error if the configuration cannot be supported on the current runtime.
49 fn setup<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, R: Runtime>(
50 client: &ComputeClient<R::Server>,
51 problem: &MatmulProblem,
52 selection: &MatmulSelection,
53 matmul_line_sizes: &MatmulLineSizes,
54 ) -> Result<Self::Config, MatmulSetupError>;
55
56 /// Filters out line sizes that are incompatible with this matmul family.
57 ///
58 /// By default, returns the input unchanged.
59 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
60 available_line_sizes
61 }
62}
63
64/// Provides matrix multiplication operations at the tile level.
65///
66/// At the tile level,
67/// - Dimensions M, N and K are fixed to an integer, and the
68/// matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
69///
70/// Assumptions:
71/// - Inputs must always be valid. If the actual matrix multiplication
72/// should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
73/// - Enough units are present to perform the whole computation
74#[cube]
75pub trait TileMatmul<L: Numeric, R: Numeric, A: Numeric>: 'static + Send + Sync {
76 /// The configuration type associated with this Matmul.
77 type Config: TileConfig;
78
79 /// Contains Lhs data for computation
80 type LhsFragment: CubeType;
81 /// Contains Rhs data for computation
82 type RhsFragment: CubeType;
83 /// Contains and accumulates results of the Tile Matmul execution
84 type AccFragment: CubeType;
85
86 /// Tile for the lhs data
87 type LhsTile: TileKind;
88 /// Tile for the rhs data
89 type RhsTile: TileKind;
90 /// Tile for the accumulator data
91 type AccTile: TileKind;
92 /// Tile for the output data
93 type OutTile: TileKind<ReadWrite>;
94
95 /// Executes the matrix multiplication of Lhs and Rhs, adding the result to the accumulator
96 fn execute(
97 lhs: &Self::LhsFragment,
98 rhs: &Self::RhsFragment,
99 out: &mut Self::AccFragment,
100 #[comptime] config: Self::Config,
101 );
102
103 /// Create the container for Lhs
104 ///
105 /// # Safety
106 ///
107 /// This may point towards uninitialized memory.
108 /// Make sure to call [load_lhs](TileMatmul::load_lhs) prior to [execute](TileMatmul::execute).
109 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment;
110
111 /// Load the container of Lhs from tile data
112 fn load_lhs<E: Numeric>(
113 tile: &Tile<Self::LhsTile, E>,
114 lhs: &mut Self::LhsFragment,
115 #[comptime] config: Self::Config,
116 );
117
118 /// Create the container for Rhs
119 ///
120 /// # Safety
121 ///
122 /// This may point towards uninitialized memory.
123 /// Make sure to call [load_rhs](TileMatmul::load_rhs) prior to [execute](TileMatmul::execute).
124 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment;
125
126 /// Load the container of Rhs from tile data
127 fn load_rhs<E: Numeric>(
128 tile: &Tile<Self::RhsTile, E>,
129 rhs: &mut Self::RhsFragment,
130 #[comptime] config: Self::Config,
131 );
132
133 /// Allocate the container to receive the execution output.
134 ///
135 /// # Safety
136 ///
137 /// The output container must be initialized to some value (typically 0),
138 /// because the execution adds to the already present value.
139 /// Make sure to call [load_acc](TileMatmul::load_acc) prior to [execute](TileMatmul::execute).
140 fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment;
141
142 /// Load the container of Acc from tile data
143 fn load_acc<E: Numeric>(
144 tile: &Tile<Self::AccTile, E>,
145 acc: &mut Self::AccFragment,
146 #[comptime] config: Self::Config,
147 );
148
149 /// Write the content of the output container to the given slice
150 fn write_results<E: Numeric>(
151 tile: &mut TileMut<Self::OutTile, E>,
152 out: &Self::AccFragment,
153 #[comptime] config: Self::Config,
154 );
155}
156
157/// Configuration for the Tile Matmul level
158pub trait TileConfig: Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static {
159 /// Returns the number of units in a plane
160 fn plane_dim(&self) -> u32;
161
162 /// Returns the [MatrixLayout] for the given ident
163 fn matrix_layout(&self, ident: StageIdent) -> MatrixLayout;
164
165 /// Returns the line size for the given ident
166 fn stage_line_size(&self, ident: StageIdent) -> u32;
167
168 /// Returns the line size for the given ident
169 fn global_line_size(&self, ident: StageIdent) -> u32;
170
171 /// Returns the (m,n,k) shape of the tiles
172 fn tile_size(&self) -> &TileSize;
173}