cubecl_cpp/cuda/mma/
ptx_wmma_compiler.rs

use std::fmt::Display;

use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call, stmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{self as gpu, ConstantScalarValue, Matrix, MatrixIdent};
use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
use itertools::Itertools;

use super::WMMA_MINIMUM_VERSION;

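/// WMMA compiler for the CUDA dialect that lowers fragment operations to raw
/// PTX `wmma.*` instructions emitted through `asm volatile` blocks, rather
/// than going through the `nvcuda::wmma` C++ API.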
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        // We need the mma header for the TF32 conversion helpers.
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("fragment declaration expects a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
  {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                // Important note: the current frontend has been designed around
                // CUDA wmma, which is not optimal for PTX wmma and mma.
                // We use the layout defined in the fragment first; if it is
                // unknown, we fall back to the layout passed to the instruction.
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                // instruction qualifiers
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
                // constraints
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
            }
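            // As an illustration (not emitted verbatim), an f16 A-fragment
            // load for a row-major m16n16k16 shape, with `stride` held in a
            // register, expands to roughly:
            //
            //   asm volatile(
            //       "wmma.load.a.sync.aligned.row.m16n16k16.f16 "
            //       "{%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
            //       : "=r"(frag[0]), ..., "=r"(frag[7])
            //       : "l"(ptr), "r"(stride)
            //   );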
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::StMatrix {
                registers,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&stmatrix_call(
                registers, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
                let opcode = match var_a.elem() {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                // order matters: declare the registers in the same order as the intrinsic
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
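            // For example (illustrative), f16 inputs with f32 accumulators and
            // row-major A / col-major B produce the opcode:
            //   wmma.mma.sync.aligned.m16n16k16.row.col.f32.f32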
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                // instruction qualifiers
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        // hack because wmma.store does not support bf16:
                        // f16 should still work correctly for bf16 as long
                        // as the input registers are in the correct format
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        // same hack for tf32
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        // s32 is supported natively, so no hack is needed here
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                // constraints
                let mut reg_count = 0;
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                // offset and stride can be passed as a local const or as a const
                // scalar; we need to handle both cases correctly in the asm.
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                // the fragment registers come after the buffer address and,
                // when it is a register, the stride
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}),
      {in_constraints}{stride_constraint}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                let frag = match input {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __half h_lo = __float2half_rn({input}[2*i + 0]);
    __half h_hi = __float2half_rn({input}[2*i + 1]);
    __half2 h2 = __halves2half2(h_lo, h_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            // Types fully supported.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}

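/// Number of 32-bit registers per thread needed to hold one fragment.
///
/// Worked example (illustrative): an f16 A fragment of shape m16n16k16 holds
/// 16 * 16 = 256 elements; with 2 f16 lanes per 32-bit register and 16
/// threads per fragment, that is 256 / (2 * 16) = 8 registers, which matches
/// the PTX wmma f16 A-fragment layout.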
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    // TODO: retrieve the warp size from the compiler CompilationOptions
    let lanes_per_reg = 32 / bits_per_elem;
    // choose threads-per-frag:
    // - accumulators always use 32 lanes
    // - A/B use 16 lanes _except_ TF32 (k=8) which also uses 32 lanes
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}

fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}

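/// Builds the `%N` register placeholders and the matching constraint list for
/// one asm operand, advancing the shared register counter.
///
/// For example (illustrative), a 2-register f32 output fragment starting at
/// counter 0 yields `("%0,%1", "\"=f\"(var[0]), \"=f\"(var[1])")`, while a
/// constant-scalar operand is inlined directly with no constraint.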
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::ConstantScalar(number, ..) => match number {
            ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}

fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}

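/// Lowers a manual MMA to a call to an `__mma_*` helper, which is assumed to
/// be defined elsewhere in the generated source. For example (illustrative),
/// an f16 x f16 -> f32 op of shape m16n8k16 becomes a call like
/// `__mma_m16n8k16_f16_f16_f32(...)`, with the fragments reinterpreted as raw
/// 32-bit registers.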
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = match a_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));

    // C and D fragments are always vectorized for optimal stores and casts, but f32
    // is taken as separate registers, so we need to unpack it. f16 is taken in
    // packed registers, so it is used as is.
    let frag_c = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let frag_d = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}

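/// Block-scaled variant of [`compile_manual_mma`]: lowers to a
/// `__mma_scaled_*` helper that takes the A/B scale registers and the scale
/// factor in addition to the fragments.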
pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = &format!("{}", Elem::<D>::U32);
    let cd_ty = &format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}

pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    // The version requirement is higher than for WMMA because we only support the
    // newest shapes; older shapes would make things very complicated.
    // Also, only f32 accumulators are used for now.
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // a
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // b
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), // cd
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            // TODO: u4/i4/b1, there are no types for them yet
        ]);
    }
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}

pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    // sm_120f
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types
            .iter()
            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            // Sign of scales is ignored
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}

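/// Number of contiguous elements each lane holds for a fragment.
///
/// For example (illustrative), f16 A/B matrices give 32 / 16 = 2 contiguous
/// elements per lane and 8-bit types give 4, while accumulators always hold 2.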
pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> u32 {
    match ident {
        MatrixIdent::A | MatrixIdent::B => (32 / matrix.storage.size_bits()) as u32,
        MatrixIdent::Accumulator => 2,
    }
}