cubecl_linalg/matmul/components/
config.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use std::fmt::{Debug, Display};
4use std::hash::Hash;
5
6use crate::matmul::kernels::MatmulAvailabilityError;
7
8use super::{MatmulPrecision, MatmulProblem, MatmulSize};
9
/// Error type returned by `check_config` when a matmul configuration is rejected.
///
/// It is a boxed `Display` trait object, so any displayable message can be
/// carried, including the lazily formatted one built by
/// `FormattedConfigError::new`.
pub type InvalidConfigError = Box<dyn Display>;
11
/// Configuration error whose message is produced on demand.
///
/// The stored closure is only invoked when the error is formatted, so
/// constructing the error does not run the message-building code.
pub struct FormattedConfigError {
    func: Box<dyn Fn() -> String>,
}

impl FormattedConfigError {
    /// Wraps `func` into a boxed, displayable error.
    ///
    /// `func` is called each time the returned value is formatted.
    // `new` deliberately returns the type-erased box instead of `Self`,
    // which is why the lint is silenced here.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<F: Fn() -> String + 'static>(func: F) -> Box<dyn Display> {
        let message_builder: Box<dyn Fn() -> String> = Box::new(func);
        Box::new(FormattedConfigError {
            func: message_builder,
        })
    }
}

impl Display for FormattedConfigError {
    /// Writes the message produced by the stored closure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&(self.func)())
    }
}
31
/// Provides configuration for a matmul kernel at any level
///
/// Implementors expose associated functions only (none of them take `self`):
/// one to build a [`MatmulConfig`] from launch information and two to
/// validate a config that has been built.
pub trait MatmulConfigFactory: Send + Sync + 'static {
    /// Configuration tailored to the matmul implementation
    type Config: MatmulConfig;
    /// Implementation-specific input consumed by `make_config`
    type Input;

    /// Asserts that the configuration for this matmul will lead to a valid computation
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidConfigError`] when `config` is rejected.
    fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;

    /// Checks if the client can handle the features used in this computation
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulAvailabilityError`] when the check fails.
    // NOTE(review): the lint exemption suggests `MatmulAvailabilityError` is a
    // large type; confirm before boxing or removing the `allow`.
    #[allow(clippy::result_large_err)]
    fn check_availability<R: Runtime, MP: MatmulPrecision>(
        _client: &ComputeClient<R::Server, R::Channel>,
        _config: &Self::Config,
    ) -> Result<(), MatmulAvailabilityError>;

    /// Create config for this matmul, given launch information
    ///
    /// Takes the implementation-specific `input`, the `problem` description
    /// and the launch `cube_dim` / `cube_count`.
    // NOTE(review): `quantized` presumably tells whether the matmul runs on
    // quantized data; confirm against callers.
    fn make_config(
        input: Self::Input,
        problem: &MatmulProblem,
        cube_dim: &CubeDim,
        cube_count: &CubeCount,
        quantized: bool,
    ) -> Self::Config;
}
57
/// Marker trait for a matmul configuration.
///
/// It declares no items of its own: it only bundles the bounds every config
/// has to satisfy, so generic code can name a single trait instead of
/// repeating the whole list.
pub trait MatmulConfig:
    Copy + Clone + PartialEq + Eq + Hash + Debug + Send + Sync + 'static
{
}
65
66#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
67/// Identifier for all three tensors in a matmul
68///
69/// Useful to specialize some functions depending on the tensor
70pub enum Ident {
71    Lhs,
72    Rhs,
73    Out,
74}
75
76impl Ident {
77    pub fn as_input_ident(&self) -> InputIdent {
78        match self {
79            Ident::Lhs => InputIdent::Lhs,
80            Ident::Rhs => InputIdent::Rhs,
81            Ident::Out => panic!("Out is not an input."),
82        }
83    }
84}
85
86#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
87/// Identifier for the two input tensors in a matmul.
88///
89/// Useful to specialize some functions depending on the tensor
90pub enum InputIdent {
91    Lhs,
92    Rhs,
93}
94
95impl InputIdent {
96    pub fn as_ident(&self) -> Ident {
97        match self {
98            InputIdent::Lhs => Ident::Lhs,
99            InputIdent::Rhs => Ident::Rhs,
100        }
101    }
102}
103
104impl From<InputIdent> for Ident {
105    fn from(value: InputIdent) -> Self {
106        value.as_ident()
107    }
108}
109
#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug)]
/// Layout of a 2D structure such as a tensor, shared memory or slice,
/// used within any matmul kernel level
///
/// `as_cmma_layout` maps each variant to the `cmma::MatrixLayout` variant of
/// the same name.
pub enum MatrixLayout {
    /// Row-major layout.
    RowMajor,
    /// Column-major layout.
    ColMajor,
}
117
#[cube]
/// Maps the matmul MatrixLayout to cmma's MatrixLayout, for use in Cmma API.
///
/// `layout` is a comptime argument. Each variant is mapped to the
/// `cmma::MatrixLayout` variant of the same name.
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
    }
}
126
127#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
128/// Aggregation of [StageTiling]s for all components.
129pub struct CompleteStageTiling {
130    pub tile_shape: MatmulSize,
131    pub tile_count: MatmulSize,
132}
133
134impl CompleteStageTiling {
135    pub fn get(&self, ident: Ident) -> TilingDimensions {
136        match ident {
137            Ident::Lhs => TilingDimensions {
138                tile_shape_row: self.tile_shape.m,
139                tile_shape_col: self.tile_shape.k,
140                tile_count_row: self.tile_count.m,
141                tile_count_col: self.tile_count.k,
142            },
143            Ident::Rhs => TilingDimensions {
144                tile_shape_row: self.tile_shape.k,
145                tile_shape_col: self.tile_shape.n,
146                tile_count_row: self.tile_count.k,
147                tile_count_col: self.tile_count.n,
148            },
149            Ident::Out => TilingDimensions {
150                tile_shape_row: self.tile_shape.m,
151                tile_shape_col: self.tile_shape.n,
152                tile_count_row: self.tile_count.m,
153                tile_count_col: self.tile_count.n,
154            },
155        }
156    }
157
158    pub fn total_shape(&self) -> MatmulSize {
159        MatmulSize {
160            m: self.tile_shape.m * self.tile_count.m,
161            n: self.tile_shape.n * self.tile_count.n,
162            k: self.tile_shape.k * self.tile_count.k,
163        }
164    }
165}
166
/// Dimensions for stage.
///
/// Holds the shape of one tile and the number of tiles, along the row and
/// column axes.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct TilingDimensions {
    /// Size of the row axis of a tile.
    pub tile_shape_row: u32,
    /// Size of the column axis of a tile.
    pub tile_shape_col: u32,
    /// Number of tiles across the row axis.
    pub tile_count_row: u32,
    /// Number of tiles across the column axis.
    pub tile_count_col: u32,
}

