Skip to main content

cubek_matmul/definition/
spec.rs

1use cubecl::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::definition::MatmulIdent;
5
/// Compile-time description of the element precisions used by a matrix multiplication.
///
/// Each associated type is a [`MatrixPrecision`] that gives the global, stage and
/// register element types of one operand.
pub trait MatmulPrecision: Send + Sync + Copy + 'static {
    /// Precisions (global / stage / register) of the lhs input tensor of the kernel.
    type Lhs: MatrixPrecision;
    /// Precisions (global / stage / register) of the rhs input tensor of the kernel.
    type Rhs: MatrixPrecision;
    /// Precisions (global / stage / register) of the accumulator. Its global element
    /// type is the element type of the output tensor.
    type Acc: MatrixPrecision;
}
15
/// Element types of one matmul operand at each memory level.
///
/// Implemented for `(Global, Stage)` tuples, where the register type reuses the stage
/// type, and for `(Global, Stage, Register)` tuples.
pub trait MatrixPrecision: Send + Sync + Copy + 'static {
    /// Element type of input tensor in global memory
    type Global: Numeric;
    /// Element type once stored in shared memory
    type Stage: Numeric;
    /// Element type once in registers for computation
    type Register: Numeric;
}
24
25impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
26    type Global = EG;
27    type Stage = ES;
28    type Register = ES;
29}
30
31impl<EG: Numeric, ES: Numeric, ER: Numeric> MatrixPrecision for (EG, ES, ER) {
32    type Global = EG;
33    type Stage = ES;
34    type Register = ER;
35}
36
/// Matrix multiplication types: element types *and* sizes for every operand.
///
/// This is the counterpart of [`MatmulPrecision`] in which each operand is a
/// [`MatrixTypes`]. A `MatrixTypes` also carries the `Size` paired with the element
/// type at each memory level.
pub trait MatmulTypes: Send + Sync + Copy + 'static {
    /// Element types and sizes of the lhs input tensor of the kernel.
    type Lhs: MatrixTypes;
    /// Element types and sizes of the rhs input tensor of the kernel.
    type Rhs: MatrixTypes;
    /// Element types and sizes of the accumulator (acc) of the kernel.
    type Acc: MatrixTypes;
}
46
/// Element types, each paired with a `Size`, of one matmul operand at each memory level.
///
/// At every level the vectorized type is written `Vector<Element, Size>`
/// (e.g. `Vector<Global, GlobalSize>`).
pub trait MatrixTypes: Send + Sync + Copy + 'static {
    /// Element type of input tensor in global memory
    type Global: Numeric;
    /// `Size` paired with `Global` to form `Vector<Global, GlobalSize>`.
    type GlobalSize: Size;
    /// Element type once stored in shared memory
    type Stage: Numeric;
    /// `Size` paired with `Stage` to form `Vector<Stage, StageSize>`.
    type StageSize: Size;
    /// Element type once in registers for computation
    type Register: Numeric;
    /// `Size` paired with `Register` to form `Vector<Register, RegisterSize>`.
    type RegisterSize: Size;
}
58
59impl<EG: Numeric, SG: Size, ES: Numeric, SS: Size, ER: Numeric, SR: Size> MatrixTypes
60    for (EG, SG, ES, SS, ER, SR)
61{
62    type Global = EG;
63    type GlobalSize = SG;
64    type Stage = ES;
65    type StageSize = SS;
66    type Register = ER;
67    type RegisterSize = SR;
68}
69
70impl MatmulPrecision for f16 {
71    type Lhs = (f16, f16);
72    type Rhs = (f16, f16);
73    #[cfg(target_os = "macos")]
74    type Acc = (f16, f16);
75    #[cfg(not(target_os = "macos"))]
76    type Acc = (f16, f32);
77}
78
79impl MatmulPrecision for flex32 {
80    type Lhs = (f32, f16);
81    type Rhs = (f32, f16);
82    type Acc = (f32, f32);
83}
84
85impl MatmulPrecision for bf16 {
86    type Lhs = (bf16, bf16);
87    type Rhs = (bf16, bf16);
88    #[cfg(target_os = "macos")]
89    type Acc = (bf16, bf16);
90    #[cfg(not(target_os = "macos"))]
91    type Acc = (bf16, f32);
92}
93
94impl MatmulPrecision for f32 {
95    type Lhs = (f32, f32);
96    type Rhs = (f32, f32);
97    type Acc = (f32, f32);
98}
99
100impl MatmulPrecision for f64 {
101    type Lhs = (f64, f32);
102    type Rhs = (f64, f32);
103    type Acc = (f64, f32);
104}
105
106impl MatmulPrecision for u8 {
107    type Lhs = (u8, u8);
108    type Rhs = (u8, u8);
109    type Acc = (i32, i32);
110}
111
112impl MatmulPrecision for u16 {
113    type Lhs = (u16, u16);
114    type Rhs = (u16, u16);
115    type Acc = (i32, i32);
116}
117
118impl MatmulPrecision for u32 {
119    type Lhs = (u32, u32);
120    type Rhs = (u32, u32);
121    type Acc = (u32, u32);
122}
123
124impl MatmulPrecision for u64 {
125    type Lhs = (u64, u64);
126    type Rhs = (u64, u64);
127    type Acc = (u64, u64);
128}
129
130impl MatmulPrecision for i8 {
131    type Lhs = (i8, i8);
132    type Rhs = (i8, i8);
133    type Acc = (i32, i32);
134}
135
136impl MatmulPrecision for i16 {
137    type Lhs = (i16, i16);
138    type Rhs = (i16, i16);
139    type Acc = (i32, i32);
140}
141
142impl MatmulPrecision for i32 {
143    type Lhs = (i32, i32);
144    type Rhs = (i32, i32);
145    type Acc = (i32, i32);
146}
147
148impl MatmulPrecision for i64 {
149    type Lhs = (i64, i64);
150    type Rhs = (i64, i64);
151    type Acc = (i64, i64);
152}
153
154impl<Lhs: MatrixPrecision, Rhs: MatrixPrecision, Acc: MatrixPrecision> MatmulPrecision
155    for (Lhs, Rhs, Acc)
156{
157    type Lhs = Lhs;
158    type Rhs = Rhs;
159    type Acc = Acc;
160}
161
162impl<Lhs: MatrixTypes, Rhs: MatrixTypes, Acc: MatrixTypes> MatmulTypes for (Lhs, Rhs, Acc) {
163    type Lhs = Lhs;
164    type Rhs = Rhs;
165    type Acc = Acc;
166}
167
168pub type Lhs<MT> = <MT as MatmulTypes>::Lhs;
169pub type Rhs<MT> = <MT as MatmulTypes>::Rhs;
170pub type Acc<MT> = <MT as MatmulTypes>::Acc;
171
172pub type Global<MT> = <MT as MatrixTypes>::Global;
173pub type GlobalSize<MT> = <MT as MatrixTypes>::GlobalSize;
174
175pub type Stage<MT> = <MT as MatrixTypes>::Stage;
176pub type StageSize<MT> = <MT as MatrixTypes>::StageSize;
177
178pub type Register<MT> = <MT as MatrixTypes>::Register;
179pub type RegisterSize<MT> = <MT as MatrixTypes>::RegisterSize;
180
181// ==================== LHS ====================
182
183// Vector forms
184pub type LhsG<MT> = Vector<Global<Lhs<MT>>, GlobalSize<Lhs<MT>>>;
185pub type LhsS<MT> = Vector<Stage<Lhs<MT>>, StageSize<Lhs<MT>>>;
186pub type LhsR<MT> = Vector<Register<Lhs<MT>>, RegisterSize<Lhs<MT>>>;
187
188// Element / Size splits
189pub type LhsGE<MT> = <Lhs<MT> as MatrixTypes>::Global;
190pub type LhsGS<MT> = <Lhs<MT> as MatrixTypes>::GlobalSize;
191
192pub type LhsSE<MT> = <Lhs<MT> as MatrixTypes>::Stage;
193pub type LhsSS<MT> = <Lhs<MT> as MatrixTypes>::StageSize;
194
195pub type LhsRE<MT> = <Lhs<MT> as MatrixTypes>::Register;
196pub type LhsRS<MT> = <Lhs<MT> as MatrixTypes>::RegisterSize;
197
198// ==================== RHS ====================
199
200// Vector forms
201pub type RhsG<MT> = Vector<Global<Rhs<MT>>, GlobalSize<Rhs<MT>>>;
202pub type RhsS<MT> = Vector<Stage<Rhs<MT>>, StageSize<Rhs<MT>>>;
203pub type RhsR<MT> = Vector<Register<Rhs<MT>>, RegisterSize<Rhs<MT>>>;
204
205// Element / Size splits
206pub type RhsGE<MT> = <Rhs<MT> as MatrixTypes>::Global;
207pub type RhsGS<MT> = <Rhs<MT> as MatrixTypes>::GlobalSize;
208
209pub type RhsSE<MT> = <Rhs<MT> as MatrixTypes>::Stage;
210pub type RhsSS<MT> = <Rhs<MT> as MatrixTypes>::StageSize;
211
212pub type RhsRE<MT> = <Rhs<MT> as MatrixTypes>::Register;
213pub type RhsRS<MT> = <Rhs<MT> as MatrixTypes>::RegisterSize;
214
215// ==================== ACC ====================
216
217// Vector forms
218pub type AccG<MT> = Vector<Global<Acc<MT>>, GlobalSize<Acc<MT>>>;
219pub type AccS<MT> = Vector<Stage<Acc<MT>>, StageSize<Acc<MT>>>;
220pub type AccR<MT> = Vector<Register<Acc<MT>>, RegisterSize<Acc<MT>>>;
221
222// Element / Size splits
223pub type AccGE<MT> = <Acc<MT> as MatrixTypes>::Global;
224pub type AccGS<MT> = <Acc<MT> as MatrixTypes>::GlobalSize;
225
226pub type AccSE<MT> = <Acc<MT> as MatrixTypes>::Stage;
227pub type AccSS<MT> = <Acc<MT> as MatrixTypes>::StageSize;
228
229pub type AccRE<MT> = <Acc<MT> as MatrixTypes>::Register;
230pub type AccRS<MT> = <Acc<MT> as MatrixTypes>::RegisterSize;
231
/// Value-level counterpart of [`MatmulPrecision`]: the storage type of every operand
/// at every memory level.
///
/// `lhs`/`rhs` are the two inputs. `acc` is the accumulator, and `acc_global` is the
/// element type of the output tensor.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct MatmulElems {
    /// Lhs element type in global memory.
    pub lhs_global: StorageType,
    /// Rhs element type in global memory.
    pub rhs_global: StorageType,
    /// Output element type in global memory.
    pub acc_global: StorageType,
    /// Lhs element type once stored in shared memory.
    pub lhs_stage: StorageType,
    /// Rhs element type once stored in shared memory.
    pub rhs_stage: StorageType,
    /// Accumulator element type once stored in shared memory.
    pub acc_stage: StorageType,
    /// Lhs element type in registers.
    pub lhs_register: StorageType,
    /// Rhs element type in registers.
    pub rhs_register: StorageType,
    /// Accumulator element type in registers.
    pub acc_register: StorageType,
}
244
245#[derive(Clone, Debug)]
246pub struct MatmulGlobalElems {
247    pub lhs: StorageType,
248    pub rhs: StorageType,
249    pub out: StorageType,
250}
251
252impl MatmulElems {
253    pub fn new_deprecated<MP: MatmulPrecision>() -> Self {
254        Self {
255            lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked()
256                .storage_type(),
257            rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked()
258                .storage_type(),
259            acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked()
260                .storage_type(),
261            lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked()
262                .storage_type(),
263            rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked()
264                .storage_type(),
265            acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked()
266                .storage_type(),
267            lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked()
268                .storage_type(),
269            rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked()
270                .storage_type(),
271            acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked()
272                .storage_type(),
273        }
274    }
275
276    pub fn from_globals(global_elems: &MatmulGlobalElems) -> Self {
277        let acc_type = if global_elems.out == half::f16::as_type_native_unchecked().storage_type()
278            || global_elems.out == half::bf16::as_type_native_unchecked().storage_type()
279        {
280            f32::as_type_native_unchecked().storage_type()
281        } else {
282            global_elems.out
283        };
284
285        Self {
286            lhs_global: global_elems.lhs,
287            rhs_global: global_elems.rhs,
288            acc_global: global_elems.out,
289            lhs_stage: global_elems.lhs,
290            rhs_stage: global_elems.rhs,
291            acc_stage: acc_type,
292            lhs_register: global_elems.lhs,
293            rhs_register: global_elems.rhs,
294            acc_register: acc_type,
295        }
296    }
297
298    pub fn from_single_dtype(dtype: Type) -> Self {
299        let dtype = dtype.storage_type();
300        Self {
301            lhs_global: dtype,
302            rhs_global: dtype,
303            acc_global: dtype,
304            lhs_stage: dtype,
305            rhs_stage: dtype,
306            acc_stage: dtype,
307            lhs_register: dtype,
308            rhs_register: dtype,
309            acc_register: dtype,
310        }
311    }
312
313    pub fn global(&self, ident: MatmulIdent) -> StorageType {
314        match ident {
315            MatmulIdent::Lhs => self.lhs_global,
316            MatmulIdent::Rhs => self.rhs_global,
317            MatmulIdent::Out => self.acc_global,
318        }
319    }
320
321    pub fn stage(&self, ident: MatmulIdent) -> StorageType {
322        match ident {
323            MatmulIdent::Lhs => self.lhs_stage,
324            MatmulIdent::Rhs => self.rhs_stage,
325            MatmulIdent::Out => self.acc_stage,
326        }
327    }
328
329    pub fn register(&self, ident: MatmulIdent) -> StorageType {
330        match ident {
331            MatmulIdent::Lhs => self.lhs_register,
332            MatmulIdent::Rhs => self.rhs_register,
333            MatmulIdent::Out => self.acc_register,
334        }
335    }
336
337    pub fn as_global_elems(&self) -> MatmulGlobalElems {
338        MatmulGlobalElems {
339            lhs: self.lhs_global,
340            rhs: self.rhs_global,
341            out: self.acc_global,
342        }
343    }
344
345    /// Prefer output type for stage because it's the same size at best, but often smaller.
346    /// Having stage == global also enables things like TMA, and an f16 stage for output enables
347    /// using `stmatrix` on the registers after casting.
348    pub fn adjust_stage_dtypes(&mut self) {
349        self.lhs_stage = self.lhs_global;
350        self.rhs_stage = self.rhs_global;
351        self.acc_stage = self.acc_global;
352    }
353}