cubecl_matmul/components/
spec.rs

1use cubecl_core::{ir::StorageType, prelude::*};
2use half::{bf16, f16};
3
4use crate::components::MatmulIdent;
5
6use super::global::args::MatmulArgs;
7
8/// Matrix multiplication precisions.
9pub trait MatmulPrecision: Send + Sync + Copy + 'static {
10    /// Element type of lhs input tensor of the kernel.
11    type Lhs: MatrixPrecision;
12    /// Element type of rhs input tensor of the kernel.
13    type Rhs: MatrixPrecision;
14    /// Element type of acc input tensor of the kernel.
15    type Acc: MatrixPrecision;
16}
17
18pub trait MatrixPrecision: Send + Sync + Copy + 'static {
19    /// Element type of input tensor in global memory
20    type Global: Numeric;
21    /// Element type once stored in shared memory
22    type Stage: Numeric;
23    /// Element type once in registers for computation
24    type Register: Numeric;
25}
26
27impl<EG: Numeric, ES: Numeric> MatrixPrecision for (EG, ES) {
28    type Global = EG;
29    type Stage = ES;
30    type Register = ES;
31}
32
33impl MatmulPrecision for f16 {
34    type Lhs = (f16, f16);
35    type Rhs = (f16, f16);
36    #[cfg(target_os = "macos")]
37    type Acc = (f16, f16);
38    #[cfg(not(target_os = "macos"))]
39    type Acc = (f16, f32);
40}
41
42impl MatmulPrecision for flex32 {
43    type Lhs = (f32, f16);
44    type Rhs = (f32, f16);
45    type Acc = (f32, f32);
46}
47
48impl MatmulPrecision for bf16 {
49    type Lhs = (bf16, bf16);
50    type Rhs = (bf16, bf16);
51    #[cfg(target_os = "macos")]
52    type Acc = (bf16, bf16);
53    #[cfg(not(target_os = "macos"))]
54    type Acc = (bf16, f32);
55}
56
57impl MatmulPrecision for f32 {
58    type Lhs = (f32, f32);
59    type Rhs = (f32, f32);
60    type Acc = (f32, f32);
61}
62
63impl MatmulPrecision for f64 {
64    type Lhs = (f64, f32);
65    type Rhs = (f64, f32);
66    type Acc = (f64, f32);
67}
68
69impl MatmulPrecision for u8 {
70    type Lhs = (u8, u8);
71    type Rhs = (u8, u8);
72    type Acc = (i32, i32);
73}
74
75impl MatmulPrecision for u16 {
76    type Lhs = (u16, u16);
77    type Rhs = (u16, u16);
78    type Acc = (i32, i32);
79}
80
81impl MatmulPrecision for u32 {
82    type Lhs = (u32, u32);
83    type Rhs = (u32, u32);
84    type Acc = (u32, u32);
85}
86
87impl MatmulPrecision for u64 {
88    type Lhs = (u64, u64);
89    type Rhs = (u64, u64);
90    type Acc = (u64, u64);
91}
92
93impl MatmulPrecision for i8 {
94    type Lhs = (i8, i8);
95    type Rhs = (i8, i8);
96    type Acc = (i32, i32);
97}
98
99impl MatmulPrecision for i16 {
100    type Lhs = (i16, i16);
101    type Rhs = (i16, i16);
102    type Acc = (i32, i32);
103}
104
105impl MatmulPrecision for i32 {
106    type Lhs = (i32, i32);
107    type Rhs = (i32, i32);
108    type Acc = (i32, i32);
109}
110
111impl MatmulPrecision for i64 {
112    type Lhs = (i64, i64);
113    type Rhs = (i64, i64);
114    type Acc = (i64, i64);
115}
116
117impl<LhsG: Numeric, RhsG: Numeric, AccG: Numeric, LhsS: Numeric, RhsS: Numeric, AccS: Numeric>
118    MatmulPrecision for (LhsG, RhsG, AccG, LhsS, RhsS, AccS)
119{
120    type Lhs = (LhsG, LhsS);
121    type Rhs = (RhsG, RhsS);
122    type Acc = (AccG, AccS);
123}
124
125pub type LhsG<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Global;
126pub type LhsS<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Stage;
127pub type LhsR<MP> = <<MP as MatmulPrecision>::Lhs as MatrixPrecision>::Register;
128
129pub type RhsG<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Global;
130pub type RhsS<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Stage;
131pub type RhsR<MP> = <<MP as MatmulPrecision>::Rhs as MatrixPrecision>::Register;
132
133pub type AccG<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Global;
134pub type AccS<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Stage;
135pub type AccR<MP> = <<MP as MatmulPrecision>::Acc as MatrixPrecision>::Register;
136
137/// Input argument
138pub type InputArg<MA> =
139    <MA as MatmulArgs>::Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>;
140
141/// Output argument
142pub type OutputArg<MA> = <MA as MatmulArgs>::Output<NumericExpand<2>>;
143
144/// Input runtime argument
145pub type InputRuntimeArg<'a, MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;
146
147/// Output runtime argument
148pub type OutputRuntimeArg<'a, MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;
149
150#[derive(Clone, Debug)]
151pub struct MatmulElems {
152    pub lhs_global: StorageType,
153    pub rhs_global: StorageType,
154    pub acc_global: StorageType,
155    pub lhs_stage: StorageType,
156    pub rhs_stage: StorageType,
157    pub acc_stage: StorageType,
158    pub lhs_register: StorageType,
159    pub rhs_register: StorageType,
160    pub acc_register: StorageType,
161}
162
163impl MatmulElems {
164    pub fn new<MP: MatmulPrecision>() -> Self {
165        Self {
166            lhs_global: <MP::Lhs as MatrixPrecision>::Global::as_type_native_unchecked(),
167            rhs_global: <MP::Rhs as MatrixPrecision>::Global::as_type_native_unchecked(),
168            acc_global: <MP::Acc as MatrixPrecision>::Global::as_type_native_unchecked(),
169            lhs_stage: <MP::Lhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
170            rhs_stage: <MP::Rhs as MatrixPrecision>::Stage::as_type_native_unchecked(),
171            acc_stage: <MP::Acc as MatrixPrecision>::Stage::as_type_native_unchecked(),
172            lhs_register: <MP::Lhs as MatrixPrecision>::Register::as_type_native_unchecked(),
173            rhs_register: <MP::Rhs as MatrixPrecision>::Register::as_type_native_unchecked(),
174            acc_register: <MP::Acc as MatrixPrecision>::Register::as_type_native_unchecked(),
175        }
176    }
177
178    pub fn from_globals(lhs: StorageType, rhs: StorageType, out: StorageType) -> Self {
179        let acc_type = |dtype: StorageType| {
180            if dtype == half::f16::as_type_native_unchecked()
181                || dtype == half::bf16::as_type_native_unchecked()
182            {
183                return f32::as_type_native_unchecked();
184            }
185
186            dtype
187        };
188
189        Self {
190            lhs_global: lhs,
191            rhs_global: rhs,
192            acc_global: out,
193            lhs_stage: lhs,
194            rhs_stage: rhs,
195            acc_stage: acc_type(out),
196            lhs_register: lhs,
197            rhs_register: rhs,
198            acc_register: acc_type(out),
199        }
200    }
201
202    pub fn global(&self, ident: MatmulIdent) -> StorageType {
203        match ident {
204            MatmulIdent::Lhs => self.lhs_global,
205            MatmulIdent::Rhs => self.rhs_global,
206            MatmulIdent::Out => self.acc_global,
207        }
208    }
209
210    pub fn stage(&self, ident: MatmulIdent) -> StorageType {
211        match ident {
212            MatmulIdent::Lhs => self.lhs_stage,
213            MatmulIdent::Rhs => self.rhs_stage,
214            MatmulIdent::Out => self.acc_stage,
215        }
216    }
217
218    pub fn register(&self, ident: MatmulIdent) -> StorageType {
219        match ident {
220            MatmulIdent::Lhs => self.lhs_register,
221            MatmulIdent::Rhs => self.rhs_register,
222            MatmulIdent::Out => self.acc_register,
223        }
224    }
225}