cubecl_cpp/shared/
mma.rs

1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
7pub type SupportedMmaCombinations = Vec<MmaConfig>;
8pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Queries describing the target GPU architecture's matrix-multiply capabilities.
pub trait Architecture {
    /// Number of threads in a warp for this architecture.
    fn warp_size(&self) -> u32;

    /// Whether the architecture supports WMMA-style matrix instructions.
    fn is_wmma_capable(&self) -> bool;

    /// Whether the architecture supports MFMA-style matrix instructions.
    fn is_mfma_capable(&self) -> bool;

    /// Architecture version number.
    ///
    /// Implementors that do not override this report `0`.
    fn get_version(&self) -> u32 {
        0
    }
}
18
19pub fn register_wmma_features(
20    supported_combinations: SupportedMmaCombinations,
21    properties: &mut DeviceProperties,
22) {
23    for config in supported_combinations {
24        properties.features.cmma.insert(config);
25    }
26}
27
28pub fn register_mma_features(
29    supported_combinations: SupportedMmaCombinations,
30    properties: &mut DeviceProperties,
31) {
32    for config in supported_combinations {
33        properties.features.mma.insert(config);
34    }
35}
36
37pub fn register_scaled_mma_features(
38    supported_combinations: SupportedScaledMmaCombinations,
39    properties: &mut DeviceProperties,
40) {
41    for config in supported_combinations {
42        properties.features.scaled_mma.insert(config);
43    }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Copy)]
47pub enum FragmentIdent<D: Dialect> {
48    A,
49    B,
50    Accumulator,
51    _Dialect(PhantomData<D>),
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Copy)]
55pub enum FragmentLayout<D: Dialect> {
56    ColMajor,
57    RowMajor,
58    _Dialect(PhantomData<D>),
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Copy)]
62pub struct Fragment<D: Dialect> {
63    pub ident: FragmentIdent<D>,
64    pub m: u32,
65    pub n: u32,
66    pub k: u32,
67    pub elem: Elem<D>,
68    pub layout: Option<FragmentLayout<D>>,
69}
70
71#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
72pub struct MmaShape<D: Dialect> {
73    pub m: u32,
74    pub n: u32,
75    pub k: u32,
76    _d: PhantomData<D>,
77}
78
79impl<D: Dialect> MmaShape<D> {
80    pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81        match ident {
82            FragmentIdent::A => self.m * self.k,
83            FragmentIdent::B => self.k * self.n,
84            FragmentIdent::Accumulator => self.m * self.n,
85            _ => unimplemented!(),
86        }
87    }
88}
89
90/// Warp Matrix-Multiply and Accumulate Instruction.
91#[derive(Debug, Clone, PartialEq)]
92pub enum WmmaInstruction<D: Dialect> {
93    /// Fill the fragment with the value.
94    Fill {
95        frag: Variable<D>,
96        value: Variable<D>,
97    },
98    /// Load the value into the fragment given the stride.
99    Load {
100        frag: Variable<D>,
101        value: Variable<D>,
102        offset: Variable<D>,
103        stride: Variable<D>,
104        layout: Option<FragmentLayout<D>>,
105    },
106    /// Executes D=A*B+C;
107    ///
108    /// For implementing a matmul, `D=C` : `C+=A*B`
109    Execute {
110        frag_a: Variable<D>,
111        frag_b: Variable<D>,
112        frag_c: Variable<D>,
113        frag_d: Variable<D>,
114        warp_size: u32,
115    },
116    /// Executes D=A*B+C using manually managed registers;
117    ///
118    /// For implementing a matmul, `D=C` : `C+=A*B`
119    /// Takes a sequence of registers for the inputs, and returns an array of registers for the
120    /// output. PTX requires output registers to be non-overlapping, so we use array to ensure that
121    /// and handle potentially destructuring it internally.
122    ExecuteManual {
123        shape: MmaShape<D>,
124        frag_a: Variable<D>,
125        frag_b: Variable<D>,
126        frag_c: Variable<D>,
127        frag_d: Variable<D>,
128    },
129    /// Executes D=A*B+C using manually managed registers;
130    ///
131    /// For implementing a matmul, `D=C` : `C+=A*B`
132    /// Takes a sequence of registers for the inputs, and returns an array of registers for the
133    /// output. PTX requires output registers to be non-overlapping, so we use array to ensure that
134    /// and handle potentially destructuring it internally.
135    ExecuteScaled {
136        shape: MmaShape<D>,
137        frag_a: Variable<D>,
138        frag_b: Variable<D>,
139        frag_c: Variable<D>,
140        frag_d: Variable<D>,
141
142        scales_a: Variable<D>,
143        scales_b: Variable<D>,
144        scales_factor: u32,
145    },
146    /// Store the fragment in an output variable following the stride and the layout.
147    Store {
148        output: Variable<D>,
149        frag: Variable<D>,
150        stride: Variable<D>,
151        offset: Variable<D>,
152        layout: FragmentLayout<D>,
153    },
154    /// Load a part of a fragment into registers, either 1, 2, or 4 at once.
155    LdMatrix {
156        output: Variable<D>,
157        buffer: Variable<D>,
158        offset: Variable<D>,
159        line_size: Option<u32>,
160        factor: u32,
161        transpose: bool,
162    },
163    /// Cast
164    Cast {
165        input: Variable<D>,
166        output: Variable<D>,
167    },
168}
169
170impl<D: Dialect> Display for FragmentLayout<D> {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        D::compile_wmma_fragment_layout(f, self)
173    }
174}
175
176impl<D: Dialect> Display for FragmentIdent<D> {
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        D::compile_wwma_fragment_ident(f, self)
179    }
180}
181
182impl<D: Dialect> Display for Fragment<D> {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        D::compile_wmma_fragment(f, self)
185    }
186}
187
188impl<D: Dialect> Display for WmmaInstruction<D> {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        D::compile_wmma_instruction(f, self)
191    }
192}
193
194pub mod wmma_api_base {
195    use crate::{cuda::ptx::ldmatrix_call, shared::ManualMma};
196
197    use super::*;
198
199    pub fn compile_fragment_declaration<D: Dialect>(
200        f: &mut std::fmt::Formatter<'_>,
201        var: &Variable<D>,
202    ) -> std::fmt::Result {
203        match var {
204            Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
205            _ => panic!("variable must be a fragment"),
206        }
207    }
208
209    pub fn compile_fragment_ident<D: Dialect>(
210        f: &mut std::fmt::Formatter<'_>,
211        namespace: &str,
212        ident: &FragmentIdent<D>,
213    ) -> std::fmt::Result {
214        match ident {
215            FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
216            FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
217            FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
218            FragmentIdent::_Dialect(_) => Ok(()),
219        }
220    }
221
222    pub fn compile_fragment_layout<D: Dialect>(
223        f: &mut std::fmt::Formatter<'_>,
224        namespace: &str,
225        layout: &FragmentLayout<D>,
226    ) -> std::fmt::Result {
227        match layout {
228            FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
229            FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
230            FragmentLayout::_Dialect(_) => Ok(()),
231        }
232    }
233
234    pub fn compile_fragment<D: Dialect>(
235        f: &mut std::fmt::Formatter<'_>,
236        namespace: &str,
237        fragment: &Fragment<D>,
238    ) -> std::fmt::Result {
239        let elem = match fragment.elem {
240            Elem::TF32 => format!("{namespace}::precision::tf32"),
241            Elem::BF16 => {
242                if fragment.ident == FragmentIdent::Accumulator {
243                    format!("{}", Elem::<D>::F16) // Normally not supported except for cast.
244                } else {
245                    format!("{}", fragment.elem)
246                }
247            }
248            elem => format!("{elem}"),
249        };
250        match fragment.layout {
251            Some(layout) => write!(
252                f,
253                "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
254                fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
255            ),
256            None => write!(
257                f,
258                "{namespace}::fragment<{}, {}, {}, {}, {}>",
259                fragment.ident, fragment.m, fragment.n, fragment.k, elem,
260            ),
261        }
262    }
263
264    pub fn compile_instruction<D: Dialect>(
265        f: &mut std::fmt::Formatter<'_>,
266        namespace: &str,
267        instruction: &WmmaInstruction<D>,
268    ) -> std::fmt::Result {
269        match instruction {
270            WmmaInstruction::Fill { frag, value } => {
271                writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
272            }
273            WmmaInstruction::Load {
274                frag,
275                value,
276                stride,
277                offset,
278                layout: None,
279            } => {
280                let item = value.item();
281                if item.vectorization > 1 {
282                    let elem = item.elem;
283                    let qualifier = value.const_qualifier();
284                    writeln!(
285                        f,
286                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
287                    )
288                } else {
289                    writeln!(
290                        f,
291                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
292                    )
293                }
294            }
295            WmmaInstruction::Load {
296                frag,
297                value,
298                offset,
299                stride,
300                layout: Some(layout),
301            } => {
302                let layout = match layout {
303                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
304                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
305                    FragmentLayout::_Dialect(_) => "".to_string(),
306                };
307                let item = value.item();
308                if item.vectorization > 1 {
309                    let elem = item.elem;
310                    writeln!(
311                        f,
312                        "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
313                    )
314                } else {
315                    writeln!(
316                        f,
317                        "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
318                    )
319                }
320            }
321            WmmaInstruction::LdMatrix {
322                output,
323                buffer,
324                offset,
325                line_size,
326                factor,
327                transpose,
328            } => f.write_str(&ldmatrix_call(
329                output, buffer, offset, line_size, factor, transpose,
330            )),
331            WmmaInstruction::Execute {
332                frag_a,
333                frag_b,
334                frag_c,
335                frag_d,
336                ..
337            } => writeln!(
338                f,
339                "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
340            ),
341            WmmaInstruction::Store {
342                output,
343                frag,
344                stride,
345                offset,
346                layout,
347            } => {
348                let layout = match layout {
349                    FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
350                    FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
351                    FragmentLayout::_Dialect(_) => "".to_string(),
352                };
353
354                let item = output.item();
355                let mut reinterpret_cast = item.vectorization > 1;
356                let elem = match item.elem {
357                    Elem::BF16 => {
358                        reinterpret_cast = true;
359                        Elem::F16
360                    }
361                    _ => item.elem,
362                };
363                if reinterpret_cast {
364                    writeln!(
365                        f,
366                        "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
367                    )
368                } else {
369                    writeln!(
370                        f,
371                        "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
372                    )
373                }
374            }
375            WmmaInstruction::Cast { input, output } => {
376                let ty = match output {
377                    Variable::WmmaFragment { frag, .. } => frag.elem,
378                    _ => panic!("Should be a fragment"),
379                };
380                match ty {
381                    Elem::BF16 => {
382                        let elem = Elem::<D>::F16;
383                        write!(
384                            f,
385                            "// cast
386for(int t=0; t<{input}.num_elements; t++) {{
387  {ty} elem = {ty}({input}.x[t]);
388  {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
389}}
390"
391                        )
392                    }
393                    _ => {
394                        write!(
395                            f,
396                            "// cast
397for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
398"
399                        )
400                    }
401                }
402            }
403            WmmaInstruction::ExecuteManual {
404                shape,
405                frag_a,
406                frag_b,
407                frag_c,
408                frag_d,
409            } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
410            WmmaInstruction::ExecuteScaled {
411                shape,
412                frag_a,
413                frag_b,
414                frag_c,
415                frag_d,
416                scales_a,
417                scales_b,
418                scales_factor,
419            } => D::compile_scaled_mma(
420                f,
421                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
422                *scales_a,
423                *scales_b,
424                *scales_factor,
425            ),
426        }
427    }
428}
429
430pub fn frag_as_ptr<D: Dialect>(
431    f: &mut Formatter<'_>,
432    frag: &Variable<D>,
433    offset: &Variable<D>,
434) -> Variable<D> {
435    let item = frag.item();
436    let mut frag_ptr = Variable::tmp_ptr(item);
437    if frag.is_const() {
438        frag_ptr.to_const();
439    }
440    let frag_ptr_out = frag_ptr.fmt_left();
441    writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
442
443    if item.vectorization > 1 {
444        let mut item_value = item;
445        item_value.vectorization = 1;
446        frag_ptr.reinterpret_ptr(f, item_value)
447    } else {
448        frag_ptr
449    }
450}
451
452pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
453    match frag {
454        FragmentIdent::A => "a",
455        FragmentIdent::B => "b",
456        FragmentIdent::Accumulator => "c",
457        FragmentIdent::_Dialect(_) => "d",
458    }
459}
460
461pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
462    match frag {
463        Some(layout) => match layout {
464            FragmentLayout::ColMajor => "col",
465            FragmentLayout::RowMajor => "row",
466            FragmentLayout::_Dialect(_) => "",
467        },
468        None => "",
469    }
470}
471
472pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
473    match frag {
474        Variable::WmmaFragment { frag, .. } => *frag,
475        _ => panic!(),
476    }
477}