//! PTX-level WMMA/MMA compiler for the CUDA dialect.
//! (cubecl_cpp/cuda/mma/ptx_wmma_compiler.rs)

1use super::WMMA_MINIMUM_VERSION;
2use crate::{
3    Dialect,
4    cuda::{
5        CudaDialect,
6        arch::CudaArchitecture,
7        ptx::{comma_separated, ldmatrix_call, stmatrix_call},
8    },
9    shared::{
10        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
11        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
12        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
13    },
14};
15use cubecl_core::ir::{
16    self as gpu, ConstantValue, Matrix, MatrixIdent,
17    features::{MmaConfig, ScaledMmaConfig},
18};
19use itertools::Itertools;
20use std::fmt::Display;
21
/// WMMA compiler that lowers matrix-multiply-accumulate operations to raw
/// PTX `wmma.*` inline assembly (rather than the `nvcuda::wmma` C++ API).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}
24
25impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
26    fn compile_wmma_includes(
27        f: &mut std::fmt::Formatter<'_>,
28        flags: &Flags<CudaDialect<Self>>,
29    ) -> std::fmt::Result {
30        // We need mma header for conversion
31        if flags.elem_tf32 {
32            f.write_str("#include <mma.h>\n")?;
33        }
34        Ok(())
35    }
36
37    fn compile_wmma_fragment_declaration(
38        f: &mut std::fmt::Formatter<'_>,
39        var: &Variable<CudaDialect<Self>>,
40    ) -> std::fmt::Result {
41        let frag = match var {
42            Variable::WmmaFragment { frag, .. } => *frag,
43            _ => panic!("load instruction expects a WmmaFragment"),
44        };
45        let reg_count = get_fragment_register_total_count(&frag);
46        let ty = match frag.elem {
47            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
48            Elem::F32 => "float",
49            Elem::F64 => "double",
50            _ => panic!("unsupported type"),
51        };
52        writeln!(f, "{ty} {var}[{reg_count}];")
53    }
54
55    fn compile_wmma_instruction(
56        f: &mut std::fmt::Formatter<'_>,
57        instruction: &WmmaInstruction<CudaDialect<Self>>,
58    ) -> std::fmt::Result {
59        match instruction {
60            WmmaInstruction::Fill { frag: var, value } => {
61                let frag = match var {
62                    Variable::WmmaFragment { frag, .. } => *frag,
63                    _ => panic!("variable should be WmmaFragment"),
64                };
65                let reg_count = get_fragment_register_total_count(&frag);
66                write!(
67                    f,
68                    "// fill
69for (uint i = 0; i < uint({reg_count}); ++i) {{
70  {var}[i] = {value};
71}}
72 "
73                )
74            }
75            WmmaInstruction::Load {
76                frag: var,
77                value,
78                offset,
79                stride,
80                layout,
81            } => {
82                let frag = match var {
83                    Variable::WmmaFragment { frag, .. } => *frag,
84                    _ => panic!("load instruction expects a WmmaFragment"),
85                };
86                // Important note: the current frontend has been designed around
87                // CUDA wmma which is not optimal in the case of PTX wmma and mma
88                // We choose here to use the layout defined in the fragment first,
89                // if it is unknown and we look into the layout passed to the instruction.
90                let layout = if frag.layout.is_some() {
91                    get_fragment_layout_qualifier(var)
92                } else if let Some(layout) = layout {
93                    get_qualifier_from_layout(layout)
94                } else {
95                    panic!("unknown matrix layout for wmma load instruction");
96                };
97                // instruction qualifiers
98                let ty = get_type_qualifier(value);
99                let matrix = match frag.ident {
100                    FragmentIdent::A => "a",
101                    FragmentIdent::B => "b",
102                    FragmentIdent::Accumulator => "c",
103                    FragmentIdent::_Dialect(_) => unreachable!(),
104                };
105                let value_ty = value.item();
106                let opcode = match frag.elem {
107                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
108                        format!(
109                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
110                            frag.m, frag.n, frag.k,
111                        )
112                    }
113                    other => panic!("{other} fragment type not supported"),
114                };
115                // constraints
116                let mut reg_count = 0;
117                let (regs_decl, out_constraints) =
118                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
119                let buffer_reg = format_reg_and_inc(&mut reg_count);
120                let (stride_reg, stride_constraint) =
121                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
122                let tmp_ptr = Variable::tmp_ptr(value.item());
123                let tmp_ptr_left = tmp_ptr.fmt_left();
124                write!(
125                    f,
126                    r#"// load
127{tmp_ptr_left} = ({value_ty}*){value} + {offset};
128asm volatile(
129    "{opcode} "
130    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
131    : {out_constraints}
132    : "l"({tmp_ptr}){stride_constraint}
133);
134"#
135                )
136            }
137            WmmaInstruction::LdMatrix {
138                output,
139                buffer,
140                offset,
141                vector_size,
142                factor,
143                transpose,
144            } => f.write_str(&ldmatrix_call(
145                output,
146                buffer,
147                offset,
148                vector_size,
149                factor,
150                transpose,
151            )),
152            WmmaInstruction::StMatrix {
153                registers,
154                buffer,
155                offset,
156                vector_size,
157                factor,
158                transpose,
159            } => f.write_str(&stmatrix_call(
160                registers,
161                buffer,
162                offset,
163                vector_size,
164                factor,
165                transpose,
166            )),
167            WmmaInstruction::Execute {
168                frag_a: var_a,
169                frag_b: var_b,
170                frag_c: var_c,
171                frag_d: var_d,
172                ..
173            } => {
174                let frag_a = match var_a {
175                    Variable::WmmaFragment { frag, .. } => *frag,
176                    _ => panic!("variable should be WmmaFragment"),
177                };
178                let layout_a = get_fragment_layout_qualifier(var_a);
179                let layout_b = get_fragment_layout_qualifier(var_b);
180                let type_c = get_type_qualifier(var_c);
181                let type_d = get_type_qualifier(var_d);
182                let opcode = match var_a.elem() {
183                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
184                        "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
185                        frag_a.m, frag_a.n, frag_a.k,
186                    ),
187                    Elem::BF16 => format!(
188                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
189                        frag_a.m, frag_a.n, frag_a.k,
190                    ),
191                    Elem::TF32 => format!(
192                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
193                        frag_a.m, frag_a.n, frag_a.k,
194                    ),
195                    other => panic!("{other} fragment type not supported"),
196                };
197                let mut reg_count = 0;
198                // order matters, declare the registers in the same order as the intrinsic
199                let (regs_decl_d, out_constraints_d) =
200                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
201                let (regs_decl_a, in_constraints_a) =
202                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
203                let (regs_decl_b, in_constraints_b) =
204                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
205                let (regs_decl_c, in_constraints_c) =
206                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
207                write!(
208                    f,
209                    r#"// execute
210asm volatile(
211    "{opcode} "
212    "{{{regs_decl_d}}}, "
213    "{{{regs_decl_a}}}, "
214    "{{{regs_decl_b}}}, "
215    "{{{regs_decl_c}}};\n"
216    : {out_constraints_d}
217    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
218);
219"#
220                )
221            }
222            WmmaInstruction::ExecuteManual {
223                shape,
224                frag_a,
225                frag_b,
226                frag_c,
227                frag_d,
228            } => {
229                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
230            }
231            WmmaInstruction::ExecuteScaled {
232                shape,
233                frag_a,
234                frag_b,
235                frag_c,
236                frag_d,
237
238                scales_a,
239                scales_b,
240                scales_factor,
241            } => Self::compile_scaled_mma(
242                f,
243                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
244                *scales_a,
245                *scales_b,
246                *scales_factor,
247            ),
248            WmmaInstruction::Store {
249                output,
250                frag: var,
251                stride,
252                offset,
253                layout,
254            } => {
255                let frag_acc = match var {
256                    Variable::WmmaFragment { frag, .. } => *frag,
257                    _ => panic!("variable should be WmmaFragment"),
258                };
259                // instruction qualifiers
260                let layout = match layout {
261                    FragmentLayout::ColMajor => "col",
262                    FragmentLayout::RowMajor => "row",
263                    FragmentLayout::_Dialect(..) => unreachable!(),
264                };
265                let opcode = match var.elem() {
266                    Elem::F16 | Elem::BF16 => format!(
267                        // hack because wmma.store does not support bf16
268                        // f16 should still work correctly for bf16 as long
269                        // as the input registers are in correct format
270                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
271                        frag_acc.m, frag_acc.n, frag_acc.k,
272                    ),
273                    Elem::TF32 | Elem::F32 => format!(
274                        // same hack for tf32
275                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
276                        frag_acc.m, frag_acc.n, frag_acc.k,
277                    ),
278                    Elem::I32 => format!(
279                        // same hack for tf32
280                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
281                        frag_acc.m, frag_acc.n, frag_acc.k,
282                    ),
283                    other => panic!("{other} fragment type not supported"),
284                };
285                // constraints
286                let mut reg_count = 0;
287                let buffer_reg = format_reg_and_inc(&mut reg_count);
288                // offset and stride can be passed as local const or as const scalar
289                // we need to handle both cases correctly in the asm.
290                let (stride_reg, stride_constraint) =
291                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
292                // we start at 2 because of the buffer address calculation
293                let (regs_decl, in_constraints) =
294                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
295                let tmp_ptr = Variable::tmp_ptr(output.item());
296                let tmp_ptr_left = tmp_ptr.fmt_left();
297                write!(
298                    f,
299                    r#"// store
300{tmp_ptr_left} = {output} + {offset};
301asm volatile(
302    "{opcode} "
303    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
304    :
305    : "l"({tmp_ptr}),
306      {in_constraints}{stride_constraint}
307);
308"#
309                )
310            }
311            WmmaInstruction::Cast { input, output } => {
312                let frag = match input {
313                    Variable::WmmaFragment { frag, .. } => *frag,
314                    _ => panic!("variable should be WmmaFragment"),
315                };
316                let reg_count = get_fragment_register_total_count(&frag);
317                match output.elem() {
318                    Elem::F16 => {
319                        write!(
320                            f,
321                            "// cast
322for (int i = 0; i < {reg_count}; ++i) {{
323    __half h_lo = __float2half_rn({input}[2*i + 0]);
324    __half h_hi = __float2half_rn({input}[2*i + 1]);
325    __half2 h2 = __halves2half2(h_lo, h_hi);
326    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
327}}
328"
329                        )
330                    }
331                    Elem::BF16 => {
332                        write!(
333                            f,
334                            "// cast
335for (int i = 0; i < {reg_count}; ++i) {{
336    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
337    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
338    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
339    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
340}}
341"
342                        )
343                    }
344                    other => panic!("casting fragment to {other} not supported"),
345                }
346            }
347        }
348    }
349
350    fn compile_manual_mma(
351        f: &mut std::fmt::Formatter<'_>,
352        mma: ManualMma<CudaDialect<Self>>,
353    ) -> std::fmt::Result {
354        compile_manual_mma(f, mma)
355    }
356
357    fn compile_scaled_mma(
358        f: &mut std::fmt::Formatter<'_>,
359        mma: ManualMma<CudaDialect<Self>>,
360        scales_a: Variable<CudaDialect<Self>>,
361        scales_b: Variable<CudaDialect<Self>>,
362        scales_factor: u32,
363    ) -> std::fmt::Result {
364        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
365    }
366
367    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
368        let mut result: SupportedMmaCombinations = vec![];
369        if arch.get_version() >= WMMA_MINIMUM_VERSION {
370            // Types fully supported.
371            let types = vec![
372                (
373                    gpu::ElemType::Float(gpu::FloatKind::F16), // m
374                    gpu::ElemType::Float(gpu::FloatKind::F16), // n
375                    gpu::ElemType::Float(gpu::FloatKind::F16), // k
376                ),
377                (
378                    gpu::ElemType::Float(gpu::FloatKind::F16),
379                    gpu::ElemType::Float(gpu::FloatKind::F16),
380                    gpu::ElemType::Float(gpu::FloatKind::F32),
381                ),
382                (
383                    gpu::ElemType::Float(gpu::FloatKind::BF16),
384                    gpu::ElemType::Float(gpu::FloatKind::BF16),
385                    gpu::ElemType::Float(gpu::FloatKind::F32),
386                ),
387            ];
388            let combinations: SupportedMmaCombinations = types
389                .into_iter()
390                .map(|(a, b, cd)| MmaConfig {
391                    a_type: a.into(),
392                    b_type: b.into(),
393                    cd_type: cd.into(),
394                    m: 16,
395                    n: 16,
396                    k: 16,
397                })
398                .collect();
399            result.extend(combinations);
400            if arch.get_version() >= 72 {
401                result.extend([
402                    MmaConfig {
403                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
404                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
405                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
406                        m: 16,
407                        n: 16,
408                        k: 16,
409                    },
410                    MmaConfig {
411                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
412                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
413                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
414                        m: 16,
415                        n: 16,
416                        k: 16,
417                    },
418                ]);
419            }
420            if arch.get_version() >= 80 {
421                result.push(MmaConfig {
422                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
423                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
424                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
425                    m: 16,
426                    n: 16,
427                    k: 8,
428                });
429            }
430        }
431        result
432    }
433
434    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
435        supported_mma_combinations(arch)
436    }
437
438    fn supported_scaled_mma_combinations(
439        arch: &CudaArchitecture,
440    ) -> SupportedScaledMmaCombinations {
441        supported_scaled_mma_combinations(arch)
442    }
443}
444
445fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
446    let Fragment {
447        ident,
448        m,
449        n,
450        k,
451        elem,
452        ..
453    } = frag;
454    let elements = match ident {
455        FragmentIdent::A => m * k,
456        FragmentIdent::B => k * n,
457        FragmentIdent::Accumulator => m * n,
458        _ => unreachable!(),
459    };
460    let bits_per_elem = elem.size_bits() as u32;
461    // TODO: retrieve the warp size from the compiler CompilationOptions
462    let lanes_per_reg = 32 / bits_per_elem;
463    // choose threads-per-frag:
464    // - accumulators always use 32 lanes
465    // - A/B use 16 lanes _except_ TF32 (k=8) which also uses 32 lanes
466    let threads_per_frag = match ident {
467        FragmentIdent::Accumulator => 32,
468        FragmentIdent::A | FragmentIdent::B => {
469            if frag.elem == Elem::TF32 {
470                32
471            } else {
472                16
473            }
474        }
475        _ => unreachable!(),
476    };
477
478    elements / (lanes_per_reg * threads_per_frag)
479}
480
481fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
482    match var.elem() {
483        Elem::U8 => "u8",
484        Elem::I8 => "s8",
485        Elem::F16 => "f16",
486        Elem::BF16 => "bf16",
487        Elem::F32 => "f32",
488        Elem::TF32 => "tf32",
489        Elem::I32 => "s32",
490        Elem::F64 => "f64",
491        _ => panic!("unsupported WMMA fragment type"),
492    }
493    .to_string()
494}
495
496fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
497    let frag = match var {
498        Variable::WmmaFragment { frag, .. } => *frag,
499        _ => panic!("variable should be WmmaFragment"),
500    };
501    match frag.layout {
502        Some(layout) => get_qualifier_from_layout(&layout),
503        None => "".to_string(),
504    }
505}
506
507fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
508    match layout {
509        FragmentLayout::ColMajor => "col",
510        FragmentLayout::RowMajor => "row",
511        FragmentLayout::_Dialect(..) => unreachable!(),
512    }
513    .to_string()
514}
515
516fn get_variable_regs_decl_constraints(
517    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
518    output: bool,
519    reg_count: &mut u8,
520) -> (String, String) {
521    match var {
522        Variable::WmmaFragment { frag, .. } => {
523            let reg_total_count = get_fragment_register_total_count(frag);
524            let reg_decl = (0..reg_total_count)
525                .map(|_| format_reg_and_inc(reg_count))
526                .collect::<Vec<_>>()
527                .join(",");
528            let frag_elem = frag.elem;
529            let modifier = format!(
530                "{}{}",
531                if output { "=" } else { "" },
532                match frag_elem {
533                    Elem::F32 => "f",
534                    Elem::F64 => "d",
535                    _ => "r",
536                },
537            );
538            let constraints = (0..reg_total_count)
539                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
540                .collect::<Vec<_>>()
541                .join(", ");
542            (reg_decl, constraints)
543        }
544        Variable::Constant(number, ..) => match number {
545            ConstantValue::UInt(val, ..) => (val.to_string(), "".to_string()),
546            _ => panic!("variable should be an unsigned integer"),
547        },
548        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
549    }
550}
551
/// Returns the next asm operand placeholder (`%0`, `%1`, …) and advances the
/// shared placeholder counter.
fn format_reg_and_inc(count: &mut u8) -> String {
    let reg = format!("%{}", *count);
    *count += 1;
    reg
}
557
/// Formats a mutable indexed access through a `reinterpret_cast` to `ty*`.
fn as_ty_idx(var: impl Display, idx: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{}*>({})[{}]", ty, var, idx)
}
561
/// Formats a read-only indexed access through a `reinterpret_cast` to
/// `const ty*`.
fn as_const_ty_idx(var: impl Display, idx: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {}*>({})[{}]", ty, var, idx)
}
565
566pub(super) fn compile_manual_mma<D: Dialect>(
567    f: &mut core::fmt::Formatter<'_>,
568    mma: ManualMma<D>,
569) -> std::fmt::Result {
570    let ManualMma {
571        shape,
572        frag_a,
573        frag_b,
574        frag_c,
575        frag_d,
576    } = mma;
577
578    let a_elem = frag_a.elem().unpacked();
579    let b_elem = frag_b.elem().unpacked();
580    let cd_elem = frag_c.elem().unpacked();
581
582    let ab_ty = match a_elem {
583        Elem::F32 => &format!("{}", Elem::<D>::F32),
584        _ => &format!("{}", Elem::<D>::U32),
585    };
586    let cd_ty = match cd_elem {
587        Elem::F32 => &format!("{}", Elem::<D>::F32),
588        _ => &format!("{}", Elem::<D>::U32),
589    };
590
591    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
592    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
593    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
594
595    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
596    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
597    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
598
599    let frag_a = (0..a_regs).map(|i| as_const_ty_idx(frag_a, i, ab_ty));
600    let frag_b = (0..b_regs).map(|i| as_const_ty_idx(frag_b, i, ab_ty));
601    let frag_c = (0..cd_regs).map(|i| as_const_ty_idx(frag_c, i, cd_ty));
602    let frag_d = (0..cd_regs).map(|i| as_ty_idx(frag_d, i, cd_ty));
603
604    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
605    write!(
606        f,
607        "__mma_m16n8k{}_{}_{}_{}({args});",
608        shape.k, a_elem, b_elem, cd_elem
609    )
610}
611
612pub(super) fn compile_scaled_mma<D: Dialect>(
613    f: &mut core::fmt::Formatter<'_>,
614    mma: ManualMma<D>,
615    scales_a: Variable<D>,
616    scales_b: Variable<D>,
617    scales_factor: u32,
618) -> std::fmt::Result {
619    let ManualMma {
620        shape,
621        frag_a,
622        frag_b,
623        frag_c,
624        frag_d,
625    } = mma;
626
627    let a_elem = frag_a.elem().unpacked();
628    let b_elem = frag_b.elem().unpacked();
629    let cd_elem = frag_c.elem().unpacked();
630
631    let ab_ty = &format!("{}", Elem::<D>::U32);
632    let cd_ty = &format!("{}", Elem::<D>::F32);
633
634    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
635    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
636    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
637
638    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
639    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
640    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
641
642    let frag_a = (0..a_regs).map(|i| as_const_ty_idx(frag_a, i, ab_ty));
643    let frag_b = (0..b_regs).map(|i| as_const_ty_idx(frag_b, i, ab_ty));
644    let frag_c = (0..cd_regs).map(|i| as_const_ty_idx(frag_c, i, cd_ty));
645    let frag_d = (0..cd_regs).map(|i| as_ty_idx(frag_d, i, cd_ty));
646
647    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
648    write!(
649        f,
650        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
651        shape.k, a_elem, b_elem, cd_elem
652    )
653}
654
655pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
656    let mut result: SupportedMmaCombinations = vec![];
657    // Higher than WMMA because we only support the newest shapes. Other shapes would make things
658    // very complicated.
659    // Also only use f32 accumulators for now
660    if arch.get_version() >= 80 {
661        result.extend([
662            MmaConfig {
663                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // a
664                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // b
665                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), // cd
666                m: 16,
667                n: 8,
668                k: 16,
669            },
670            MmaConfig {
671                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
672                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
673                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
674                m: 16,
675                n: 8,
676                k: 16,
677            },
678            MmaConfig {
679                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
680                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
681                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
682                m: 16,
683                n: 8,
684                k: 8,
685            },
686            MmaConfig {
687                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
688                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
689                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
690                m: 16,
691                n: 8,
692                k: 32,
693            },
694            MmaConfig {
695                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
696                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
697                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
698                m: 16,
699                n: 8,
700                k: 32,
701            },
702            MmaConfig {
703                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
704                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
705                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
706                m: 16,
707                n: 8,
708                k: 32,
709            },
710            MmaConfig {
711                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
712                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
713                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
714                m: 16,
715                n: 8,
716                k: 32,
717            },
718            // TODO: u4/i4/b1, there's no types for them yet
719        ]);
720    }
721    if arch.get_version() >= 89 {
722        let f8f6f4_types = [
723            gpu::FloatKind::E4M3,
724            gpu::FloatKind::E5M2,
725            gpu::FloatKind::E3M2,
726            gpu::FloatKind::E2M3,
727            gpu::FloatKind::E2M1,
728        ];
729        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
730        result.extend(combinations.map(|(t1, t2)| MmaConfig {
731            a_type: gpu::ElemType::Float(*t1).into(),
732            b_type: gpu::ElemType::Float(*t2).into(),
733            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
734            m: 16,
735            n: 8,
736            k: 32,
737        }));
738    }
739    // Warning: this likely does not follow the same layout pattern as those after 80
740    if arch.get_version() >= 70 && arch.get_version() < 80 {
741        result.push(MmaConfig {
742            a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
743            b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
744            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
745            m: 16,
746            n: 8,
747            k: 8,
748        });
749    }
750    result
751}
752
753pub(super) fn supported_scaled_mma_combinations(
754    arch: &CudaArchitecture,
755) -> SupportedScaledMmaCombinations {
756    let mut result: SupportedScaledMmaCombinations = vec![];
757    // sm_120f
758    if arch.get_version() >= 120 && arch.get_version() < 130 {
759        let f8f6f4_types = [
760            gpu::FloatKind::E4M3,
761            gpu::FloatKind::E5M2,
762            gpu::FloatKind::E3M2,
763            gpu::FloatKind::E2M3,
764            gpu::FloatKind::E2M1,
765        ];
766        let combinations = f8f6f4_types
767            .iter()
768            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));
769
770        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
771            a_type: gpu::ElemType::Float(*t1).into(),
772            b_type: gpu::ElemType::Float(*t2).into(),
773            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
774            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
775            m: 16,
776            n: 8,
777            k: 32,
778            scales_factor: 1,
779        }));
780
781        result.extend([
782            ScaledMmaConfig {
783                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
784                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
785                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
786                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
787                m: 16,
788                n: 8,
789                k: 64,
790                scales_factor: 2,
791            },
792            // Sign of scales is ignored
793            ScaledMmaConfig {
794                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
795                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
796                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
797                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
798                m: 16,
799                n: 8,
800                k: 64,
801                scales_factor: 4,
802            },
803        ]);
804    }
805    result
806}
807
808pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> usize {
809    match ident {
810        MatrixIdent::A | MatrixIdent::B => 32 / matrix.storage.size_bits(),
811        MatrixIdent::Accumulator => 2,
812    }
813}