cubecl_cpp/shared/
mma.rs

1use cubecl_core::Feature;
2use cubecl_core::ir::{self as gpu};
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::{fmt::Debug, marker::PhantomData};
6
7use super::{Component, Dialect, Elem, Variable};
8
9pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
10
/// Capabilities reported by a target architecture.
pub trait Architecture {
    /// Returns the warp size of the architecture.
    fn warp_size(&self) -> u32;

    /// Returns `true` when the architecture reports itself as WMMA capable.
    fn is_wmma_capable(&self) -> bool;

    /// Returns `true` when the architecture reports itself as MFMA capable.
    fn is_mfma_capable(&self) -> bool;

    /// Returns the architecture version.
    ///
    /// Implementors that do not override this method report `0`.
    fn get_version(&self) -> u32 {
        0
    }
}
19
20pub fn register_wmma_features(
21    supported_combinations: SupportedWmmaCombinations,
22    properties: &mut DeviceProperties<Feature>,
23) {
24    for (i, o, c, tdims) in supported_combinations {
25        for (m, n, k) in tdims {
26            properties.register_feature(Feature::Cmma {
27                a: i,
28                b: o,
29                c,
30                m,
31                n,
32                k,
33            });
34        }
35    }
36}
37
38#[derive(Debug, Clone, PartialEq, Eq, Copy)]
39pub enum FragmentIdent<D: Dialect> {
40    A,
41    B,
42    Accumulator,
43    _Dialect(PhantomData<D>),
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Copy)]
47pub enum FragmentLayout<D: Dialect> {
48    ColMajor,
49    RowMajor,
50    _Dialect(PhantomData<D>),
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Copy)]
54pub struct Fragment<D: Dialect> {
55    pub ident: FragmentIdent<D>,
56    pub m: u8,
57    pub n: u8,
58    pub k: u8,
59    pub elem: Elem<D>,
60    pub layout: Option<FragmentLayout<D>>,
61}
62
63/// Warp Matrix-Multiply and Accumulate Instruction.
64#[derive(Debug, Clone, Copy)]
65pub enum WmmaInstruction<D: Dialect> {
66    /// Fill the fragment with the value.
67    Fill {
68        frag: Variable<D>,
69        value: Variable<D>,
70    },
71    /// Load the value into the fragment given the stride.
72    Load {
73        frag: Variable<D>,
74        value: Variable<D>,
75        offset: Variable<D>,
76        stride: Variable<D>,
77        layout: Option<FragmentLayout<D>>,
78    },
79    /// Executes D=A*B+C;
80    ///
81    /// For implementing a matmul, `D=C` : `C+=A*B`
82    Execute {
83        frag_a: Variable<D>,
84        frag_b: Variable<D>,
85        frag_c: Variable<D>,
86        frag_d: Variable<D>,
87        warp_size: u32,
88    },
89    /// Store the fragment in an output variable following the stride and the layout.
90    Store {
91        output: Variable<D>,
92        frag: Variable<D>,
93        stride: Variable<D>,
94        offset: Variable<D>,
95        layout: FragmentLayout<D>,
96    },
97    /// Cast
98    Cast {
99        input: Variable<D>,
100        output: Variable<D>,
101    },
102}
103
104impl<D: Dialect> Display for FragmentLayout<D> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        D::compile_wmma_fragment_layout(f, self)
107    }
108}
109
110impl<D: Dialect> Display for FragmentIdent<D> {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        D::compile_wwma_fragment_ident(f, self)
113    }
114}
115
116impl<D: Dialect> Display for Fragment<D> {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        D::compile_wmma_fragment(f, self)
119    }
120}
121
122impl<D: Dialect> Display for WmmaInstruction<D> {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        D::compile_wmma_instruction(f, self)
125    }
126}
127
128pub mod wmma_api_base {
129    use super::*;
130
131    pub fn compile_fragment_declaration<D: Dialect>(
132        f: &mut std::fmt::Formatter<'_>,
133        var: &Variable<D>,
134    ) -> std::fmt::Result {
135        match var {
136            Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
137            _ => panic!("variable must be a fragment"),
138        }
139    }
140
141    pub fn compile_fragment_ident<D: Dialect>(
142        f: &mut std::fmt::Formatter<'_>,
143        namespace: &str,
144        ident: &FragmentIdent<D>,
145    ) -> std::fmt::Result {
146        match ident {
147            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
148            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
149            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
150            FragmentIdent::_Dialect(_) => Ok(()),
151        }
152    }
153
154    pub fn compile_fragment_layout<D: Dialect>(
155        f: &mut std::fmt::Formatter<'_>,
156        namespace: &str,
157        layout: &FragmentLayout<D>,
158    ) -> std::fmt::Result {
159        match layout {
160            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
161            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
162            FragmentLayout::_Dialect(_) => Ok(()),
163        }
164    }
165
166    pub fn compile_fragment<D: Dialect>(
167        f: &mut std::fmt::Formatter<'_>,
168        namespace: &str,
169        fragment: &Fragment<D>,
170    ) -> std::fmt::Result {
171        let elem = match fragment.elem {
172            Elem::TF32 => format!("{namespace}::precision::tf32"),
173            Elem::BF16 => {
174                if fragment.ident == FragmentIdent::Accumulator {
175                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
176                } else {
177                    format!("{}", fragment.elem)
178                }
179            }
180            elem => format!("{elem}"),
181        };
182        match fragment.layout {
183            Some(layout) => write!(
184                f,
185                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
186                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
187            ),
188            None => write!(
189                f,
190                "{namespace}::fragment<{}, {}, {}, {}, {}>",
191                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
192            ),
193        }
194    }
195
196    pub fn compile_instruction<D: Dialect>(
197        f: &mut std::fmt::Formatter<'_>,
198        namespace: &str,
199        instruction: &WmmaInstruction<D>,
200    ) -> std::fmt::Result {
201        match instruction {
202            WmmaInstruction::Fill { frag, value } => {
203                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
204            }
205
206            WmmaInstruction::Load {
207                frag,
208                value,
209                stride,
210                offset,
211                layout: None,
212            } => {
213                let item = value.item();
214                if item.vectorization > 1 {
215                    let elem = item.elem;
216                    let qualifier = value.const_qualifier();
217                    writeln!(
218                        f,
219                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
220                    )
221                } else {
222                    writeln!(
223                        f,
224                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
225                    )
226                }
227            }
228
229            WmmaInstruction::Load {
230                frag,
231                value,
232                offset,
233                stride,
234                layout: Some(layout),
235            } => {
236                let layout = match layout {
237                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
238                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
239                    FragmentLayout::_Dialect(_) => "".to_string(),
240                };
241                let item = value.item();
242                if item.vectorization > 1 {
243                    let elem = item.elem;
244                    writeln!(
245                        f,
246                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
247                    )
248                } else {
249                    writeln!(
250                        f,
251                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
252                    )
253                }
254            }
255
256            WmmaInstruction::Execute {
257                frag_a,
258                frag_b,
259                frag_c,
260                frag_d,
261                ..
262            } => writeln!(
263                f,
264                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
265            ),
266
267            WmmaInstruction::Store {
268                output,
269                frag,
270                stride,
271                offset,
272                layout,
273            } => {
274                let layout = match layout {
275                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
276                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
277                    FragmentLayout::_Dialect(_) => "".to_string(),
278                };
279
280                let item = output.item();
281                let mut reinterpret_cast = item.vectorization > 1;
282                let elem = match item.elem {
283                    Elem::BF16 => {
284                        reinterpret_cast = true;
285                        Elem::F16
286                    }
287                    _ => item.elem,
288                };
289                if reinterpret_cast {
290                    writeln!(
291                        f,
292                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
293                    )
294                } else {
295                    writeln!(
296                        f,
297                        "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
298                    )
299                }
300            }
301            WmmaInstruction::Cast { input, output } => {
302                let ty = match output {
303                    Variable::WmmaFragment { frag, .. } => frag.elem,
304                    _ => panic!("Should be a fragment"),
305                };
306                match ty {
307                    Elem::BF16 => {
308                        let elem = Elem::<D>::F16;
309                        write!(
310                            f,
311                            "// cast
312for(int t=0; t<{input}.num_elements; t++) {{
313  {ty} elem = {ty}({input}.x[t]);
314  {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
315}}
316"
317                        )
318                    }
319                    _ => {
320                        write!(
321                            f,
322                            "// cast
323for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
324"
325                        )
326                    }
327                }
328            }
329        }
330    }
331}