1use std::fmt::Display;
2
3use crate::{
4 Dialect,
5 cuda::{CudaDialect, arch::CudaArchitecture, ptx::comma_separated},
6 shared::{
7 Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
8 FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
9 SupportedScaledMmaCombinations, Variable, WmmaInstruction,
10 },
11};
12use cubecl_core::ir::{self as gpu, ConstantScalarValue};
13use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
14use itertools::Itertools;
15
16use super::WMMA_MINIMUM_VERSION;
17
/// WMMA compiler that lowers matrix operations to inline PTX assembly
/// (`wmma.load` / `wmma.mma` / `wmma.store` instructions) emitted through
/// `asm volatile` blocks in the generated CUDA source.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}
20
21impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
22 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
23 if flags.elem_tf32 {
25 f.write_str("#include <mma.h>\n")?;
26 }
27 Ok(())
28 }
29
30 fn compile_wmma_fragment_declaration(
31 f: &mut std::fmt::Formatter<'_>,
32 var: &Variable<CudaDialect<Self>>,
33 ) -> std::fmt::Result {
34 let frag = match var {
35 Variable::WmmaFragment { frag, .. } => *frag,
36 _ => panic!("load instruction expects a WmmaFragment"),
37 };
38 let reg_count = get_fragment_register_total_count(&frag);
39 let ty = match frag.elem {
40 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
41 Elem::F32 => "float",
42 Elem::F64 => "double",
43 _ => panic!("unsupported type"),
44 };
45 writeln!(f, "{ty} {var}[{reg_count}];")
46 }
47
48 fn compile_wmma_instruction(
49 f: &mut std::fmt::Formatter<'_>,
50 instruction: &WmmaInstruction<CudaDialect<Self>>,
51 ) -> std::fmt::Result {
52 match instruction {
53 WmmaInstruction::Fill { frag: var, value } => {
54 let frag = match var {
55 Variable::WmmaFragment { frag, .. } => *frag,
56 _ => panic!("variable should be WmmaFragment"),
57 };
58 let reg_count = get_fragment_register_total_count(&frag);
59 write!(
60 f,
61 "// fill
62for (uint i = 0; i < uint({reg_count}); ++i) {{
63 {var}[i] = {value};
64}}
65 "
66 )
67 }
68 WmmaInstruction::Load {
69 frag: var,
70 value,
71 offset,
72 stride,
73 layout,
74 } => {
75 let frag = match var {
76 Variable::WmmaFragment { frag, .. } => *frag,
77 _ => panic!("load instruction expects a WmmaFragment"),
78 };
79 let layout = if frag.layout.is_some() {
84 get_fragment_layout_qualifier(var)
85 } else if let Some(layout) = layout {
86 get_qualifier_from_layout(layout)
87 } else {
88 panic!("unknown matrix layout for wmma load instruction");
89 };
90 let ty = get_type_qualifier(value);
92 let matrix = match frag.ident {
93 FragmentIdent::A => "a",
94 FragmentIdent::B => "b",
95 FragmentIdent::Accumulator => "c",
96 FragmentIdent::_Dialect(_) => unreachable!(),
97 };
98 let value_ty = value.item();
99 let opcode = match frag.elem {
100 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
101 format!(
102 "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
103 frag.m, frag.n, frag.k,
104 )
105 }
106 other => panic!("{other} fragment type not supported"),
107 };
108 let mut reg_count = 0;
110 let (regs_decl, out_constraints) =
111 get_variable_regs_decl_constraints(var, true, &mut reg_count);
112 let buffer_reg = format_reg_and_inc(&mut reg_count);
113 let (stride_reg, stride_constraint) =
114 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
115 let tmp_ptr = Variable::tmp_ptr(value.item());
116 let tmp_ptr_left = tmp_ptr.fmt_left();
117 write!(
118 f,
119 r#"// load
120{tmp_ptr_left} = ({value_ty}*){value} + {offset};
121asm volatile(
122 "{opcode} "
123 "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
124 : {out_constraints}
125 : "l"({tmp_ptr}){stride_constraint}
126);
127"#
128 )
129 }
130 WmmaInstruction::Execute {
131 frag_a: var_a,
132 frag_b: var_b,
133 frag_c: var_c,
134 frag_d: var_d,
135 ..
136 } => {
137 let frag_a = match var_a {
138 Variable::WmmaFragment { frag, .. } => *frag,
139 _ => panic!("variable should be WmmaFragment"),
140 };
141 let layout_a = get_fragment_layout_qualifier(var_a);
142 let layout_b = get_fragment_layout_qualifier(var_b);
143 let type_c = get_type_qualifier(var_c);
144 let type_d = get_type_qualifier(var_d);
145 let opcode = match var_a.elem() {
146 Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
147 "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
148 frag_a.m, frag_a.n, frag_a.k,
149 ),
150 Elem::BF16 => format!(
151 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
152 frag_a.m, frag_a.n, frag_a.k,
153 ),
154 Elem::TF32 => format!(
155 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
156 frag_a.m, frag_a.n, frag_a.k,
157 ),
158 other => panic!("{other} fragment type not supported"),
159 };
160 let mut reg_count = 0;
161 let (regs_decl_d, out_constraints_d) =
163 get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
164 let (regs_decl_a, in_constraints_a) =
165 get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
166 let (regs_decl_b, in_constraints_b) =
167 get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
168 let (regs_decl_c, in_constraints_c) =
169 get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
170 write!(
171 f,
172 r#"// execute
173asm volatile(
174 "{opcode} "
175 "{{{regs_decl_d}}}, "
176 "{{{regs_decl_a}}}, "
177 "{{{regs_decl_b}}}, "
178 "{{{regs_decl_c}}};\n"
179 : {out_constraints_d}
180 : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
181);
182"#
183 )
184 }
185 WmmaInstruction::ExecuteManual {
186 shape,
187 frag_a,
188 frag_b,
189 frag_c,
190 frag_d,
191 } => {
192 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
193 }
194 WmmaInstruction::ExecuteScaled {
195 shape,
196 frag_a,
197 frag_b,
198 frag_c,
199 frag_d,
200
201 scales_a,
202 scales_b,
203 scales_factor,
204 } => Self::compile_scaled_mma(
205 f,
206 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
207 *scales_a,
208 *scales_b,
209 *scales_factor,
210 ),
211 WmmaInstruction::Store {
212 output,
213 frag: var,
214 stride,
215 offset,
216 layout,
217 } => {
218 let frag_acc = match var {
219 Variable::WmmaFragment { frag, .. } => *frag,
220 _ => panic!("variable should be WmmaFragment"),
221 };
222 let layout = match layout {
224 FragmentLayout::ColMajor => "col",
225 FragmentLayout::RowMajor => "row",
226 FragmentLayout::_Dialect(..) => unreachable!(),
227 };
228 let opcode = match var.elem() {
229 Elem::F16 | Elem::BF16 => format!(
230 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
234 frag_acc.m, frag_acc.n, frag_acc.k,
235 ),
236 Elem::TF32 | Elem::F32 => format!(
237 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
239 frag_acc.m, frag_acc.n, frag_acc.k,
240 ),
241 Elem::I32 => format!(
242 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
244 frag_acc.m, frag_acc.n, frag_acc.k,
245 ),
246 other => panic!("{other} fragment type not supported"),
247 };
248 let mut reg_count = 0;
250 let buffer_reg = format_reg_and_inc(&mut reg_count);
251 let (stride_reg, stride_constraint) =
254 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
255 let (regs_decl, in_constraints) =
257 get_variable_regs_decl_constraints(var, false, &mut reg_count);
258 let tmp_ptr = Variable::tmp_ptr(output.item());
259 let tmp_ptr_left = tmp_ptr.fmt_left();
260 write!(
261 f,
262 r#"// store
263{tmp_ptr_left} = {output} + {offset};
264asm volatile(
265 "{opcode} "
266 "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
267 :
268 : "l"({tmp_ptr}),
269 {in_constraints}{stride_constraint}
270);
271"#
272 )
273 }
274 WmmaInstruction::Cast { input, output } => {
275 let frag = match input {
276 Variable::WmmaFragment { frag, .. } => *frag,
277 _ => panic!("variable should be WmmaFragment"),
278 };
279 let reg_count = get_fragment_register_total_count(&frag);
280 match output.elem() {
281 Elem::F16 => {
282 write!(
283 f,
284 "// cast
285for (int i = 0; i < {reg_count}; ++i) {{
286 __half h_lo = __float2half_rn({input}[2*i + 0]);
287 __half h_hi = __float2half_rn({input}[2*i + 1]);
288 __half2 h2 = __halves2half2(h_lo, h_hi);
289 {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
290}}
291"
292 )
293 }
294 Elem::BF16 => {
295 write!(
296 f,
297 "// cast
298for (int i = 0; i < {reg_count}; ++i) {{
299 __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
300 __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
301 __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
302 {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
303}}
304"
305 )
306 }
307 other => panic!("casting fragment to {other} not supported"),
308 }
309 }
310 }
311 }
312
313 fn compile_manual_mma(
314 f: &mut std::fmt::Formatter<'_>,
315 mma: ManualMma<CudaDialect<Self>>,
316 ) -> std::fmt::Result {
317 compile_manual_mma(f, mma)
318 }
319
320 fn compile_scaled_mma(
321 f: &mut std::fmt::Formatter<'_>,
322 mma: ManualMma<CudaDialect<Self>>,
323 scales_a: Variable<CudaDialect<Self>>,
324 scales_b: Variable<CudaDialect<Self>>,
325 scales_factor: u32,
326 ) -> std::fmt::Result {
327 compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
328 }
329
330 fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
331 let mut result: SupportedMmaCombinations = vec![];
332 if arch.get_version() >= WMMA_MINIMUM_VERSION {
333 let types = vec![
335 (
336 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
340 (
341 gpu::ElemType::Float(gpu::FloatKind::F16),
342 gpu::ElemType::Float(gpu::FloatKind::F16),
343 gpu::ElemType::Float(gpu::FloatKind::F32),
344 ),
345 (
346 gpu::ElemType::Float(gpu::FloatKind::BF16),
347 gpu::ElemType::Float(gpu::FloatKind::BF16),
348 gpu::ElemType::Float(gpu::FloatKind::F32),
349 ),
350 ];
351 let combinations: SupportedMmaCombinations = types
352 .into_iter()
353 .map(|(a, b, cd)| MmaConfig {
354 a_type: a.into(),
355 b_type: b.into(),
356 cd_type: cd.into(),
357 m: 16,
358 n: 16,
359 k: 16,
360 })
361 .collect();
362 result.extend(combinations);
363 if arch.get_version() >= 72 {
364 result.extend([
365 MmaConfig {
366 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
367 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
368 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
369 m: 16,
370 n: 16,
371 k: 16,
372 },
373 MmaConfig {
374 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
375 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
376 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
377 m: 16,
378 n: 16,
379 k: 16,
380 },
381 ]);
382 }
383 if arch.get_version() >= 80 {
384 result.push(MmaConfig {
385 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
386 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
387 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
388 m: 16,
389 n: 16,
390 k: 8,
391 });
392 }
393 }
394 result
395 }
396
397 fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
398 supported_mma_combinations(arch)
399 }
400
401 fn supported_scaled_mma_combinations(
402 arch: &CudaArchitecture,
403 ) -> SupportedScaledMmaCombinations {
404 supported_scaled_mma_combinations(arch)
405 }
406}
407
408fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
409 let Fragment {
410 ident,
411 m,
412 n,
413 k,
414 elem,
415 ..
416 } = frag;
417 let elements = match ident {
418 FragmentIdent::A => m * k,
419 FragmentIdent::B => k * n,
420 FragmentIdent::Accumulator => m * n,
421 _ => unreachable!(),
422 };
423 let bits_per_elem = elem.size_bits() as u32;
424 let lanes_per_reg = 32 / bits_per_elem;
426 let threads_per_frag = match ident {
430 FragmentIdent::Accumulator => 32,
431 FragmentIdent::A | FragmentIdent::B => {
432 if frag.elem == Elem::TF32 {
433 32
434 } else {
435 16
436 }
437 }
438 _ => unreachable!(),
439 };
440
441 elements / (lanes_per_reg * threads_per_frag)
442}
443
444fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
445 match var.elem() {
446 Elem::U8 => "u8",
447 Elem::I8 => "s8",
448 Elem::F16 => "f16",
449 Elem::BF16 => "bf16",
450 Elem::F32 => "f32",
451 Elem::TF32 => "tf32",
452 Elem::I32 => "s32",
453 Elem::F64 => "f64",
454 _ => panic!("unsupported WMMA fragment type"),
455 }
456 .to_string()
457}
458
459fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
460 let frag = match var {
461 Variable::WmmaFragment { frag, .. } => *frag,
462 _ => panic!("variable should be WmmaFragment"),
463 };
464 match frag.layout {
465 Some(layout) => get_qualifier_from_layout(&layout),
466 None => "".to_string(),
467 }
468}
469
470fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
471 match layout {
472 FragmentLayout::ColMajor => "col",
473 FragmentLayout::RowMajor => "row",
474 FragmentLayout::_Dialect(..) => unreachable!(),
475 }
476 .to_string()
477}
478
479fn get_variable_regs_decl_constraints(
480 var: &Variable<CudaDialect<PtxWmmaCompiler>>,
481 output: bool,
482 reg_count: &mut u8,
483) -> (String, String) {
484 match var {
485 Variable::WmmaFragment { frag, .. } => {
486 let reg_total_count = get_fragment_register_total_count(frag);
487 let reg_decl = (0..reg_total_count)
488 .map(|_| format_reg_and_inc(reg_count))
489 .collect::<Vec<_>>()
490 .join(",");
491 let frag_elem = frag.elem;
492 let modifier = format!(
493 "{}{}",
494 if output { "=" } else { "" },
495 match frag_elem {
496 Elem::F32 => "f",
497 Elem::F64 => "d",
498 _ => "r",
499 },
500 );
501 let constraints = (0..reg_total_count)
502 .map(|i| format!("\"{modifier}\"({var}[{i}])"))
503 .collect::<Vec<_>>()
504 .join(", ");
505 (reg_decl, constraints)
506 }
507 Variable::ConstantScalar(number, ..) => match number {
508 ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
509 _ => panic!("variable should be an unsigned integer"),
510 },
511 _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
512 }
513}
514
/// Renders the next positional asm operand (`%0`, `%1`, ...) and advances
/// the shared counter.
fn format_reg_and_inc(count: &mut u8) -> String {
    let reg = format!("%{}", *count);
    *count += 1;
    reg
}
520
/// Renders a mutable `reinterpret_cast` of `var` to `ty`.
fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{}&>({})", ty, var)
}
524
/// Renders a `const` `reinterpret_cast` of `var` to `ty`.
fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {}&>({})", ty, var)
}
528
529pub(super) fn compile_manual_mma<D: Dialect>(
530 f: &mut core::fmt::Formatter<'_>,
531 mma: ManualMma<D>,
532) -> std::fmt::Result {
533 let ManualMma {
534 shape,
535 frag_a,
536 frag_b,
537 frag_c,
538 frag_d,
539 } = mma;
540
541 let a_elem = frag_a[0].elem().unpacked();
542 let b_elem = frag_b[0].elem().unpacked();
543 let cd_elem = frag_c[0].elem().unpacked();
544
545 let ab_ty = match a_elem {
546 Elem::F32 => &format!("{}", Elem::<D>::F32),
547 _ => &format!("{}", Elem::<D>::U32),
548 };
549 let cd_ty = match cd_elem {
550 Elem::F32 => &format!("{}", Elem::<D>::F32),
551 _ => &format!("{}", Elem::<D>::U32),
552 };
553
554 let acc_elems = frag_c.len();
555 let frag_ab = frag_a.iter().chain(frag_b).map(|it| as_const_ty(it, ab_ty));
556 let frag_c = frag_c.iter().map(|it| as_const_ty(it, cd_ty));
557 let frag_d = (0..acc_elems).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
558 let args = comma_separated(frag_ab.chain(frag_c).chain(frag_d));
559 write!(
560 f,
561 "__mma_m16n8k{}_{}_{}_{}({args});",
562 shape.k, a_elem, b_elem, cd_elem
563 )
564}
565
566pub(super) fn compile_scaled_mma<D: Dialect>(
567 f: &mut core::fmt::Formatter<'_>,
568 mma: ManualMma<D>,
569 scales_a: Variable<D>,
570 scales_b: Variable<D>,
571 scales_factor: u32,
572) -> std::fmt::Result {
573 let ManualMma {
574 shape,
575 frag_a,
576 frag_b,
577 frag_c,
578 frag_d,
579 } = mma;
580
581 let a_elem = frag_a[0].elem().unpacked();
582 let b_elem = frag_b[0].elem().unpacked();
583 let cd_elem = frag_c[0].elem().unpacked();
584 let ab_ty = &format!("{}", Elem::<D>::U32);
585 let cd_ty = &format!("{}", Elem::<D>::F32);
586 let acc_elems = frag_c.len();
587 let frag_ab = frag_a.iter().chain(frag_b).map(|it| as_const_ty(it, ab_ty));
588 let frag_c = frag_c.iter().map(|it| as_const_ty(it, cd_ty));
589 let frag_d = (0..acc_elems).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
590 let fragments = comma_separated(frag_ab.chain(frag_c).chain(frag_d));
591 write!(
592 f,
593 "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
594 shape.k, a_elem, b_elem, cd_elem
595 )
596}
597
598pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
599 let mut result: SupportedMmaCombinations = vec![];
600 if arch.get_version() >= 80 {
604 result.extend([
605 MmaConfig {
606 a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), m: 16,
610 n: 8,
611 k: 16,
612 },
613 MmaConfig {
614 a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
615 b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
616 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
617 m: 16,
618 n: 8,
619 k: 16,
620 },
621 MmaConfig {
622 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
623 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
624 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
625 m: 16,
626 n: 8,
627 k: 8,
628 },
629 MmaConfig {
630 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
631 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
632 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
633 m: 16,
634 n: 8,
635 k: 32,
636 },
637 MmaConfig {
638 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
639 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
640 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
641 m: 16,
642 n: 8,
643 k: 32,
644 },
645 MmaConfig {
646 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
647 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
648 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
649 m: 16,
650 n: 8,
651 k: 32,
652 },
653 MmaConfig {
654 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
655 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
656 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
657 m: 16,
658 n: 8,
659 k: 32,
660 },
661 ]);
663 }
664 if arch.get_version() >= 89 {
665 let f8f6f4_types = [
666 gpu::FloatKind::E4M3,
667 gpu::FloatKind::E5M2,
668 gpu::FloatKind::E3M2,
669 gpu::FloatKind::E2M3,
670 gpu::FloatKind::E2M1,
671 ];
672 let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
673 result.extend(combinations.map(|(t1, t2)| MmaConfig {
674 a_type: gpu::ElemType::Float(*t1).into(),
675 b_type: gpu::ElemType::Float(*t2).into(),
676 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
677 m: 16,
678 n: 8,
679 k: 32,
680 }));
681 }
682 result
683}
684
685pub(super) fn supported_scaled_mma_combinations(
686 arch: &CudaArchitecture,
687) -> SupportedScaledMmaCombinations {
688 let mut result: SupportedScaledMmaCombinations = vec![];
689 if arch.get_version() >= 120 && arch.get_version() < 130 {
691 let f8f6f4_types = [
692 gpu::FloatKind::E4M3,
693 gpu::FloatKind::E5M2,
694 gpu::FloatKind::E3M2,
695 gpu::FloatKind::E2M3,
696 gpu::FloatKind::E2M1,
697 ];
698 let combinations = f8f6f4_types
699 .iter()
700 .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));
701
702 result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
703 a_type: gpu::ElemType::Float(*t1).into(),
704 b_type: gpu::ElemType::Float(*t2).into(),
705 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
706 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
707 m: 16,
708 n: 8,
709 k: 32,
710 scales_factor: 1,
711 }));
712
713 result.extend([
714 ScaledMmaConfig {
715 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
716 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
717 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
718 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
719 m: 16,
720 n: 8,
721 k: 64,
722 scales_factor: 2,
723 },
724 ScaledMmaConfig {
726 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
727 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
728 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
729 scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
730 m: 16,
731 n: 8,
732 k: 64,
733 scales_factor: 4,
734 },
735 ]);
736 }
737 result
738}