cubecl_cpp/shared/
mma.rs

1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
/// List of MMA configurations supported by a device, as consumed by
/// `register_wmma_features` and `register_mma_features`.
pub type SupportedMmaCombinations = Vec<MmaConfig>;
/// List of scaled-MMA configurations supported by a device, as consumed by
/// `register_scaled_mma_features`.
pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Capabilities of a target GPU architecture that matter for matrix-multiply
/// code generation.
pub trait Architecture {
    /// Number of lanes in one warp (or wavefront) on this architecture.
    fn warp_size(&self) -> u32;

    /// Whether the architecture can run WMMA instructions.
    fn is_wmma_capable(&self) -> bool;

    /// Whether the architecture can run MFMA instructions.
    fn is_mfma_capable(&self) -> bool;

    /// Architecture version number.
    ///
    /// Implementors that do not override this report version `0`.
    fn get_version(&self) -> u32 {
        0_u32
    }
}
18
19pub fn register_wmma_features(
20    supported_combinations: SupportedMmaCombinations,
21    properties: &mut DeviceProperties,
22) {
23    for config in supported_combinations {
24        properties.features.cmma.insert(config);
25    }
26}
27
28pub fn register_mma_features(
29    supported_combinations: SupportedMmaCombinations,
30    properties: &mut DeviceProperties,
31) {
32    for config in supported_combinations {
33        properties.features.mma.insert(config);
34    }
35}
36
37pub fn register_scaled_mma_features(
38    supported_combinations: SupportedScaledMmaCombinations,
39    properties: &mut DeviceProperties,
40) {
41    for config in supported_combinations {
42        properties.features.scaled_mma.insert(config);
43    }
44}
45
/// Which operand of a matrix multiply-accumulate a fragment holds.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// Left-hand input matrix (`m x k` elements).
    A,
    /// Right-hand input matrix (`k x n` elements).
    B,
    /// Accumulator / output matrix (`m x n` elements).
    Accumulator,
    /// Marker variant that only carries the dialect type parameter; it does not
    /// name a real operand.
    _Dialect(PhantomData<D>),
}
53
/// Element ordering of a fragment's matrix.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major ordering.
    ColMajor,
    /// Row-major ordering.
    RowMajor,
    /// Marker variant that only carries the dialect type parameter; it does not
    /// name a real layout.
    _Dialect(PhantomData<D>),
}
60
/// Type-level description of a matrix fragment: operand role, tile shape,
/// element type and optional layout.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Operand role of the fragment (A, B or accumulator).
    pub ident: FragmentIdent<D>,
    /// Tile size along M.
    pub m: u32,
    /// Tile size along N.
    pub n: u32,
    /// Tile size along K.
    pub k: u32,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Layout baked into the fragment type, if any.
    /// NOTE(review): `None` presumably means the layout is instead supplied at
    /// load/store time (e.g. for accumulators) — confirm against callers.
    pub layout: Option<FragmentLayout<D>>,
}
70
/// `m x n x k` shape of one matrix multiply-accumulate (`A: m x k`, `B: k x n`,
/// accumulator `m x n`).
#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
pub struct MmaShape<D: Dialect> {
    /// Rows of A and of the accumulator.
    pub m: u32,
    /// Columns of B and of the accumulator.
    pub n: u32,
    /// Reduction dimension: columns of A and rows of B.
    pub k: u32,
    /// Ties the shape to a dialect without storing one.
    _d: PhantomData<D>,
}
78
79impl<D: Dialect> MmaShape<D> {
80    pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81        match ident {
82            FragmentIdent::A => self.m * self.k,
83            FragmentIdent::B => self.k * self.n,
84            FragmentIdent::Accumulator => self.m * self.n,
85            _ => unimplemented!(),
86        }
87    }
88}
89
/// Warp Matrix-Multiply and Accumulate Instruction.
///
/// Rendered through `Display`, which delegates to `Dialect::compile_wmma_instruction`.
#[derive(Debug, Clone, PartialEq)]
pub enum WmmaInstruction<D: Dialect> {
    /// Fill every element of the fragment `frag` with `value`.
    Fill {
        frag: Variable<D>,
        value: Variable<D>,
    },
    /// Load the fragment `frag` from memory starting at `value + offset`, using
    /// `stride` as the leading dimension.
    ///
    /// `layout` is the memory layout of the source data; when `None`, no layout
    /// argument is given to the load (presumably because the fragment type already
    /// fixes it).
    Load {
        frag: Variable<D>,
        value: Variable<D>,
        offset: Variable<D>,
        stride: Variable<D>,
        layout: Option<FragmentLayout<D>>,
    },
    /// Executes `D = A * B + C`.
    ///
    /// For implementing a matmul, `D = C`: `C += A * B`.
    Execute {
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
        // NOTE(review): ignored by the `wmma_api_base` lowering; presumably used by
        // other dialects — confirm.
        warp_size: u32,
    },
    /// Executes `D = A * B + C` using manually managed registers.
    ///
    /// For implementing a matmul, `D = C`: `C += A * B`.
    /// Takes a sequence of registers for each input and writes an array of registers
    /// for the output. PTX requires the output registers to be non-overlapping, so an
    /// array is used to guarantee that, and it is destructured internally as needed.
    ExecuteManual {
        shape: MmaShape<D>,
        frag_a: Vec<Variable<D>>,
        frag_b: Vec<Variable<D>>,
        frag_c: Vec<Variable<D>>,
        frag_d: Variable<D>,
    },
    /// Executes a scaled `D = A * B + C` using manually managed registers.
    ///
    /// Uses the same register conventions as `ExecuteManual` (a sequence of registers
    /// per input, an array of non-overlapping registers for the output), plus scale
    /// operands for A and B.
    /// NOTE(review): the exact scaling semantics are defined by
    /// `Dialect::compile_scaled_mma` — confirm there.
    ExecuteScaled {
        shape: MmaShape<D>,
        frag_a: Vec<Variable<D>>,
        frag_b: Vec<Variable<D>>,
        frag_c: Vec<Variable<D>>,
        frag_d: Variable<D>,

        /// Scales associated with operand A.
        scales_a: Variable<D>,
        /// Scales associated with operand B.
        scales_b: Variable<D>,
        /// NOTE(review): presumably how many elements share one scale value — confirm.
        scales_factor: u32,
    },
    /// Store the fragment `frag` to `output + offset` with leading dimension `stride`,
    /// writing elements in the given memory `layout`.
    Store {
        output: Variable<D>,
        frag: Variable<D>,
        stride: Variable<D>,
        offset: Variable<D>,
        layout: FragmentLayout<D>,
    },
    /// Convert the elements of the `input` fragment to the element type of the
    /// `output` fragment.
    Cast {
        input: Variable<D>,
        output: Variable<D>,
    },
}
160
161impl<D: Dialect> Display for FragmentLayout<D> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        D::compile_wmma_fragment_layout(f, self)
164    }
165}
166
167impl<D: Dialect> Display for FragmentIdent<D> {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        D::compile_wwma_fragment_ident(f, self)
170    }
171}
172
173impl<D: Dialect> Display for Fragment<D> {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        D::compile_wmma_fragment(f, self)
176    }
177}
178
179impl<D: Dialect> Display for WmmaInstruction<D> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        D::compile_wmma_instruction(f, self)
182    }
183}
184
185pub mod wmma_api_base {
186    use crate::shared::ManualMma;
187
188    use super::*;
189
190    pub fn compile_fragment_declaration<D: Dialect>(
191        f: &mut std::fmt::Formatter<'_>,
192        var: &Variable<D>,
193    ) -> std::fmt::Result {
194        match var {
195            Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
196            _ => panic!("variable must be a fragment"),
197        }
198    }
199
200    pub fn compile_fragment_ident<D: Dialect>(
201        f: &mut std::fmt::Formatter<'_>,
202        namespace: &str,
203        ident: &FragmentIdent<D>,
204    ) -> std::fmt::Result {
205        match ident {
206            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
207            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
208            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
209            FragmentIdent::_Dialect(_) => Ok(()),
210        }
211    }
212
213    pub fn compile_fragment_layout<D: Dialect>(
214        f: &mut std::fmt::Formatter<'_>,
215        namespace: &str,
216        layout: &FragmentLayout<D>,
217    ) -> std::fmt::Result {
218        match layout {
219            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
220            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
221            FragmentLayout::_Dialect(_) => Ok(()),
222        }
223    }
224
225    pub fn compile_fragment<D: Dialect>(
226        f: &mut std::fmt::Formatter<'_>,
227        namespace: &str,
228        fragment: &Fragment<D>,
229    ) -> std::fmt::Result {
230        let elem = match fragment.elem {
231            Elem::TF32 => format!("{namespace}::precision::tf32"),
232            Elem::BF16 => {
233                if fragment.ident == FragmentIdent::Accumulator {
234                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
235                } else {
236                    format!("{}", fragment.elem)
237                }
238            }
239            elem => format!("{elem}"),
240        };
241        match fragment.layout {
242            Some(layout) => write!(
243                f,
244                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
245                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
246            ),
247            None => write!(
248                f,
249                "{namespace}::fragment<{}, {}, {}, {}, {}>",
250                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
251            ),
252        }
253    }
254
255    pub fn compile_instruction<D: Dialect>(
256        f: &mut std::fmt::Formatter<'_>,
257        namespace: &str,
258        instruction: &WmmaInstruction<D>,
259    ) -> std::fmt::Result {
260        match instruction {
261            WmmaInstruction::Fill { frag, value } => {
262                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
263            }
264            WmmaInstruction::Load {
265                frag,
266                value,
267                stride,
268                offset,
269                layout: None,
270            } => {
271                let item = value.item();
272                if item.vectorization > 1 {
273                    let elem = item.elem;
274                    let qualifier = value.const_qualifier();
275                    writeln!(
276                        f,
277                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
278                    )
279                } else {
280                    writeln!(
281                        f,
282                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
283                    )
284                }
285            }
286            WmmaInstruction::Load {
287                frag,
288                value,
289                offset,
290                stride,
291                layout: Some(layout),
292            } => {
293                let layout = match layout {
294                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
295                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
296                    FragmentLayout::_Dialect(_) => "".to_string(),
297                };
298                let item = value.item();
299                if item.vectorization > 1 {
300                    let elem = item.elem;
301                    writeln!(
302                        f,
303                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
304                    )
305                } else {
306                    writeln!(
307                        f,
308                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
309                    )
310                }
311            }
312            WmmaInstruction::Execute {
313                frag_a,
314                frag_b,
315                frag_c,
316                frag_d,
317                ..
318            } => writeln!(
319                f,
320                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
321            ),
322            WmmaInstruction::Store {
323                output,
324                frag,
325                stride,
326                offset,
327                layout,
328            } => {
329                let layout = match layout {
330                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
331                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
332                    FragmentLayout::_Dialect(_) => "".to_string(),
333                };
334
335                let item = output.item();
336                let mut reinterpret_cast = item.vectorization > 1;
337                let elem = match item.elem {
338                    Elem::BF16 => {
339                        reinterpret_cast = true;
340                        Elem::F16
341                    }
342                    _ => item.elem,
343                };
344                if reinterpret_cast {
345                    writeln!(
346                        f,
347                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
348                    )
349                } else {
350                    writeln!(
351                        f,
352                        "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
353                    )
354                }
355            }
356            WmmaInstruction::Cast { input, output } => {
357                let ty = match output {
358                    Variable::WmmaFragment { frag, .. } => frag.elem,
359                    _ => panic!("Should be a fragment"),
360                };
361                match ty {
362                    Elem::BF16 => {
363                        let elem = Elem::<D>::F16;
364                        write!(
365                            f,
366                            "// cast
367for(int t=0; t<{input}.num_elements; t++) {{
368  {ty} elem = {ty}({input}.x[t]);
369  {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
370}}
371"
372                        )
373                    }
374                    _ => {
375                        write!(
376                            f,
377                            "// cast
378for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
379"
380                        )
381                    }
382                }
383            }
384            WmmaInstruction::ExecuteManual {
385                shape,
386                frag_a,
387                frag_b,
388                frag_c,
389                frag_d,
390            } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
391            WmmaInstruction::ExecuteScaled {
392                shape,
393                frag_a,
394                frag_b,
395                frag_c,
396                frag_d,
397                scales_a,
398                scales_b,
399                scales_factor,
400            } => D::compile_scaled_mma(
401                f,
402                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
403                *scales_a,
404                *scales_b,
405                *scales_factor,
406            ),
407        }
408    }
409}
410
411pub fn frag_as_ptr<D: Dialect>(
412    f: &mut Formatter<'_>,
413    frag: &Variable<D>,
414    offset: &Variable<D>,
415) -> Variable<D> {
416    let item = frag.item();
417    let mut frag_ptr = Variable::tmp_ptr(item);
418    if frag.is_const() {
419        frag_ptr.to_const();
420    }
421    let frag_ptr_out = frag_ptr.fmt_left();
422    writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
423
424    if item.vectorization > 1 {
425        let mut item_value = item;
426        item_value.vectorization = 1;
427        frag_ptr.reinterpret_ptr(f, item_value)
428    } else {
429        frag_ptr
430    }
431}
432
433pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
434    match frag {
435        FragmentIdent::A => "a",
436        FragmentIdent::B => "b",
437        FragmentIdent::Accumulator => "c",
438        FragmentIdent::_Dialect(_) => "d",
439    }
440}
441
442pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
443    match frag {
444        Some(layout) => match layout {
445            FragmentLayout::ColMajor => "col",
446            FragmentLayout::RowMajor => "row",
447            FragmentLayout::_Dialect(_) => "",
448        },
449        None => "",
450    }
451}
452
453pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
454    match frag {
455        Variable::WmmaFragment { frag, .. } => *frag,
456        _ => panic!(),
457    }
458}