cubecl_cpp/cuda/mma/ptx_wmma_compiler.rs

use std::fmt::Display;

use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{self as gpu, ConstantScalarValue, Matrix, MatrixIdent};
use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
use itertools::Itertools;

use super::WMMA_MINIMUM_VERSION;

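/// Compiles CubeCL WMMA fragment operations to inline PTX `wmma` instructions
/// (via `asm volatile`) rather than going through the CUDA C++ `nvcuda::wmma` API.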
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        // We need the mma header for the tf32 conversions
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("fragment declaration expects a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
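        // e.g. an m16n16k16 f16 A fragment is declared as eight packed
        // registers: `unsigned int <var>[8];` (variable name illustrative)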
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
  {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                // Important note: the current frontend has been designed around
                // CUDA wmma, which is not optimal for PTX wmma and mma.
                // We use the layout defined in the fragment first; if it is
                // unknown, we fall back to the layout passed to the instruction.
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                // instruction qualifiers
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
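                // e.g. a row-major f16 A fragment of shape m16n16k16 yields
                // "wmma.load.a.sync.aligned.row.m16n16k16.f16"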
                // constraints
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
            }
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
                let opcode = match var_a.elem() {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
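                // e.g. f16 inputs with f32 accumulators at m16n16k16, row-major A
                // and col-major B yield
                // "wmma.mma.sync.aligned.m16n16k16.row.col.f32.f32"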
                let mut reg_count = 0;
                // order matters, declare the registers in the same order as the intrinsic
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                // instruction qualifiers
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        // hack because wmma.store does not support bf16:
                        // f16 should still work correctly for bf16 as long
                        // as the input registers are in the correct format
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        // same hack for tf32
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
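                // e.g. a row-major f32 accumulator at m16n16k16 yields
                // "wmma.store.d.sync.aligned.row.m16n16k16.f32"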
                // constraints
                let mut reg_count = 0;
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                // offset and stride can be passed as a local const or as a const
                // scalar; we need to handle both cases correctly in the asm.
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                // the fragment registers come after the buffer address (and, when
                // present, the stride) registers
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}),
      {in_constraints}{stride_constraint}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                let frag = match output {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
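                // the loop bound is the output fragment's register count: each
                // packed 32-bit output register consumes two f32 input values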
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __half h_lo = __float2half_rn({input}[2*i + 0]);
    __half h_hi = __float2half_rn({input}[2*i + 1]);
    __half2 h2 = __halves2half2(h_lo, h_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            // Fully supported types.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}

fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    // TODO: retrieve the warp size from the compiler CompilationOptions
    let lanes_per_reg = 32 / bits_per_elem;
    // choose threads-per-frag:
    // - accumulators always use 32 lanes
    // - A/B use 16 lanes _except_ TF32 (k=8) which also uses 32 lanes
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };
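    // e.g. an m16n16k16 f16 A fragment: 256 elements, 2 elements per 32-bit
    // register, 16 lanes per fragment -> 256 / (2 * 16) = 8 registers, matching
    // the eight packed registers wmma.load.a expects for that shape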

    elements / (lanes_per_reg * threads_per_frag)
}

fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}

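/// Builds the PTX register placeholder list and the matching inline-asm constraint
/// list for a variable; e.g. a two-register f32 output fragment `acc` starting at
/// register 0 yields placeholders `%0,%1` and constraints `"=f"(acc[0]), "=f"(acc[1])`
/// (name illustrative).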
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::ConstantScalar(number, ..) => match number {
            ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}

fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}

pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = match a_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));

    // C and D fragments are always vectorized for optimal stores and casts, but
    // f32 is passed as separate registers, so we need to unpack it; f16 is passed
    // in packed registers, so it is used as is.
    let frag_c = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let frag_d = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
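    // the call targets a helper such as `__mma_m16n8k16_f16_f16_f32(...)` that is
    // assumed to be emitted elsewhere in the generated source (name illustrative)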
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}

pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = &format!("{}", Elem::<D>::U32);
    let cd_ty = &format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
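    // as above, this targets a generated helper (e.g. a `__mma_scaled_1x_m16n8k32_..._f32`
    // prelude function); the exact name depends on how the element types format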
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}

pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    // The minimum version is higher than for WMMA because we only support the
    // newest shapes; other shapes would make things very complicated.
    // Also, only f32 accumulators are used for now.
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // a
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), // b
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), // cd
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            // TODO: u4/i4/b1, there are no types for them yet
        ]);
    }
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}

pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    // sm_120f
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types
            .iter()
            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            // The sign of the scales is ignored
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}

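/// Number of contiguous elements covered by one 32-bit register per matrix operand:
/// 32 / storage bits for A and B (e.g. 2 for f16), and always 2 for the accumulator.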
pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> u32 {
    match ident {
        MatrixIdent::A | MatrixIdent::B => (32 / matrix.storage.size_bits()) as u32,
        MatrixIdent::Accumulator => 2,
    }
}
796}