cubecl_cpp/shared/
mma.rs

1use cubecl_core::ir::{self as gpu};
2use cubecl_core::Feature;
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7use std::{fmt::Debug, marker::PhantomData};
8
9use super::{Component, Dialect, Elem, Variable};
10
11pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
12
/// Description of a target GPU architecture.
///
/// Values can be parsed from a string through [`FromStr`], with a `String` error message on
/// failure (NOTE(review): presumably the string is the architecture name — confirm with
/// implementors).
pub trait Architecture: FromStr<Err = String> {
    /// Warp size of this architecture.
    fn warp_size(&self) -> u32;

    /// Whether this architecture supports WMMA (warp matrix multiply-accumulate) instructions.
    fn is_wmma_capable(&self) -> bool;

    /// Whether this architecture supports MFMA matrix instructions
    /// (NOTE(review): presumably AMD's MFMA family — confirm).
    fn is_mfma_capable(&self) -> bool;
}
18
19pub trait WmmaCompiler<D: Dialect>:
20    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
21{
22    type Architecture: Architecture;
23
24    fn wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
25    fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
26    fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
27
28    fn compile_fragment_ident(
29        ident: &FragmentIdent<D>,
30        f: &mut std::fmt::Formatter<'_>,
31    ) -> std::fmt::Result;
32
33    fn compile_fragment_layout(
34        layout: &FragmentLayout<D>,
35        f: &mut std::fmt::Formatter<'_>,
36    ) -> std::fmt::Result;
37
38    fn compile_fragment(
39        fragment: &Fragment<D>,
40        f: &mut std::fmt::Formatter<'_>,
41    ) -> std::fmt::Result;
42
43    fn compile_instruction(
44        instruction: &WmmaInstruction<D>,
45        f: &mut std::fmt::Formatter<'_>,
46    ) -> std::fmt::Result;
47
48    fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations;
49}
50
51pub fn register_wmma_features(
52    supported_combinations: SupportedWmmaCombinations,
53    properties: &mut DeviceProperties<Feature>,
54) {
55    for (i, o, c, tdims) in supported_combinations {
56        for (m, n, k) in tdims {
57            properties.register_feature(Feature::Cmma {
58                a: i,
59                b: o,
60                c,
61                m,
62                n,
63                k,
64            });
65        }
66    }
67}
68
69#[derive(Debug, Clone, PartialEq, Eq, Copy)]
70pub enum FragmentIdent<D: Dialect> {
71    A,
72    B,
73    Accumulator,
74    _Dialect(PhantomData<D>),
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Copy)]
78pub enum FragmentLayout<D: Dialect> {
79    ColMajor,
80    RowMajor,
81    _Dialect(PhantomData<D>),
82}
83
84#[derive(Debug, Clone, PartialEq, Eq, Copy)]
85pub struct Fragment<D: Dialect> {
86    pub ident: FragmentIdent<D>,
87    pub m: u8,
88    pub n: u8,
89    pub k: u8,
90    pub elem: Elem<D>,
91    pub layout: Option<FragmentLayout<D>>,
92}
93
94/// Warp Matrix-Multiply and Accumulate Instruction.
95#[derive(Debug, Clone, Copy)]
96pub enum WmmaInstruction<D: Dialect> {
97    /// Fill the fragment with the value.
98    Fill {
99        frag: Variable<D>,
100        value: Variable<D>,
101    },
102    /// Load the value into the fragment given the stride.
103    Load {
104        frag: Variable<D>,
105        value: Variable<D>,
106        stride: Variable<D>,
107        layout: Option<FragmentLayout<D>>,
108    },
109    /// Executes D=A*B+C;
110    ///
111    /// For implementing a matmul, `D=C` : `C+=A*B`
112    Execute {
113        frag_a: Variable<D>,
114        frag_b: Variable<D>,
115        frag_c: Variable<D>,
116        frag_d: Variable<D>,
117        warp_size: u32,
118    },
119    /// Store the fragment in an output variable following the stride and the layout.
120    Store {
121        output: Variable<D>,
122        frag: Variable<D>,
123        stride: Variable<D>,
124        layout: FragmentLayout<D>,
125    },
126    /// Cast
127    Cast {
128        input: Variable<D>,
129        output: Variable<D>,
130    },
131}
132
133impl<D: Dialect> Display for FragmentLayout<D> {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        D::compile_fragment_layout(self, f)
136    }
137}
138
139impl<D: Dialect> Display for FragmentIdent<D> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        D::compile_fragment_ident(self, f)
142    }
143}
144
145impl<D: Dialect> Display for Fragment<D> {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        D::compile_fragment(self, f)
148    }
149}
150
151impl<D: Dialect> Display for WmmaInstruction<D> {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        D::compile_instruction(self, f)
154    }
155}
156
157pub mod wmma_api_base {
158    use super::*;
159
160    pub fn compile_fragment_ident<D: Dialect>(
161        namespace: &str,
162        ident: &FragmentIdent<D>,
163        f: &mut std::fmt::Formatter<'_>,
164    ) -> std::fmt::Result {
165        match ident {
166            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
167            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
168            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
169            FragmentIdent::_Dialect(_) => Ok(()),
170        }
171    }
172
173    pub fn compile_fragment_layout<D: Dialect>(
174        namespace: &str,
175        layout: &FragmentLayout<D>,
176        f: &mut std::fmt::Formatter<'_>,
177    ) -> std::fmt::Result {
178        match layout {
179            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
180            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
181            FragmentLayout::_Dialect(_) => Ok(()),
182        }
183    }
184
185    pub fn compile_fragment<D: Dialect>(
186        namespace: &str,
187        fragment: &Fragment<D>,
188        f: &mut std::fmt::Formatter<'_>,
189    ) -> std::fmt::Result {
190        let elem = match fragment.elem {
191            Elem::TF32 => format!("{namespace}::precision::tf32"),
192            Elem::BF16 => {
193                if fragment.ident == FragmentIdent::Accumulator {
194                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
195                } else {
196                    format!("{}", fragment.elem)
197                }
198            }
199            elem => format!("{elem}"),
200        };
201        match fragment.layout {
202            Some(layout) => write!(
203                f,
204                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
205                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
206            ),
207            None => write!(
208                f,
209                "{namespace}::fragment<{}, {}, {}, {}, {}>",
210                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
211            ),
212        }
213    }
214
215    pub fn compile_instruction<D: Dialect>(
216        namespace: &str,
217        instruction: &WmmaInstruction<D>,
218        f: &mut std::fmt::Formatter<'_>,
219    ) -> std::fmt::Result {
220        match instruction {
221            WmmaInstruction::Fill { frag, value } => {
222                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
223            }
224
225            WmmaInstruction::Load {
226                frag,
227                value,
228                stride,
229                layout: None,
230            } => {
231                let item = value.item();
232                if item.vectorization > 1 {
233                    let elem = item.elem;
234                    writeln!(f, "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride});")
235                } else {
236                    writeln!(
237                        f,
238                        "{namespace}::load_matrix_sync({frag}, {value}, {stride});"
239                    )
240                }
241            }
242
243            WmmaInstruction::Load {
244                frag,
245                value,
246                stride,
247                layout: Some(layout),
248            } => {
249                let layout = match layout {
250                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
251                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
252                    FragmentLayout::_Dialect(_) => "".to_string(),
253                };
254                let item = value.item();
255                if item.vectorization > 1 {
256                    let elem = item.elem;
257                    writeln!(f, "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride}, {layout});")
258                } else {
259                    writeln!(
260                        f,
261                        "{namespace}::load_matrix_sync({frag}, {value}, {stride}, {layout});"
262                    )
263                }
264            }
265
266            WmmaInstruction::Execute {
267                frag_a,
268                frag_b,
269                frag_c,
270                frag_d,
271                ..
272            } => writeln!(
273                f,
274                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
275            ),
276
277            WmmaInstruction::Store {
278                output,
279                frag,
280                stride,
281                layout,
282            } => {
283                let layout = match layout {
284                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
285                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
286                    FragmentLayout::_Dialect(_) => "".to_string(),
287                };
288
289                let item = output.item();
290                let mut reinterpret_cast = item.vectorization > 1;
291                let elem = match item.elem {
292                    Elem::BF16 => {
293                        reinterpret_cast = true;
294                        Elem::F16
295                    }
296                    _ => item.elem,
297                };
298                if reinterpret_cast {
299                    writeln!(
300                        f,
301                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output}), {frag}, {stride}, {layout});"
302                    )
303                } else {
304                    writeln!(
305                        f,
306                        "{namespace}::store_matrix_sync({output}, {frag}, {stride}, {layout});"
307                    )
308                }
309            }
310            WmmaInstruction::Cast { input, output } => {
311                let ty = match output {
312                    Variable::WmmaFragment { frag, .. } => frag.elem,
313                    _ => panic!("Should be a fragment"),
314                };
315                match ty {
316                    Elem::BF16 => {
317                        let elem = Elem::<D>::F16;
318                        writeln!(
319                            f,
320                            "for(int t=0; t<{input}.num_elements; t++) {{
321                                {ty} elem = {ty}({input}.x[t]);
322                                {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
323                            }}"
324                        )
325                    }
326                    _ => {
327                        writeln!(
328                            f,
329                            "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}"
330                        )
331                    }
332                }
333            }
334        }
335    }
336}