use std::fmt::Display;

use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call, stmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{self as gpu, ConstantScalarValue, Matrix, MatrixIdent};
use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
use itertools::Itertools;

use super::WMMA_MINIMUM_VERSION;
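
/// WMMA compiler that lowers matrix operations to inline PTX
/// (`wmma.load` / `wmma.mma` / `wmma.store`, plus `ldmatrix` / `stmatrix`)
/// instead of the `nvcuda::wmma` C++ intrinsics.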
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
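        // The inline-PTX path needs no C++ WMMA header on its own; <mma.h> is
        // only pulled in for tf32 kernels, presumably for its conversion
        // helpers such as nvcuda::wmma::__float_to_tf32.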
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("variable should be a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
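        // Fragments are stored as raw 32-bit registers: sub-32-bit elements
        // (u8/i8/f16/bf16) are packed two or four to a register, so the
        // backing array is declared as `unsigned int` and indexed per
        // register rather than per element.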
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
    {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
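                // A sketch of the emitted source for an f16 A fragment
                // (m16n16k16, row layout, constant stride 16); the names are
                // illustrative and a constant stride is inlined as a literal:
                //
                //   ptr_0 = (__half*)buf + offset;
                //   asm volatile(
                //       "wmma.load.a.sync.aligned.row.m16n16k16.f16 "
                //       "{%0,%1,%2,%3,%4,%5,%6,%7}, [%8], 16;\n"
                //       : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]),
                //         "=r"(a[4]), "=r"(a[5]), "=r"(a[6]), "=r"(a[7])
                //       : "l"(ptr_0));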
            }
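            // ldmatrix/stmatrix lower through the shared PTX builders in
            // `cuda::ptx`.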
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::StMatrix {
                registers,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&stmatrix_call(
                registers, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
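                // PTX places the layout qualifiers before the shape in every
                // wmma.mma variant:
                // wmma.mma.sync.aligned.<alayout>.<blayout>.<shape>.<types...>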
                let opcode = match var_a.elem() {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
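                // Operand numbering follows constraint order: the d registers
                // are outputs (%0..), then a, b and c follow as inputs.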
                let mut reg_count = 0;
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
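                // Register numbering must match the input-constraint order in
                // the asm block below: pointer first, then stride (unless it
                // is a constant folded into the template), then the fragment
                // registers.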
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}){stride_constraint},
      {in_constraints}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                let frag = match output {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
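                // The loop bound is the packed *output* register count: each
                // 32-bit output register receives two converted f32 inputs.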
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __half h_lo = __float2half_rn({input}[2*i + 0]);
    __half h_hi = __float2half_rn({input}[2*i + 1]);
    __half2 h2 = __halves2half2(h_lo, h_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
    __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
    __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
    __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
    {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}
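
/// Number of 32-bit registers a single thread holds for `frag`:
/// `elements / (values_per_register * threads_holding_the_fragment)`.
///
/// Accumulators are spread over all 32 lanes of the warp; 16-bit A/B
/// fragments are replicated (each element is held by two lanes), hence the
/// divisor of 16. Example: an f16 A fragment of m16n16k16 has 16 * 16 = 256
/// elements and needs 256 / (2 * 16) = 8 registers per thread, matching the
/// eight `.f16x2` registers of `wmma.load.a`.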
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    let lanes_per_reg = 32 / bits_per_elem;
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}

fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}
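
/// Builds the `%N` operand list and the matching constraint string for one
/// inline-asm operand.
///
/// Fragments produce one `%N` per register with `"r"`/`"f"`/`"d"` constraints
/// (prefixed with `=` for outputs); constant scalars are inlined as literals
/// and produce no constraint; any other variable becomes a single register
/// whose constraint carries a leading comma so it can be appended directly
/// after a preceding constraint list.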
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::ConstantScalar(number, ..) => match number {
            ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}
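
/// Formats the next `%N` asm operand and advances the running counter.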
fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}
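
/// Emits a call to a hand-written `__mma_m16n8k*` device helper, passing each
/// fragment register reinterpreted as the raw 32-bit word type the helper
/// expects.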
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    // A/B registers are passed as raw 32-bit words unless they hold genuine
    // f32 values.
    let ab_ty = match a_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };

    // Per-thread element counts: every fragment is distributed over the 32
    // lanes of a warp.
    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));

    // 32/64-bit accumulators are stored as 2-wide lines (`.i_0`/`.i_1`
    // components); 16-bit accumulators are already packed one register per
    // index.
    let frag_c = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let frag_d = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}
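
/// Like [`compile_manual_mma`], but for block-scaled MMA: the generated
/// helper call additionally receives the packed scale words for A and B.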
pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = &format!("{}", Elem::<D>::U32);
    let cd_ty = &format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}
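
/// `mma.sync` configurations: the m16n8 floating-point and integer shapes
/// available from sm_80, plus the f8/f6/f4 operand family on newer
/// architectures.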
pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
        ]);
    }
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}
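
/// Block-scaled MMA configurations: f8/f6/f4 operands with UE8M0 scales, and
/// packed E2M1 variants with larger scale factors, currently gated to
/// sm_120/sm_121.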
pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types
            .iter()
            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}
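
/// Number of matrix elements contiguous in one thread's registers: a full
/// 32-bit register's worth for A/B operands, and pairs of values for the
/// accumulator.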
pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> u32 {
    match ident {
        MatrixIdent::A | MatrixIdent::B => (32 / matrix.storage.size_bits()) as u32,
        MatrixIdent::Accumulator => 2,
    }
}