cubek_matmul/definition/
spec.rs

1use cubecl::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::definition::MatmulIdent;
5
6/// Matrix multiplication precisions.
7pub trait MatmulPrecision: Send + Sync + Copy + 'static {
8    /// Element type of lhs input tensor of the kernel.
9    type Lhs: MatrixPrecision;
10    /// Element type of rhs input tensor of the kernel.
11    type Rhs: MatrixPrecision;
12    /// Element type of acc input tensor of the kernel.
13    type Acc: MatrixPrecision;
14}
15
16pub trait MatrixPrecision: Send + Sync + Copy + 'static {
17    /// Element type of input tensor in global memory
18    type Global: Numeric;
19    /// Element type once stored in shared memory
20    type Stage: Numeric;
21    /// Element type once in registers for computation
22    type Register: Numeric;
23}
24
25impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
26    type Global = EG;
27    type Stage = ES;
28    type Register = ES;
29}
30
31impl<EG: Numeric, ES: Numeric, ER: Numeric> MatrixPrecision for (EG, ES, ER) {
32    type Global = EG;
33    type Stage = ES;
34    type Register = ER;
35}
36
37impl MatmulPrecision for f16 {
38    type Lhs = (f16, f16);
39    type Rhs = (f16, f16);
40    #[cfg(target_os = "macos")]
41    type Acc = (f16, f16);
42    #[cfg(not(target_os = "macos"))]
43    type Acc = (f16, f32);
44}
45
46impl MatmulPrecision for flex32 {
47    type Lhs = (f32, f16);
48    type Rhs = (f32, f16);
49    type Acc = (f32, f32);
50}
51
52impl MatmulPrecision for bf16 {
53    type Lhs = (bf16, bf16);
54    type Rhs = (bf16, bf16);
55    #[cfg(target_os = "macos")]
56    type Acc = (bf16, bf16);
57    #[cfg(not(target_os = "macos"))]
58    type Acc = (bf16, f32);
59}
60
61impl MatmulPrecision for f32 {
62    type Lhs = (f32, f32);
63    type Rhs = (f32, f32);
64    type Acc = (f32, f32);
65}
66
67impl MatmulPrecision for f64 {
68    type Lhs = (f64, f32);
69    type Rhs = (f64, f32);
70    type Acc = (f64, f32);
71}
72
73impl MatmulPrecision for u8 {
74    type Lhs = (u8, u8);
75    type Rhs = (u8, u8);
76    type Acc = (i32, i32);
77}
78
79impl MatmulPrecision for u16 {
80    type Lhs = (u16, u16);
81    type Rhs = (u16, u16);
82    type Acc = (i32, i32);
83}
84
85impl MatmulPrecision for u32 {
86    type Lhs = (u32, u32);
87    type Rhs = (u32, u32);
88    type Acc = (u32, u32);
89}
90
91impl MatmulPrecision for u64 {
92    type Lhs = (u64, u64);
93    type Rhs = (u64, u64);
94    type Acc = (u64, u64);
95}
96
97impl MatmulPrecision for i8 {
98    type Lhs = (i8, i8);
99    type Rhs = (i8, i8);
100    type Acc = (i32, i32);
101}
102
103impl MatmulPrecision for i16 {
104    type Lhs = (i16, i16);
105    type Rhs = (i16, i16);
106    type Acc = (i32, i32);
107}
108
109impl MatmulPrecision for i32 {
110    type Lhs = (i32, i32);
111    type Rhs = (i32, i32);
112    type Acc = (i32, i32);
113}
114
115impl MatmulPrecision for i64 {
116    type Lhs = (i64, i64);
117    type Rhs = (i64, i64);
118    type Acc = (i64, i64);
119}
120
121impl<Lhs: MatrixPrecision, Rhs: MatrixPrecision, Acc: MatrixPrecision> MatmulPrecision
122    for (Lhs, Rhs, Acc)
123{
124    type Lhs = Lhs;
125    type Rhs = Rhs;
126    type Acc = Acc;
127}
128
129pub type LhsG<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
130pub type LhsS<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
131pub type LhsR<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
132
133pub type RhsG<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
134pub type RhsS<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
135pub type RhsR<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
136
137pub type AccG<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global;
138pub type AccS<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
139pub type AccR<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Register;
140
141#[derive(Debug, Clone, Eq, PartialEq, Hash)]
142pub struct MatmulElems {
143    pub lhs_global: StorageType,
144    pub rhs_global: StorageType,
145    pub acc_global: StorageType,
146    pub lhs_stage: StorageType,
147    pub rhs_stage: StorageType,
148    pub acc_stage: StorageType,
149    pub lhs_register: StorageType,
150    pub rhs_register: StorageType,
151    pub acc_register: StorageType,
152}
153
154#[derive(Clone, Debug)]
155pub struct MatmulGlobalElems {
156    pub lhs: StorageType,
157    pub rhs: StorageType,
158    pub out: StorageType,
159}
160
161impl MatmulElems {
162    pub fn new_deprecated<MP: MatmulPrecision>() -> Self {
163        Self {
164            lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
165            rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
166            acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
167            lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
168            rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
169            acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
170            lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
171            rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
172            acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
173        }
174    }
175
176    pub fn from_globals(global_elems: &MatmulGlobalElems) -> Self {
177        let acc_type = if global_elems.out == half::f16::as_type_native_unchecked()
178            || global_elems.out == half::bf16::as_type_native_unchecked()
179        {
180            f32::as_type_native_unchecked()
181        } else {
182            global_elems.out
183        };
184
185        Self {
186            lhs_global: global_elems.lhs,
187            rhs_global: global_elems.rhs,
188            acc_global: global_elems.out,
189            lhs_stage: global_elems.lhs,
190            rhs_stage: global_elems.rhs,
191            acc_stage: acc_type,
192            lhs_register: global_elems.lhs,
193            rhs_register: global_elems.rhs,
194            acc_register: acc_type,
195        }
196    }
197
198    pub fn from_single_dtype(dtype: StorageType) -> Self {
199        Self {
200            lhs_global: dtype,
201            rhs_global: dtype,
202            acc_global: dtype,
203            lhs_stage: dtype,
204            rhs_stage: dtype,
205            acc_stage: dtype,
206            lhs_register: dtype,
207            rhs_register: dtype,
208            acc_register: dtype,
209        }
210    }
211
212    pub fn from_define_arrays(
213        global: [StorageType; 3],
214        stage: [StorageType; 3],
215        register: [StorageType; 3],
216    ) -> Self {
217        Self {
218            lhs_global: global[0],
219            rhs_global: global[1],
220            acc_global: global[2],
221            lhs_stage: stage[0],
222            rhs_stage: stage[1],
223            acc_stage: stage[2],
224            lhs_register: register[0],
225            rhs_register: register[1],
226            acc_register: register[2],
227        }
228    }
229
230    pub fn global(&self, ident: MatmulIdent) -> StorageType {
231        match ident {
232            MatmulIdent::Lhs => self.lhs_global,
233            MatmulIdent::Rhs => self.rhs_global,
234            MatmulIdent::Out => self.acc_global,
235        }
236    }
237
238    pub fn stage(&self, ident: MatmulIdent) -> StorageType {
239        match ident {
240            MatmulIdent::Lhs => self.lhs_stage,
241            MatmulIdent::Rhs => self.rhs_stage,
242            MatmulIdent::Out => self.acc_stage,
243        }
244    }
245
246    pub fn register(&self, ident: MatmulIdent) -> StorageType {
247        match ident {
248            MatmulIdent::Lhs => self.lhs_register,
249            MatmulIdent::Rhs => self.rhs_register,
250            MatmulIdent::Out => self.acc_register,
251        }
252    }
253
254    pub fn as_global_elems(&self) -> MatmulGlobalElems {
255        MatmulGlobalElems {
256            lhs: self.lhs_global,
257            rhs: self.rhs_global,
258            out: self.acc_global,
259        }
260    }
261
262    /// Prefer output type for stage because it's the same size at best, but often smaller.
263    /// Having stage == global also enables things like TMA, and an f16 stage for output enables
264    /// using `stmatrix` on the registers after casting.
265    pub fn adjust_stage_dtypes(&mut self) {
266        self.lhs_stage = self.lhs_global;
267        self.rhs_stage = self.rhs_global;
268        self.acc_stage = self.acc_global;
269    }
270}