cubecl_cpp/cuda/mma/ptx_wmma_compiler.rs

use super::WMMA_MINIMUM_VERSION;
use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call, stmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{
    self as gpu, ConstantValue, Matrix, MatrixIdent,
    features::{MmaConfig, ScaledMmaConfig},
};
use itertools::Itertools;
use std::fmt::Display;

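// Lowers CubeCL WMMA operations to raw PTX `wmma.*`/`mma.*` instructions
// emitted through inline assembly, instead of going through the
// `nvcuda::wmma` C++ API.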
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        // The mma header is only needed for the TF32 conversion helpers
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("fragment declaration expects a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
        // sub-32-bit elements are stored packed into 32-bit integer registers
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
  {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                // Important note: the current frontend was designed around CUDA
                // wmma, which is not optimal for PTX wmma and mma. We prefer the
                // layout defined in the fragment; if it is unknown, we fall back
                // to the layout passed to the instruction.
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                // instruction qualifiers
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
                // constraints
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
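                // A sketch of what gets generated for an f16 A fragment
                // (m16n16k16, row-major), assuming a runtime stride variable:
                //   tmp = (__half*)value + offset;
                //   asm volatile(
                //       "wmma.load.a.sync.aligned.row.m16n16k16.f16 "
                //       "{%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
                //       : "=r"(frag[0]), ..., "=r"(frag[7])
                //       : "l"(tmp), "r"(stride));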
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
            }
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::StMatrix {
                registers,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&stmatrix_call(
                registers, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
                // PTX puts the layout qualifiers before the shape:
                // wmma.mma.sync.aligned.alayout.blayout.shape.dtype.ctype
                let opcode = match var_a.elem() {
                    Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    // integer wmma also spells out the a/b types:
                    // wmma.mma.sync.aligned.alayout.blayout.shape.s32.atype.btype.s32
                    Elem::U8 | Elem::I8 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.s32.{}.{}.s32",
                        frag_a.m,
                        frag_a.n,
                        frag_a.k,
                        get_type_qualifier(var_a),
                        get_type_qualifier(var_b),
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                // order matters: declare the registers in the same order as the
                // instruction expects them (d, a, b, c)
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
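                // e.g. m16n16k16 f16 inputs with f32 accumulators use 8
                // registers per fragment, binding d to %0-%7, a to %8-%15,
                // b to %16-%23 and c to %24-%31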
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                // instruction qualifiers
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        // hack because wmma.store does not support bf16:
                        // f16 still works correctly for bf16 as long as the
                        // input registers are in the correct format
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        // same hack for tf32
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                // constraints
                let mut reg_count = 0;
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                // offset and stride can be passed as a local const or as a const
                // scalar; we need to handle both cases correctly in the asm
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                // the fragment registers come after the buffer address and stride
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}),
      {in_constraints}{stride_constraint}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                // derive the register count from the *output* fragment: each
                // packed 32-bit output register consumes two f32 input registers
                let frag = match output {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __half h_lo = __float2half_rn({input}[2*i + 0]);
    __half h_hi = __float2half_rn({input}[2*i + 1]);
    __half2 h2 = __halves2half2(h_lo, h_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            // Types fully supported.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            // integer wmma requires sm_72
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            // tf32 wmma requires sm_80
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}

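// Number of 32-bit registers each lane contributes to a fragment. Worked
// examples (matching the PTX wmma register mapping):
//   - f16 A, m16n16k16: 256 elements / (2 per reg * 16 lanes) = 8 regs
//   - f32 accumulator, m16n16k16: 256 / (1 * 32) = 8 regs
//   - tf32 A, m16n16k8: 128 / (1 * 32) = 4 regs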
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    // TODO: retrieve the warp size from the compiler CompilationOptions
    // number of elements packed into one 32-bit register
    let lanes_per_reg = 32 / bits_per_elem;
    // threads that share one fragment:
    // - accumulators always use 32 lanes
    // - A/B use 16 lanes, except TF32 (k = 8) which also uses 32 lanes
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}

fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}

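// Builds the "%N,%N+1,..." placeholder list and the matching constraint list
// for one asm operand. The constraint letters follow the usual inline-asm
// meaning: "f" = f32 register, "d" = f64 register, "r" = 32-bit register,
// and a leading "=" marks a write-only output.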
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        // compile-time constants are inlined directly, with no constraint
        Variable::Constant(number, ..) => match number {
            ConstantValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        // other scalars take one register; the constraint starts with ", " so
        // it can be appended directly after the preceding constraints
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}

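// PTX inline asm refers to operands positionally (%0, %1, ...) in the order
// their constraints are declared; this emits the next placeholder and
// advances the counter.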
fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}

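// Emits a call to a `__mma_*` helper rather than inline asm; the helper is
// assumed to be defined elsewhere in the generated source. As a sketch, an
// m16n8k16 f16 x f16 -> f32 mma works out to 4 u32 a-registers, 2 u32
// b-registers and 4 f32 c/d-registers:
//   __mma_m16n8k16_f16_f16_f32(a0..a3, b0..b1, c0..c3, d0..d3);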
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = match a_elem {
        Elem::F32 => format!("{}", Elem::<D>::F32),
        _ => format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => format!("{}", Elem::<D>::F32),
        _ => format!("{}", Elem::<D>::U32),
    };

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), &ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), &ab_ty));

    // C and D fragments are always vectorized for optimal stores and casts,
    // but f32 accumulators are taken as separate registers, so we need to
    // unpack them. f16 is taken in packed registers, so it is used as is.
    let frag_c = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), &cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{i}]"), &cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let frag_d = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), &cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{i}]"), &cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}

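// Same lowering as `compile_manual_mma`, but A/B registers are always packed
// u32 words, accumulators are always f32, and the block scales are appended
// as two reinterpreted u32 operands; `scales_factor` selects the helper
// variant (`_1x`, `_2x`, ...).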
pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = &format!("{}", Elem::<D>::U32);
    let cd_ty = &format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}

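// mma.sync shape support: m16n8k16 for f16/bf16, m16n8k8 for tf32, and
// m16n8k32 for the 8-bit integer and (on sm_89+) f8f6f4 float formats.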
pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    // The version requirement is higher than for WMMA because we only support
    // the newest shapes; older shapes would complicate things considerably.
    // Also, only f32 accumulators are used for now.
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // a
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // b
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), // cd
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            // TODO: u4/i4/b1, there are no types for them yet
        ]);
    }
    // the f8f6f4 formats require sm_89 or newer
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}

pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    // only the sm_120 family (sm_120f) supports these
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            // Sign of scales is ignored
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}

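// How many contiguous elements each lane loads per 32-bit register, e.g.
// 2 for 16-bit A/B types and 4 for 8-bit types; accumulators always hold
// 2 contiguous elements per lane.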
pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> usize {
    match ident {
        MatrixIdent::A | MatrixIdent::B => 32 / matrix.storage.size_bits(),
        MatrixIdent::Accumulator => 2,
    }
}