1use crate::{
2 hip::{
3 HipDialect,
4 arch::AMDArchitecture,
5 mma::{compile_manual_mma, supported_mma_combinations},
6 },
7 shared::{
8 DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout, ManualMma,
9 SupportedMmaCombinations, Variable, WmmaInstruction, wmma_api_base,
10 },
11};
12use cubecl_core::ir::{self as gpu};
13use cubecl_runtime::MmaConfig;
14
/// Namespace string handed to the shared `wmma_api_base` helpers whenever this
/// compiler asks them to emit a fragment identifier, fragment layout, fragment
/// type or WMMA instruction.
const ROCWMMA_NAMESPACE: &str = "rocwmma";
16
/// Field-less marker type that carries the [`DialectWmmaCompiler`] implementation
/// for `HipDialect<Self>` found in this file.
///
/// It holds no state; every trait item implemented for it is an associated
/// function that does not take `self`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct RocWmmaCompiler {}
19
20impl DialectWmmaCompiler<HipDialect<Self>> for RocWmmaCompiler {
21 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
22 f.write_str("#include <rocwmma/rocwmma.hpp>\n")
23 }
24
25 fn compile_wmma_type_definitions(
26 f: &mut std::fmt::Formatter<'_>,
27 flags: &Flags,
28 ) -> std::fmt::Result {
29 if flags.elem_bf16 {
31 f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
32 f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
33 }
34 if flags.elem_f16 {
35 f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
36 f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
37 }
38 f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
39 }
40
41 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 Ok(())
43 }
44
45 fn compile_wmma_fragment_declaration(
46 f: &mut std::fmt::Formatter<'_>,
47 var: &crate::shared::Variable<HipDialect<Self>>,
48 ) -> std::fmt::Result {
49 wmma_api_base::compile_fragment_declaration(f, var)
50 }
51
52 fn compile_wwma_fragment_ident(
53 f: &mut std::fmt::Formatter<'_>,
54 ident: &FragmentIdent<HipDialect<Self>>,
55 ) -> std::fmt::Result {
56 wmma_api_base::compile_fragment_ident(f, ROCWMMA_NAMESPACE, ident)
57 }
58
59 fn compile_wmma_fragment_layout(
60 f: &mut std::fmt::Formatter<'_>,
61 layout: &FragmentLayout<HipDialect<Self>>,
62 ) -> std::fmt::Result {
63 wmma_api_base::compile_fragment_layout(f, ROCWMMA_NAMESPACE, layout)
64 }
65
66 fn compile_wmma_fragment(
67 f: &mut std::fmt::Formatter<'_>,
68 fragment: &Fragment<HipDialect<Self>>,
69 ) -> std::fmt::Result {
70 wmma_api_base::compile_fragment(f, ROCWMMA_NAMESPACE, fragment)
71 }
72
73 fn compile_wmma_instruction(
74 f: &mut std::fmt::Formatter<'_>,
75 instruction: &WmmaInstruction<HipDialect<Self>>,
76 ) -> std::fmt::Result {
77 wmma_api_base::compile_instruction(f, ROCWMMA_NAMESPACE, instruction)
78 }
79
80 fn compile_manual_mma(
81 f: &mut std::fmt::Formatter<'_>,
82 mma: ManualMma<HipDialect<Self>>,
83 ) -> std::fmt::Result {
84 compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
85 }
86
87 fn compile_scaled_mma(
88 f: &mut std::fmt::Formatter<'_>,
89 _mma: ManualMma<HipDialect<Self>>,
90 _scales_a: Variable<HipDialect<Self>>,
91 _scales_b: Variable<HipDialect<Self>>,
92 _scales_factor: u32,
93 ) -> std::fmt::Result {
94 f.write_str("#error Scaled MMA not supported on HIP\n")
95 }
96
97 fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
98 let combinations = match arch {
99 AMDArchitecture::GFX10 | AMDArchitecture::GFX11 => {
100 let tdims = vec![(16, 16, 16), (16, 16, 32)];
103 let types = vec![
104 (
105 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), ),
109 (
110 gpu::ElemType::Float(gpu::FloatKind::F16),
111 gpu::ElemType::Float(gpu::FloatKind::F16),
112 gpu::ElemType::Float(gpu::FloatKind::F32),
113 ),
114 (
115 gpu::ElemType::Float(gpu::FloatKind::F16),
116 gpu::ElemType::Float(gpu::FloatKind::F16),
117 gpu::ElemType::Float(gpu::FloatKind::F16),
118 ),
119 (
120 gpu::ElemType::Float(gpu::FloatKind::BF16),
121 gpu::ElemType::Float(gpu::FloatKind::F32),
122 gpu::ElemType::Float(gpu::FloatKind::F32),
123 ),
124 (
125 gpu::ElemType::Float(gpu::FloatKind::BF16),
126 gpu::ElemType::Float(gpu::FloatKind::BF16),
127 gpu::ElemType::Float(gpu::FloatKind::F32),
128 ),
129 (
130 gpu::ElemType::Float(gpu::FloatKind::BF16),
131 gpu::ElemType::Float(gpu::FloatKind::BF16),
132 gpu::ElemType::Float(gpu::FloatKind::BF16),
133 ),
134 ];
135 types.into_iter().map(|it| (it, tdims.clone())).collect()
136 }
137 AMDArchitecture::GFX908 => {
138 vec![
139 (
140 (
141 gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32),
144 ), vec![
146 (16, 16, 4),
148 (16, 16, 8),
149 (16, 16, 16),
150 (16, 16, 32),
151 (32, 32, 2),
152 (32, 32, 4),
153 (32, 32, 8),
154 (32, 32, 16),
155 (32, 32, 32),
156 ],
157 ),
158 (
159 (
160 gpu::ElemType::Float(gpu::FloatKind::F16),
161 gpu::ElemType::Float(gpu::FloatKind::F32),
162 gpu::ElemType::Float(gpu::FloatKind::F32),
163 ),
164 vec![
165 (16, 16, 16),
166 (16, 16, 32),
167 (32, 32, 8),
168 (32, 32, 16),
169 (32, 32, 32),
170 ],
171 ),
172 (
173 (
174 gpu::ElemType::Float(gpu::FloatKind::F16),
175 gpu::ElemType::Float(gpu::FloatKind::F16),
176 gpu::ElemType::Float(gpu::FloatKind::F32),
177 ),
178 vec![
179 (16, 16, 16),
180 (16, 16, 32),
181 (32, 32, 8),
182 (32, 32, 16),
183 (32, 32, 32),
184 ],
185 ),
186 (
187 (
188 gpu::ElemType::Float(gpu::FloatKind::F16),
189 gpu::ElemType::Float(gpu::FloatKind::F16),
190 gpu::ElemType::Float(gpu::FloatKind::F16),
191 ),
192 vec![
193 (16, 16, 16),
194 (16, 16, 32),
195 (32, 32, 8),
196 (32, 32, 16),
197 (32, 32, 32),
198 ],
199 ),
200 (
201 (
202 gpu::ElemType::Float(gpu::FloatKind::BF16),
203 gpu::ElemType::Float(gpu::FloatKind::F32),
204 gpu::ElemType::Float(gpu::FloatKind::F32),
205 ),
206 vec![
207 (16, 16, 8),
208 (16, 16, 16),
209 (16, 16, 32),
210 (32, 32, 4),
211 (32, 32, 8),
212 (32, 32, 16),
213 (32, 32, 32),
214 ],
215 ),
216 (
217 (
218 gpu::ElemType::Float(gpu::FloatKind::BF16),
219 gpu::ElemType::Float(gpu::FloatKind::BF16),
220 gpu::ElemType::Float(gpu::FloatKind::F32),
221 ),
222 vec![
223 (16, 16, 8),
224 (16, 16, 16),
225 (16, 16, 32),
226 (32, 32, 4),
227 (32, 32, 8),
228 (32, 32, 16),
229 (32, 32, 32),
230 ],
231 ),
232 (
233 (
234 gpu::ElemType::Float(gpu::FloatKind::BF16),
235 gpu::ElemType::Float(gpu::FloatKind::BF16),
236 gpu::ElemType::Float(gpu::FloatKind::BF16),
237 ),
238 vec![
239 (16, 16, 8),
240 (16, 16, 16),
241 (16, 16, 32),
242 (32, 32, 4),
243 (32, 32, 8),
244 (32, 32, 16),
245 (32, 32, 32),
246 ],
247 ),
248 ]
249 }
250 AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => {
251 vec![
252 (
253 (
254 gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32),
257 ), vec![
259 (16, 16, 4),
261 (16, 16, 8),
262 (16, 16, 16),
263 (16, 16, 32),
264 (32, 32, 2),
265 (32, 32, 4),
266 (32, 32, 8),
267 (32, 32, 16),
268 (32, 32, 32),
269 ],
270 ),
271 (
272 (
273 gpu::ElemType::Float(gpu::FloatKind::F16),
274 gpu::ElemType::Float(gpu::FloatKind::F32),
275 gpu::ElemType::Float(gpu::FloatKind::F32),
276 ),
277 vec![
278 (16, 16, 16),
279 (16, 16, 32),
280 (32, 32, 8),
281 (32, 32, 16),
282 (32, 32, 32),
283 ],
284 ),
285 (
286 (
287 gpu::ElemType::Float(gpu::FloatKind::F16),
288 gpu::ElemType::Float(gpu::FloatKind::F16),
289 gpu::ElemType::Float(gpu::FloatKind::F32),
290 ),
291 vec![
292 (16, 16, 16),
293 (16, 16, 32),
294 (32, 32, 8),
295 (32, 32, 16),
296 (32, 32, 32),
297 ],
298 ),
299 (
300 (
301 gpu::ElemType::Float(gpu::FloatKind::F16),
302 gpu::ElemType::Float(gpu::FloatKind::F16),
303 gpu::ElemType::Float(gpu::FloatKind::F16),
304 ),
305 vec![
306 (16, 16, 16),
307 (16, 16, 32),
308 (32, 32, 8),
309 (32, 32, 16),
310 (32, 32, 32),
311 ],
312 ),
313 (
314 (
315 gpu::ElemType::Float(gpu::FloatKind::BF16),
316 gpu::ElemType::Float(gpu::FloatKind::F32),
317 gpu::ElemType::Float(gpu::FloatKind::F32),
318 ),
319 vec![
320 (16, 16, 16),
321 (16, 16, 32),
322 (32, 32, 8),
323 (32, 32, 16),
324 (32, 32, 32),
325 ],
326 ),
327 (
328 (
329 gpu::ElemType::Float(gpu::FloatKind::BF16),
330 gpu::ElemType::Float(gpu::FloatKind::BF16),
331 gpu::ElemType::Float(gpu::FloatKind::F32),
332 ),
333 vec![
334 (16, 16, 16),
335 (16, 16, 32),
336 (32, 32, 8),
337 (32, 32, 16),
338 (32, 32, 32),
339 ],
340 ),
341 (
342 (
343 gpu::ElemType::Float(gpu::FloatKind::BF16),
344 gpu::ElemType::Float(gpu::FloatKind::BF16),
345 gpu::ElemType::Float(gpu::FloatKind::BF16),
346 ),
347 vec![
348 (16, 16, 16),
349 (16, 16, 32),
350 (32, 32, 8),
351 (32, 32, 16),
352 (32, 32, 32),
353 ],
354 ),
355 ]
356 }
357 AMDArchitecture::Other => vec![],
358 };
359 combinations
360 .into_iter()
361 .flat_map(|(ty, sizes)| sizes.into_iter().map(move |size| (ty, size)))
362 .map(|((i, o, c), (m, n, k))| MmaConfig {
363 a_type: i.into(),
364 b_type: o.into(),
365 cd_type: c.into(),
366 m,
367 n,
368 k,
369 })
370 .collect()
371 }
372
373 fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
374 supported_mma_combinations(arch)
375 }
376}