1use crate::{
2 cuda::{
3 CudaDialect,
4 arch::CudaArchitecture,
5 mma::{
6 compile_manual_mma, compile_scaled_mma, supported_mma_combinations,
7 supported_scaled_mma_combinations,
8 },
9 },
10 shared::{
11 Architecture, DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout,
12 ManualMma, SupportedMmaCombinations, SupportedScaledMmaCombinations, Variable,
13 WmmaInstruction, wmma_api_base,
14 },
15};
16use cubecl_core::ir::{self as gpu};
17use cubecl_runtime::MmaConfig;
18use itertools::Itertools;
19
20use super::{WMMA_MINIMUM_VERSION, WMMA_NAMESPACE};
21
/// Field-less marker type that carries the CUDA implementation of
/// [`DialectWmmaCompiler`] for [`CudaDialect`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaWmmaCompiler {}
24
25impl DialectWmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
26 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
27 f.write_str("#include <mma.h>\n")
28 }
29
30 fn compile_wmma_fragment_declaration(
31 f: &mut std::fmt::Formatter<'_>,
32 var: &crate::shared::Variable<CudaDialect<Self>>,
33 ) -> std::fmt::Result {
34 wmma_api_base::compile_fragment_declaration(f, var)
35 }
36
37 fn compile_wwma_fragment_ident(
38 f: &mut std::fmt::Formatter<'_>,
39 ident: &FragmentIdent<CudaDialect<Self>>,
40 ) -> std::fmt::Result {
41 wmma_api_base::compile_fragment_ident(f, WMMA_NAMESPACE, ident)
42 }
43
44 fn compile_wmma_fragment_layout(
45 f: &mut std::fmt::Formatter<'_>,
46 layout: &FragmentLayout<CudaDialect<Self>>,
47 ) -> std::fmt::Result {
48 wmma_api_base::compile_fragment_layout(f, WMMA_NAMESPACE, layout)
49 }
50
51 fn compile_wmma_fragment(
52 f: &mut std::fmt::Formatter<'_>,
53 fragment: &Fragment<CudaDialect<Self>>,
54 ) -> std::fmt::Result {
55 wmma_api_base::compile_fragment(f, WMMA_NAMESPACE, fragment)
56 }
57
58 fn compile_wmma_instruction(
59 f: &mut std::fmt::Formatter<'_>,
60 instruction: &WmmaInstruction<CudaDialect<Self>>,
61 ) -> std::fmt::Result {
62 wmma_api_base::compile_instruction(f, WMMA_NAMESPACE, instruction)
63 }
64
65 fn compile_manual_mma(
66 f: &mut std::fmt::Formatter<'_>,
67 mma: ManualMma<CudaDialect<Self>>,
68 ) -> std::fmt::Result {
69 compile_manual_mma(f, mma)
70 }
71
72 fn compile_scaled_mma(
73 f: &mut std::fmt::Formatter<'_>,
74 mma: ManualMma<CudaDialect<Self>>,
75 scales_a: Variable<CudaDialect<Self>>,
76 scales_b: Variable<CudaDialect<Self>>,
77 scales_factor: u32,
78 ) -> std::fmt::Result {
79 compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
80 }
81
82 fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
83 let mut result: SupportedMmaCombinations = vec![];
84 if arch.get_version() >= WMMA_MINIMUM_VERSION {
85 let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
86 let types = vec![
88 (
89 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
93 (
94 gpu::ElemType::Float(gpu::FloatKind::F16),
95 gpu::ElemType::Float(gpu::FloatKind::F16),
96 gpu::ElemType::Float(gpu::FloatKind::F32),
97 ),
98 (
99 gpu::ElemType::Float(gpu::FloatKind::BF16),
100 gpu::ElemType::Float(gpu::FloatKind::BF16),
101 gpu::ElemType::Float(gpu::FloatKind::F32),
102 ),
103 (
104 gpu::ElemType::Int(gpu::IntKind::I8),
105 gpu::ElemType::Int(gpu::IntKind::I8),
106 gpu::ElemType::Int(gpu::IntKind::I32),
107 ),
108 (
109 gpu::ElemType::UInt(gpu::UIntKind::U8),
110 gpu::ElemType::UInt(gpu::UIntKind::U8),
111 gpu::ElemType::Int(gpu::IntKind::I32),
112 ),
113 ];
114 let combinations: SupportedMmaCombinations = types
115 .into_iter()
116 .cartesian_product(tdims)
117 .map(|((a, b, c), (m, n, k))| MmaConfig {
118 a_type: a.into(),
119 b_type: b.into(),
120 cd_type: c.into(),
121 m,
122 n,
123 k,
124 })
125 .collect();
126 result.extend(combinations);
127 if arch.get_version() >= 80 {
128 result.push(MmaConfig {
129 a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
130 b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
131 cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
132 m: 16,
133 n: 16,
134 k: 8,
135 });
136 }
137 }
138 result
139 }
140
141 fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
142 supported_mma_combinations(arch)
143 }
144
145 fn supported_scaled_mma_combinations(
146 arch: &CudaArchitecture,
147 ) -> SupportedScaledMmaCombinations {
148 supported_scaled_mma_combinations(arch)
149 }
150}