cubecl_matmul/components/
spec.rs1use cubecl_core::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use super::global::args::{MatmulArgs, TensorArgs};
5
6pub trait MatmulSpec: Send + Sync + Clone + 'static {
9 type Precision: MatmulPrecision;
10 type Args: MatmulArgs;
12}
13
14impl<MP: MatmulPrecision, Args: MatmulArgs> MatmulSpec for (MP, Args) {
15 type Precision = MP;
16 type Args = Args;
17}
18
19impl<MP: MatmulPrecision> MatmulSpec for MP {
21 type Precision = MP;
22 type Args = TensorArgs;
23}
24
25pub trait MatmulPrecision: Send + Sync + Copy + 'static {
27 type Lhs: MatrixPrecision;
29 type Rhs: MatrixPrecision;
31 type Acc: MatrixPrecision;
33}
34
35pub trait MatrixPrecision: Send + Sync + Copy + 'static {
36 type Global: Numeric;
38 type Stage: Numeric;
40 type Register: Numeric;
42}
43
44impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
45 type Global = EG;
46 type Stage = ES;
47 type Register = ES;
48}
49
50impl MatmulPrecision for f16 {
51 type Lhs = (f16, f16);
52 type Rhs = (f16, f16);
53 #[cfg(target_os = "macos")]
54 type Acc = (f16, f16);
55 #[cfg(not(target_os = "macos"))]
56 type Acc = (f16, f32);
57}
58
59impl MatmulPrecision for flex32 {
60 type Lhs = (f32, f16);
61 type Rhs = (f32, f16);
62 type Acc = (f32, f32);
63}
64
65impl MatmulPrecision for bf16 {
66 type Lhs = (bf16, bf16);
67 type Rhs = (bf16, bf16);
68 #[cfg(target_os = "macos")]
69 type Acc = (bf16, bf16);
70 #[cfg(not(target_os = "macos"))]
71 type Acc = (bf16, f32);
72}
73
74impl MatmulPrecision for f32 {
75 type Lhs = (f32, f32);
76 type Rhs = (f32, f32);
77 type Acc = (f32, f32);
78}
79
80impl MatmulPrecision for f64 {
81 type Lhs = (f64, f32);
82 type Rhs = (f64, f32);
83 type Acc = (f64, f32);
84}
85
86impl MatmulPrecision for u8 {
87 type Lhs = (u8, u8);
88 type Rhs = (u8, u8);
89 type Acc = (i32, i32);
90}
91
92impl MatmulPrecision for u16 {
93 type Lhs = (u16, u16);
94 type Rhs = (u16, u16);
95 type Acc = (i32, i32);
96}
97
98impl MatmulPrecision for u32 {
99 type Lhs = (u32, u32);
100 type Rhs = (u32, u32);
101 type Acc = (u32, u32);
102}
103
104impl MatmulPrecision for u64 {
105 type Lhs = (u64, u64);
106 type Rhs = (u64, u64);
107 type Acc = (u64, u64);
108}
109
110impl MatmulPrecision for i8 {
111 type Lhs = (i8, i8);
112 type Rhs = (i8, i8);
113 type Acc = (i32, i32);
114}
115
116impl MatmulPrecision for i16 {
117 type Lhs = (i16, i16);
118 type Rhs = (i16, i16);
119 type Acc = (i32, i32);
120}
121
122impl MatmulPrecision for i32 {
123 type Lhs = (i32, i32);
124 type Rhs = (i32, i32);
125 type Acc = (i32, i32);
126}
127
128impl MatmulPrecision for i64 {
129 type Lhs = (i64, i64);
130 type Rhs = (i64, i64);
131 type Acc = (i64, i64);
132}
133
134impl<LhsG: Numeric, RhsG: Numeric, AccG: Numeric, LhsS: Numeric, RhsS: Numeric, AccS: Numeric>
135 MatmulPrecision for (LhsG, RhsG, AccG, LhsS, RhsS, AccS)
136{
137 type Lhs = (LhsG, LhsS);
138 type Rhs = (RhsG, RhsS);
139 type Acc = (AccG, AccS);
140}
141
142pub type InputArg<MS> = <Args<MS> as MatmulArgs>::Input<LhsG<MS>, RhsG<MS>, AccG<MS>>;
144
145pub type OutputArg<MS> = <Args<MS> as MatmulArgs>::Output<AccG<MS>>;
147
148pub type InputRuntimeArg<'a, MS, R> = <InputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
150
151pub type OutputRuntimeArg<'a, MS, R> = <OutputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
153
154pub type LhsG<MS> =
155 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
156pub type LhsS<MS> =
157 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
158pub type LhsR<MS> =
159 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
160pub type RhsG<MS> =
161 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
162pub type RhsS<MS> =
163 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
164pub type RhsR<MS> =
165 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
166pub type AccG<MS> =
167 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Global;
168pub type AccS<MS> =
169 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
170pub type AccR<MS> =
171 <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Register;
172
173pub type Args<MS> = <MS as MatmulSpec>::Args;
174
175pub struct MatmulElems {
176 pub lhs_global: StorageType,
177 pub rhs_global: StorageType,
178 pub acc_global: StorageType,
179 pub lhs_stage: StorageType,
180 pub rhs_stage: StorageType,
181 pub acc_stage: StorageType,
182 pub lhs_register: StorageType,
183 pub rhs_register: StorageType,
184 pub acc_register: StorageType,
185}
186
187impl MatmulElems {
188 pub fn new<MP: MatmulPrecision>() -> Self {
189 Self {
190 lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
191 rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
192 acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
193 lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
194 rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
195 acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
196 lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
197 rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
198 acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
199 }
200 }
201}