cubecl_matmul/components/
spec.rs

1use cubecl_core::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use super::global::args::{MatmulArgs, TensorArgs};
5
/// Matrix multiplication spec defining the element types used in the computation as well as
/// how the arguments are passed to the kernel.
pub trait MatmulSpec: Send + Sync + Clone + 'static {
    /// The precision, i.e. the element types used for the lhs, rhs and accumulator matrices.
    type Precision: MatmulPrecision;
    /// How the input and output tensors are passed as arguments.
    type Args: MatmulArgs;
}
13
14impl<MP: MatmulPrecision, Args: MatmulArgs> MatmulSpec for (MP, Args) {
15    type Precision = MP;
16    type Args = Args;
17}
18
19// A simple default for TensorArgs
20impl<MP: MatmulPrecision> MatmulSpec for MP {
21    type Precision = MP;
22    type Args = TensorArgs;
23}
24
/// Matrix multiplication precisions.
///
/// Groups one [`MatrixPrecision`] per matrix taking part in the computation (lhs, rhs and
/// accumulator); each of them in turn provides a global, a stage and a register element type.
pub trait MatmulPrecision: Send + Sync + Copy + 'static {
    /// Element types of the lhs input tensor of the kernel.
    type Lhs: MatrixPrecision;
    /// Element types of the rhs input tensor of the kernel.
    type Rhs: MatrixPrecision;
    /// Element types of the accumulator of the kernel.
    type Acc: MatrixPrecision;
}
34
/// Element types of a single matrix, one for each place the data can live in
/// (global memory, shared memory and registers).
pub trait MatrixPrecision: Send + Sync + Copy + 'static {
    /// Element type of input tensor in global memory
    type Global: Numeric;
    /// Element type once stored in shared memory
    type Stage: Numeric;
    /// Element type once in registers for computation
    type Register: Numeric;
}
43
44impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
45    type Global = EG;
46    type Stage = ES;
47    type Register = ES;
48}
49
50impl MatmulPrecision for f16 {
51    type Lhs = (f16, f16);
52    type Rhs = (f16, f16);
53    #[cfg(target_os = "macos")]
54    type Acc = (f16, f16);
55    #[cfg(not(target_os = "macos"))]
56    type Acc = (f16, f32);
57}
58
// `flex32`: lhs and rhs are `f32` in global memory but `f16` for stage/register, while the
// accumulator is `f32` throughout.
impl MatmulPrecision for flex32 {
    type Lhs = (f32, f16);
    type Rhs = (f32, f16);
    type Acc = (f32, f32);
}
64
65impl MatmulPrecision for bf16 {
66    type Lhs = (bf16, bf16);
67    type Rhs = (bf16, bf16);
68    #[cfg(target_os = "macos")]
69    type Acc = (bf16, bf16);
70    #[cfg(not(target_os = "macos"))]
71    type Acc = (bf16, f32);
72}
73
74impl MatmulPrecision for f32 {
75    type Lhs = (f32, f32);
76    type Rhs = (f32, f32);
77    type Acc = (f32, f32);
78}
79
80impl MatmulPrecision for f64 {
81    type Lhs = (f64, f32);
82    type Rhs = (f64, f32);
83    type Acc = (f64, f32);
84}
85
86impl MatmulPrecision for u8 {
87    type Lhs = (u8, u8);
88    type Rhs = (u8, u8);
89    type Acc = (i32, i32);
90}
91
92impl MatmulPrecision for u16 {
93    type Lhs = (u16, u16);
94    type Rhs = (u16, u16);
95    type Acc = (i32, i32);
96}
97
98impl MatmulPrecision for u32 {
99    type Lhs = (u32, u32);
100    type Rhs = (u32, u32);
101    type Acc = (u32, u32);
102}
103
104impl MatmulPrecision for u64 {
105    type Lhs = (u64, u64);
106    type Rhs = (u64, u64);
107    type Acc = (u64, u64);
108}
109
110impl MatmulPrecision for i8 {
111    type Lhs = (i8, i8);
112    type Rhs = (i8, i8);
113    type Acc = (i32, i32);
114}
115
116impl MatmulPrecision for i16 {
117    type Lhs = (i16, i16);
118    type Rhs = (i16, i16);
119    type Acc = (i32, i32);
120}
121
122impl MatmulPrecision for i32 {
123    type Lhs = (i32, i32);
124    type Rhs = (i32, i32);
125    type Acc = (i32, i32);
126}
127
128impl MatmulPrecision for i64 {
129    type Lhs = (i64, i64);
130    type Rhs = (i64, i64);
131    type Acc = (i64, i64);
132}
133
134impl<LhsG: Numeric, RhsG: Numeric, AccG: Numeric, LhsS: Numeric, RhsS: Numeric, AccS: Numeric>
135    MatmulPrecision for (LhsG, RhsG, AccG, LhsS, RhsS, AccS)
136{
137    type Lhs = (LhsG, LhsS);
138    type Rhs = (RhsG, RhsS);
139    type Acc = (AccG, AccS);
140}
141
/// Input argument of the spec `MS`: the `Input` type of its [`MatmulArgs`], instantiated with
/// the global element types of lhs, rhs and acc.
pub type InputArg<MS> = <Args<MS> as MatmulArgs>::Input<LhsG<MS>, RhsG<MS>, AccG<MS>>;
144
/// Output argument of the spec `MS`: the `Output` type of its [`MatmulArgs`], instantiated
/// with the global element type of acc.
pub type OutputArg<MS> = <Args<MS> as MatmulArgs>::Output<AccG<MS>>;
147
/// Input runtime argument: the `LaunchArg::RuntimeArg` of [`InputArg`] for lifetime `'a`
/// and `R` (presumably the runtime type — confirm against `LaunchArg`).
pub type InputRuntimeArg<'a, MS, R> = <InputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
150
/// Output runtime argument: the `LaunchArg::RuntimeArg` of [`OutputArg`] for lifetime `'a`
/// and `R` (presumably the runtime type — confirm against `LaunchArg`).
pub type OutputRuntimeArg<'a, MS, R> = <OutputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
153
/// Global element type of the lhs matrix for the spec `MS`.
pub type LhsG<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
/// Stage element type of the lhs matrix for the spec `MS`.
pub type LhsS<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
/// Register element type of the lhs matrix for the spec `MS`.
pub type LhsR<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
/// Global element type of the rhs matrix for the spec `MS`.
pub type RhsG<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
/// Stage element type of the rhs matrix for the spec `MS`.
pub type RhsS<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
/// Register element type of the rhs matrix for the spec `MS`.
pub type RhsR<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
/// Global element type of the accumulator matrix for the spec `MS`.
pub type AccG<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Global;
/// Stage element type of the accumulator matrix for the spec `MS`.
pub type AccS<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
/// Register element type of the accumulator matrix for the spec `MS`.
pub type AccR<MS> =
    <<<MS as MatmulSpec>::Precision as MatmulPrecision>::Acc as MatrixPrecision>::Register;
172
/// The [`MatmulArgs`] implementation selected by the spec `MS`.
pub type Args<MS> = <MS as MatmulSpec>::Args;
174
175pub struct MatmulElems {
176    pub lhs_global: StorageType,
177    pub rhs_global: StorageType,
178    pub acc_global: StorageType,
179    pub lhs_stage: StorageType,
180    pub rhs_stage: StorageType,
181    pub acc_stage: StorageType,
182    pub lhs_register: StorageType,
183    pub rhs_register: StorageType,
184    pub acc_register: StorageType,
185}
186
187impl MatmulElems {
188    pub fn new<MP: MatmulPrecision>() -> Self {
189        Self {
190            lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
191            rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
192            acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
193            lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
194            rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
195            acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
196            lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
197            rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
198            acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
199        }
200    }
201}