use super::{WMMA_MINIMUM_VERSION, WMMA_NAMESPACE};
use crate::{
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        mma::{
            compile_manual_mma, compile_scaled_mma, supported_mma_combinations,
            supported_scaled_mma_combinations,
        },
    },
    shared::{
        Architecture, DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout,
        ManualMma, SupportedMmaCombinations, SupportedScaledMmaCombinations, Variable,
        WmmaInstruction, wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu, features::MmaConfig};
use itertools::Itertools;

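/// WMMA compiler for the CUDA dialect.
///
/// Most formatting is delegated to the shared [`wmma_api_base`] helpers,
/// parameterized by [`WMMA_NAMESPACE`], while manual and block-scaled MMA
/// lowering lives in the CUDA-specific `mma` module.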
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
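    // The CUDA WMMA API is provided by the `<mma.h>` header.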
    fn compile_wmma_includes(
        f: &mut std::fmt::Formatter<'_>,
        _flags: &Flags<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        f.write_str("#include <mma.h>\n")
    }

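    // Fragment declarations are emitted by the shared `wmma_api_base` helpers.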
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

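    // Fragment identifiers are qualified with the CUDA WMMA namespace.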
    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &FragmentIdent<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_ident(f, WMMA_NAMESPACE, ident)
    }

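    // Fragment layouts use the same namespaced helpers.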
    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &FragmentLayout<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_layout(f, WMMA_NAMESPACE, layout)
    }

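    // Full fragment types are formatted by the shared helpers as well.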
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment(f, WMMA_NAMESPACE, fragment)
    }

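    // WMMA instructions are lowered through the shared WMMA writer.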
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_instruction(f, WMMA_NAMESPACE, instruction)
    }

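    // Manual MMA is CUDA-specific and compiled by the `cuda::mma` module.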
    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

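    // Block-scaled MMA additionally takes scale operands for the A and B matrices.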
    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

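    // WMMA is only available from `WMMA_MINIMUM_VERSION` onwards; the shapes and
    // type combinations below follow the CUDA WMMA API.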
    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
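            // Tile shapes (m, n, k) supported by the WMMA API.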
            let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
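            // Supported (a, b, cd) element type combinations.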
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Int(gpu::IntKind::I8),
                    gpu::ElemType::Int(gpu::IntKind::I8),
                    gpu::ElemType::Int(gpu::IntKind::I32),
                ),
                (
                    gpu::ElemType::UInt(gpu::UIntKind::U8),
                    gpu::ElemType::UInt(gpu::UIntKind::U8),
                    gpu::ElemType::Int(gpu::IntKind::I32),
                ),
            ];
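            // Every type combination is valid for every tile shape.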
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .cartesian_product(tdims)
                .map(|((a, b, c), (m, n, k))| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m,
                    n,
                    k,
                })
                .collect();
            result.extend(combinations);
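            // TF32 inputs (16x16x8 tiles) require compute capability 8.0 or newer.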
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

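    // Manual MMA support is architecture-dependent and resolved by the CUDA `mma` module.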
    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

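    // Likewise for block-scaled MMA support.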
    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}