cubecl_cpp/shared/
mma.rs

1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use cubecl_core::ir::{
3    DeviceProperties,
4    features::{MmaConfig, ScaledMmaConfig},
5};
6use std::{
7    fmt::{Debug, Display, Formatter},
8    marker::PhantomData,
9};
10
/// List of [`MmaConfig`] entries handed to the `register_*_features` helpers below.
pub type SupportedMmaCombinations = Vec<MmaConfig>;
/// List of [`ScaledMmaConfig`] entries handed to [`register_scaled_mma_features`].
pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
13
/// Description of a target architecture.
pub trait Architecture {
    /// Warp size of the architecture.
    fn warp_size(&self) -> u32;
    /// Whether the architecture reports WMMA capability.
    fn is_wmma_capable(&self) -> bool;
    /// Whether the architecture reports MFMA capability.
    fn is_mfma_capable(&self) -> bool;
    /// Version of the architecture; implementors that do not override this report `0`.
    fn get_version(&self) -> u32 {
        0
    }
}
22
23pub fn register_wmma_features(
24    supported_combinations: SupportedMmaCombinations,
25    properties: &mut DeviceProperties,
26) {
27    for config in supported_combinations {
28        properties.features.cmma.insert(config);
29    }
30}
31
32pub fn register_mma_features(
33    supported_combinations: SupportedMmaCombinations,
34    properties: &mut DeviceProperties,
35) {
36    for config in supported_combinations {
37        properties.features.mma.insert(config);
38    }
39}
40
41pub fn register_scaled_mma_features(
42    supported_combinations: SupportedScaledMmaCombinations,
43    properties: &mut DeviceProperties,
44) {
45    for config in supported_combinations {
46        properties.features.scaled_mma.insert(config);
47    }
48}
49
/// Role a fragment plays in the `D = A * B + C` matrix-multiply-accumulate.
///
/// How the role is spelled in generated code is decided by the dialect (see the `Display`
/// implementation).
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// The `A` operand; `MmaShape::num_elems` sizes it as `m * k`.
    A,
    /// The `B` operand; `MmaShape::num_elems` sizes it as `k * n`.
    B,
    /// The accumulator; `MmaShape::num_elems` sizes it as `m * n`.
    Accumulator,
    /// Marker variant that only carries the dialect type parameter.
    _Dialect(PhantomData<D>),
}
57
/// Memory layout attached to a fragment or to a load/store instruction.
///
/// How the layout is spelled in generated code is decided by the dialect (see the `Display`
/// implementation).
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major layout.
    ColMajor,
    /// Row-major layout.
    RowMajor,
    /// Marker variant that only carries the dialect type parameter.
    _Dialect(PhantomData<D>),
}
64
/// Type description of a matrix fragment: its role, shape, element type and optional layout.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Role of the fragment in the multiply-accumulate.
    pub ident: FragmentIdent<D>,
    /// `m` shape parameter.
    pub m: u32,
    /// `n` shape parameter.
    pub n: u32,
    /// `k` shape parameter.
    pub k: u32,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Layout of the fragment; `wmma_api_base::compile_fragment` omits the layout template
    /// argument when this is `None`.
    pub layout: Option<FragmentLayout<D>>,
}
74
75#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
76pub struct MmaShape<D: Dialect> {
77    pub m: u32,
78    pub n: u32,
79    pub k: u32,
80    _d: PhantomData<D>,
81}
82
83impl<D: Dialect> MmaShape<D> {
84    pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
85        match ident {
86            FragmentIdent::A => self.m * self.k,
87            FragmentIdent::B => self.k * self.n,
88            FragmentIdent::Accumulator => self.m * self.n,
89            _ => unimplemented!(),
90        }
91    }
92}
93
/// Warp Matrix-Multiply and Accumulate Instruction.
///
/// The `Display` implementation hands each instruction to the dialect's
/// `compile_wmma_instruction` hook to produce source text.
#[derive(Debug, Clone, PartialEq)]
pub enum WmmaInstruction<D: Dialect> {
    /// Fill the fragment with the value.
    Fill {
        frag: Variable<D>,
        value: Variable<D>,
    },
    /// Load the value into the fragment given the stride.
    ///
    /// Reading starts at `value + offset`. When `layout` is `Some`, it is passed along as
    /// an extra memory-layout argument.
    Load {
        frag: Variable<D>,
        value: Variable<D>,
        offset: Variable<D>,
        stride: Variable<D>,
        layout: Option<FragmentLayout<D>>,
    },
    /// Executes D=A*B+C;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    Execute {
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
        warp_size: u32,
    },
    /// Executes D=A*B+C using manually managed registers;
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    /// Takes a sequence of registers for the inputs, and returns an array of registers for the
    /// output. PTX requires output registers to be non-overlapping, so we use array to ensure that
    /// and handle potentially destructuring it internally.
    ExecuteManual {
        shape: MmaShape<D>,
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
    },
    /// Scaled variant of [`WmmaInstruction::ExecuteManual`]: executes D=A*B+C using manually
    /// managed registers, with additional scale operands.
    ///
    /// For implementing a matmul, `D=C` : `C+=A*B`
    /// Takes a sequence of registers for the inputs, and returns an array of registers for the
    /// output. PTX requires output registers to be non-overlapping, so we use array to ensure that
    /// and handle potentially destructuring it internally.
    ///
    /// `scales_a`, `scales_b` and `scales_factor` are forwarded, together with the fragments,
    /// to the dialect's `compile_scaled_mma` hook.
    ExecuteScaled {
        shape: MmaShape<D>,
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,

        scales_a: Variable<D>,
        scales_b: Variable<D>,
        scales_factor: u32,
    },
    /// Store the fragment in an output variable following the stride and the layout.
    ///
    /// Writing starts at `output + offset`.
    Store {
        output: Variable<D>,
        frag: Variable<D>,
        stride: Variable<D>,
        offset: Variable<D>,
        layout: FragmentLayout<D>,
    },
    /// Load a part of a fragment into registers, either 1, 2, or 4 at once.
    LdMatrix {
        output: Variable<D>,
        buffer: Variable<D>,
        offset: Variable<D>,
        line_size: Option<usize>,
        factor: u32,
        transpose: bool,
    },
    /// Store a part of a fragment into smem, either 1, 2, or 4 at once.
    StMatrix {
        registers: Variable<D>,
        buffer: Variable<D>,
        offset: Variable<D>,
        line_size: Option<usize>,
        factor: u32,
        transpose: bool,
    },
    /// Convert every element of the `input` fragment to the element type of the `output`
    /// fragment and write it into `output`.
    Cast {
        input: Variable<D>,
        output: Variable<D>,
    },
}
182
183impl<D: Dialect> Display for FragmentLayout<D> {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        D::compile_wmma_fragment_layout(f, self)
186    }
187}
188
189impl<D: Dialect> Display for FragmentIdent<D> {
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        D::compile_wwma_fragment_ident(f, self)
192    }
193}
194
195impl<D: Dialect> Display for Fragment<D> {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        D::compile_wmma_fragment(f, self)
198    }
199}
200
201impl<D: Dialect> Display for WmmaInstruction<D> {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        D::compile_wmma_instruction(f, self)
204    }
205}
206
207pub mod wmma_api_base {
208    use crate::{
209        cuda::ptx::{ldmatrix_call, stmatrix_call},
210        shared::ManualMma,
211    };
212
213    use super::*;
214
215    pub fn compile_fragment_declaration<D: Dialect>(
216        f: &mut std::fmt::Formatter<'_>,
217        var: &Variable<D>,
218    ) -> std::fmt::Result {
219        match var {
220            Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
221            _ => panic!("variable must be a fragment"),
222        }
223    }
224
225    pub fn compile_fragment_ident<D: Dialect>(
226        f: &mut std::fmt::Formatter<'_>,
227        namespace: &str,
228        ident: &FragmentIdent<D>,
229    ) -> std::fmt::Result {
230        match ident {
231            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
232            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
233            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
234            FragmentIdent::_Dialect(_) => Ok(()),
235        }
236    }
237
238    pub fn compile_fragment_layout<D: Dialect>(
239        f: &mut std::fmt::Formatter<'_>,
240        namespace: &str,
241        layout: &FragmentLayout<D>,
242    ) -> std::fmt::Result {
243        match layout {
244            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
245            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
246            FragmentLayout::_Dialect(_) => Ok(()),
247        }
248    }
249
250    pub fn compile_fragment<D: Dialect>(
251        f: &mut std::fmt::Formatter<'_>,
252        namespace: &str,
253        fragment: &Fragment<D>,
254    ) -> std::fmt::Result {
255        let elem = match fragment.elem {
256            Elem::TF32 => format!("{namespace}::precision::tf32"),
257            Elem::BF16 => {
258                if fragment.ident == FragmentIdent::Accumulator {
259                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
260                } else {
261                    format!("{}", fragment.elem)
262                }
263            }
264            elem => format!("{elem}"),
265        };
266        match fragment.layout {
267            Some(layout) => write!(
268                f,
269                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
270                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
271            ),
272            None => write!(
273                f,
274                "{namespace}::fragment<{}, {}, {}, {}, {}>",
275                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
276            ),
277        }
278    }
279
280    pub fn compile_instruction<D: Dialect>(
281        f: &mut std::fmt::Formatter<'_>,
282        namespace: &str,
283        instruction: &WmmaInstruction<D>,
284    ) -> std::fmt::Result {
285        match instruction {
286            WmmaInstruction::Fill { frag, value } => {
287                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
288            }
289            WmmaInstruction::Load {
290                frag,
291                value,
292                stride,
293                offset,
294                layout: None,
295            } => {
296                let item = value.item();
297                if item.vectorization > 1 {
298                    let elem = item.elem;
299                    let qualifier = value.const_qualifier();
300                    writeln!(
301                        f,
302                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
303                    )
304                } else {
305                    writeln!(
306                        f,
307                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
308                    )
309                }
310            }
311            WmmaInstruction::Load {
312                frag,
313                value,
314                offset,
315                stride,
316                layout: Some(layout),
317            } => {
318                let layout = match layout {
319                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
320                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
321                    FragmentLayout::_Dialect(_) => "".to_string(),
322                };
323                let item = value.item();
324                if item.vectorization > 1 {
325                    let elem = item.elem;
326                    writeln!(
327                        f,
328                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
329                    )
330                } else {
331                    writeln!(
332                        f,
333                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
334                    )
335                }
336            }
337            WmmaInstruction::LdMatrix {
338                output,
339                buffer,
340                offset,
341                line_size,
342                factor,
343                transpose,
344            } => f.write_str(&ldmatrix_call(
345                output, buffer, offset, line_size, factor, transpose,
346            )),
347            WmmaInstruction::StMatrix {
348                registers,
349                buffer,
350                offset,
351                line_size,
352                factor,
353                transpose,
354            } => f.write_str(&stmatrix_call(
355                registers, buffer, offset, line_size, factor, transpose,
356            )),
357            WmmaInstruction::Execute {
358                frag_a,
359                frag_b,
360                frag_c,
361                frag_d,
362                ..
363            } => writeln!(
364                f,
365                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
366            ),
367            WmmaInstruction::Store {
368                output,
369                frag,
370                stride,
371                offset,
372                layout,
373            } => {
374                let layout = match layout {
375                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
376                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
377                    FragmentLayout::_Dialect(_) => "".to_string(),
378                };
379
380                let item = output.item();
381                let mut reinterpret_cast = item.vectorization > 1;
382                let elem = match item.elem {
383                    Elem::BF16 => {
384                        reinterpret_cast = true;
385                        Elem::F16
386                    }
387                    _ => item.elem,
388                };
389                if reinterpret_cast {
390                    writeln!(
391                        f,
392                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
393                    )
394                } else {
395                    writeln!(
396                        f,
397                        "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
398                    )
399                }
400            }
401            WmmaInstruction::Cast { input, output } => {
402                let ty = match output {
403                    Variable::WmmaFragment { frag, .. } => frag.elem,
404                    _ => panic!("Should be a fragment"),
405                };
406                match ty {
407                    Elem::BF16 => {
408                        let elem = Elem::<D>::F16;
409                        write!(
410                            f,
411                            "// cast
412for(int t=0; t<{input}.num_elements; t++) {{
413  {ty} elem = {ty}({input}.x[t]);
414  {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
415}}
416"
417                        )
418                    }
419                    _ => {
420                        write!(
421                            f,
422                            "// cast
423for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
424"
425                        )
426                    }
427                }
428            }
429            WmmaInstruction::ExecuteManual {
430                shape,
431                frag_a,
432                frag_b,
433                frag_c,
434                frag_d,
435            } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
436            WmmaInstruction::ExecuteScaled {
437                shape,
438                frag_a,
439                frag_b,
440                frag_c,
441                frag_d,
442                scales_a,
443                scales_b,
444                scales_factor,
445            } => D::compile_scaled_mma(
446                f,
447                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
448                *scales_a,
449                *scales_b,
450                *scales_factor,
451            ),
452        }
453    }
454}
455
456pub fn frag_as_ptr<D: Dialect>(
457    f: &mut Formatter<'_>,
458    frag: &Variable<D>,
459    offset: &Variable<D>,
460) -> Variable<D> {
461    let item = frag.item();
462    let mut frag_ptr = Variable::tmp_ptr(item);
463    if frag.is_const() {
464        frag_ptr.to_const();
465    }
466    let frag_ptr_out = frag_ptr.fmt_left();
467    writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
468
469    if item.vectorization > 1 {
470        let mut item_value = item;
471        item_value.vectorization = 1;
472        frag_ptr.reinterpret_ptr(f, item_value)
473    } else {
474        frag_ptr
475    }
476}
477
478pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
479    match frag {
480        FragmentIdent::A => "a",
481        FragmentIdent::B => "b",
482        FragmentIdent::Accumulator => "c",
483        FragmentIdent::_Dialect(_) => "d",
484    }
485}
486
487pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
488    match frag {
489        Some(layout) => match layout {
490            FragmentLayout::ColMajor => "col",
491            FragmentLayout::RowMajor => "row",
492            FragmentLayout::_Dialect(_) => "",
493        },
494        None => "",
495    }
496}
497
498pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
499    match frag {
500        Variable::WmmaFragment { frag, .. } => *frag,
501        _ => panic!(),
502    }
503}