cubecl_linalg/matmul/components/
config.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use std::fmt::{Debug, Display};
4use std::hash::Hash;
5
6use crate::matmul::kernels::MatmulAvailabilityError;
7
8use super::{MatmulPrecision, MatmulProblem, MatmulSize};
9
/// Boxed, displayable explanation attached to a rejected configuration.
///
/// This is the error type returned by [`MatmulConfigFactory::check_config`].
pub type InvalidConfigError = Box<dyn Display + 'static>;
11
/// Configuration error whose message is produced lazily.
///
/// The stored closure is called each time the value is formatted through
/// [`Display`]; nothing is rendered when the error is constructed.
pub struct FormattedConfigError {
    func: Box<dyn Fn() -> String>,
}

impl FormattedConfigError {
    /// Wraps `func` and hands it back already boxed as a `dyn Display`.
    ///
    /// `func` must be `'static` because it is kept inside an owning box.
    // The constructor intentionally returns the boxed trait object, not `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<F>(func: F) -> Box<dyn Display>
    where
        F: Fn() -> String + 'static,
    {
        let error = FormattedConfigError {
            func: Box::new(func),
        };
        Box::new(error)
    }
}

impl Display for FormattedConfigError {
    /// Runs the stored closure and writes its result verbatim to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = (self.func)();
        f.write_str(&rendered)
    }
}
31
32pub trait MatmulConfigFactory: Send + Sync + 'static {
34 type Config: MatmulConfig;
36 type Input;
37
38 fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>;
40
41 #[allow(clippy::result_large_err)]
43 fn check_availability<R: Runtime, MP: MatmulPrecision>(
44 _client: &ComputeClient<R::Server, R::Channel>,
45 _config: &Self::Config,
46 ) -> Result<(), MatmulAvailabilityError>;
47
48 fn make_config(
50 input: Self::Input,
51 problem: &MatmulProblem,
52 cube_dim: &CubeDim,
53 cube_count: &CubeCount,
54 quantized: bool,
55 ) -> Self::Config;
56}
57
/// Marker trait bundling the bounds required of every matmul configuration type.
///
/// It declares no items of its own: implementing it only asserts that the type
/// is copyable, comparable, hashable, debuggable and shareable across threads.
pub trait MatmulConfig:
    Debug + Hash + Eq + PartialEq + Copy + Clone + Send + Sync + 'static
{
}
65
/// Names one of the three matrices of a matmul: left-hand side, right-hand side or output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Ident {
    Lhs,
    Rhs,
    Out,
}

impl Ident {
    /// Narrows this identifier to the corresponding [`InputIdent`].
    ///
    /// # Panics
    /// Panics when called on [`Ident::Out`], which has no input counterpart.
    pub fn as_input_ident(&self) -> InputIdent {
        match *self {
            Ident::Out => panic!("Out is not an input."),
            Ident::Lhs => InputIdent::Lhs,
            Ident::Rhs => InputIdent::Rhs,
        }
    }
}

/// Names one of the two input matrices of a matmul.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputIdent {
    Lhs,
    Rhs,
}

impl InputIdent {
    /// Widens this input identifier to the general [`Ident`]; never fails.
    pub fn as_ident(&self) -> Ident {
        Ident::from(*self)
    }
}

impl From<InputIdent> for Ident {
    /// Maps each input identifier onto the [`Ident`] variant of the same name.
    fn from(value: InputIdent) -> Self {
        match value {
            InputIdent::Lhs => Ident::Lhs,
            InputIdent::Rhs => Ident::Rhs,
        }
    }
}
109
110#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug)]
111pub enum MatrixLayout {
114 RowMajor,
115 ColMajor,
116}
117
118#[cube]
119pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
121 match layout {
122 MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
123 MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
124 }
125}
126
127#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
128pub struct CompleteStageTiling {
130 pub tile_shape: MatmulSize,
131 pub tile_count: MatmulSize,
132}
133
134impl CompleteStageTiling {
135 pub fn get(&self, ident: Ident) -> TilingDimensions {
136 match ident {
137 Ident::Lhs => TilingDimensions {
138 tile_shape_row: self.tile_shape.m,
139 tile_shape_col: self.tile_shape.k,
140 tile_count_row: self.tile_count.m,
141 tile_count_col: self.tile_count.k,
142 },
143 Ident::Rhs => TilingDimensions {
144 tile_shape_row: self.tile_shape.k,
145 tile_shape_col: self.tile_shape.n,
146 tile_count_row: self.tile_count.k,
147 tile_count_col: self.tile_count.n,
148 },
149 Ident::Out => TilingDimensions {
150 tile_shape_row: self.tile_shape.m,
151 tile_shape_col: self.tile_shape.n,
152 tile_count_row: self.tile_count.m,
153 tile_count_col: self.tile_count.n,
154 },
155 }
156 }
157
158 pub fn total_shape(&self) -> MatmulSize {
159 MatmulSize {
160 m: self.tile_shape.m * self.tile_count.m,
161 n: self.tile_shape.n * self.tile_count.n,
162 k: self.tile_shape.k * self.tile_count.k,
163 }
164 }
165}
166
/// Row/column tiling of a single matrix: the shape of one tile and how many
/// tiles there are along each axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TilingDimensions {
    pub tile_shape_row: u32,
    pub tile_shape_col: u32,
    pub tile_count_row: u32,
    pub tile_count_col: u32,
}

impl TilingDimensions {
    /// Product of [`Self::total_row`] and [`Self::total_col`].
    pub fn total_size(&self) -> u32 {
        let rows = self.total_row();
        let cols = self.total_col();
        rows * cols
    }

    /// Tile count along rows multiplied by the row extent of one tile.
    pub fn total_row(&self) -> u32 {
        self.tile_count_row * self.tile_shape_row
    }

    /// Tile count along columns multiplied by the column extent of one tile.
    pub fn total_col(&self) -> u32 {
        self.tile_count_col * self.tile_shape_col
    }

    /// Row extent of one tile multiplied by its column extent.
    pub fn tile_size(&self) -> u32 {
        self.tile_shape_row * self.tile_shape_col
    }

    /// Value of the `tile_shape_row` field.
    pub fn tile_shape_row(&self) -> u32 {
        self.tile_shape_row
    }

    /// Value of the `tile_shape_col` field.
    pub fn tile_shape_col(&self) -> u32 {
        self.tile_shape_col
    }

    /// Tile count along rows multiplied by the tile count along columns.
    pub fn tile_count(&self) -> u32 {
        self.tile_count_row * self.tile_count_col
    }

    /// Value of the `tile_count_row` field.
    pub fn tile_count_row(&self) -> u32 {
        self.tile_count_row
    }

    /// Value of the `tile_count_col` field.
    pub fn tile_count_col(&self) -> u32 {
        self.tile_count_col
    }
}
222
223pub trait TensorIdent:
224 Clone + Copy + Debug + Hash + PartialEq + Eq + Send + Sync + 'static
225{
226 const IDENT: Ident;
227}
228
229#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
230pub struct Lhs;
231#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
232pub struct Rhs;
233#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
234pub struct Out;
235
236impl TensorIdent for Lhs {
237 const IDENT: Ident = Ident::Lhs;
238}
239
240impl TensorIdent for Rhs {
241 const IDENT: Ident = Ident::Rhs;
242}
243
244impl TensorIdent for Out {
245 const IDENT: Ident = Ident::Out;
246}