cubecl_linalg/matmul/components/
spec.rs1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_std::SymQ8;
5use half::{bf16, f16};
6
7use super::global::args::{MatmulArgs, TensorArgs};
8
9pub trait MatmulSpec: Send + Sync + Clone + 'static {
12 type Precision: MatmulPrecision;
13 type Args: MatmulArgs;
15}
16
17impl<MP: MatmulPrecision, Args: MatmulArgs> MatmulSpec for (MP, Args) {
18 type Precision = MP;
19 type Args = Args;
20}
21
22impl<MP: MatmulPrecision> MatmulSpec for MP {
24 type Precision = MP;
25 type Args = TensorArgs;
26}
27
28pub trait MatmulPrecision: Send + Sync + Copy + 'static {
30 const QUANTIZED: bool;
31
32 type EI: Numeric;
34 type ES: Numeric;
36 type EA: Numeric;
39 type EO: Numeric;
41}
42
43impl MatmulPrecision for f16 {
44 const QUANTIZED: bool = false;
45 type EI = f16;
46 type ES = f16;
47 #[cfg(target_os = "macos")]
48 type EA = f16;
49 #[cfg(not(target_os = "macos"))]
50 type EA = f32;
51 type EO = f16;
52}
53
54impl MatmulPrecision for flex32 {
55 const QUANTIZED: bool = false;
56 type EI = f32;
57 type ES = f16;
58 type EA = f32;
59 type EO = f32;
60}
61
62impl MatmulPrecision for bf16 {
63 const QUANTIZED: bool = false;
64 type EI = bf16;
65 type ES = bf16;
66 #[cfg(target_os = "macos")]
67 type EA = bf16;
68 #[cfg(not(target_os = "macos"))]
69 type EA = f32;
70 type EO = bf16;
71}
72
73impl MatmulPrecision for f32 {
74 const QUANTIZED: bool = false;
75 type EI = f32;
76 type ES = f32;
77 type EA = f32;
78 type EO = f32;
79}
80
81impl MatmulPrecision for f64 {
82 const QUANTIZED: bool = false;
83 type EI = f64;
84 type ES = f32;
85 type EA = f32;
86 type EO = f64;
87}
88
89#[derive(Clone, Copy)]
90pub struct ReplaceES<MP: MatmulPrecision, ES: Numeric> {
91 _phantom: PhantomData<(ES, MP)>,
92}
93
94impl<MP: MatmulPrecision, ES: Numeric> MatmulPrecision for ReplaceES<MP, ES> {
95 const QUANTIZED: bool = MP::QUANTIZED;
96 type EI = MP::EI;
97 type ES = ES;
98 type EA = MP::EA;
99 type EO = MP::EO;
100}
101
102impl<EI: Numeric, ES: Numeric, EA: Numeric, EO: Numeric> MatmulPrecision for (EI, ES, EA, EO) {
103 const QUANTIZED: bool = false;
104 type EI = EI;
105 type ES = ES;
106 type EA = EA;
107 type EO = EO;
108}
109
/// Zero-sized tag type. Appending it as a fifth element to an
/// `(EI, ES, EA, EO)` tuple selects the quantized variant of that precision.
#[derive(Copy, Clone)]
pub struct Quantized;
112
113impl<EI: Numeric, ES: Numeric, EA: Numeric, EO: Numeric> MatmulPrecision
114 for (EI, ES, EA, EO, Quantized)
115{
116 const QUANTIZED: bool = true;
117 type EI = EI;
118 type ES = ES;
119 type EA = EA;
120 type EO = EO;
121}
122
123impl MatmulPrecision for SymQ8 {
124 const QUANTIZED: bool = true;
125 type EI = i8;
126 type ES = f16;
127 type EA = f16;
128 type EO = f16;
129}
130
131pub type InputArg<MS> = <Args<MS> as MatmulArgs>::Input<EI<MS>>;
133
134pub type OutputArg<MS> = <Args<MS> as MatmulArgs>::Output<EO<MS>>;
136
137pub type InputRuntimeArg<'a, MS, R> = <InputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
139
140pub type OutputRuntimeArg<'a, MS, R> = <OutputArg<MS> as LaunchArg>::RuntimeArg<'a, R>;
142
143pub type EI<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EI;
144pub type ES<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::ES;
145pub type EA<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EA;
146pub type EO<MS> = <<MS as MatmulSpec>::Precision as MatmulPrecision>::EO;
147
148pub type Args<MS> = <MS as MatmulSpec>::Args;