cubecl_linalg/matmul/components/
spec.rs

1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_std::SymQ8;
5use half::{bf16, f16};
6
7use super::global::args::{MatmulArgs, TensorArgs};
8
9/// Matrix multiplication spec definiting each element types used in the computation as well as
10/// how the arguments are passed to the kernel.
11pub trait MatmulSpec: Send + Sync + Clone + 'static {
12    type Precision: MatmulPrecision;
13    /// How the input and output tensors are passed as arguments.
14    type Args: MatmulArgs;
15}
16
17impl<MP: MatmulPrecision, Args: MatmulArgs> MatmulSpec for (MP, Args) {
18    type Precision = MP;
19    type Args = Args;
20}
21
22// A simple default for TensorArgs
23impl<MP: MatmulPrecision> MatmulSpec for MP {
24    type Precision = MP;
25    type Args = TensorArgs;
26}
27
28/// Matrix multiplication precisions.
29pub trait MatmulPrecision: Send + Sync + Copy + 'static {
30    const QUANTIZED: bool;
31
32    /// Element type of each input tensors of the kernel.
33    type EI: Numeric;
34    /// Element type for the shared memories used to read inputs.
35    type ES: Numeric;
36    /// Element type for the shared memories or fragments used to accumulate
37    /// smaller matmul results before writing to the output tensor.
38    type EA: Numeric;
39    /// Element type of the output tensor of the kernel.
40    type EO: Numeric;
41}
42
43impl MatmulPrecision for f16 {
44    const QUANTIZED: bool = false;
45    type EI = f16;
46    type ES = f16;
47    #[cfg(target_os = "macos")]
48    type EA = f16;
49    #[cfg(not(target_os = "macos"))]
50    type EA = f32;
51    type EO = f16;
52}
53
54impl MatmulPrecision for flex32 {
55    const QUANTIZED: bool = false;
56    type EI = f32;
57    type ES = f16;
58    type EA = f32;
59    type EO = f32;
60}
61
62impl MatmulPrecision for bf16 {
63    const QUANTIZED: bool = false;
64    type EI = bf16;
65    type ES = bf16;
66    #[cfg(target_os = "macos")]
67    type EA = bf16;
68    #[cfg(not(target_os = "macos"))]
69    type EA = f32;
70    type EO = bf16;
71}
72
73impl MatmulPrecision for f32 {
74    const QUANTIZED: bool = false;
75    type EI = f32;
76    type ES = f32;
77    type EA = f32;
78    type EO = f32;
79}
80
81impl MatmulPrecision for f64 {
82    const QUANTIZED: bool = false;
83    type EI = f64;
84    type ES = f32;
85    type EA = f32;
86    type EO = f64;
87}
88
89#[derive(Clone, Copy)]
90pub struct ReplaceES<MP: MatmulPrecision, ES: Numeric> {
91    _phantom: PhantomData<(ES, MP)>,
92}
93
94impl<MP: MatmulPrecision, ES: Numeric> MatmulPrecision for ReplaceES<MP, ES> {
95    const QUANTIZED: bool = MP::QUANTIZED;
96    type EI = MP::EI;
97    type ES = ES;
98    type EA = MP::EA;
99    type EO = MP::EO;
100}
101
102impl<EI: Numeric, ES: Numeric, EA: Numeric, EO: Numeric> MatmulPrecision for (EI, ES, EA, EO) {
103    const QUANTIZED: bool = false;
104    type EI = EI;
105    type ES = ES;
106    type EA = EA;
107    type EO = EO;
108}
109
/// Marker placed in the fifth slot of an `(EI, ES, EA, EO, Quantized)` precision tuple to flag
/// that precision as quantized (`MatmulPrecision::QUANTIZED == true`).
#[derive(Clone, Copy, Debug, Default)]
pub struct Quantized;
112
113impl<EI: Numeric, ES: Numeric, EA: Numeric, EO: Numeric> MatmulPrecision
114    for (EI, ES, EA, EO, Quantized)
115{
116    const QUANTIZED: bool = true;
117    type EI = EI;
118    type ES = ES;
119    type EA = EA;
120    type EO = EO;
121}
122
123impl MatmulPrecision for SymQ8 {
124    const QUANTIZED: bool = true;
125    type EI = i8;
126    type ES = f16;
127    type EA = f16;
128    type EO = f16;
129}
130
131/// Input argument
132pub type InputArg<MS> = <Args<MS> as MatmulArgs>::Input<EI<MS>>;
133
134/// Output argument
135pub type OutputArg<MS> = <Args<MS> as MatmulArgs>::Output<EO<MS>>;
136
137/// Input runtime argument
138pub type InputRuntimeArg<'a, MS, R> = <InputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
139
140/// Output runtime argument
141pub type OutputRuntimeArg<'a, MS, R> = <OutputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
142
143pub type EI<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EI;
144pub type ES<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::ES;
145pub type EA<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EA;
146pub type EO<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EO;
147
148pub type Args<MS> = <MS as MatmulSpec>::Args;