cubecl_matmul/components/
spec.rs1use cubecl_core::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::components::MatmulIdent;
5
6use super::global::args::MatmulArgs;
7
8pub trait MatmulPrecision: Send + Sync + Copy + 'static {
10 type Lhs: MatrixPrecision;
12 type Rhs: MatrixPrecision;
14 type Acc: MatrixPrecision;
16}
17
18pub trait MatrixPrecision: Send + Sync + Copy + 'static {
19 type Global: Numeric;
21 type Stage: Numeric;
23 type Register: Numeric;
25}
26
27impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
28 type Global = EG;
29 type Stage = ES;
30 type Register = ES;
31}
32
33impl MatmulPrecision for f16 {
34 type Lhs = (f16, f16);
35 type Rhs = (f16, f16);
36 #[cfg(target_os = "macos")]
37 type Acc = (f16, f16);
38 #[cfg(not(target_os = "macos"))]
39 type Acc = (f16, f32);
40}
41
42impl MatmulPrecision for flex32 {
43 type Lhs = (f32, f16);
44 type Rhs = (f32, f16);
45 type Acc = (f32, f32);
46}
47
48impl MatmulPrecision for bf16 {
49 type Lhs = (bf16, bf16);
50 type Rhs = (bf16, bf16);
51 #[cfg(target_os = "macos")]
52 type Acc = (bf16, bf16);
53 #[cfg(not(target_os = "macos"))]
54 type Acc = (bf16, f32);
55}
56
57impl MatmulPrecision for f32 {
58 type Lhs = (f32, f32);
59 type Rhs = (f32, f32);
60 type Acc = (f32, f32);
61}
62
63impl MatmulPrecision for f64 {
64 type Lhs = (f64, f32);
65 type Rhs = (f64, f32);
66 type Acc = (f64, f32);
67}
68
69impl MatmulPrecision for u8 {
70 type Lhs = (u8, u8);
71 type Rhs = (u8, u8);
72 type Acc = (i32, i32);
73}
74
75impl MatmulPrecision for u16 {
76 type Lhs = (u16, u16);
77 type Rhs = (u16, u16);
78 type Acc = (i32, i32);
79}
80
81impl MatmulPrecision for u32 {
82 type Lhs = (u32, u32);
83 type Rhs = (u32, u32);
84 type Acc = (u32, u32);
85}
86
87impl MatmulPrecision for u64 {
88 type Lhs = (u64, u64);
89 type Rhs = (u64, u64);
90 type Acc = (u64, u64);
91}
92
93impl MatmulPrecision for i8 {
94 type Lhs = (i8, i8);
95 type Rhs = (i8, i8);
96 type Acc = (i32, i32);
97}
98
99impl MatmulPrecision for i16 {
100 type Lhs = (i16, i16);
101 type Rhs = (i16, i16);
102 type Acc = (i32, i32);
103}
104
105impl MatmulPrecision for i32 {
106 type Lhs = (i32, i32);
107 type Rhs = (i32, i32);
108 type Acc = (i32, i32);
109}
110
111impl MatmulPrecision for i64 {
112 type Lhs = (i64, i64);
113 type Rhs = (i64, i64);
114 type Acc = (i64, i64);
115}
116
117impl<LhsG: Numeric, RhsG: Numeric, AccG: Numeric, LhsS: Numeric, RhsS: Numeric, AccS: Numeric>
118 MatmulPrecision for (LhsG, RhsG, AccG, LhsS, RhsS, AccS)
119{
120 type Lhs = (LhsG, LhsS);
121 type Rhs = (RhsG, RhsS);
122 type Acc = (AccG, AccS);
123}
124
125pub type LhsG<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
126pub type LhsS<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
127pub type LhsR<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
128
129pub type RhsG<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
130pub type RhsS<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
131pub type RhsR<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
132
133pub type AccG<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global;
134pub type AccS<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
135pub type AccR<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Register;
136
137pub type InputArg<MA> =
139 <MA as MatmulArgs>::Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>;
140
141pub type OutputArg<MA> = <MA as MatmulArgs>::Output<NumericExpand<2>>;
143
144pub type InputRuntimeArg<'a, MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;
146
147pub type OutputRuntimeArg<'a, MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;
149
150#[derive(Clone, Debug)]
151pub struct MatmulElems {
152 pub lhs_global: StorageType,
153 pub rhs_global: StorageType,
154 pub acc_global: StorageType,
155 pub lhs_stage: StorageType,
156 pub rhs_stage: StorageType,
157 pub acc_stage: StorageType,
158 pub lhs_register: StorageType,
159 pub rhs_register: StorageType,
160 pub acc_register: StorageType,
161}
162
163impl MatmulElems {
164 pub fn new<MP: MatmulPrecision>() -> Self {
165 Self {
166 lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
167 rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
168 acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
169 lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
170 rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
171 acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
172 lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
173 rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
174 acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
175 }
176 }
177
178 pub fn from_globals(lhs: StorageType, rhs: StorageType, out: StorageType) -> Self {
179 let acc_type = |dtype: StorageType| {
180 if dtype == half::f16::as_type_native_unchecked()
181 || dtype == half::bf16::as_type_native_unchecked()
182 {
183 return f32::as_type_native_unchecked();
184 }
185
186 dtype
187 };
188
189 Self {
190 lhs_global: lhs,
191 rhs_global: rhs,
192 acc_global: out,
193 lhs_stage: lhs,
194 rhs_stage: rhs,
195 acc_stage: acc_type(out),
196 lhs_register: lhs,
197 rhs_register: rhs,
198 acc_register: acc_type(out),
199 }
200 }
201
202 pub fn global(&self, ident: MatmulIdent) -> StorageType {
203 match ident {
204 MatmulIdent::Lhs => self.lhs_global,
205 MatmulIdent::Rhs => self.rhs_global,
206 MatmulIdent::Out => self.acc_global,
207 }
208 }
209
210 pub fn stage(&self, ident: MatmulIdent) -> StorageType {
211 match ident {
212 MatmulIdent::Lhs => self.lhs_stage,
213 MatmulIdent::Rhs => self.rhs_stage,
214 MatmulIdent::Out => self.acc_stage,
215 }
216 }
217
218 pub fn register(&self, ident: MatmulIdent) -> StorageType {
219 match ident {
220 MatmulIdent::Lhs => self.lhs_register,
221 MatmulIdent::Rhs => self.rhs_register,
222 MatmulIdent::Out => self.acc_register,
223 }
224 }
225}