1use std::fmt::Display;
2
3use crate::{
4 Dialect,
5 cuda::{
6 CudaDialect,
7 arch::CudaArchitecture,
8 ptx::{comma_separated, ldmatrix_call},
9 },
10 shared::{
11 Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
12 FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
13 SupportedScaledMmaCombinations, Variable, WmmaInstruction,
14 },
15};
16use cubecl_core::ir::{self as gpu, ConstantScalarValue, Matrix, MatrixIdent};
17use cubecl_runtime::{MmaConfig, ScaledMmaConfig};
18use itertools::Itertools;
19
20use super::WMMA_MINIMUM_VERSION;
21
/// WMMA compiler for the CUDA dialect that lowers fragment operations to raw
/// PTX `wmma.*` inline assembly instead of the C++ `nvcuda::wmma` API.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}
24
25impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
26 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
27 if flags.elem_tf32 {
29 f.write_str("#include <mma.h>\n")?;
30 }
31 Ok(())
32 }
33
34 fn compile_wmma_fragment_declaration(
35 f: &mut std::fmt::Formatter<'_>,
36 var: &Variable<CudaDialect<Self>>,
37 ) -> std::fmt::Result {
38 let frag = match var {
39 Variable::WmmaFragment { frag, .. } => *frag,
40 _ => panic!("load instruction expects a WmmaFragment"),
41 };
42 let reg_count = get_fragment_register_total_count(&frag);
43 let ty = match frag.elem {
44 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
45 Elem::F32 => "float",
46 Elem::F64 => "double",
47 _ => panic!("unsupported type"),
48 };
49 writeln!(f, "{ty} {var}[{reg_count}];")
50 }
51
52 fn compile_wmma_instruction(
53 f: &mut std::fmt::Formatter<'_>,
54 instruction: &WmmaInstruction<CudaDialect<Self>>,
55 ) -> std::fmt::Result {
56 match instruction {
57 WmmaInstruction::Fill { frag: var, value } => {
58 let frag = match var {
59 Variable::WmmaFragment { frag, .. } => *frag,
60 _ => panic!("variable should be WmmaFragment"),
61 };
62 let reg_count = get_fragment_register_total_count(&frag);
63 write!(
64 f,
65 "// fill
66for (uint i = 0; i < uint({reg_count}); ++i) {{
67 {var}[i] = {value};
68}}
69 "
70 )
71 }
72 WmmaInstruction::Load {
73 frag: var,
74 value,
75 offset,
76 stride,
77 layout,
78 } => {
79 let frag = match var {
80 Variable::WmmaFragment { frag, .. } => *frag,
81 _ => panic!("load instruction expects a WmmaFragment"),
82 };
83 let layout = if frag.layout.is_some() {
88 get_fragment_layout_qualifier(var)
89 } else if let Some(layout) = layout {
90 get_qualifier_from_layout(layout)
91 } else {
92 panic!("unknown matrix layout for wmma load instruction");
93 };
94 let ty = get_type_qualifier(value);
96 let matrix = match frag.ident {
97 FragmentIdent::A => "a",
98 FragmentIdent::B => "b",
99 FragmentIdent::Accumulator => "c",
100 FragmentIdent::_Dialect(_) => unreachable!(),
101 };
102 let value_ty = value.item();
103 let opcode = match frag.elem {
104 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
105 format!(
106 "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
107 frag.m, frag.n, frag.k,
108 )
109 }
110 other => panic!("{other} fragment type not supported"),
111 };
112 let mut reg_count = 0;
114 let (regs_decl, out_constraints) =
115 get_variable_regs_decl_constraints(var, true, &mut reg_count);
116 let buffer_reg = format_reg_and_inc(&mut reg_count);
117 let (stride_reg, stride_constraint) =
118 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
119 let tmp_ptr = Variable::tmp_ptr(value.item());
120 let tmp_ptr_left = tmp_ptr.fmt_left();
121 write!(
122 f,
123 r#"// load
124{tmp_ptr_left} = ({value_ty}*){value} + {offset};
125asm volatile(
126 "{opcode} "
127 "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
128 : {out_constraints}
129 : "l"({tmp_ptr}){stride_constraint}
130);
131"#
132 )
133 }
134 WmmaInstruction::LdMatrix {
135 output,
136 buffer,
137 offset,
138 line_size,
139 factor,
140 transpose,
141 } => f.write_str(&ldmatrix_call(
142 output, buffer, offset, line_size, factor, transpose,
143 )),
144 WmmaInstruction::Execute {
145 frag_a: var_a,
146 frag_b: var_b,
147 frag_c: var_c,
148 frag_d: var_d,
149 ..
150 } => {
151 let frag_a = match var_a {
152 Variable::WmmaFragment { frag, .. } => *frag,
153 _ => panic!("variable should be WmmaFragment"),
154 };
155 let layout_a = get_fragment_layout_qualifier(var_a);
156 let layout_b = get_fragment_layout_qualifier(var_b);
157 let type_c = get_type_qualifier(var_c);
158 let type_d = get_type_qualifier(var_d);
159 let opcode = match var_a.elem() {
160 Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
161 "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
162 frag_a.m, frag_a.n, frag_a.k,
163 ),
164 Elem::BF16 => format!(
165 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
166 frag_a.m, frag_a.n, frag_a.k,
167 ),
168 Elem::TF32 => format!(
169 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
170 frag_a.m, frag_a.n, frag_a.k,
171 ),
172 other => panic!("{other} fragment type not supported"),
173 };
174 let mut reg_count = 0;
175 let (regs_decl_d, out_constraints_d) =
177 get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
178 let (regs_decl_a, in_constraints_a) =
179 get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
180 let (regs_decl_b, in_constraints_b) =
181 get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
182 let (regs_decl_c, in_constraints_c) =
183 get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
184 write!(
185 f,
186 r#"// execute
187asm volatile(
188 "{opcode} "
189 "{{{regs_decl_d}}}, "
190 "{{{regs_decl_a}}}, "
191 "{{{regs_decl_b}}}, "
192 "{{{regs_decl_c}}};\n"
193 : {out_constraints_d}
194 : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
195);
196"#
197 )
198 }
199 WmmaInstruction::ExecuteManual {
200 shape,
201 frag_a,
202 frag_b,
203 frag_c,
204 frag_d,
205 } => {
206 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
207 }
208 WmmaInstruction::ExecuteScaled {
209 shape,
210 frag_a,
211 frag_b,
212 frag_c,
213 frag_d,
214
215 scales_a,
216 scales_b,
217 scales_factor,
218 } => Self::compile_scaled_mma(
219 f,
220 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
221 *scales_a,
222 *scales_b,
223 *scales_factor,
224 ),
225 WmmaInstruction::Store {
226 output,
227 frag: var,
228 stride,
229 offset,
230 layout,
231 } => {
232 let frag_acc = match var {
233 Variable::WmmaFragment { frag, .. } => *frag,
234 _ => panic!("variable should be WmmaFragment"),
235 };
236 let layout = match layout {
238 FragmentLayout::ColMajor => "col",
239 FragmentLayout::RowMajor => "row",
240 FragmentLayout::_Dialect(..) => unreachable!(),
241 };
242 let opcode = match var.elem() {
243 Elem::F16 | Elem::BF16 => format!(
244 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
248 frag_acc.m, frag_acc.n, frag_acc.k,
249 ),
250 Elem::TF32 | Elem::F32 => format!(
251 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
253 frag_acc.m, frag_acc.n, frag_acc.k,
254 ),
255 Elem::I32 => format!(
256 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
258 frag_acc.m, frag_acc.n, frag_acc.k,
259 ),
260 other => panic!("{other} fragment type not supported"),
261 };
262 let mut reg_count = 0;
264 let buffer_reg = format_reg_and_inc(&mut reg_count);
265 let (stride_reg, stride_constraint) =
268 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
269 let (regs_decl, in_constraints) =
271 get_variable_regs_decl_constraints(var, false, &mut reg_count);
272 let tmp_ptr = Variable::tmp_ptr(output.item());
273 let tmp_ptr_left = tmp_ptr.fmt_left();
274 write!(
275 f,
276 r#"// store
277{tmp_ptr_left} = {output} + {offset};
278asm volatile(
279 "{opcode} "
280 "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
281 :
282 : "l"({tmp_ptr}),
283 {in_constraints}{stride_constraint}
284);
285"#
286 )
287 }
288 WmmaInstruction::Cast { input, output } => {
289 let frag = match input {
290 Variable::WmmaFragment { frag, .. } => *frag,
291 _ => panic!("variable should be WmmaFragment"),
292 };
293 let reg_count = get_fragment_register_total_count(&frag);
294 match output.elem() {
295 Elem::F16 => {
296 write!(
297 f,
298 "// cast
299for (int i = 0; i < {reg_count}; ++i) {{
300 __half h_lo = __float2half_rn({input}[2*i + 0]);
301 __half h_hi = __float2half_rn({input}[2*i + 1]);
302 __half2 h2 = __halves2half2(h_lo, h_hi);
303 {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
304}}
305"
306 )
307 }
308 Elem::BF16 => {
309 write!(
310 f,
311 "// cast
312for (int i = 0; i < {reg_count}; ++i) {{
313 __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
314 __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
315 __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
316 {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
317}}
318"
319 )
320 }
321 other => panic!("casting fragment to {other} not supported"),
322 }
323 }
324 }
325 }
326
327 fn compile_manual_mma(
328 f: &mut std::fmt::Formatter<'_>,
329 mma: ManualMma<CudaDialect<Self>>,
330 ) -> std::fmt::Result {
331 compile_manual_mma(f, mma)
332 }
333
334 fn compile_scaled_mma(
335 f: &mut std::fmt::Formatter<'_>,
336 mma: ManualMma<CudaDialect<Self>>,
337 scales_a: Variable<CudaDialect<Self>>,
338 scales_b: Variable<CudaDialect<Self>>,
339 scales_factor: u32,
340 ) -> std::fmt::Result {
341 compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
342 }
343
344 fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
345 let mut result: SupportedMmaCombinations = vec![];
346 if arch.get_version() >= WMMA_MINIMUM_VERSION {
347 let types = vec![
349 (
350 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
354 (
355 gpu::ElemType::Float(gpu::FloatKind::F16),
356 gpu::ElemType::Float(gpu::FloatKind::F16),
357 gpu::ElemType::Float(gpu::FloatKind::F32),
358 ),
359 (
360 gpu::ElemType::Float(gpu::FloatKind::BF16),
361 gpu::ElemType::Float(gpu::FloatKind::BF16),
362 gpu::ElemType::Float(gpu::FloatKind::F32),
363 ),
364 ];
365 let combinations: SupportedMmaCombinations = types
366 .into_iter()
367 .map(|(a, b, cd)| MmaConfig {
368 a_type: a.into(),
369 b_type: b.into(),
370 cd_type: cd.into(),
371 m: 16,
372 n: 16,
373 k: 16,
374 })
375 .collect();
376 result.extend(combinations);
377 if arch.get_version() >= 72 {
378 result.extend([
379 MmaConfig {
380 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
381 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
382 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
383 m: 16,
384 n: 16,
385 k: 16,
386 },
387 MmaConfig {
388 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
389 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
390 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
391 m: 16,
392 n: 16,
393 k: 16,
394 },
395 ]);
396 }
397 if arch.get_version() >= 80 {
398 result.push(MmaConfig {
399 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
400 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
401 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
402 m: 16,
403 n: 16,
404 k: 8,
405 });
406 }
407 }
408 result
409 }
410
411 fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
412 supported_mma_combinations(arch)
413 }
414
415 fn supported_scaled_mma_combinations(
416 arch: &CudaArchitecture,
417 ) -> SupportedScaledMmaCombinations {
418 supported_scaled_mma_combinations(arch)
419 }
420}
421
422fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
423 let Fragment {
424 ident,
425 m,
426 n,
427 k,
428 elem,
429 ..
430 } = frag;
431 let elements = match ident {
432 FragmentIdent::A => m * k,
433 FragmentIdent::B => k * n,
434 FragmentIdent::Accumulator => m * n,
435 _ => unreachable!(),
436 };
437 let bits_per_elem = elem.size_bits() as u32;
438 let lanes_per_reg = 32 / bits_per_elem;
440 let threads_per_frag = match ident {
444 FragmentIdent::Accumulator => 32,
445 FragmentIdent::A | FragmentIdent::B => {
446 if frag.elem == Elem::TF32 {
447 32
448 } else {
449 16
450 }
451 }
452 _ => unreachable!(),
453 };
454
455 elements / (lanes_per_reg * threads_per_frag)
456}
457
458fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
459 match var.elem() {
460 Elem::U8 => "u8",
461 Elem::I8 => "s8",
462 Elem::F16 => "f16",
463 Elem::BF16 => "bf16",
464 Elem::F32 => "f32",
465 Elem::TF32 => "tf32",
466 Elem::I32 => "s32",
467 Elem::F64 => "f64",
468 _ => panic!("unsupported WMMA fragment type"),
469 }
470 .to_string()
471}
472
473fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
474 let frag = match var {
475 Variable::WmmaFragment { frag, .. } => *frag,
476 _ => panic!("variable should be WmmaFragment"),
477 };
478 match frag.layout {
479 Some(layout) => get_qualifier_from_layout(&layout),
480 None => "".to_string(),
481 }
482}
483
484fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
485 match layout {
486 FragmentLayout::ColMajor => "col",
487 FragmentLayout::RowMajor => "row",
488 FragmentLayout::_Dialect(..) => unreachable!(),
489 }
490 .to_string()
491}
492
493fn get_variable_regs_decl_constraints(
494 var: &Variable<CudaDialect<PtxWmmaCompiler>>,
495 output: bool,
496 reg_count: &mut u8,
497) -> (String, String) {
498 match var {
499 Variable::WmmaFragment { frag, .. } => {
500 let reg_total_count = get_fragment_register_total_count(frag);
501 let reg_decl = (0..reg_total_count)
502 .map(|_| format_reg_and_inc(reg_count))
503 .collect::<Vec<_>>()
504 .join(",");
505 let frag_elem = frag.elem;
506 let modifier = format!(
507 "{}{}",
508 if output { "=" } else { "" },
509 match frag_elem {
510 Elem::F32 => "f",
511 Elem::F64 => "d",
512 _ => "r",
513 },
514 );
515 let constraints = (0..reg_total_count)
516 .map(|i| format!("\"{modifier}\"({var}[{i}])"))
517 .collect::<Vec<_>>()
518 .join(", ");
519 (reg_decl, constraints)
520 }
521 Variable::ConstantScalar(number, ..) => match number {
522 ConstantScalarValue::UInt(val, ..) => (val.to_string(), "".to_string()),
523 _ => panic!("variable should be an unsigned integer"),
524 },
525 _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
526 }
527}
528
/// Renders the current asm register placeholder (e.g. `%3`) and advances the
/// shared counter.
fn format_reg_and_inc(count: &mut u8) -> String {
    let placeholder = format!("%{}", *count);
    *count += 1;
    placeholder
}
534
/// Formats a C++ expression reinterpreting `var` as a mutable `ty` reference.
fn as_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{}&>({})", ty, var)
}
538
/// Formats a C++ expression reinterpreting `var` as a const `ty` reference.
fn as_const_ty(var: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {}&>({})", ty, var)
}
542
543pub(super) fn compile_manual_mma<D: Dialect>(
544 f: &mut core::fmt::Formatter<'_>,
545 mma: ManualMma<D>,
546) -> std::fmt::Result {
547 let ManualMma {
548 shape,
549 frag_a,
550 frag_b,
551 frag_c,
552 frag_d,
553 } = mma;
554
555 let a_elem = frag_a.elem().unpacked();
556 let b_elem = frag_b.elem().unpacked();
557 let cd_elem = frag_c.elem().unpacked();
558
559 let ab_ty = match a_elem {
560 Elem::F32 => &format!("{}", Elem::<D>::F32),
561 _ => &format!("{}", Elem::<D>::U32),
562 };
563 let cd_ty = match cd_elem {
564 Elem::F32 => &format!("{}", Elem::<D>::F32),
565 _ => &format!("{}", Elem::<D>::U32),
566 };
567
568 let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
569 let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
570 let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
571
572 let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
573 let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
574 let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
575
576 let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
577 let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
578
579 let frag_c = match cd_elem.size() {
582 4 | 8 => (0..cd_regs)
583 .map(|i| as_ty(format!("{frag_c}[{}].i_{}", i / 2, i % 2), cd_ty))
584 .collect::<Vec<_>>(),
585 2 => (0..cd_regs)
586 .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
587 .collect::<Vec<_>>(),
588 other => panic!("Found unhandled accumulator elem size {other}"),
589 };
590 let frag_d = match cd_elem.size() {
591 4 | 8 => (0..cd_regs)
592 .map(|i| as_ty(format!("{frag_d}[{}].i_{}", i / 2, i % 2), cd_ty))
593 .collect::<Vec<_>>(),
594 2 => (0..cd_regs)
595 .map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty))
596 .collect::<Vec<_>>(),
597 other => panic!("Found unhandled accumulator elem size {other}"),
598 };
599 let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
600 write!(
601 f,
602 "__mma_m16n8k{}_{}_{}_{}({args});",
603 shape.k, a_elem, b_elem, cd_elem
604 )
605}
606
607pub(super) fn compile_scaled_mma<D: Dialect>(
608 f: &mut core::fmt::Formatter<'_>,
609 mma: ManualMma<D>,
610 scales_a: Variable<D>,
611 scales_b: Variable<D>,
612 scales_factor: u32,
613) -> std::fmt::Result {
614 let ManualMma {
615 shape,
616 frag_a,
617 frag_b,
618 frag_c,
619 frag_d,
620 } = mma;
621
622 let a_elem = frag_a.elem().unpacked();
623 let b_elem = frag_b.elem().unpacked();
624 let cd_elem = frag_c.elem().unpacked();
625
626 let ab_ty = &format!("{}", Elem::<D>::U32);
627 let cd_ty = &format!("{}", Elem::<D>::F32);
628
629 let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
630 let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
631 let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
632
633 let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
634 let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
635 let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
636
637 let frag_a = (0..a_regs).map(|i| as_const_ty(format!("{frag_a}[{i}]"), ab_ty));
638 let frag_b = (0..b_regs).map(|i| as_const_ty(format!("{frag_b}[{i}]"), ab_ty));
639 let frag_c = (0..cd_regs).map(|i| as_const_ty(format!("{frag_c}[{i}]"), cd_ty));
640 let frag_d = (0..cd_regs).map(|i| as_ty(format!("{frag_d}[{i}]"), cd_ty));
641 let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
642 write!(
643 f,
644 "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
645 shape.k, a_elem, b_elem, cd_elem
646 )
647}
648
649pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
650 let mut result: SupportedMmaCombinations = vec![];
651 if arch.get_version() >= 80 {
655 result.extend([
656 MmaConfig {
657 a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), m: 16,
661 n: 8,
662 k: 16,
663 },
664 MmaConfig {
665 a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
666 b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
667 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
668 m: 16,
669 n: 8,
670 k: 16,
671 },
672 MmaConfig {
673 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
674 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
675 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
676 m: 16,
677 n: 8,
678 k: 8,
679 },
680 MmaConfig {
681 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
682 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
683 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
684 m: 16,
685 n: 8,
686 k: 32,
687 },
688 MmaConfig {
689 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
690 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
691 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
692 m: 16,
693 n: 8,
694 k: 32,
695 },
696 MmaConfig {
697 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
698 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
699 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
700 m: 16,
701 n: 8,
702 k: 32,
703 },
704 MmaConfig {
705 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
706 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
707 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
708 m: 16,
709 n: 8,
710 k: 32,
711 },
712 ]);
714 }
715 if arch.get_version() >= 89 {
716 let f8f6f4_types = [
717 gpu::FloatKind::E4M3,
718 gpu::FloatKind::E5M2,
719 gpu::FloatKind::E3M2,
720 gpu::FloatKind::E2M3,
721 gpu::FloatKind::E2M1,
722 ];
723 let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
724 result.extend(combinations.map(|(t1, t2)| MmaConfig {
725 a_type: gpu::ElemType::Float(*t1).into(),
726 b_type: gpu::ElemType::Float(*t2).into(),
727 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
728 m: 16,
729 n: 8,
730 k: 32,
731 }));
732 }
733 result
734}
735
736pub(super) fn supported_scaled_mma_combinations(
737 arch: &CudaArchitecture,
738) -> SupportedScaledMmaCombinations {
739 let mut result: SupportedScaledMmaCombinations = vec![];
740 if arch.get_version() >= 120 && arch.get_version() < 130 {
742 let f8f6f4_types = [
743 gpu::FloatKind::E4M3,
744 gpu::FloatKind::E5M2,
745 gpu::FloatKind::E3M2,
746 gpu::FloatKind::E2M3,
747 gpu::FloatKind::E2M1,
748 ];
749 let combinations = f8f6f4_types
750 .iter()
751 .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));
752
753 result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
754 a_type: gpu::ElemType::Float(*t1).into(),
755 b_type: gpu::ElemType::Float(*t2).into(),
756 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
757 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
758 m: 16,
759 n: 8,
760 k: 32,
761 scales_factor: 1,
762 }));
763
764 result.extend([
765 ScaledMmaConfig {
766 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
767 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
768 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
769 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
770 m: 16,
771 n: 8,
772 k: 64,
773 scales_factor: 2,
774 },
775 ScaledMmaConfig {
777 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
778 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
779 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
780 scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
781 m: 16,
782 n: 8,
783 k: 64,
784 scales_factor: 4,
785 },
786 ]);
787 }
788 result
789}
790
791pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> u32 {
792 match ident {
793 MatrixIdent::A | MatrixIdent::B => (32 / matrix.storage.size_bits()) as u32,
794 MatrixIdent::Accumulator => 2,
795 }
796}