1use crate::{
2 hip::{
3 HipDialect,
4 arch::AMDArchitecture,
5 mma::{compile_manual_mma, supported_mma_combinations},
6 },
7 shared::{
8 DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout, ManualMma,
9 SupportedMmaCombinations, Variable, WmmaInstruction, wmma_api_base,
10 },
11};
12use cubecl_core::ir::{self as gpu, features::MmaConfig};
13
14const ROCWMMA_NAMESPACE: &str = "rocwmma";
15
16#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
17pub struct RocWmmaCompiler {}
18
19impl DialectWmmaCompiler<HipDialect<Self>> for RocWmmaCompiler {
20 fn compile_wmma_includes(
21 f: &mut std::fmt::Formatter<'_>,
22 _flags: &Flags<HipDialect<Self>>,
23 ) -> std::fmt::Result {
24 f.write_str("#include <rocwmma/rocwmma.hpp>\n")
25 }
26
27 fn compile_wmma_type_definitions(
28 f: &mut std::fmt::Formatter<'_>,
29 flags: &Flags<HipDialect<Self>>,
30 ) -> std::fmt::Result {
31 if flags.elem_bf16 {
33 f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
34 f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
35 }
36 if flags.elem_f16 {
37 f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
38 f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
39 }
40 f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
41 }
42
43 fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 Ok(())
45 }
46
47 fn compile_wmma_fragment_declaration(
48 f: &mut std::fmt::Formatter<'_>,
49 var: &crate::shared::Variable<HipDialect<Self>>,
50 ) -> std::fmt::Result {
51 wmma_api_base::compile_fragment_declaration(f, var)
52 }
53
54 fn compile_wwma_fragment_ident(
55 f: &mut std::fmt::Formatter<'_>,
56 ident: &FragmentIdent<HipDialect<Self>>,
57 ) -> std::fmt::Result {
58 wmma_api_base::compile_fragment_ident(f, ROCWMMA_NAMESPACE, ident)
59 }
60
61 fn compile_wmma_fragment_layout(
62 f: &mut std::fmt::Formatter<'_>,
63 layout: &FragmentLayout<HipDialect<Self>>,
64 ) -> std::fmt::Result {
65 wmma_api_base::compile_fragment_layout(f, ROCWMMA_NAMESPACE, layout)
66 }
67
68 fn compile_wmma_fragment(
69 f: &mut std::fmt::Formatter<'_>,
70 fragment: &Fragment<HipDialect<Self>>,
71 ) -> std::fmt::Result {
72 wmma_api_base::compile_fragment(f, ROCWMMA_NAMESPACE, fragment)
73 }
74
75 fn compile_wmma_instruction(
76 f: &mut std::fmt::Formatter<'_>,
77 instruction: &WmmaInstruction<HipDialect<Self>>,
78 ) -> std::fmt::Result {
79 wmma_api_base::compile_instruction(f, ROCWMMA_NAMESPACE, instruction)
80 }
81
82 fn compile_manual_mma(
83 f: &mut std::fmt::Formatter<'_>,
84 mma: ManualMma<HipDialect<Self>>,
85 ) -> std::fmt::Result {
86 compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
87 }
88
89 fn compile_scaled_mma(
90 f: &mut std::fmt::Formatter<'_>,
91 _mma: ManualMma<HipDialect<Self>>,
92 _scales_a: Variable<HipDialect<Self>>,
93 _scales_b: Variable<HipDialect<Self>>,
94 _scales_factor: u32,
95 ) -> std::fmt::Result {
96 f.write_str("#error Scaled MMA not supported on HIP\n")
97 }
98
99 fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
100 let combinations = match arch {
101 AMDArchitecture::GFX12 => {
102 let tdims_16_16_32 = vec![(16, 16, 32)];
104 let types_16_16_32 = vec![
105 (
106 gpu::ElemType::Float(gpu::FloatKind::E5M2), gpu::ElemType::Float(gpu::FloatKind::F32),
108 gpu::ElemType::Float(gpu::FloatKind::F32),
109 ),
110 (
111 gpu::ElemType::Float(gpu::FloatKind::E4M3), gpu::ElemType::Float(gpu::FloatKind::F32),
113 gpu::ElemType::Float(gpu::FloatKind::F32),
114 ),
115 ];
116
117 let tdims_16_16_16 = vec![(16, 16, 16)];
118 let types_16_16_16 = vec![
119 (
120 gpu::ElemType::Int(gpu::IntKind::I8),
121 gpu::ElemType::Int(gpu::IntKind::I32),
122 gpu::ElemType::Int(gpu::IntKind::I32),
123 ),
124 (
125 gpu::ElemType::Int(gpu::IntKind::I8),
126 gpu::ElemType::Int(gpu::IntKind::I8),
127 gpu::ElemType::Int(gpu::IntKind::I32),
128 ),
129 (
130 gpu::ElemType::Float(gpu::FloatKind::F16),
131 gpu::ElemType::Float(gpu::FloatKind::F32),
132 gpu::ElemType::Float(gpu::FloatKind::F32),
133 ),
134 (
135 gpu::ElemType::Float(gpu::FloatKind::F16),
136 gpu::ElemType::Float(gpu::FloatKind::F16),
137 gpu::ElemType::Float(gpu::FloatKind::F32),
138 ),
139 (
140 gpu::ElemType::Float(gpu::FloatKind::F16),
141 gpu::ElemType::Float(gpu::FloatKind::F16),
142 gpu::ElemType::Float(gpu::FloatKind::F16),
143 ),
144 (
145 gpu::ElemType::Float(gpu::FloatKind::BF16),
146 gpu::ElemType::Float(gpu::FloatKind::F32),
147 gpu::ElemType::Float(gpu::FloatKind::F32),
148 ),
149 (
150 gpu::ElemType::Float(gpu::FloatKind::BF16),
151 gpu::ElemType::Float(gpu::FloatKind::BF16),
152 gpu::ElemType::Float(gpu::FloatKind::F32),
153 ),
154 (
155 gpu::ElemType::Float(gpu::FloatKind::BF16),
156 gpu::ElemType::Float(gpu::FloatKind::BF16),
157 gpu::ElemType::Float(gpu::FloatKind::BF16),
158 ),
159 ];
160
161 types_16_16_32
163 .into_iter()
164 .map(|it| (it, tdims_16_16_32.clone()))
165 .chain(
166 types_16_16_16
167 .into_iter()
168 .map(|it| (it, tdims_16_16_16.clone())),
169 )
170 .collect()
171 }
172 AMDArchitecture::GFX10 | AMDArchitecture::GFX11 => {
173 let tdims = vec![(16, 16, 16), (16, 16, 32)];
176 let types = vec![
177 (
178 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), ),
182 (
183 gpu::ElemType::Float(gpu::FloatKind::F16),
184 gpu::ElemType::Float(gpu::FloatKind::F16),
185 gpu::ElemType::Float(gpu::FloatKind::F32),
186 ),
187 (
188 gpu::ElemType::Float(gpu::FloatKind::F16),
189 gpu::ElemType::Float(gpu::FloatKind::F16),
190 gpu::ElemType::Float(gpu::FloatKind::F16),
191 ),
192 (
193 gpu::ElemType::Float(gpu::FloatKind::BF16),
194 gpu::ElemType::Float(gpu::FloatKind::F32),
195 gpu::ElemType::Float(gpu::FloatKind::F32),
196 ),
197 (
198 gpu::ElemType::Float(gpu::FloatKind::BF16),
199 gpu::ElemType::Float(gpu::FloatKind::BF16),
200 gpu::ElemType::Float(gpu::FloatKind::F32),
201 ),
202 (
203 gpu::ElemType::Float(gpu::FloatKind::BF16),
204 gpu::ElemType::Float(gpu::FloatKind::BF16),
205 gpu::ElemType::Float(gpu::FloatKind::BF16),
206 ),
207 ];
208 types.into_iter().map(|it| (it, tdims.clone())).collect()
209 }
210 AMDArchitecture::GFX908 => {
211 vec![
212 (
213 (
214 gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32),
217 ), vec![
219 (16, 16, 4),
221 (16, 16, 8),
222 (16, 16, 16),
223 (16, 16, 32),
224 (32, 32, 2),
225 (32, 32, 4),
226 (32, 32, 8),
227 (32, 32, 16),
228 (32, 32, 32),
229 ],
230 ),
231 (
232 (
233 gpu::ElemType::Float(gpu::FloatKind::F16),
234 gpu::ElemType::Float(gpu::FloatKind::F32),
235 gpu::ElemType::Float(gpu::FloatKind::F32),
236 ),
237 vec![
238 (16, 16, 16),
239 (16, 16, 32),
240 (32, 32, 8),
241 (32, 32, 16),
242 (32, 32, 32),
243 ],
244 ),
245 (
246 (
247 gpu::ElemType::Float(gpu::FloatKind::F16),
248 gpu::ElemType::Float(gpu::FloatKind::F16),
249 gpu::ElemType::Float(gpu::FloatKind::F32),
250 ),
251 vec![
252 (16, 16, 16),
253 (16, 16, 32),
254 (32, 32, 8),
255 (32, 32, 16),
256 (32, 32, 32),
257 ],
258 ),
259 (
260 (
261 gpu::ElemType::Float(gpu::FloatKind::F16),
262 gpu::ElemType::Float(gpu::FloatKind::F16),
263 gpu::ElemType::Float(gpu::FloatKind::F16),
264 ),
265 vec![
266 (16, 16, 16),
267 (16, 16, 32),
268 (32, 32, 8),
269 (32, 32, 16),
270 (32, 32, 32),
271 ],
272 ),
273 (
274 (
275 gpu::ElemType::Float(gpu::FloatKind::BF16),
276 gpu::ElemType::Float(gpu::FloatKind::F32),
277 gpu::ElemType::Float(gpu::FloatKind::F32),
278 ),
279 vec![
280 (16, 16, 8),
281 (16, 16, 16),
282 (16, 16, 32),
283 (32, 32, 4),
284 (32, 32, 8),
285 (32, 32, 16),
286 (32, 32, 32),
287 ],
288 ),
289 (
290 (
291 gpu::ElemType::Float(gpu::FloatKind::BF16),
292 gpu::ElemType::Float(gpu::FloatKind::BF16),
293 gpu::ElemType::Float(gpu::FloatKind::F32),
294 ),
295 vec![
296 (16, 16, 8),
297 (16, 16, 16),
298 (16, 16, 32),
299 (32, 32, 4),
300 (32, 32, 8),
301 (32, 32, 16),
302 (32, 32, 32),
303 ],
304 ),
305 (
306 (
307 gpu::ElemType::Float(gpu::FloatKind::BF16),
308 gpu::ElemType::Float(gpu::FloatKind::BF16),
309 gpu::ElemType::Float(gpu::FloatKind::BF16),
310 ),
311 vec![
312 (16, 16, 8),
313 (16, 16, 16),
314 (16, 16, 32),
315 (32, 32, 4),
316 (32, 32, 8),
317 (32, 32, 16),
318 (32, 32, 32),
319 ],
320 ),
321 ]
322 }
323 AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => {
324 vec![
325 (
326 (
327 gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32), gpu::ElemType::Float(gpu::FloatKind::F32),
330 ), vec![
332 (16, 16, 4),
334 (16, 16, 8),
335 (16, 16, 16),
336 (16, 16, 32),
337 (32, 32, 2),
338 (32, 32, 4),
339 (32, 32, 8),
340 (32, 32, 16),
341 (32, 32, 32),
342 ],
343 ),
344 (
345 (
346 gpu::ElemType::Float(gpu::FloatKind::F16),
347 gpu::ElemType::Float(gpu::FloatKind::F32),
348 gpu::ElemType::Float(gpu::FloatKind::F32),
349 ),
350 vec![
351 (16, 16, 16),
352 (16, 16, 32),
353 (32, 32, 8),
354 (32, 32, 16),
355 (32, 32, 32),
356 ],
357 ),
358 (
359 (
360 gpu::ElemType::Float(gpu::FloatKind::F16),
361 gpu::ElemType::Float(gpu::FloatKind::F16),
362 gpu::ElemType::Float(gpu::FloatKind::F32),
363 ),
364 vec![
365 (16, 16, 16),
366 (16, 16, 32),
367 (32, 32, 8),
368 (32, 32, 16),
369 (32, 32, 32),
370 ],
371 ),
372 (
373 (
374 gpu::ElemType::Float(gpu::FloatKind::F16),
375 gpu::ElemType::Float(gpu::FloatKind::F16),
376 gpu::ElemType::Float(gpu::FloatKind::F16),
377 ),
378 vec![
379 (16, 16, 16),
380 (16, 16, 32),
381 (32, 32, 8),
382 (32, 32, 16),
383 (32, 32, 32),
384 ],
385 ),
386 (
387 (
388 gpu::ElemType::Float(gpu::FloatKind::BF16),
389 gpu::ElemType::Float(gpu::FloatKind::F32),
390 gpu::ElemType::Float(gpu::FloatKind::F32),
391 ),
392 vec![
393 (16, 16, 16),
394 (16, 16, 32),
395 (32, 32, 8),
396 (32, 32, 16),
397 (32, 32, 32),
398 ],
399 ),
400 (
401 (
402 gpu::ElemType::Float(gpu::FloatKind::BF16),
403 gpu::ElemType::Float(gpu::FloatKind::BF16),
404 gpu::ElemType::Float(gpu::FloatKind::F32),
405 ),
406 vec![
407 (16, 16, 16),
408 (16, 16, 32),
409 (32, 32, 8),
410 (32, 32, 16),
411 (32, 32, 32),
412 ],
413 ),
414 (
415 (
416 gpu::ElemType::Float(gpu::FloatKind::BF16),
417 gpu::ElemType::Float(gpu::FloatKind::BF16),
418 gpu::ElemType::Float(gpu::FloatKind::BF16),
419 ),
420 vec![
421 (16, 16, 16),
422 (16, 16, 32),
423 (32, 32, 8),
424 (32, 32, 16),
425 (32, 32, 32),
426 ],
427 ),
428 ]
429 }
430 AMDArchitecture::Other => vec![],
431 };
432 combinations
433 .into_iter()
434 .flat_map(|(ty, sizes)| sizes.into_iter().map(move |size| (ty, size)))
435 .map(|((i, o, c), (m, n, k))| MmaConfig {
436 a_type: i.into(),
437 b_type: o.into(),
438 cd_type: c.into(),
439 m,
440 n,
441 k,
442 })
443 .collect()
444 }
445
446 fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
447 supported_mma_combinations(arch)
448 }
449}