1use cubecl::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::definition::MatmulIdent;
5
6pub trait MatmulPrecision: Send + Sync + Copy + 'static {
8 type Lhs: MatrixPrecision;
10 type Rhs: MatrixPrecision;
12 type Acc: MatrixPrecision;
14}
15
16pub trait MatrixPrecision: Send + Sync + Copy + 'static {
17 type Global: Numeric;
19 type Stage: Numeric;
21 type Register: Numeric;
23}
24
25impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
26 type Global = EG;
27 type Stage = ES;
28 type Register = ES;
29}
30
31impl<EG: Numeric, ES: Numeric, ER: Numeric> MatrixPrecision for (EG, ES, ER) {
32 type Global = EG;
33 type Stage = ES;
34 type Register = ER;
35}
36
37pub trait MatmulTypes: Send + Sync + Copy + 'static {
39 type Lhs: MatrixTypes;
41 type Rhs: MatrixTypes;
43 type Acc: MatrixTypes;
45}
46
47pub trait MatrixTypes: Send + Sync + Copy + 'static {
48 type Global: Numeric;
50 type GlobalSize: Size;
51 type Stage: Numeric;
53 type StageSize: Size;
54 type Register: Numeric;
56 type RegisterSize: Size;
57}
58
59impl<EG: Numeric, SG: Size, ES: Numeric, SS: Size, ER: Numeric, SR: Size> MatrixTypes
60 for (EG, SG, ES, SS, ER, SR)
61{
62 type Global = EG;
63 type GlobalSize = SG;
64 type Stage = ES;
65 type StageSize = SS;
66 type Register = ER;
67 type RegisterSize = SR;
68}
69
70impl MatmulPrecision for f16 {
71 type Lhs = (f16, f16);
72 type Rhs = (f16, f16);
73 #[cfg(target_os = "macos")]
74 type Acc = (f16, f16);
75 #[cfg(not(target_os = "macos"))]
76 type Acc = (f16, f32);
77}
78
79impl MatmulPrecision for flex32 {
80 type Lhs = (f32, f16);
81 type Rhs = (f32, f16);
82 type Acc = (f32, f32);
83}
84
85impl MatmulPrecision for bf16 {
86 type Lhs = (bf16, bf16);
87 type Rhs = (bf16, bf16);
88 #[cfg(target_os = "macos")]
89 type Acc = (bf16, bf16);
90 #[cfg(not(target_os = "macos"))]
91 type Acc = (bf16, f32);
92}
93
94impl MatmulPrecision for f32 {
95 type Lhs = (f32, f32);
96 type Rhs = (f32, f32);
97 type Acc = (f32, f32);
98}
99
100impl MatmulPrecision for f64 {
101 type Lhs = (f64, f32);
102 type Rhs = (f64, f32);
103 type Acc = (f64, f32);
104}
105
106impl MatmulPrecision for u8 {
107 type Lhs = (u8, u8);
108 type Rhs = (u8, u8);
109 type Acc = (i32, i32);
110}
111
112impl MatmulPrecision for u16 {
113 type Lhs = (u16, u16);
114 type Rhs = (u16, u16);
115 type Acc = (i32, i32);
116}
117
118impl MatmulPrecision for u32 {
119 type Lhs = (u32, u32);
120 type Rhs = (u32, u32);
121 type Acc = (u32, u32);
122}
123
124impl MatmulPrecision for u64 {
125 type Lhs = (u64, u64);
126 type Rhs = (u64, u64);
127 type Acc = (u64, u64);
128}
129
130impl MatmulPrecision for i8 {
131 type Lhs = (i8, i8);
132 type Rhs = (i8, i8);
133 type Acc = (i32, i32);
134}
135
136impl MatmulPrecision for i16 {
137 type Lhs = (i16, i16);
138 type Rhs = (i16, i16);
139 type Acc = (i32, i32);
140}
141
142impl MatmulPrecision for i32 {
143 type Lhs = (i32, i32);
144 type Rhs = (i32, i32);
145 type Acc = (i32, i32);
146}
147
148impl MatmulPrecision for i64 {
149 type Lhs = (i64, i64);
150 type Rhs = (i64, i64);
151 type Acc = (i64, i64);
152}
153
154impl<Lhs: MatrixPrecision, Rhs: MatrixPrecision, Acc: MatrixPrecision> MatmulPrecision
155 for (Lhs, Rhs, Acc)
156{
157 type Lhs = Lhs;
158 type Rhs = Rhs;
159 type Acc = Acc;
160}
161
162impl<Lhs: MatrixTypes, Rhs: MatrixTypes, Acc: MatrixTypes> MatmulTypes for (Lhs, Rhs, Acc) {
163 type Lhs = Lhs;
164 type Rhs = Rhs;
165 type Acc = Acc;
166}
167
168pub type Lhs<MT> = <MT as MatmulTypes>::Lhs;
169pub type Rhs<MT> = <MT as MatmulTypes>::Rhs;
170pub type Acc<MT> = <MT as MatmulTypes>::Acc;
171
172pub type Global<MT> = <MT as MatrixTypes>::Global;
173pub type GlobalSize<MT> = <MT as MatrixTypes>::GlobalSize;
174
175pub type Stage<MT> = <MT as MatrixTypes>::Stage;
176pub type StageSize<MT> = <MT as MatrixTypes>::StageSize;
177
178pub type Register<MT> = <MT as MatrixTypes>::Register;
179pub type RegisterSize<MT> = <MT as MatrixTypes>::RegisterSize;
180
181pub type LhsG<MT> = Vector<Global<Lhs<MT>>, GlobalSize<Lhs<MT>>>;
185pub type LhsS<MT> = Vector<Stage<Lhs<MT>>, StageSize<Lhs<MT>>>;
186pub type LhsR<MT> = Vector<Register<Lhs<MT>>, RegisterSize<Lhs<MT>>>;
187
188pub type LhsGE<MT> = <Lhs<MT> as MatrixTypes>::Global;
190pub type LhsGS<MT> = <Lhs<MT> as MatrixTypes>::GlobalSize;
191
192pub type LhsSE<MT> = <Lhs<MT> as MatrixTypes>::Stage;
193pub type LhsSS<MT> = <Lhs<MT> as MatrixTypes>::StageSize;
194
195pub type LhsRE<MT> = <Lhs<MT> as MatrixTypes>::Register;
196pub type LhsRS<MT> = <Lhs<MT> as MatrixTypes>::RegisterSize;
197
198pub type RhsG<MT> = Vector<Global<Rhs<MT>>, GlobalSize<Rhs<MT>>>;
202pub type RhsS<MT> = Vector<Stage<Rhs<MT>>, StageSize<Rhs<MT>>>;
203pub type RhsR<MT> = Vector<Register<Rhs<MT>>, RegisterSize<Rhs<MT>>>;
204
205pub type RhsGE<MT> = <Rhs<MT> as MatrixTypes>::Global;
207pub type RhsGS<MT> = <Rhs<MT> as MatrixTypes>::GlobalSize;
208
209pub type RhsSE<MT> = <Rhs<MT> as MatrixTypes>::Stage;
210pub type RhsSS<MT> = <Rhs<MT> as MatrixTypes>::StageSize;
211
212pub type RhsRE<MT> = <Rhs<MT> as MatrixTypes>::Register;
213pub type RhsRS<MT> = <Rhs<MT> as MatrixTypes>::RegisterSize;
214
215pub type AccG<MT> = Vector<Global<Acc<MT>>, GlobalSize<Acc<MT>>>;
219pub type AccS<MT> = Vector<Stage<Acc<MT>>, StageSize<Acc<MT>>>;
220pub type AccR<MT> = Vector<Register<Acc<MT>>, RegisterSize<Acc<MT>>>;
221
222pub type AccGE<MT> = <Acc<MT> as MatrixTypes>::Global;
224pub type AccGS<MT> = <Acc<MT> as MatrixTypes>::GlobalSize;
225
226pub type AccSE<MT> = <Acc<MT> as MatrixTypes>::Stage;
227pub type AccSS<MT> = <Acc<MT> as MatrixTypes>::StageSize;
228
229pub type AccRE<MT> = <Acc<MT> as MatrixTypes>::Register;
230pub type AccRS<MT> = <Acc<MT> as MatrixTypes>::RegisterSize;
231
232#[derive(Debug, Clone, Eq, PartialEq, Hash)]
233pub struct MatmulElems {
234 pub lhs_global: StorageType,
235 pub rhs_global: StorageType,
236 pub acc_global: StorageType,
237 pub lhs_stage: StorageType,
238 pub rhs_stage: StorageType,
239 pub acc_stage: StorageType,
240 pub lhs_register: StorageType,
241 pub rhs_register: StorageType,
242 pub acc_register: StorageType,
243}
244
245#[derive(Clone, Debug)]
246pub struct MatmulGlobalElems {
247 pub lhs: StorageType,
248 pub rhs: StorageType,
249 pub out: StorageType,
250}
251
252impl MatmulElems {
253 pub fn new_deprecated<MP: MatmulPrecision>() -> Self {
254 Self {
255 lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked()
256 .storage_type(),
257 rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked()
258 .storage_type(),
259 acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked()
260 .storage_type(),
261 lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked()
262 .storage_type(),
263 rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked()
264 .storage_type(),
265 acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked()
266 .storage_type(),
267 lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked()
268 .storage_type(),
269 rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked()
270 .storage_type(),
271 acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked()
272 .storage_type(),
273 }
274 }
275
276 pub fn from_globals(global_elems: &MatmulGlobalElems) -> Self {
277 let acc_type = if global_elems.out == half::f16::as_type_native_unchecked().storage_type()
278 || global_elems.out == half::bf16::as_type_native_unchecked().storage_type()
279 {
280 f32::as_type_native_unchecked().storage_type()
281 } else {
282 global_elems.out
283 };
284
285 Self {
286 lhs_global: global_elems.lhs,
287 rhs_global: global_elems.rhs,
288 acc_global: global_elems.out,
289 lhs_stage: global_elems.lhs,
290 rhs_stage: global_elems.rhs,
291 acc_stage: acc_type,
292 lhs_register: global_elems.lhs,
293 rhs_register: global_elems.rhs,
294 acc_register: acc_type,
295 }
296 }
297
298 pub fn from_single_dtype(dtype: Type) -> Self {
299 let dtype = dtype.storage_type();
300 Self {
301 lhs_global: dtype,
302 rhs_global: dtype,
303 acc_global: dtype,
304 lhs_stage: dtype,
305 rhs_stage: dtype,
306 acc_stage: dtype,
307 lhs_register: dtype,
308 rhs_register: dtype,
309 acc_register: dtype,
310 }
311 }
312
313 pub fn global(&self, ident: MatmulIdent) -> StorageType {
314 match ident {
315 MatmulIdent::Lhs => self.lhs_global,
316 MatmulIdent::Rhs => self.rhs_global,
317 MatmulIdent::Out => self.acc_global,
318 }
319 }
320
321 pub fn stage(&self, ident: MatmulIdent) -> StorageType {
322 match ident {
323 MatmulIdent::Lhs => self.lhs_stage,
324 MatmulIdent::Rhs => self.rhs_stage,
325 MatmulIdent::Out => self.acc_stage,
326 }
327 }
328
329 pub fn register(&self, ident: MatmulIdent) -> StorageType {
330 match ident {
331 MatmulIdent::Lhs => self.lhs_register,
332 MatmulIdent::Rhs => self.rhs_register,
333 MatmulIdent::Out => self.acc_register,
334 }
335 }
336
337 pub fn as_global_elems(&self) -> MatmulGlobalElems {
338 MatmulGlobalElems {
339 lhs: self.lhs_global,
340 rhs: self.rhs_global,
341 out: self.acc_global,
342 }
343 }
344
345 pub fn adjust_stage_dtypes(&mut self) {
349 self.lhs_stage = self.lhs_global;
350 self.rhs_stage = self.rhs_global;
351 self.acc_stage = self.acc_global;
352 }
353}