cubecl_matmul/components/tile/base.rs
1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::error::MatmulSetupError;
5use crate::components::{
6 AvailableLineSizes, Ident, InvalidConfigError, MatmulPrecision, MatmulProblem, MatrixLayout,
7 TileSize, resource::ComputeResources, tile::tile_data::Tile,
8};
9use crate::components::{MatmulLineSizes, MatmulSelection};
10use std::{fmt::Debug, hash::Hash};
11
12/// A family of [TileMatmul] implementations that operate with any [precision](MatmulPrecision).
13pub trait TileMatmulFamily: Send + Sync + 'static {
14 /// The specific [TileMatmul] implementation associated with this family.
15 type Matmul<MP: MatmulPrecision>: TileMatmul<MP, Config = Self::Config>;
16
17 /// The configuration type associated with this matmul family.
18 type Config: TileConfig;
19
20 /// Returns whether this tile matmul requires specialized hardware accelerators (e.g., tensor cores).
21 fn requires_accelerator() -> bool;
22
23 /// Returns the compute resources required to run this tile matmul.
24 fn computation_resources() -> Result<ComputeResources, InvalidConfigError>;
25
26 /// Constructs the configuration based on the matmul problem, selection, and line sizes.
27 ///
28 /// This function may return an error if the configuration cannot be supported on the current runtime.
29 fn setup<MP: MatmulPrecision, R: Runtime>(
30 client: &ComputeClient<R::Server, R::Channel>,
31 problem: &MatmulProblem,
32 selection: &MatmulSelection,
33 line_sizes: &MatmulLineSizes,
34 ) -> Result<Self::Config, MatmulSetupError>;
35
36 /// Filters out line sizes that are incompatible with this matmul family.
37 ///
38 /// By default, returns the input unchanged.
39 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
40 available_line_sizes
41 }
42}
43
44/// Provides matrix multiplication operations at the tile level.
45///
46/// At the tile level,
47/// - Dimensions M, N and K are fixed to an integer, and the
48/// matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
49///
50/// Assumptions:
51/// - Inputs must always be valid. If the actual matrix multiplication
52/// should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
53/// - Enough units are present to perform the whole computation
54#[cube]
55pub trait TileMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
56 /// The configuration type associated with this Matmul.
57 type Config: TileConfig;
58 /// Contains Lhs data for computation
59 type Lhs: CubeType;
60 /// Contains Rhs data for computation
61 type Rhs: CubeType;
62 /// Contains and accumulates results of the Tile Matmul execution
63 type Accumulator: CubeType;
64
65 /// Executes the matrix multiplication of Lhs and Rhs, adding the result to the accumulator
66 fn execute(
67 lhs: &Self::Lhs,
68 rhs: &Self::Rhs,
69 out: &mut Self::Accumulator,
70 #[comptime] config: Self::Config,
71 );
72
73 /// Create the container for Lhs
74 ///
75 /// # Safety
76 ///
77 /// This may point towards uninitialized memory.
78 /// Make sure to call [fill_lhs](TileMatmul::fill_lhs) prior to [execute](TileMatmul::execute).
79 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::Lhs;
80
81 /// Fill the container of Lhs with tile data
82 fn fill_lhs(tile: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Self::Config);
83
84 /// Create the container for Rhs
85 ///
86 /// # Safety
87 ///
88 /// This may point towards uninitialized memory.
89 /// Make sure to call [fill_rhs](TileMatmul::fill_rhs) prior to [execute](TileMatmul::execute).
90 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::Rhs;
91
92 /// Fill the container of Rhs with tile data
93 fn fill_rhs(tile: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Self::Config);
94
95 /// Allocate the container to receive the execution output.
96 ///
97 /// # Safety
98 ///
99 /// The output container must be initialized to some value (typically 0),
100 /// because the execution adds to the already present value.
101 /// Make sure to call either [fill_accumulator](TileMatmul::fill_accumulator)
102 /// or [zero_accumulator](TileMatmul::zero_accumulator).
103 fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
104
105 /// Fill the accumulator with data
106 fn fill_accumulator(
107 tile: &Tile<MP::EA>,
108 acc: &mut Self::Accumulator,
109 #[comptime] config: Self::Config,
110 );
111
112 /// Fill the accumulator with zeros.
113 fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
114
115 /// Write the content of the output container to the given slice
116 fn write_results(
117 out: &Self::Accumulator,
118 slice: &mut SliceMut<Line<MP::EO>>,
119 #[comptime] config: Self::Config,
120 );
121}
122
123/// Configuration for the Tile Matmul level
124pub trait TileConfig: Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static {
125 /// Returns the number of units in a plane
126 fn plane_dim(&self) -> u32;
127
128 /// Returns the [MatrixLayout] for the given ident
129 fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout;
130
131 /// Returns the line size for the given ident
132 fn stage_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
133
134 /// Returns the line size for the given ident
135 fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
136
137 /// Returns the (m,n,k) shape of the tiles
138 fn tile_size(&self) -> &TileSize;
139}