impl TilingDimensions {
    /// Returns the total number of elements of the stage.
    pub fn total_size(&self) -> u32 {
        let rows = self.total_row();
        let cols = self.total_col();
        rows * cols
    }

    /// Returns the total number of rows of the stage.
    pub fn total_row(&self) -> u32 {
        let Self {
            tile_count_row,
            tile_shape_row,
            ..
        } = *self;
        tile_count_row * tile_shape_row
    }

    /// Returns the total number of columns of the stage.
    pub fn total_col(&self) -> u32 {
        let Self {
            tile_count_col,
            tile_shape_col,
            ..
        } = *self;
        tile_count_col * tile_shape_col
    }

    /// Returns the number of elements within one tile.
    pub fn tile_size(&self) -> u32 {
        let Self {
            tile_shape_row,
            tile_shape_col,
            ..
        } = *self;
        tile_shape_row * tile_shape_col
    }

    /// Returns the size of the row axis of a tile.
    pub fn tile_shape_row(&self) -> u32 {
        self.tile_shape_row
    }

    /// Returns the size of the column axis of a tile.
    pub fn tile_shape_col(&self) -> u32 {
        self.tile_shape_col
    }

    /// Returns the number of tiles within the stage.
    pub fn tile_count(&self) -> u32 {
        let Self {
            tile_count_row,
            tile_count_col,
            ..
        } = *self;
        tile_count_row * tile_count_col
    }

    /// Returns the number of tiles across the row axis of the stage.
    pub fn tile_count_row(&self) -> u32 {
        self.tile_count_row
    }

    /// Returns the number of tiles across the column axis of the stage.
    pub fn tile_count_col(&self) -> u32 {
        self.tile_count_col
    }
}
222
223pub trait TensorIdent:
224    Clone + Copy + Debug + Hash + PartialEq + Eq + Send + Sync + 'static
225{
226    const IDENT: Ident;
227}
228
229#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
230pub struct Lhs;
231#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
232pub struct Rhs;
233#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
234pub struct Out;
235
236impl TensorIdent for Lhs {
237    const IDENT: Ident = Ident::Lhs;
238}
239
240impl TensorIdent for Rhs {
241    const IDENT: Ident = Ident::Rhs;
242}
243
244impl TensorIdent for Out {
245    const IDENT: Ident = Ident::Out;
246}