1use super::WMMA_MINIMUM_VERSION;
2use crate::{
3 Dialect,
4 cuda::{
5 CudaDialect,
6 arch::CudaArchitecture,
7 ptx::{comma_separated, ldmatrix_call, stmatrix_call},
8 },
9 shared::{
10 Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
11 FragmentIdent, FragmentLayout, ManualMma, SupportedMmaCombinations,
12 SupportedScaledMmaCombinations, Variable, WmmaInstruction,
13 },
14};
15use cubecl_core::ir::{
16 self as gpu, ConstantValue, Matrix, MatrixIdent,
17 features::{MmaConfig, ScaledMmaConfig},
18};
19use itertools::Itertools;
20use std::fmt::Display;
21
/// WMMA compiler backend that lowers fragment operations to raw inline PTX
/// (`wmma.*`, `ldmatrix`, `stmatrix` via `asm volatile`) for the CUDA dialect,
/// instead of going through the CUDA C++ `nvcuda::wmma` intrinsics.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PtxWmmaCompiler {}
24
25impl DialectWmmaCompiler<CudaDialect<Self>> for PtxWmmaCompiler {
26 fn compile_wmma_includes(
27 f: &mut std::fmt::Formatter<'_>,
28 flags: &Flags<CudaDialect<Self>>,
29 ) -> std::fmt::Result {
30 if flags.elem_tf32 {
32 f.write_str("#include <mma.h>\n")?;
33 }
34 Ok(())
35 }
36
37 fn compile_wmma_fragment_declaration(
38 f: &mut std::fmt::Formatter<'_>,
39 var: &Variable<CudaDialect<Self>>,
40 ) -> std::fmt::Result {
41 let frag = match var {
42 Variable::WmmaFragment { frag, .. } => *frag,
43 _ => panic!("load instruction expects a WmmaFragment"),
44 };
45 let reg_count = get_fragment_register_total_count(&frag);
46 let ty = match frag.elem {
47 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::TF32 => "unsigned int",
48 Elem::F32 => "float",
49 Elem::F64 => "double",
50 _ => panic!("unsupported type"),
51 };
52 writeln!(f, "{ty} {var}[{reg_count}];")
53 }
54
55 fn compile_wmma_instruction(
56 f: &mut std::fmt::Formatter<'_>,
57 instruction: &WmmaInstruction<CudaDialect<Self>>,
58 ) -> std::fmt::Result {
59 match instruction {
60 WmmaInstruction::Fill { frag: var, value } => {
61 let frag = match var {
62 Variable::WmmaFragment { frag, .. } => *frag,
63 _ => panic!("variable should be WmmaFragment"),
64 };
65 let reg_count = get_fragment_register_total_count(&frag);
66 write!(
67 f,
68 "// fill
69for (uint i = 0; i < uint({reg_count}); ++i) {{
70 {var}[i] = {value};
71}}
72 "
73 )
74 }
75 WmmaInstruction::Load {
76 frag: var,
77 value,
78 offset,
79 stride,
80 layout,
81 } => {
82 let frag = match var {
83 Variable::WmmaFragment { frag, .. } => *frag,
84 _ => panic!("load instruction expects a WmmaFragment"),
85 };
86 let layout = if frag.layout.is_some() {
91 get_fragment_layout_qualifier(var)
92 } else if let Some(layout) = layout {
93 get_qualifier_from_layout(layout)
94 } else {
95 panic!("unknown matrix layout for wmma load instruction");
96 };
97 let ty = get_type_qualifier(value);
99 let matrix = match frag.ident {
100 FragmentIdent::A => "a",
101 FragmentIdent::B => "b",
102 FragmentIdent::Accumulator => "c",
103 FragmentIdent::_Dialect(_) => unreachable!(),
104 };
105 let value_ty = value.item();
106 let opcode = match frag.elem {
107 Elem::U8 | Elem::I8 | Elem::F16 | Elem::BF16 | Elem::F32 | Elem::TF32 => {
108 format!(
109 "wmma.load.{matrix}.sync.aligned.{layout}.m{}n{}k{}.{ty}",
110 frag.m, frag.n, frag.k,
111 )
112 }
113 other => panic!("{other} fragment type not supported"),
114 };
115 let mut reg_count = 0;
117 let (regs_decl, out_constraints) =
118 get_variable_regs_decl_constraints(var, true, &mut reg_count);
119 let buffer_reg = format_reg_and_inc(&mut reg_count);
120 let (stride_reg, stride_constraint) =
121 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
122 let tmp_ptr = Variable::tmp_ptr(value.item());
123 let tmp_ptr_left = tmp_ptr.fmt_left();
124 write!(
125 f,
126 r#"// load
127{tmp_ptr_left} = ({value_ty}*){value} + {offset};
128asm volatile(
129 "{opcode} "
130 "{{{regs_decl}}}, [{buffer_reg}], {stride_reg};\n"
131 : {out_constraints}
132 : "l"({tmp_ptr}){stride_constraint}
133);
134"#
135 )
136 }
137 WmmaInstruction::LdMatrix {
138 output,
139 buffer,
140 offset,
141 vector_size,
142 factor,
143 transpose,
144 } => f.write_str(&ldmatrix_call(
145 output,
146 buffer,
147 offset,
148 vector_size,
149 factor,
150 transpose,
151 )),
152 WmmaInstruction::StMatrix {
153 registers,
154 buffer,
155 offset,
156 vector_size,
157 factor,
158 transpose,
159 } => f.write_str(&stmatrix_call(
160 registers,
161 buffer,
162 offset,
163 vector_size,
164 factor,
165 transpose,
166 )),
167 WmmaInstruction::Execute {
168 frag_a: var_a,
169 frag_b: var_b,
170 frag_c: var_c,
171 frag_d: var_d,
172 ..
173 } => {
174 let frag_a = match var_a {
175 Variable::WmmaFragment { frag, .. } => *frag,
176 _ => panic!("variable should be WmmaFragment"),
177 };
178 let layout_a = get_fragment_layout_qualifier(var_a);
179 let layout_b = get_fragment_layout_qualifier(var_b);
180 let type_c = get_type_qualifier(var_c);
181 let type_d = get_type_qualifier(var_d);
182 let opcode = match var_a.elem() {
183 Elem::U8 | Elem::I8 | Elem::F16 | Elem::F32 => format!(
184 "wmma.mma.sync.aligned.m{}n{}k{}.{layout_a}.{layout_b}.{type_d}.{type_c}",
185 frag_a.m, frag_a.n, frag_a.k,
186 ),
187 Elem::BF16 => format!(
188 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.bf16.bf16.f32",
189 frag_a.m, frag_a.n, frag_a.k,
190 ),
191 Elem::TF32 => format!(
192 "wmma.mma.sync.aligned.{layout_a}.{layout_b}.m{}n{}k{}.f32.tf32.tf32.f32",
193 frag_a.m, frag_a.n, frag_a.k,
194 ),
195 other => panic!("{other} fragment type not supported"),
196 };
197 let mut reg_count = 0;
198 let (regs_decl_d, out_constraints_d) =
200 get_variable_regs_decl_constraints(var_d, true, &mut reg_count);
201 let (regs_decl_a, in_constraints_a) =
202 get_variable_regs_decl_constraints(var_a, false, &mut reg_count);
203 let (regs_decl_b, in_constraints_b) =
204 get_variable_regs_decl_constraints(var_b, false, &mut reg_count);
205 let (regs_decl_c, in_constraints_c) =
206 get_variable_regs_decl_constraints(var_c, false, &mut reg_count);
207 write!(
208 f,
209 r#"// execute
210asm volatile(
211 "{opcode} "
212 "{{{regs_decl_d}}}, "
213 "{{{regs_decl_a}}}, "
214 "{{{regs_decl_b}}}, "
215 "{{{regs_decl_c}}};\n"
216 : {out_constraints_d}
217 : {in_constraints_a}, {in_constraints_b}, {in_constraints_c}
218);
219"#
220 )
221 }
222 WmmaInstruction::ExecuteManual {
223 shape,
224 frag_a,
225 frag_b,
226 frag_c,
227 frag_d,
228 } => {
229 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
230 }
231 WmmaInstruction::ExecuteScaled {
232 shape,
233 frag_a,
234 frag_b,
235 frag_c,
236 frag_d,
237
238 scales_a,
239 scales_b,
240 scales_factor,
241 } => Self::compile_scaled_mma(
242 f,
243 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
244 *scales_a,
245 *scales_b,
246 *scales_factor,
247 ),
248 WmmaInstruction::Store {
249 output,
250 frag: var,
251 stride,
252 offset,
253 layout,
254 } => {
255 let frag_acc = match var {
256 Variable::WmmaFragment { frag, .. } => *frag,
257 _ => panic!("variable should be WmmaFragment"),
258 };
259 let layout = match layout {
261 FragmentLayout::ColMajor => "col",
262 FragmentLayout::RowMajor => "row",
263 FragmentLayout::_Dialect(..) => unreachable!(),
264 };
265 let opcode = match var.elem() {
266 Elem::F16 | Elem::BF16 => format!(
267 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f16",
271 frag_acc.m, frag_acc.n, frag_acc.k,
272 ),
273 Elem::TF32 | Elem::F32 => format!(
274 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.f32",
276 frag_acc.m, frag_acc.n, frag_acc.k,
277 ),
278 Elem::I32 => format!(
279 "wmma.store.d.sync.aligned.{layout}.m{}n{}k{}.s32",
281 frag_acc.m, frag_acc.n, frag_acc.k,
282 ),
283 other => panic!("{other} fragment type not supported"),
284 };
285 let mut reg_count = 0;
287 let buffer_reg = format_reg_and_inc(&mut reg_count);
288 let (stride_reg, stride_constraint) =
291 get_variable_regs_decl_constraints(stride, false, &mut reg_count);
292 let (regs_decl, in_constraints) =
294 get_variable_regs_decl_constraints(var, false, &mut reg_count);
295 let tmp_ptr = Variable::tmp_ptr(output.item());
296 let tmp_ptr_left = tmp_ptr.fmt_left();
297 write!(
298 f,
299 r#"// store
300{tmp_ptr_left} = {output} + {offset};
301asm volatile(
302 "{opcode} "
303 "[{buffer_reg}], {{{regs_decl}}}, {stride_reg};\n"
304 :
305 : "l"({tmp_ptr}),
306 {in_constraints}{stride_constraint}
307);
308"#
309 )
310 }
311 WmmaInstruction::Cast { input, output } => {
312 let frag = match input {
313 Variable::WmmaFragment { frag, .. } => *frag,
314 _ => panic!("variable should be WmmaFragment"),
315 };
316 let reg_count = get_fragment_register_total_count(&frag);
317 match output.elem() {
318 Elem::F16 => {
319 write!(
320 f,
321 "// cast
322for (int i = 0; i < {reg_count}; ++i) {{
323 __half h_lo = __float2half_rn({input}[2*i + 0]);
324 __half h_hi = __float2half_rn({input}[2*i + 1]);
325 __half2 h2 = __halves2half2(h_lo, h_hi);
326 {output}[i] = *reinterpret_cast<unsigned int*>(&h2);
327}}
328"
329 )
330 }
331 Elem::BF16 => {
332 write!(
333 f,
334 "// cast
335for (int i = 0; i < {reg_count}; ++i) {{
336 __nv_bfloat16 b_lo = __float2bfloat16({input}[2*i + 0]);
337 __nv_bfloat16 b_hi = __float2bfloat16({input}[2*i + 1]);
338 __nv_bfloat162 bf2 = __halves2bfloat162(b_lo, b_hi);
339 {output}[i] = *reinterpret_cast<unsigned int*>(&bf2);
340}}
341"
342 )
343 }
344 other => panic!("casting fragment to {other} not supported"),
345 }
346 }
347 }
348 }
349
350 fn compile_manual_mma(
351 f: &mut std::fmt::Formatter<'_>,
352 mma: ManualMma<CudaDialect<Self>>,
353 ) -> std::fmt::Result {
354 compile_manual_mma(f, mma)
355 }
356
357 fn compile_scaled_mma(
358 f: &mut std::fmt::Formatter<'_>,
359 mma: ManualMma<CudaDialect<Self>>,
360 scales_a: Variable<CudaDialect<Self>>,
361 scales_b: Variable<CudaDialect<Self>>,
362 scales_factor: u32,
363 ) -> std::fmt::Result {
364 compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
365 }
366
367 fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
368 let mut result: SupportedMmaCombinations = vec![];
369 if arch.get_version() >= WMMA_MINIMUM_VERSION {
370 let types = vec![
372 (
373 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
377 (
378 gpu::ElemType::Float(gpu::FloatKind::F16),
379 gpu::ElemType::Float(gpu::FloatKind::F16),
380 gpu::ElemType::Float(gpu::FloatKind::F32),
381 ),
382 (
383 gpu::ElemType::Float(gpu::FloatKind::BF16),
384 gpu::ElemType::Float(gpu::FloatKind::BF16),
385 gpu::ElemType::Float(gpu::FloatKind::F32),
386 ),
387 ];
388 let combinations: SupportedMmaCombinations = types
389 .into_iter()
390 .map(|(a, b, cd)| MmaConfig {
391 a_type: a.into(),
392 b_type: b.into(),
393 cd_type: cd.into(),
394 m: 16,
395 n: 16,
396 k: 16,
397 })
398 .collect();
399 result.extend(combinations);
400 if arch.get_version() >= 72 {
401 result.extend([
402 MmaConfig {
403 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
404 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
405 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
406 m: 16,
407 n: 16,
408 k: 16,
409 },
410 MmaConfig {
411 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
412 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
413 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
414 m: 16,
415 n: 16,
416 k: 16,
417 },
418 ]);
419 }
420 if arch.get_version() >= 80 {
421 result.push(MmaConfig {
422 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
423 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
424 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
425 m: 16,
426 n: 16,
427 k: 8,
428 });
429 }
430 }
431 result
432 }
433
434 fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
435 supported_mma_combinations(arch)
436 }
437
438 fn supported_scaled_mma_combinations(
439 arch: &CudaArchitecture,
440 ) -> SupportedScaledMmaCombinations {
441 supported_scaled_mma_combinations(arch)
442 }
443}
444
445fn get_fragment_register_total_count(frag: &Fragment<CudaDialect<PtxWmmaCompiler>>) -> u32 {
446 let Fragment {
447 ident,
448 m,
449 n,
450 k,
451 elem,
452 ..
453 } = frag;
454 let elements = match ident {
455 FragmentIdent::A => m * k,
456 FragmentIdent::B => k * n,
457 FragmentIdent::Accumulator => m * n,
458 _ => unreachable!(),
459 };
460 let bits_per_elem = elem.size_bits() as u32;
461 let lanes_per_reg = 32 / bits_per_elem;
463 let threads_per_frag = match ident {
467 FragmentIdent::Accumulator => 32,
468 FragmentIdent::A | FragmentIdent::B => {
469 if frag.elem == Elem::TF32 {
470 32
471 } else {
472 16
473 }
474 }
475 _ => unreachable!(),
476 };
477
478 elements / (lanes_per_reg * threads_per_frag)
479}
480
481fn get_type_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
482 match var.elem() {
483 Elem::U8 => "u8",
484 Elem::I8 => "s8",
485 Elem::F16 => "f16",
486 Elem::BF16 => "bf16",
487 Elem::F32 => "f32",
488 Elem::TF32 => "tf32",
489 Elem::I32 => "s32",
490 Elem::F64 => "f64",
491 _ => panic!("unsupported WMMA fragment type"),
492 }
493 .to_string()
494}
495
496fn get_fragment_layout_qualifier(var: &Variable<CudaDialect<PtxWmmaCompiler>>) -> String {
497 let frag = match var {
498 Variable::WmmaFragment { frag, .. } => *frag,
499 _ => panic!("variable should be WmmaFragment"),
500 };
501 match frag.layout {
502 Some(layout) => get_qualifier_from_layout(&layout),
503 None => "".to_string(),
504 }
505}
506
507fn get_qualifier_from_layout(layout: &FragmentLayout<CudaDialect<PtxWmmaCompiler>>) -> String {
508 match layout {
509 FragmentLayout::ColMajor => "col",
510 FragmentLayout::RowMajor => "row",
511 FragmentLayout::_Dialect(..) => unreachable!(),
512 }
513 .to_string()
514}
515
/// Builds the `%N` placeholder list and the matching constraint list for one
/// asm operand, advancing `reg_count` by the number of operands consumed.
///
/// Returns `(placeholders, constraints)`:
/// * `WmmaFragment` — one `%N` per fragment register (comma-joined) and one
///   `"r"/"f"/"d"` constraint per register; `output` prepends `=`.
/// * unsigned-int `Constant` — the literal is inlined into the asm template,
///   so no register number is consumed and the constraint part is empty.
/// * anything else — a single fresh `%N` plus a constraint string beginning
///   with `", "` so callers can append it directly after a preceding
///   constraint (see the load/store emission sites).
///
/// NOTE: calls to this function must happen in the same order as the
/// constraint lists of the final `asm volatile` block — operand numbers are
/// positional.
fn get_variable_regs_decl_constraints(
    var: &Variable<CudaDialect<PtxWmmaCompiler>>,
    output: bool,
    reg_count: &mut u8,
) -> (String, String) {
    match var {
        Variable::WmmaFragment { frag, .. } => {
            let reg_total_count = get_fragment_register_total_count(frag);
            // One placeholder per fragment register, numbered sequentially.
            let reg_decl = (0..reg_total_count)
                .map(|_| format_reg_and_inc(reg_count))
                .collect::<Vec<_>>()
                .join(",");
            let frag_elem = frag.elem;
            // Constraint letter by register class: f32 -> "f", f64 -> "d",
            // everything else lives in 32-bit integer registers -> "r".
            let modifier = format!(
                "{}{}",
                if output { "=" } else { "" },
                match frag_elem {
                    Elem::F32 => "f",
                    Elem::F64 => "d",
                    _ => "r",
                },
            );
            let constraints = (0..reg_total_count)
                .map(|i| format!("\"{modifier}\"({var}[{i}])"))
                .collect::<Vec<_>>()
                .join(", ");
            (reg_decl, constraints)
        }
        // Constants are emitted as immediates in the asm template.
        Variable::Constant(number, ..) => match number {
            ConstantValue::UInt(val, ..) => (val.to_string(), "".to_string()),
            _ => panic!("variable should be an unsigned integer"),
        },
        // Scalar runtime value: single register, leading-comma constraint.
        _ => (format_reg_and_inc(reg_count), format!(r#", "r"({var})"#)),
    }
}
551
/// Returns the asm operand placeholder `%<count>` and post-increments the
/// shared operand counter.
fn format_reg_and_inc(count: &mut u8) -> String {
    let reg = *count;
    *count += 1;
    format!("%{reg}")
}
557
/// Formats a mutable access to element `idx` of `var` viewed as a `ty*`.
fn as_ty_idx(var: impl Display, idx: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<{2}*>({0})[{1}]", var, idx, ty)
}
561
/// Formats a read-only access to element `idx` of `var` viewed as a
/// `const ty*`.
fn as_const_ty_idx(var: impl Display, idx: impl Display, ty: impl Display) -> String {
    format!("reinterpret_cast<const {2}*>({0})[{1}]", var, idx, ty)
}
565
566pub(super) fn compile_manual_mma<D: Dialect>(
567 f: &mut core::fmt::Formatter<'_>,
568 mma: ManualMma<D>,
569) -> std::fmt::Result {
570 let ManualMma {
571 shape,
572 frag_a,
573 frag_b,
574 frag_c,
575 frag_d,
576 } = mma;
577
578 let a_elem = frag_a.elem().unpacked();
579 let b_elem = frag_b.elem().unpacked();
580 let cd_elem = frag_c.elem().unpacked();
581
582 let ab_ty = match a_elem {
583 Elem::F32 => &format!("{}", Elem::<D>::F32),
584 _ => &format!("{}", Elem::<D>::U32),
585 };
586 let cd_ty = match cd_elem {
587 Elem::F32 => &format!("{}", Elem::<D>::F32),
588 _ => &format!("{}", Elem::<D>::U32),
589 };
590
591 let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
592 let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
593 let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
594
595 let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
596 let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
597 let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
598
599 let frag_a = (0..a_regs).map(|i| as_const_ty_idx(frag_a, i, ab_ty));
600 let frag_b = (0..b_regs).map(|i| as_const_ty_idx(frag_b, i, ab_ty));
601 let frag_c = (0..cd_regs).map(|i| as_const_ty_idx(frag_c, i, cd_ty));
602 let frag_d = (0..cd_regs).map(|i| as_ty_idx(frag_d, i, cd_ty));
603
604 let args = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
605 write!(
606 f,
607 "__mma_m16n8k{}_{}_{}_{}({args});",
608 shape.k, a_elem, b_elem, cd_elem
609 )
610}
611
612pub(super) fn compile_scaled_mma<D: Dialect>(
613 f: &mut core::fmt::Formatter<'_>,
614 mma: ManualMma<D>,
615 scales_a: Variable<D>,
616 scales_b: Variable<D>,
617 scales_factor: u32,
618) -> std::fmt::Result {
619 let ManualMma {
620 shape,
621 frag_a,
622 frag_b,
623 frag_c,
624 frag_d,
625 } = mma;
626
627 let a_elem = frag_a.elem().unpacked();
628 let b_elem = frag_b.elem().unpacked();
629 let cd_elem = frag_c.elem().unpacked();
630
631 let ab_ty = &format!("{}", Elem::<D>::U32);
632 let cd_ty = &format!("{}", Elem::<D>::F32);
633
634 let a_elems = shape.num_elems(FragmentIdent::<D>::A) / 32;
635 let b_elems = shape.num_elems(FragmentIdent::<D>::B) / 32;
636 let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
637
638 let a_regs = a_elems as usize / (32 / frag_a.elem().unpacked().size_bits());
639 let b_regs = b_elems as usize / (32 / frag_b.elem().unpacked().size_bits());
640 let cd_regs = cd_elems as usize / (32 / frag_c.elem().unpacked().size_bits());
641
642 let frag_a = (0..a_regs).map(|i| as_const_ty_idx(frag_a, i, ab_ty));
643 let frag_b = (0..b_regs).map(|i| as_const_ty_idx(frag_b, i, ab_ty));
644 let frag_c = (0..cd_regs).map(|i| as_const_ty_idx(frag_c, i, cd_ty));
645 let frag_d = (0..cd_regs).map(|i| as_ty_idx(frag_d, i, cd_ty));
646
647 let fragments = comma_separated(frag_a.chain(frag_b).chain(frag_c).chain(frag_d));
648 write!(
649 f,
650 "__mma_scaled_{scales_factor}x_m16n8k{}_{}_{}_{}({fragments}, reinterpret_cast<uint32&>({scales_a}), reinterpret_cast<uint32&>({scales_b}));",
651 shape.k, a_elem, b_elem, cd_elem
652 )
653}
654
655pub(super) fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
656 let mut result: SupportedMmaCombinations = vec![];
657 if arch.get_version() >= 80 {
661 result.extend([
662 MmaConfig {
663 a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(), cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(), m: 16,
667 n: 8,
668 k: 16,
669 },
670 MmaConfig {
671 a_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
672 b_type: gpu::ElemType::Float(gpu::FloatKind::BF16).into(),
673 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
674 m: 16,
675 n: 8,
676 k: 16,
677 },
678 MmaConfig {
679 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
680 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
681 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
682 m: 16,
683 n: 8,
684 k: 8,
685 },
686 MmaConfig {
687 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
688 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
689 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
690 m: 16,
691 n: 8,
692 k: 32,
693 },
694 MmaConfig {
695 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
696 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
697 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
698 m: 16,
699 n: 8,
700 k: 32,
701 },
702 MmaConfig {
703 a_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
704 b_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
705 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
706 m: 16,
707 n: 8,
708 k: 32,
709 },
710 MmaConfig {
711 a_type: gpu::ElemType::UInt(gpu::UIntKind::U8).into(),
712 b_type: gpu::ElemType::Int(gpu::IntKind::I8).into(),
713 cd_type: gpu::ElemType::Int(gpu::IntKind::I32).into(),
714 m: 16,
715 n: 8,
716 k: 32,
717 },
718 ]);
720 }
721 if arch.get_version() >= 89 {
722 let f8f6f4_types = [
723 gpu::FloatKind::E4M3,
724 gpu::FloatKind::E5M2,
725 gpu::FloatKind::E3M2,
726 gpu::FloatKind::E2M3,
727 gpu::FloatKind::E2M1,
728 ];
729 let combinations = f8f6f4_types.iter().cartesian_product(f8f6f4_types.iter());
730 result.extend(combinations.map(|(t1, t2)| MmaConfig {
731 a_type: gpu::ElemType::Float(*t1).into(),
732 b_type: gpu::ElemType::Float(*t2).into(),
733 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
734 m: 16,
735 n: 8,
736 k: 32,
737 }));
738 }
739 if arch.get_version() >= 70 && arch.get_version() < 80 {
741 result.push(MmaConfig {
742 a_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
743 b_type: gpu::ElemType::Float(gpu::FloatKind::F16).into(),
744 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
745 m: 16,
746 n: 8,
747 k: 8,
748 });
749 }
750 result
751}
752
753pub(super) fn supported_scaled_mma_combinations(
754 arch: &CudaArchitecture,
755) -> SupportedScaledMmaCombinations {
756 let mut result: SupportedScaledMmaCombinations = vec![];
757 if arch.get_version() >= 120 && arch.get_version() < 130 {
759 let f8f6f4_types = [
760 gpu::FloatKind::E4M3,
761 gpu::FloatKind::E5M2,
762 gpu::FloatKind::E3M2,
763 gpu::FloatKind::E2M3,
764 gpu::FloatKind::E2M1,
765 ];
766 let combinations = f8f6f4_types
767 .iter()
768 .flat_map(|t1| f8f6f4_types.iter().map(move |t2| (t1, t2)));
769
770 result.extend(combinations.map(|(t1, t2)| ScaledMmaConfig {
771 a_type: gpu::ElemType::Float(*t1).into(),
772 b_type: gpu::ElemType::Float(*t2).into(),
773 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
774 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
775 m: 16,
776 n: 8,
777 k: 32,
778 scales_factor: 1,
779 }));
780
781 result.extend([
782 ScaledMmaConfig {
783 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
784 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
785 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
786 scales_type: gpu::ElemType::Float(gpu::FloatKind::UE8M0).into(),
787 m: 16,
788 n: 8,
789 k: 64,
790 scales_factor: 2,
791 },
792 ScaledMmaConfig {
794 a_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
795 b_type: gpu::StorageType::Packed(gpu::ElemType::Float(gpu::FloatKind::E2M1), 2),
796 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
797 scales_type: gpu::ElemType::Float(gpu::FloatKind::E4M3).into(),
798 m: 16,
799 n: 8,
800 k: 64,
801 scales_factor: 4,
802 },
803 ]);
804 }
805 result
806}
807
808pub fn contiguous_elements_cuda(ident: MatrixIdent, matrix: Matrix) -> usize {
809 match ident {
810 MatrixIdent::A | MatrixIdent::B => 32 / matrix.storage.size_bits(),
811 MatrixIdent::Accumulator => 2,
812 }
813}