cubecl_cpp/cuda/mma/
ptx_wmma_compiler.rs

1use std::fmt::Display;
2
3use crate::{
4    Dialect,
5    cuda::{CudaDialect, arch::CudaArchitecture, ptx::comma_separated},
6    shared::{
7        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
8        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
9        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
10    },
11};
12use cubecl_core::ir::{self as gpu, ConstantScalarValue};
13use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
14use itertools::Itertools;
15
16use super::WMMA_MINIMUM_VERSION;
17
/// WMMA compiler for the CUDA dialect that lowers fragment operations to raw
/// PTX `wmma.*` instructions via inline assembly (rather than going through
/// the `nvcuda::wmma` C++ API; `mma.h` is only included for tf32 conversion).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}
20
21impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
22    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
23        // We need mma header for conversion
24        if flags.elem_tf32 {
25            f.write_str("#include <mma.h>\n")?;
26        }
27        Ok(())
28    }
29
30    fn compile_wmma_fragment_declaration(
31        f: &mut std::fmt::Formatter<'_>,
32        var: &Variable<CudaDialect<Self>>,
33    ) -> std::fmt::Result {
34        let frag = match var {
35            Variable::WmmaFragment { frag, .. } => *frag,
36            _ => panic!("load instruction expects a WmmaFragment"),
37        };
38        let reg_count = get_fragment_register_total_count(&frag);
39        let ty = match frag.elem {
40            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
41            Elem::F32 => "float",
42            Elem::F64 => "double",
43            _ => panic!("unsupported type"),
44        };
45        writeln!(f, "{ty} {var}[{reg_count}];")
46    }
47
48    fn compile_wmma_instruction(
49        f: &mut std::fmt::Formatter<'_>,
50        instruction: &WmmaInstruction<CudaDialect<Self>>,
51    ) -> std::fmt::Result {
52        match instruction {
53            WmmaInstruction::Fill { frag: var, value } => {
54                let frag = match var {
55                    Variable::WmmaFragment { frag, .. } => *frag,
56                    _ => panic!("variable should be WmmaFragment"),
57                };
58                let reg_count = get_fragment_register_total_count(&frag);
59                write!(
60                    f,
61                    "// fill
62for (uint i = 0; i < uint({reg_count}); ++i) {{
63  {var}[i] = {value};
64}}
65 "
66                )
67            }
68            WmmaInstruction::Load {
69                frag: var,
70                value,
71                offset,
72                stride,
73                layout,
74            } => {
75                let frag = match var {
76                    Variable::WmmaFragment { frag, .. } => *frag,
77                    _ => panic!("load instruction expects a WmmaFragment"),
78                };
79                // Important note: the current frontend has been designed around
80                // CUDA wmma which is not optimal in the case of PTX wmma and mma
81                // We choose here to use the layout defined in the fragment first,
82                // if it is unknown and we look into the layout passed to the instruction.
83                let layout = if frag.layout.is_some() {
84                    get_fragment_layout_qualifier(var)
85                } else if let Some(layout) = layout {
86                    get_qualifier_from_layout(layout)
87                } else {
88                    panic!("unknown matrix layout for wmma load instruction");
89                };
90                // instruction qualifiers
91                let ty = get_type_qualifier(value);
92                let matrix = match frag.ident {
93                    FragmentIdent::A => "a",
94                    FragmentIdent::B => "b",
95                    FragmentIdent::Accumulator => "c",
96                    FragmentIdent::_Dialect(_) => unreachable!(),
97                };
98                let value_ty = value.item();
99                let opcode = match frag.elem {
100                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
101                        format!(
102                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
103                            frag.m, frag.n, frag.k,
104                        )
105                    }
106                    other => panic!("{other} fragment type not supported"),
107                };
108                // constraints
109                let mut reg_count = 0;
110                let (regs_decl, out_constraints) =
111                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
112                let buffer_reg = format_reg_and_inc(&mut reg_count);
113                let (stride_reg, stride_constraint) =
114                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
115                let tmp_ptr = Variable::tmp_ptr(value.item());
116                let tmp_ptr_left = tmp_ptr.fmt_left();
117                write!(
118                    f,
119                    r#"// load
120{tmp_ptr_left} = ({value_ty}*){value} + {offset};
121asm volatile(
122    "{opcode} "
123    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
124    : {out_constraints}
125    : "l"({tmp_ptr}){stride_constraint}
126);
127"#
128                )
129            }
130            WmmaInstruction::Execute {
131                frag_a: var_a,
132                frag_b: var_b,
133                frag_c: var_c,
134                frag_d: var_d,
135                ..
136            } => {
137                let frag_a = match var_a {
138                    Variable::WmmaFragment { frag, .. } => *frag,
139                    _ => panic!("variable should be WmmaFragment"),
140                };
141                let layout_a = get_fragment_layout_qualifier(var_a);
142                let layout_b = get_fragment_layout_qualifier(var_b);
143                let type_c = get_type_qualifier(var_c);
144                let type_d = get_type_qualifier(var_d);
145                let opcode = match var_a.elem() {
146                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
147                        "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
148                        frag_a.m, frag_a.n, frag_a.k,
149                    ),
150                    Elem::BF16 => format!(
151                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
152                        frag_a.m, frag_a.n, frag_a.k,
153                    ),
154                    Elem::TF32 => format!(
155                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
156                        frag_a.m, frag_a.n, frag_a.k,
157                    ),
158                    other => panic!("{other} fragment type not supported"),
159                };
160                let mut reg_count = 0;
161                // order matters, declare the registers in the same order as the intrinsic
162                let (regs_decl_d, out_constraints_d) =
163                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
164                let (regs_decl_a, in_constraints_a) =
165                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
166                let (regs_decl_b, in_constraints_b) =
167                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
168                let (regs_decl_c, in_constraints_c) =
169                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
170                write!(
171                    f,
172                    r#"// execute
173asm volatile(
174    "{opcode} "
175    "{{{regs_decl_d}}}, "
176    "{{{regs_decl_a}}}, "
177    "{{{regs_decl_b}}}, "
178    "{{{regs_decl_c}}};\n"
179    : {out_constraints_d}
180    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
181);
182"#
183                )
184            }
185            WmmaInstruction::ExecuteManual {
186                shape,
187                frag_a,
188                frag_b,
189                frag_c,
190                frag_d,
191            } => {
192                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
193            }
194            WmmaInstruction::ExecuteScaled {
195                shape,
196                frag_a,
197                frag_b,
198                frag_c,
199                frag_d,
200
201                scales_a,
202                scales_b,
203                scales_factor,
204            } => Self::compile_scaled_mma(
205                f,
206                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
207                *scales_a,
208                *scales_b,
209                *scales_factor,
210            ),
211            WmmaInstruction::Store {
212                output,
213                frag: var,
214                stride,
215                offset,
216                layout,
217            } => {
218                let frag_acc = match var {
219                    Variable::WmmaFragment { frag, .. } => *frag,
220                    _ => panic!("variable should be WmmaFragment"),
221                };
222                // instruction qualifiers
223                let layout = match layout {
224                    FragmentLayout::ColMajor => "col",
225                    FragmentLayout::RowMajor => "row",
226                    FragmentLayout::_Dialect(..) => unreachable!(),
227                };
228                let opcode = match var.elem() {
229                    Elem::F16 | Elem::BF16 => format!(
230                        // hack because wmma.store does not support bf16
231                        // f16 should still work correctly for bf16 as long
232                        // as the input registers are in correct format
233                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
234                        frag_acc.m, frag_acc.n, frag_acc.k,
235                    ),
236                    Elem::TF32 | Elem::F32 => format!(
237                        // same hack for tf32
238                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
239                        frag_acc.m, frag_acc.n, frag_acc.k,
240                    ),
241                    Elem::I32 => format!(
242                        // same hack for tf32
243                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
244                        frag_acc.m, frag_acc.n, frag_acc.k,
245                    ),
246                    other => panic!("{other} fragment type not supported"),
247                };
248                // constraints
249                let mut reg_count = 0;
250                let buffer_reg = format_reg_and_inc(&mut reg_count);
251                // offset and stride can be passed as local const or as const scalar
252                // we need to handle both cases correctly in the asm.
253                let (stride_reg, stride_constraint) =
254                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
255                // we start at 2 because of the buffer address calculation
256                let (regs_decl, in_constraints) =
257                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
258                let tmp_ptr = Variable::tmp_ptr(output.item());
259                let tmp_ptr_left = tmp_ptr.fmt_left();
260                write!(
261                    f,
262                    r#"// store
263{tmp_ptr_left} = {output} + {offset};
264asm volatile(
265    "{opcode} "
266    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
267    :
268    : "l"({tmp_ptr}),
269      {in_constraints}{stride_constraint}
270);
271"#
272                )
273            }
274            WmmaInstruction::Cast { input, output } => {
275                let frag = match input {
276                    Variable::WmmaFragment { frag, .. } => *frag,
277                    _ => panic!("variable should be WmmaFragment"),
278                };
279                let reg_count = get_fragment_register_total_count(&frag);
280                match output.elem() {
281                    Elem::F16 => {
282                        write!(
283                            f,
284                            "// cast
285for (int i = 0; i < {reg_count}; ++i) {{
286    __half h_lo = __float2half_rn({input}[2*i + 0]);
287    __half h_hi = __float2half_rn({input}[2*i + 1]);
288    __half2 h2 = __halves2half2(h_lo, h_hi);
289    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
290}}
291"
292                        )
293                    }
294                    Elem::BF16 => {
295                        write!(
296                            f,
297                            "// cast
298for (int i = 0; i < {reg_count}; ++i) {{
299    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
300    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
301    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
302    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
303}}
304"
305                        )
306                    }
307                    other => panic!("casting fragment to {other} not supported"),
308                }
309            }
310        }
311    }
312
313    fn compile_manual_mma(
314        f: &mut std::fmt::Formatter<'_>,
315        mma: ManualMma<CudaDialect<Self>>,
316    ) -> std::fmt::Result {
317        compile_manual_mma(f, mma)
318    }
319
320    fn compile_scaled_mma(
321        f: &mut std::fmt::Formatter<'_>,
322        mma: ManualMma<CudaDialect<Self>>,
323        scales_a: Variable<CudaDialect<Self>>,
324        scales_b: Variable<CudaDialect<Self>>,
325        scales_factor: u32,
326    ) -> std::fmt::Result {
327        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
328    }
329
330    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
331        let mut result: SupportedMmaCombinations = vec![];
332        if arch.get_version() >= WMMA_MINIMUM_VERSION {
333            // Types fully supported.
334            let types = vec![
335                (
336                    gpu::ElemType::Float(gpu::FloatKind::F16), // m
337                    gpu::ElemType::Float(gpu::FloatKind::F16), // n
338                    gpu::ElemType::Float(gpu::FloatKind::F16), // k
339                ),
340                (
341                    gpu::ElemType::Float(gpu::FloatKind::F16),
342                    gpu::ElemType::Float(gpu::FloatKind::F16),
343                    gpu::ElemType::Float(gpu::FloatKind::F32),
344                ),
345                (
346                    gpu::ElemType::Float(gpu::FloatKind::BF16),
347                    gpu::ElemType::Float(gpu::FloatKind::BF16),
348                    gpu::ElemType::Float(gpu::FloatKind::F32),
349                ),
350            ];
351            let combinations: SupportedMmaCombinations = types
352                .into_iter()
353                .map(|(a, b, cd)| MmaConfig {
354                    a_type: a.into(),
355                    b_type: b.into(),
356                    cd_type: cd.into(),
357                    m: 16,
358                    n: 16,
359                    k: 16,
360                })
361                .collect();
362            result.extend(combinations);
363            if arch.get_version() >= 72 {
364                result.extend([
365                    MmaConfig {
366                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
367                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
368                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
369                        m: 16,
370                        n: 16,
371                        k: 16,
372                    },
373                    MmaConfig {
374                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
375                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
376                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
377                        m: 16,
378                        n: 16,
379                        k: 16,
380                    },
381                ]);
382            }
383            if arch.get_version() >= 80 {
384                result.push(MmaConfig {
385                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
386                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
387                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
388                    m: 16,
389                    n: 16,
390                    k: 8,
391                });
392            }
393        }
394        result
395    }
396
397    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
398        supported_mma_combinations(arch)
399    }
400
401    fn supported_scaled_mma_combinations(
402        arch: &CudaArchitecture,
403    ) -> SupportedScaledMmaCombinations {
404        supported_scaled_mma_combinations(arch)
405    }
406}
407
/// Number of 32-bit registers needed per thread to hold `frag`.
///
/// Computed as `elements / (lanes_per_reg * threads_per_frag)`, where
/// `lanes_per_reg` is how many elements pack into one 32-bit register and
/// `threads_per_frag` models how the fragment's elements are distributed
/// (and, for 16-lane A/B fragments, duplicated) across the warp.
///
/// NOTE(review): the 16-lane rule reproduces the PTX ISA counts for the
/// f16/f32 cases; the u8/i8 and bf16 A/B counts it produces should be
/// cross-checked against the PTX ISA wmma fragment tables before relying
/// on them.
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    // Total element count of the matrix tile held by this fragment.
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    // TODO: retrieve the warp size from the compiler CompilationOptions
    let lanes_per_reg = 32 / bits_per_elem;
    // choose threads-per-frag:
    // - accumulators always use 32 lanes
    // - A/B use 16 lanes _except_ TF32 (k=8) which also uses 32 lanes
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}
443
444fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
445    match var.elem() {
446        Elem::U8 => "u8",
447        Elem::I8 => "s8",
448        Elem::F16 => "f16",
449        Elem::BF16 => "bf16",
450        Elem::F32 => "f32",
451        Elem::TF32 => "tf32",
452        Elem::I32 => "s32",
453        Elem::F64 => "f64",
454        _ => panic!("unsupported WMMA fragment type"),
455    }
456    .to_string()
457}
458
459fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
460    let frag = match var {
461        Variable::WmmaFragment { frag, .. } => *frag,
462        _ => panic!("variable should be WmmaFragment"),
463    };
464    match frag.layout {
465        Some(layout) => get_qualifier_from_layout(&layout),
466        None => "".to_string(),
467    }
468}
469
470fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
471    match layout {
472        FragmentLayout::ColMajor => "col",
473        FragmentLayout::RowMajor => "row",
474        FragmentLayout::_Dialect(..) => unreachable!(),
475    }
476    .to_string()
477}
478
/// Builds the `%N` placeholder list and the matching inline-asm constraint
/// list for `var`, advancing `reg_count` by the number of placeholders
/// consumed (numbering is global across one asm statement).
///
/// Returns `(register_placeholders, constraints)`:
/// * `WmmaFragment` — one placeholder and one constraint per fragment
///   register; the constraint letter is chosen from the element type and
///   prefixed with `=` when `output` is true.
/// * unsigned `ConstantScalar` — the literal is inlined directly into the
///   asm text, so no placeholder is consumed and the constraint is empty.
/// * anything else — one placeholder whose constraint string deliberately
///   starts with `", "` so callers can append it right after a preceding
///   constraint list.
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            // One "%N" placeholder per fragment register.
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            // Constraint letter: "f" = .f32 register, "d" = .f64 register,
            // "r" = generic .b32 register; "=" marks an output operand.
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::ConstantScalar(number, ..) => match number {
            // Inlined immediate: no register, no constraint.
            ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        // The leading ", " is intentional: this string is appended directly
        // after an existing constraint list by the callers.
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}
514
/// Returns the next inline-asm placeholder (`%0`, `%1`, ...) and advances
/// the shared counter.
fn format_reg_and_inc(count: &mut u8) -> String {
    let current = *count;
    *count += 1;
    format!("%{current}")
}
520
/// Reinterprets `var` as a mutable reference to `ty` in generated C++.
fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{}&>({})", ty, var)
}
524
/// Reinterprets `var` as a const reference to `ty` in generated C++.
fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {}&>({})", ty, var)
}
528
529pub(super) fn compile_manual_mma<D: Dialect>(
530    f: &mut core::fmt::Formatter<'_>,
531    mma: ManualMma<D>,
532) -> std::fmt::Result {
533    let ManualMma {
534        shape,
535        frag_a,
536        frag_b,
537        frag_c,
538        frag_d,
539    } = mma;
540
541    let a_elem = frag_a[0].elem().unpacked();
542    let b_elem = frag_b[0].elem().unpacked();
543    let cd_elem = frag_c[0].elem().unpacked();
544
545    let ab_ty = match a_elem {
546        Elem::F32 => &format!("{}", Elem::<D>::F32),
547        _ => &format!("{}", Elem::<D>::U32),
548    };
549    let cd_ty = match cd_elem {
550        Elem::F32 => &format!("{}", Elem::<D>::F32),
551        _ => &format!("{}", Elem::<D>::U32),
552    };
553
554    let acc_elems = frag_c.len();
555    let frag_ab = frag_a.iter().chain(frag_b).map(|it| as_const_ty(it, ab_ty));
556    let frag_c = frag_c.iter().map(|it| as_const_ty(it, cd_ty));
557    let frag_d = (0..acc_elems).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
558    let args = comma_separated(frag_ab.chain(frag_c).chain(frag_d));
559    write!(
560        f,
561        "__mma_m16n8k{}_{}_{}_{}({args});",
562        shape.k, a_elem, b_elem, cd_elem
563    )
564}
565
566pub(super) fn compile_scaled_mma<D: Dialect>(
567    f: &mut core::fmt::Formatter<'_>,
568    mma: ManualMma<D>,
569    scales_a: Variable<D>,
570    scales_b: Variable<D>,
571    scales_factor: u32,
572) -> std::fmt::Result {
573    let ManualMma {
574        shape,
575        frag_a,
576        frag_b,
577        frag_c,
578        frag_d,
579    } = mma;
580
581    let a_elem = frag_a[0].elem().unpacked();
582    let b_elem = frag_b[0].elem().unpacked();
583    let cd_elem = frag_c[0].elem().unpacked();
584    let ab_ty = &format!("{}", Elem::<D>::U32);
585    let cd_ty = &format!("{}", Elem::<D>::F32);
586    let acc_elems = frag_c.len();
587    let frag_ab = frag_a.iter().chain(frag_b).map(|it| as_const_ty(it, ab_ty));
588    let frag_c = frag_c.iter().map(|it| as_const_ty(it, cd_ty));
589    let frag_d = (0..acc_elems).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
590    let fragments = comma_separated(frag_ab.chain(frag_c).chain(frag_d));
591    write!(
592        f,
593        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
594        shape.k, a_elem, b_elem, cd_elem
595    )
596}
597
598pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
599    let mut result: SupportedMmaCombinations = vec![];
600    // Higher than WMMA because we only support the newest shapes. Other shapes would make things
601    // very complicated.
602    // Also only use f32 accumulators for now
603    if arch.get_version() >= 80 {
604        result.extend([
605            MmaConfig {
606                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // a
607                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // b
608                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), // cd
609                m: 16,
610                n: 8,
611                k: 16,
612            },
613            MmaConfig {
614                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
615                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
616                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
617                m: 16,
618                n: 8,
619                k: 16,
620            },
621            MmaConfig {
622                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
623                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
624                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
625                m: 16,
626                n: 8,
627                k: 8,
628            },
629            MmaConfig {
630                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
631                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
632                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
633                m: 16,
634                n: 8,
635                k: 32,
636            },
637            MmaConfig {
638                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
639                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
640                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
641                m: 16,
642                n: 8,
643                k: 32,
644            },
645            MmaConfig {
646                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
647                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
648                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
649                m: 16,
650                n: 8,
651                k: 32,
652            },
653            MmaConfig {
654                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
655                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
656                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
657                m: 16,
658                n: 8,
659                k: 32,
660            },
661            // TODO: u4/i4/b1, there's no types for them yet
662        ]);
663    }
664    if arch.get_version() >= 89 {
665        let f8f6f4_types = [
666            gpu::FloatKind::E4M3,
667            gpu::FloatKind::E5M2,
668            gpu::FloatKind::E3M2,
669            gpu::FloatKind::E2M3,
670            gpu::FloatKind::E2M1,
671        ];
672        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
673        result.extend(combinations.map(|(t1, t2)| MmaConfig {
674            a_type: gpu::ElemType::Float(*t1).into(),
675            b_type: gpu::ElemType::Float(*t2).into(),
676            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
677            m: 16,
678            n: 8,
679            k: 32,
680        }));
681    }
682    result
683}
684
685pub(super) fn supported_scaled_mma_combinations(
686    arch: &CudaArchitecture,
687) -> SupportedScaledMmaCombinations {
688    let mut result: SupportedScaledMmaCombinations = vec![];
689    // sm_120f
690    if arch.get_version() >= 120 && arch.get_version() < 130 {
691        let f8f6f4_types = [
692            gpu::FloatKind::E4M3,
693            gpu::FloatKind::E5M2,
694            gpu::FloatKind::E3M2,
695            gpu::FloatKind::E2M3,
696            gpu::FloatKind::E2M1,
697        ];
698        let combinations = f8f6f4_types
699            .iter()
700            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));
701
702        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
703            a_type: gpu::ElemType::Float(*t1).into(),
704            b_type: gpu::ElemType::Float(*t2).into(),
705            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
706            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
707            m: 16,
708            n: 8,
709            k: 32,
710            scales_factor: 1,
711        }));
712
713        result.extend([
714            ScaledMmaConfig {
715                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
716                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
717                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
718                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
719                m: 16,
720                n: 8,
721                k: 64,
722                scales_factor: 2,
723            },
724            // Sign of scales is ignored
725            ScaledMmaConfig {
726                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
727                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
728                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
729                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
730                m: 16,
731                n: 8,
732                k: 64,
733                scales_factor: 4,
734            },
735        ]);
736    }
737    result
738}