1use cubecl::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::definition::MatmulIdent;
5
6pub trait MatmulPrecision: Send + Sync + Copy + 'static {
8 type Lhs: MatrixPrecision;
10 type Rhs: MatrixPrecision;
12 type Acc: MatrixPrecision;
14}
15
16pub trait MatrixPrecision: Send + Sync + Copy + 'static {
17 type Global: Numeric;
19 type Stage: Numeric;
21 type Register: Numeric;
23}
24
25impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
26 type Global = EG;
27 type Stage = ES;
28 type Register = ES;
29}
30
31impl<EG: Numeric, ES: Numeric, ER: Numeric> MatrixPrecision for (EG, ES, ER) {
32 type Global = EG;
33 type Stage = ES;
34 type Register = ER;
35}
36
37impl MatmulPrecision for f16 {
38 type Lhs = (f16, f16);
39 type Rhs = (f16, f16);
40 #[cfg(target_os = "macos")]
41 type Acc = (f16, f16);
42 #[cfg(not(target_os = "macos"))]
43 type Acc = (f16, f32);
44}
45
46impl MatmulPrecision for flex32 {
47 type Lhs = (f32, f16);
48 type Rhs = (f32, f16);
49 type Acc = (f32, f32);
50}
51
52impl MatmulPrecision for bf16 {
53 type Lhs = (bf16, bf16);
54 type Rhs = (bf16, bf16);
55 #[cfg(target_os = "macos")]
56 type Acc = (bf16, bf16);
57 #[cfg(not(target_os = "macos"))]
58 type Acc = (bf16, f32);
59}
60
61impl MatmulPrecision for f32 {
62 type Lhs = (f32, f32);
63 type Rhs = (f32, f32);
64 type Acc = (f32, f32);
65}
66
67impl MatmulPrecision for f64 {
68 type Lhs = (f64, f32);
69 type Rhs = (f64, f32);
70 type Acc = (f64, f32);
71}
72
73impl MatmulPrecision for u8 {
74 type Lhs = (u8, u8);
75 type Rhs = (u8, u8);
76 type Acc = (i32, i32);
77}
78
79impl MatmulPrecision for u16 {
80 type Lhs = (u16, u16);
81 type Rhs = (u16, u16);
82 type Acc = (i32, i32);
83}
84
85impl MatmulPrecision for u32 {
86 type Lhs = (u32, u32);
87 type Rhs = (u32, u32);
88 type Acc = (u32, u32);
89}
90
91impl MatmulPrecision for u64 {
92 type Lhs = (u64, u64);
93 type Rhs = (u64, u64);
94 type Acc = (u64, u64);
95}
96
97impl MatmulPrecision for i8 {
98 type Lhs = (i8, i8);
99 type Rhs = (i8, i8);
100 type Acc = (i32, i32);
101}
102
103impl MatmulPrecision for i16 {
104 type Lhs = (i16, i16);
105 type Rhs = (i16, i16);
106 type Acc = (i32, i32);
107}
108
109impl MatmulPrecision for i32 {
110 type Lhs = (i32, i32);
111 type Rhs = (i32, i32);
112 type Acc = (i32, i32);
113}
114
115impl MatmulPrecision for i64 {
116 type Lhs = (i64, i64);
117 type Rhs = (i64, i64);
118 type Acc = (i64, i64);
119}
120
121impl<Lhs: MatrixPrecision, Rhs: MatrixPrecision, Acc: MatrixPrecision> MatmulPrecision
122 for (Lhs, Rhs, Acc)
123{
124 type Lhs = Lhs;
125 type Rhs = Rhs;
126 type Acc = Acc;
127}
128
129pub type LhsG<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
130pub type LhsS<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
131pub type LhsR<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
132
133pub type RhsG<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
134pub type RhsS<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
135pub type RhsR<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
136
137pub type AccG<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global;
138pub type AccS<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
139pub type AccR<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Register;
140
141#[derive(Debug, Clone, Eq, PartialEq, Hash)]
142pub struct MatmulElems {
143 pub lhs_global: StorageType,
144 pub rhs_global: StorageType,
145 pub acc_global: StorageType,
146 pub lhs_stage: StorageType,
147 pub rhs_stage: StorageType,
148 pub acc_stage: StorageType,
149 pub lhs_register: StorageType,
150 pub rhs_register: StorageType,
151 pub acc_register: StorageType,
152}
153
154#[derive(Clone, Debug)]
155pub struct MatmulGlobalElems {
156 pub lhs: StorageType,
157 pub rhs: StorageType,
158 pub out: StorageType,
159}
160
161impl MatmulElems {
162 pub fn new_deprecated<MP: MatmulPrecision>() -> Self {
163 Self {
164 lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
165 rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
166 acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
167 lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
168 rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
169 acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
170 lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
171 rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
172 acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
173 }
174 }
175
176 pub fn from_globals(global_elems: &MatmulGlobalElems) -> Self {
177 let acc_type = if global_elems.out == half::f16::as_type_native_unchecked()
178 || global_elems.out == half::bf16::as_type_native_unchecked()
179 {
180 f32::as_type_native_unchecked()
181 } else {
182 global_elems.out
183 };
184
185 Self {
186 lhs_global: global_elems.lhs,
187 rhs_global: global_elems.rhs,
188 acc_global: global_elems.out,
189 lhs_stage: global_elems.lhs,
190 rhs_stage: global_elems.rhs,
191 acc_stage: acc_type,
192 lhs_register: global_elems.lhs,
193 rhs_register: global_elems.rhs,
194 acc_register: acc_type,
195 }
196 }
197
198 pub fn from_single_dtype(dtype: StorageType) -> Self {
199 Self {
200 lhs_global: dtype,
201 rhs_global: dtype,
202 acc_global: dtype,
203 lhs_stage: dtype,
204 rhs_stage: dtype,
205 acc_stage: dtype,
206 lhs_register: dtype,
207 rhs_register: dtype,
208 acc_register: dtype,
209 }
210 }
211
212 pub fn from_define_arrays(
213 global: [StorageType; 3],
214 stage: [StorageType; 3],
215 register: [StorageType; 3],
216 ) -> Self {
217 Self {
218 lhs_global: global[0],
219 rhs_global: global[1],
220 acc_global: global[2],
221 lhs_stage: stage[0],
222 rhs_stage: stage[1],
223 acc_stage: stage[2],
224 lhs_register: register[0],
225 rhs_register: register[1],
226 acc_register: register[2],
227 }
228 }
229
230 pub fn global(&self, ident: MatmulIdent) -> StorageType {
231 match ident {
232 MatmulIdent::Lhs => self.lhs_global,
233 MatmulIdent::Rhs => self.rhs_global,
234 MatmulIdent::Out => self.acc_global,
235 }
236 }
237
238 pub fn stage(&self, ident: MatmulIdent) -> StorageType {
239 match ident {
240 MatmulIdent::Lhs => self.lhs_stage,
241 MatmulIdent::Rhs => self.rhs_stage,
242 MatmulIdent::Out => self.acc_stage,
243 }
244 }
245
246 pub fn register(&self, ident: MatmulIdent) -> StorageType {
247 match ident {
248 MatmulIdent::Lhs => self.lhs_register,
249 MatmulIdent::Rhs => self.rhs_register,
250 MatmulIdent::Out => self.acc_register,
251 }
252 }
253
254 pub fn as_global_elems(&self) -> MatmulGlobalElems {
255 MatmulGlobalElems {
256 lhs: self.lhs_global,
257 rhs: self.rhs_global,
258 out: self.acc_global,
259 }
260 }
261
262 pub fn adjust_stage_dtypes(&mut self) {
266 self.lhs_stage = self.lhs_global;
267 self.rhs_stage = self.rhs_global;
268 self.acc_stage = self.acc_global;
269 }
270}