use super::WMMA_MINIMUM_VERSION;
use crate::{
    Dialect,
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        ptx::{comma_separated, ldmatrix_call, stmatrix_call},
    },
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
        SupportedScaledMmaCombinations, Variable, WmmaInstruction,
    },
};
use cubecl_core::ir::{
    self as gpu, ConstantValue, Matrix, MatrixIdent,
    features::{MmaConfig, ScaledMmaConfig},
};
use itertools::Itertools;
use std::fmt::Display;

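/// WMMA compiler that lowers matrix operations to inline PTX `asm volatile`
/// blocks instead of the `nvcuda::wmma` C++ API, giving direct control over
/// the fragment register layout.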
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
    fn compile_wmma_includes(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        if flags.elem_tf32 {
            f.write_str("#include <mma.h>\n")?;
        }
        Ok(())
    }

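    /// Declares a WMMA fragment as a plain register array. Element types that
    /// PTX addresses as `.b32` registers (u8/i8/f16/bf16/tf32) are declared as
    /// `unsigned int`; f32 and f64 keep their native type.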
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        let frag = match var {
            Variable::WmmaFragment { frag, .. } => *frag,
            _ => panic!("fragment declaration expects a WmmaFragment"),
        };
        let reg_count = get_fragment_register_total_count(&frag);
        let ty = match frag.elem {
            Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
            Elem::F32 => "float",
            Elem::F64 => "double",
            _ => panic!("unsupported type"),
        };
        writeln!(f, "{ty} {var}[{reg_count}];")
    }

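    /// Lowers a single WMMA instruction to PTX. Loads, stores and `mma` go
    /// through `asm volatile` with manually numbered `%N` operands;
    /// `ldmatrix`/`stmatrix` delegate to shared helpers; fills and casts are
    /// emitted as plain CUDA loops over the fragment registers.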
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag: var, value } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
                write!(
                    f,
                    "// fill
for (uint i = 0; i < uint({reg_count}); ++i) {{
  {var}[i] = {value};
}}
"
                )
            }
            WmmaInstruction::Load {
                frag: var,
                value,
                offset,
                stride,
                layout,
            } => {
                let frag = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("load instruction expects a WmmaFragment"),
                };
                let layout = if frag.layout.is_some() {
                    get_fragment_layout_qualifier(var)
                } else if let Some(layout) = layout {
                    get_qualifier_from_layout(layout)
                } else {
                    panic!("unknown matrix layout for wmma load instruction");
                };
                let ty = get_type_qualifier(value);
                let matrix = match frag.ident {
                    FragmentIdent::A => "a",
                    FragmentIdent::B => "b",
                    FragmentIdent::Accumulator => "c",
                    FragmentIdent::_Dialect(_) => unreachable!(),
                };
                let value_ty = value.item();
                let opcode = match frag.elem {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
                        format!(
                            "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
                            frag.m, frag.n, frag.k,
                        )
                    }
                    other => panic!("{other} fragment type not supported"),
                };
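                // Inline asm operands are numbered in declaration order: the
                // destination fragment registers first (outputs), then the
                // buffer pointer and the stride (inputs).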
                let mut reg_count = 0;
                let (regs_decl, out_constraints) =
                    get_variable_regs_decl_constraints(var, true, &mut reg_count);
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(value.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// load
{tmp_ptr_left} = ({value_ty}*){value} + {offset};
asm volatile(
    "{opcode} "
    "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
    : {out_constraints}
    : "l"({tmp_ptr}){stride_constraint}
);
"#
                )
            }
            WmmaInstruction::LdMatrix {
                output,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&ldmatrix_call(
                output, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::StMatrix {
                registers,
                buffer,
                offset,
                line_size,
                factor,
                transpose,
            } => f.write_str(&stmatrix_call(
                registers, buffer, offset, line_size, factor, transpose,
            )),
            WmmaInstruction::Execute {
                frag_a: var_a,
                frag_b: var_b,
                frag_c: var_c,
                frag_d: var_d,
                ..
            } => {
                let frag_a = match var_a {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout_a = get_fragment_layout_qualifier(var_a);
                let layout_b = get_fragment_layout_qualifier(var_b);
                let type_c = get_type_qualifier(var_c);
                let type_d = get_type_qualifier(var_d);
                let opcode = match var_a.elem() {
                    Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.{type_d}.{type_c}",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::BF16 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    Elem::TF32 => format!(
                        "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
                        frag_a.m, frag_a.n, frag_a.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                let (regs_decl_d, out_constraints_d) =
                    get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
                let (regs_decl_a, in_constraints_a) =
                    get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
                let (regs_decl_b, in_constraints_b) =
                    get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
                let (regs_decl_c, in_constraints_c) =
                    get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
                write!(
                    f,
                    r#"// execute
asm volatile(
    "{opcode} "
    "{{{regs_decl_d}}}, "
    "{{{regs_decl_a}}}, "
    "{{{regs_decl_b}}}, "
    "{{{regs_decl_c}}};\n"
    : {out_constraints_d}
    : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
);
"#
                )
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag: var,
                stride,
                offset,
                layout,
            } => {
                let frag_acc = match var {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let layout = match layout {
                    FragmentLayout::ColMajor => "col",
                    FragmentLayout::RowMajor => "row",
                    FragmentLayout::_Dialect(..) => unreachable!(),
                };
                let opcode = match var.elem() {
                    Elem::F16 | Elem::BF16 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::TF32 | Elem::F32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    Elem::I32 => format!(
                        "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
                        frag_acc.m, frag_acc.n, frag_acc.k,
                    ),
                    other => panic!("{other} fragment type not supported"),
                };
                let mut reg_count = 0;
                let buffer_reg = format_reg_and_inc(&mut reg_count);
                let (regs_decl, in_constraints) =
                    get_variable_regs_decl_constraints(var, false, &mut reg_count);
                let (stride_reg, stride_constraint) =
                    get_variable_regs_decl_constraints(stride, false, &mut reg_count);
                let tmp_ptr = Variable::tmp_ptr(output.item());
                let tmp_ptr_left = tmp_ptr.fmt_left();
                write!(
                    f,
                    r#"// store
{tmp_ptr_left} = {output} + {offset};
asm volatile(
    "{opcode} "
    "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
    :
    : "l"({tmp_ptr}),
      {in_constraints}{stride_constraint}
);
"#
                )
            }
            WmmaInstruction::Cast { input, output } => {
                let frag = match output {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!("variable should be WmmaFragment"),
                };
                let reg_count = get_fragment_register_total_count(&frag);
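                // Each output register packs two converted 16-bit values, so
                // the loop runs once per *output* register and consumes two
                // input f32 registers per iteration.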
                match output.elem() {
                    Elem::F16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
  __half h_lo = __float2half_rn({input}[2*i + 0]);
  __half h_hi = __float2half_rn({input}[2*i + 1]);
  __half2 h2 = __halves2half2(h_lo, h_hi);
  {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
}}
"
                        )
                    }
                    Elem::BF16 => {
                        write!(
                            f,
                            "// cast
for (int i = 0; i < {reg_count}; ++i) {{
  __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
  __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
  __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
  {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
}}
"
                        )
                    }
                    other => panic!("casting fragment to {other} not supported"),
                }
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, cd)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: cd.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
            if arch.get_version() >= 72 {
                result.extend([
                    MmaConfig {
                        a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                    MmaConfig {
                        a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                        cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                        m: 16,
                        n: 16,
                        k: 16,
                    },
                ]);
            }
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}

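/// Computes how many 32-bit registers a single thread contributes to a WMMA
/// fragment: total elements divided by (elements packed per register × the
/// number of threads the fragment is distributed over).
///
/// Worked example: an f16 `A` fragment of shape m16n16k16 holds 16 × 16 = 256
/// elements; with 2 halves per 32-bit register and an effective distribution
/// over 16 threads, each thread holds 256 / (2 × 16) = 8 registers, matching
/// the PTX wmma fragment layout.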
fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
    let Fragment {
        ident,
        m,
        n,
        k,
        elem,
        ..
    } = frag;
    let elements = match ident {
        FragmentIdent::A => m * k,
        FragmentIdent::B => k * n,
        FragmentIdent::Accumulator => m * n,
        _ => unreachable!(),
    };
    let bits_per_elem = elem.size_bits() as u32;
    let lanes_per_reg = 32 / bits_per_elem;
    let threads_per_frag = match ident {
        FragmentIdent::Accumulator => 32,
        FragmentIdent::A | FragmentIdent::B => {
            if frag.elem == Elem::TF32 {
                32
            } else {
                16
            }
        }
        _ => unreachable!(),
    };

    elements / (lanes_per_reg * threads_per_frag)
}

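/// Maps an element type to the PTX type suffix used in wmma opcodes
/// (e.g. `Elem::I8` becomes `s8`, `Elem::F32` becomes `f32`).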
fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    match var.elem() {
        Elem::U8 => "u8",
        Elem::I8 => "s8",
        Elem::F16 => "f16",
        Elem::BF16 => "bf16",
        Elem::F32 => "f32",
        Elem::TF32 => "tf32",
        Elem::I32 => "s32",
        Elem::F64 => "f64",
        _ => panic!("unsupported WMMA fragment type"),
    }
    .to_string()
}

fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
    let frag = match var {
        Variable::WmmaFragment { frag, .. } => *frag,
        _ => panic!("variable should be WmmaFragment"),
    };
    match frag.layout {
        Some(layout) => get_qualifier_from_layout(&layout),
        None => "".to_string(),
    }
}

fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
    match layout {
        FragmentLayout::ColMajor => "col",
        FragmentLayout::RowMajor => "row",
        FragmentLayout::_Dialect(..) => unreachable!(),
    }
    .to_string()
}

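/// Builds the `%N` placeholder list and the matching constraint string for an
/// inline asm operand. Fragments expand to one placeholder and one constraint
/// per register (`"=f"`/`"f"` for f32, `"=d"`/`"d"` for f64, `"=r"`/`"r"`
/// otherwise); unsigned constants are inlined directly with no constraint;
/// any other variable gets a single `"r"` input constraint, emitted with a
/// leading comma so it can be appended to an existing constraint list.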
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        Variable::Constant(number, ..) => match number {
            ConstantValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}

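/// Returns the next `%N` asm operand placeholder and advances the counter.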
fn format_reg_and_inc(count: &mut u8) -> String {
    let res = format!("%{count}");
    *count += 1;
    res
}

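// `as_ty`/`as_const_ty` reinterpret a register expression as a (const)
// reference of the target type, so packed fragment registers can be passed to
// the generated `__mma_*` helpers without copies.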
fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{ty}&>({var})")
}

fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {ty}&>({var})")
}

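/// Emits a call to a generated `__mma_m16n8k{K}_{a}_{b}_{cd}` device helper,
/// passing each fragment register reinterpreted as the register type the
/// helper expects. Register counts are derived from the per-thread element
/// count (total elements / 32 threads) and the packing density of the
/// element type.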
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = match a_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };
    let cd_ty = match cd_elem {
        Elem::F32 => &format!("{}", Elem::<D>::F32),
        _ => &format!("{}", Elem::<D>::U32),
    };

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));

    let frag_c = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_c}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let frag_d = match cd_elem.size() {
        4 | 8 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), cd_ty))
            .collect::<Vec<_>>(),
        2 => (0..cd_regs)
            .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
            .collect::<Vec<_>>(),
        other => panic!("Found unhandled accumulator elem size {other}"),
    };
    let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_m16n8k{}_{}_{}_{}({args});",
        shape.k, a_elem, b_elem, cd_elem
    )
}

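/// Emits a call to a generated `__mma_scaled_{factor}x_m16n8k{K}_{a}_{b}_{cd}`
/// device helper for block-scaled MMA, appending the per-block scale operands
/// for A and B after the fragment registers.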
pub(super) fn compile_scaled_mma<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    mma: ManualMma<D>,
    scales_a: Variable<D>,
    scales_b: Variable<D>,
    scales_factor: u32,
) -> std::fmt::Result {
    let ManualMma {
        shape,
        frag_a,
        frag_b,
        frag_c,
        frag_d,
    } = mma;

    let a_elem = frag_a.elem().unpacked();
    let b_elem = frag_b.elem().unpacked();
    let cd_elem = frag_c.elem().unpacked();

    let ab_ty = &format!("{}", Elem::<D>::U32);
    let cd_ty = &format!("{}", Elem::<D>::F32);

    let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
    let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
    let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
    let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());

    let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
    let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
    let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
    let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
    let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
    write!(
        f,
        "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
        shape.k, a_elem, b_elem, cd_elem
    )
}

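/// `mma.sync` shapes and type combinations supported per architecture: the
/// m16n8k16/k8/k32 baseline set from sm_80, plus all FP8/FP6/FP4 (`f8f6f4`)
/// input pairings from sm_89.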
pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
    let mut result: SupportedMmaCombinations = vec![];
    if arch.get_version() >= 80 {
        result.extend([
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 16,
            },
            MmaConfig {
                a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                m: 16,
                n: 8,
                k: 8,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
            MmaConfig {
                a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
                b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
                cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
                m: 16,
                n: 8,
                k: 32,
            },
        ]);
    }
    if arch.get_version() >= 89 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
        result.extend(combinations.map(|(t1, t2)| MmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            m: 16,
            n: 8,
            k: 32,
        }));
    }
    result
}

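/// Block-scaled MMA combinations, currently gated to sm_120/sm_121: every
/// f8f6f4 input pairing with UE8M0 scales at scale factor 1, plus packed
/// E2M1 (FP4) variants with UE8M0 scales at factor 2 and E4M3 scales at
/// factor 4.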
pub(super) fn supported_scaled_mma_combinations(
    arch: &CudaArchitecture,
) -> SupportedScaledMmaCombinations {
    let mut result: SupportedScaledMmaCombinations = vec![];
    if arch.get_version() >= 120 && arch.get_version() < 130 {
        let f8f6f4_types = [
            gpu::FloatKind::E4M3,
            gpu::FloatKind::E5M2,
            gpu::FloatKind::E3M2,
            gpu::FloatKind::E2M3,
            gpu::FloatKind::E2M1,
        ];
        let combinations = f8f6f4_types
            .iter()
            .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));

        result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
            a_type: gpu::ElemType::Float(*t1).into(),
            b_type: gpu::ElemType::Float(*t2).into(),
            cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
            scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
            m: 16,
            n: 8,
            k: 32,
            scales_factor: 1,
        }));

        result.extend([
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 2,
            },
            ScaledMmaConfig {
                a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
                cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
                m: 16,
                n: 8,
                k: 64,
                scales_factor: 4,
            },
        ]);
    }
    result
}

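/// Number of elements that are contiguous per thread in a manual MMA
/// fragment: a full 32-bit register worth of packed elements for A and B,
/// and a fixed pair for accumulator registers.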
pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> usize {
    match ident {
        MatrixIdent::A | MatrixIdent::B => 32 / matrix.storage.size_bits(),
        MatrixIdent::Accumulator => 2,
    }
}