cubecl_cpp/shared/
mma.rs

1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
/// List of MMA configurations supported by a device; consumed by `register_wmma_features`
/// (into `features.cmma`) and `register_mma_features` (into `features.mma`).
pub type SupportedMmaCombinations = Vec<MmaConfig>;
/// List of scaled MMA configurations supported by a device; consumed by
/// `register_scaled_mma_features` (into `features.scaled_mma`).
pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Capability queries for a target GPU architecture.
pub trait Architecture {
    /// Number of threads per warp on this architecture.
    fn warp_size(&self) -> u32;
    /// Whether the architecture supports WMMA instructions.
    fn is_wmma_capable(&self) -> bool;
    /// Whether the architecture supports MFMA instructions.
    fn is_mfma_capable(&self) -> bool;
    /// Architecture version number. Defaults to `0` when an implementation does not override it.
    fn get_version(&self) -> u32 {
        0
    }
}
18
19pub fn register_wmma_features(
20    supported_combinations: SupportedMmaCombinations,
21    properties: &mut DeviceProperties,
22) {
23    for config in supported_combinations {
24        properties.features.cmma.insert(config);
25    }
26}
27
28pub fn register_mma_features(
29    supported_combinations: SupportedMmaCombinations,
30    properties: &mut DeviceProperties,
31) {
32    for config in supported_combinations {
33        properties.features.mma.insert(config);
34    }
35}
36
37pub fn register_scaled_mma_features(
38    supported_combinations: SupportedScaledMmaCombinations,
39    properties: &mut DeviceProperties,
40) {
41    for config in supported_combinations {
42        properties.features.scaled_mma.insert(config);
43    }
44}
45
/// Role of a fragment in `D = A * B + C`.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// Left input operand (`m x k` elements, see `MmaShape::num_elems`).
    A,
    /// Right input operand (`k x n` elements).
    B,
    /// Accumulator / output operand (`m x n` elements).
    Accumulator,
    /// Phantom variant that only carries the dialect type parameter; not a real operand.
    _Dialect(PhantomData<D>),
}
53
/// Memory layout of a fragment's matrix.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major layout.
    ColMajor,
    /// Row-major layout.
    RowMajor,
    /// Phantom variant that only carries the dialect type parameter; not a real layout.
    _Dialect(PhantomData<D>),
}
60
/// Type description of a matrix fragment: its role, tile shape, element type and layout.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Role of the fragment (A, B or accumulator).
    pub ident: FragmentIdent<D>,
    /// Tile dimension `m`.
    pub m: u32,
    /// Tile dimension `n`.
    pub n: u32,
    /// Tile dimension `k`.
    pub k: u32,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Optional layout. When `None`, `wmma_api_base::compile_fragment` omits the layout
    /// template argument from the fragment type.
    pub layout: Option<FragmentLayout<D>>,
}
70
/// Shape `m x n x k` of a matrix multiply: A is `m x k`, B is `k x n`, the accumulator is
/// `m x n` (see `num_elems`).
#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
pub struct MmaShape<D: Dialect> {
    pub m: u32,
    pub n: u32,
    pub k: u32,
    _d: PhantomData<D>,
}
78
79impl<D: Dialect> MmaShape<D> {
80    pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81        match ident {
82            FragmentIdent::A => self.m * self.k,
83            FragmentIdent::B => self.k * self.n,
84            FragmentIdent::Accumulator => self.m * self.n,
85            _ => unimplemented!(),
86        }
87    }
88}
89
/// Warp Matrix-Multiply and Accumulate Instruction.
///
/// Rendered to source through its `Display` impl, which forwards to
/// `Dialect::compile_wmma_instruction`.
#[derive(Debug, Clone, PartialEq)]
pub enum WmmaInstruction<D: Dialect> {
    /// Fill the fragment with the value.
    Fill {
        frag: Variable<D>,
        value: Variable<D>,
    },
    /// Load the value into the fragment given the stride.
    ///
    /// The source address is `value + offset`. With `layout: None`, `wmma_api_base` emits no
    /// memory-layout argument; with `Some(layout)` the layout is passed explicitly
    /// (presumably needed for accumulator loads — TODO confirm).
    Load {
        frag: Variable<D>,
        value: Variable<D>,
        offset: Variable<D>,
        stride: Variable<D>,
        layout: Option<FragmentLayout<D>>,
    },
    /// Executes D=A*B+C;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    ///
    /// `warp_size` is not read by `wmma_api_base::compile_instruction`; presumably other
    /// dialect back-ends use it — TODO confirm.
    Execute {
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
        warp_size: u32,
    },
    /// Executes D=A*B+C using manually managed registers;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    /// Takes a sequence of registers for the inputs, and returns an array of registers for the
    /// output. PTX requires output registers to be non-overlapping, so we use array to ensure that
    /// and handle potentially destructuring it internally.
    ExecuteManual {
        shape: MmaShape<D>,
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
    },
    /// Executes a scaled D=A*B+C using manually managed registers;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    /// Uses the same register conventions as `ExecuteManual`, plus scale operands for A and B.
    /// `scales_a`, `scales_b` and `scales_factor` are forwarded unchanged to
    /// `Dialect::compile_scaled_mma`, which defines their exact meaning
    /// (the meaning of `scales_factor` is dialect-defined — TODO document).
    ExecuteScaled {
        shape: MmaShape<D>,
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,

        scales_a: Variable<D>,
        scales_b: Variable<D>,
        scales_factor: u32,
    },
    /// Store the fragment in an output variable following the stride and the layout.
    ///
    /// The destination address is `output + offset`.
    Store {
        output: Variable<D>,
        frag: Variable<D>,
        stride: Variable<D>,
        offset: Variable<D>,
        layout: FragmentLayout<D>,
    },
    /// Load a part of a fragment into registers, either 1, 2, or 4 at once.
    ///
    /// In `wmma_api_base` code generation is delegated to `cuda::ptx::ldmatrix_call`;
    /// `factor` presumably selects the 1/2/4 count and `transpose` the transposed form —
    /// TODO confirm.
    LdMatrix {
        output: Variable<D>,
        buffer: Variable<D>,
        offset: Variable<D>,
        line_size: Option<u32>,
        factor: u32,
        transpose: bool,
    },
    /// Store a part of a fragment into smem, either 1, 2, or 4 at once.
    ///
    /// In `wmma_api_base` code generation is delegated to `cuda::ptx::stmatrix_call`;
    /// `factor` presumably selects the 1/2/4 count and `transpose` the transposed form —
    /// TODO confirm.
    StMatrix {
        registers: Variable<D>,
        buffer: Variable<D>,
        offset: Variable<D>,
        line_size: Option<u32>,
        factor: u32,
        transpose: bool,
    },
    /// Element-wise cast of the `input` fragment into the `output` fragment.
    ///
    /// The target element type is taken from `output`'s fragment type.
    Cast {
        input: Variable<D>,
        output: Variable<D>,
    },
}
178
179impl<D: Dialect> Display for FragmentLayout<D> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        D::compile_wmma_fragment_layout(f, self)
182    }
183}
184
185impl<D: Dialect> Display for FragmentIdent<D> {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        D::compile_wwma_fragment_ident(f, self)
188    }
189}
190
191impl<D: Dialect> Display for Fragment<D> {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        D::compile_wmma_fragment(f, self)
194    }
195}
196
197impl<D: Dialect> Display for WmmaInstruction<D> {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        D::compile_wmma_instruction(f, self)
200    }
201}
202
203pub mod wmma_api_base {
204    use crate::{
205        cuda::ptx::{ldmatrix_call, stmatrix_call},
206        shared::ManualMma,
207    };
208
209    use super::*;
210
211    pub fn compile_fragment_declaration<D: Dialect>(
212        f: &mut std::fmt::Formatter<'_>,
213        var: &Variable<D>,
214    ) -> std::fmt::Result {
215        match var {
216            Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
217            _ => panic!("variable must be a fragment"),
218        }
219    }
220
221    pub fn compile_fragment_ident<D: Dialect>(
222        f: &mut std::fmt::Formatter<'_>,
223        namespace: &str,
224        ident: &FragmentIdent<D>,
225    ) -> std::fmt::Result {
226        match ident {
227            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
228            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
229            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
230            FragmentIdent::_Dialect(_) => Ok(()),
231        }
232    }
233
234    pub fn compile_fragment_layout<D: Dialect>(
235        f: &mut std::fmt::Formatter<'_>,
236        namespace: &str,
237        layout: &FragmentLayout<D>,
238    ) -> std::fmt::Result {
239        match layout {
240            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
241            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
242            FragmentLayout::_Dialect(_) => Ok(()),
243        }
244    }
245
246    pub fn compile_fragment<D: Dialect>(
247        f: &mut std::fmt::Formatter<'_>,
248        namespace: &str,
249        fragment: &Fragment<D>,
250    ) -> std::fmt::Result {
251        let elem = match fragment.elem {
252            Elem::TF32 => format!("{namespace}::precision::tf32"),
253            Elem::BF16 => {
254                if fragment.ident == FragmentIdent::Accumulator {
255                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
256                } else {
257                    format!("{}", fragment.elem)
258                }
259            }
260            elem => format!("{elem}"),
261        };
262        match fragment.layout {
263            Some(layout) => write!(
264                f,
265                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
266                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
267            ),
268            None => write!(
269                f,
270                "{namespace}::fragment<{}, {}, {}, {}, {}>",
271                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
272            ),
273        }
274    }
275
276    pub fn compile_instruction<D: Dialect>(
277        f: &mut std::fmt::Formatter<'_>,
278        namespace: &str,
279        instruction: &WmmaInstruction<D>,
280    ) -> std::fmt::Result {
281        match instruction {
282            WmmaInstruction::Fill { frag, value } => {
283                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
284            }
285            WmmaInstruction::Load {
286                frag,
287                value,
288                stride,
289                offset,
290                layout: None,
291            } => {
292                let item = value.item();
293                if item.vectorization > 1 {
294                    let elem = item.elem;
295                    let qualifier = value.const_qualifier();
296                    writeln!(
297                        f,
298                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
299                    )
300                } else {
301                    writeln!(
302                        f,
303                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
304                    )
305                }
306            }
307            WmmaInstruction::Load {
308                frag,
309                value,
310                offset,
311                stride,
312                layout: Some(layout),
313            } => {
314                let layout = match layout {
315                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
316                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
317                    FragmentLayout::_Dialect(_) => "".to_string(),
318                };
319                let item = value.item();
320                if item.vectorization > 1 {
321                    let elem = item.elem;
322                    writeln!(
323                        f,
324                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
325                    )
326                } else {
327                    writeln!(
328                        f,
329                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
330                    )
331                }
332            }
333            WmmaInstruction::LdMatrix {
334                output,
335                buffer,
336                offset,
337                line_size,
338                factor,
339                transpose,
340            } => f.write_str(&ldmatrix_call(
341                output, buffer, offset, line_size, factor, transpose,
342            )),
343            WmmaInstruction::StMatrix {
344                registers,
345                buffer,
346                offset,
347                line_size,
348                factor,
349                transpose,
350            } => f.write_str(&stmatrix_call(
351                registers, buffer, offset, line_size, factor, transpose,
352            )),
353            WmmaInstruction::Execute {
354                frag_a,
355                frag_b,
356                frag_c,
357                frag_d,
358                ..
359            } => writeln!(
360                f,
361                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
362            ),
363            WmmaInstruction::Store {
364                output,
365                frag,
366                stride,
367                offset,
368                layout,
369            } => {
370                let layout = match layout {
371                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
372                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
373                    FragmentLayout::_Dialect(_) => "".to_string(),
374                };
375
376                let item = output.item();
377                let mut reinterpret_cast = item.vectorization > 1;
378                let elem = match item.elem {
379                    Elem::BF16 => {
380                        reinterpret_cast = true;
381                        Elem::F16
382                    }
383                    _ => item.elem,
384                };
385                if reinterpret_cast {
386                    writeln!(
387                        f,
388                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
389                    )
390                } else {
391                    writeln!(
392                        f,
393                        "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
394                    )
395                }
396            }
397            WmmaInstruction::Cast { input, output } => {
398                let ty = match output {
399                    Variable::WmmaFragment { frag, .. } => frag.elem,
400                    _ => panic!("Should be a fragment"),
401                };
402                match ty {
403                    Elem::BF16 => {
404                        let elem = Elem::<D>::F16;
405                        write!(
406                            f,
407                            "// cast
408for(int t=0; t<{input}.num_elements; t++) {{
409  {ty} elem = {ty}({input}.x[t]);
410  {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
411}}
412"
413                        )
414                    }
415                    _ => {
416                        write!(
417                            f,
418                            "// cast
419for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
420"
421                        )
422                    }
423                }
424            }
425            WmmaInstruction::ExecuteManual {
426                shape,
427                frag_a,
428                frag_b,
429                frag_c,
430                frag_d,
431            } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
432            WmmaInstruction::ExecuteScaled {
433                shape,
434                frag_a,
435                frag_b,
436                frag_c,
437                frag_d,
438                scales_a,
439                scales_b,
440                scales_factor,
441            } => D::compile_scaled_mma(
442                f,
443                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
444                *scales_a,
445                *scales_b,
446                *scales_factor,
447            ),
448        }
449    }
450}
451
452pub fn frag_as_ptr<D: Dialect>(
453    f: &mut Formatter<'_>,
454    frag: &Variable<D>,
455    offset: &Variable<D>,
456) -> Variable<D> {
457    let item = frag.item();
458    let mut frag_ptr = Variable::tmp_ptr(item);
459    if frag.is_const() {
460        frag_ptr.to_const();
461    }
462    let frag_ptr_out = frag_ptr.fmt_left();
463    writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
464
465    if item.vectorization > 1 {
466        let mut item_value = item;
467        item_value.vectorization = 1;
468        frag_ptr.reinterpret_ptr(f, item_value)
469    } else {
470        frag_ptr
471    }
472}
473
474pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
475    match frag {
476        FragmentIdent::A => "a",
477        FragmentIdent::B => "b",
478        FragmentIdent::Accumulator => "c",
479        FragmentIdent::_Dialect(_) => "d",
480    }
481}
482
483pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
484    match frag {
485        Some(layout) => match layout {
486            FragmentLayout::ColMajor => "col",
487            FragmentLayout::RowMajor => "row",
488            FragmentLayout::_Dialect(_) => "",
489        },
490        None => "",
491    }
492}
493
494pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
495    match frag {
496        Variable::WmmaFragment { frag, .. } => *frag,
497        _ => panic!(),
498    }
499}