use std::fmt::Display;

use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{self as gpu, ConstantScalarValue};
use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
use itertools::Itertools;

use super::WMMA_MINIMUM_VERSION;

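/// Compiles WMMA operations to inline PTX (`asm volatile` blocks for
/// `wmma.load`, `wmma.mma` and `wmma.store`), keeping fragments in plain
/// per-thread register arrays rather than `nvcuda::wmma` fragment types.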
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("fragment declaration expects a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
    {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
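                // Illustrative shape of the emitted CUDA source for an f16
                // A-fragment load in row layout (assumed example; the actual
                // shape, layout and register count depend on the fragment):
                //
                //   tmp = (__half*)buf + offset;
                //   asm volatile(
                //       "wmma.load.a.sync.aligned.row.m16n16k16.f16 "
                //       "{%0,%1,%2,%3,%4,%5,%6,%7}, [%8], %9;\n"
                //       : "=r"(frag[0]), ..., "=r"(frag[7])
                //       : "l"(tmp), "r"(stride));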
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
            }
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
                let opcode = match var_a.elem() {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
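                // Illustrative shape of the emitted CUDA source for an f16
                // MMA with f32 accumulators (assumed m16n16k16, row x col;
                // operands are numbered d, a, b, c in allocation order):
                //
                //   asm volatile(
                //       "wmma.mma.sync.aligned.m16n16k16.row.col.f32.f32 "
                //       "{%0,...,%7}, {%8,...,%15}, {%16,...,%23}, {%24,...,%31};\n"
                //       : "=f"(d[0]), ...
                //       : "r"(a[0]), ..., "r"(b[0]), ..., "f"(c[0]), ...);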
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                // Allocate operands in the same order the constraints are
                // emitted below (buffer, fragment registers, then stride) so
                // the `%n` numbers line up with their constraint positions.
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
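                // Illustrative shape of the emitted CUDA source for an f32
                // accumulator store in row layout (assumed m16n16k16):
                //
                //   tmp = out + offset;
                //   asm volatile(
                //       "wmma.store.d.sync.aligned.row.m16n16k16.f32 "
                //       "[%0], {%1,%2,%3,%4,%5,%6,%7,%8}, %9;\n"
                //       :: "l"(tmp), "f"(acc[0]), ..., "f"(acc[7]), "r"(stride));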
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}),
    {in_constraints}{stride_constraint}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                // The loop count must come from the *output* fragment: each
                // packed output register consumes two input registers, so
                // using the input count would index out of bounds.
                let frag = match output {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __half h_lo = __float2half_rn({input}[2*i + 0]);
    __half h_hi = __float2half_rn({input}[2*i + 1]);
    __half2 h2 = __halves2half2(h_lo, h_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}

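/// Number of 32-bit registers each thread contributes to a fragment.
///
/// Worked example for an m16n16k16 f16 A fragment: 16 * 16 = 256 elements,
/// two f16 lanes packed per 32-bit register, divided by 16 threads (the code
/// divides non-TF32 A/B fragments by 16 rather than 32, matching the
/// per-thread register counts PTX mandates), giving 256 / (2 * 16) = 8
/// registers. Note the packing math assumes element sizes of at most 32 bits.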
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    let lanes_per_reg = 32 / bits_per_elem;
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}

fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}

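/// Builds the comma-separated `%n` operand list and the matching inline-asm
/// constraint string for a variable, advancing `reg_count` as operands are
/// allocated. Fragments yield one operand per register ("f"/"d"/"r" by
/// element type, prefixed with "=" for outputs); constant scalars are
/// inlined as immediates and consume no operand slot; anything else becomes
/// a single "r" input, returned with a leading ", " so it can be appended
/// directly to an existing constraint list.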
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::ConstantScalar(number, ..) => match number {
            ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}

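/// Formats the next positional asm operand (`%0`, `%1`, ...) and advances
/// the counter.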
fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}

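/// Emits a call to a generated `__mma_m16n8k{k}_{a}_{b}_{cd}` helper (the
/// name follows the element types' display form), reinterpreting each 32-bit
/// fragment register as the helper's parameter type. The helper itself is
/// assumed to be declared elsewhere in the generated source.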
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    // Bind the formatted type names to locals: borrowing a `format!`
    // temporary out of a match arm would not outlive the arm.
    let ab_ty = match a_elem {
        Elem::F32 => format!("{}", Elem::<D>::F32),
        _ => format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => format!("{}", Elem::<D>::F32),
        _ => format!("{}", Elem::<D>::U32),
    };

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / a_elem.size_bits());
    let b_regs = b_elems as usize / (32 / b_elem.size_bits());
    let cd_regs = cd_elems as usize / (32 / cd_elem.size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), &ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), &ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), &cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), &cd_ty));
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}

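/// Emits a call to a generated `__mma_scaled_{factor}x_m16n8k{k}_{a}_{b}_{cd}`
/// helper, passing the fragment registers followed by the A and B scale
/// registers. As with `compile_manual_mma`, the helper is assumed to be
/// declared elsewhere in the generated source.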
pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = format!("{}", Elem::<D>::U32);
    let cd_ty = format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / a_elem.size_bits());
    let b_regs = b_elems as usize / (32 / b_elem.size_bits());
    let cd_regs = cd_elems as usize / (32 / cd_elem.size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), &ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), &ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), &cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), &cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}

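/// MMA shapes usable via the manual (`mma.sync`-style) path: the m16n8
/// f16/bf16/tf32/s8/u8 shapes from sm_80 on, plus every pairing of the
/// f8f6f4 types (e4m3, e5m2, e3m2, e2m3, e2m1) at m16n8k32 from sm_89 on.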
pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
        ]);
    }
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}

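/// Block-scaled MMA combinations, only advertised for sm_120/sm_121
/// (`120 <= version < 130`): all f8f6f4 pairings with ue8m0 scales at k=32,
/// plus packed-e2m1 variants at k=64 with ue8m0 (factor 2) or e4m3
/// (factor 4) scales.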
pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types
            .iter()
            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}