cubecl_cpp/hip/mma/rocwmma_compiler.rs

1use crate::{
2    hip::{
3        HipDialect,
4        arch::AMDArchitecture,
5        mma::{compile_manual_mma, supported_mma_combinations},
6    },
7    shared::{
8        DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout, ManualMma,
9        SupportedMmaCombinations, Variable, WmmaInstruction, wmma_api_base,
10    },
11};
12use cubecl_core::ir::{self as gpu, features::MmaConfig};
13
/// Namespace handed to the shared `wmma_api_base` helpers so that emitted fragment types,
/// layouts and instructions are qualified with `rocwmma::`.
const ROCWMMA_NAMESPACE: &str = "rocwmma";

/// Stateless marker type selecting the rocWMMA-based WMMA code generator for the HIP dialect.
///
/// It holds no data; everything it does is provided by its `DialectWmmaCompiler`
/// implementation.
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct RocWmmaCompiler {}
18
19impl DialectWmmaCompiler<HipDialect<Self>> for RocWmmaCompiler {
20    fn compile_wmma_includes(
21        f: &mut std::fmt::Formatter<'_>,
22        _flags: &Flags<HipDialect<Self>>,
23    ) -> std::fmt::Result {
24        f.write_str("#include <rocwmma/rocwmma.hpp>\n")
25    }
26
27    fn compile_wmma_type_definitions(
28        f: &mut std::fmt::Formatter<'_>,
29        flags: &Flags<HipDialect<Self>>,
30    ) -> std::fmt::Result {
31        // For manual MMA, maybe add a flag for this at some point
32        if flags.elem_bf16 {
33            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
34            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
35        }
36        if flags.elem_f16 {
37            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
38            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
39        }
40        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
41    }
42
43    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        Ok(())
45    }
46
47    fn compile_wmma_fragment_declaration(
48        f: &mut std::fmt::Formatter<'_>,
49        var: &crate::shared::Variable<HipDialect<Self>>,
50    ) -> std::fmt::Result {
51        wmma_api_base::compile_fragment_declaration(f, var)
52    }
53
54    fn compile_wwma_fragment_ident(
55        f: &mut std::fmt::Formatter<'_>,
56        ident: &FragmentIdent<HipDialect<Self>>,
57    ) -> std::fmt::Result {
58        wmma_api_base::compile_fragment_ident(f, ROCWMMA_NAMESPACE, ident)
59    }
60
61    fn compile_wmma_fragment_layout(
62        f: &mut std::fmt::Formatter<'_>,
63        layout: &FragmentLayout<HipDialect<Self>>,
64    ) -> std::fmt::Result {
65        wmma_api_base::compile_fragment_layout(f, ROCWMMA_NAMESPACE, layout)
66    }
67
68    fn compile_wmma_fragment(
69        f: &mut std::fmt::Formatter<'_>,
70        fragment: &Fragment<HipDialect<Self>>,
71    ) -> std::fmt::Result {
72        wmma_api_base::compile_fragment(f, ROCWMMA_NAMESPACE, fragment)
73    }
74
75    fn compile_wmma_instruction(
76        f: &mut std::fmt::Formatter<'_>,
77        instruction: &WmmaInstruction<HipDialect<Self>>,
78    ) -> std::fmt::Result {
79        wmma_api_base::compile_instruction(f, ROCWMMA_NAMESPACE, instruction)
80    }
81
82    fn compile_manual_mma(
83        f: &mut std::fmt::Formatter<'_>,
84        mma: ManualMma<HipDialect<Self>>,
85    ) -> std::fmt::Result {
86        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
87    }
88
89    fn compile_scaled_mma(
90        f: &mut std::fmt::Formatter<'_>,
91        _mma: ManualMma<HipDialect<Self>>,
92        _scales_a: Variable<HipDialect<Self>>,
93        _scales_b: Variable<HipDialect<Self>>,
94        _scales_factor: u32,
95    ) -> std::fmt::Result {
96        f.write_str("#error Scaled MMA not supported on HIP\n")
97    }
98
99    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
100        let combinations = match arch {
101            AMDArchitecture::GFX12 => {
102                // Group types by their tile dimensions for readability
103                let tdims_16_16_32 = vec![(16, 16, 32)];
104                let types_16_16_32 = vec![
105                    (
106                        gpu::ElemType::Float(gpu::FloatKind::E5M2), // bfloat8_t / bf8
107                        gpu::ElemType::Float(gpu::FloatKind::F32),
108                        gpu::ElemType::Float(gpu::FloatKind::F32),
109                    ),
110                    (
111                        gpu::ElemType::Float(gpu::FloatKind::E4M3), // float8_t / f8
112                        gpu::ElemType::Float(gpu::FloatKind::F32),
113                        gpu::ElemType::Float(gpu::FloatKind::F32),
114                    ),
115                ];
116
117                let tdims_16_16_16 = vec![(16, 16, 16)];
118                let types_16_16_16 = vec![
119                    (
120                        gpu::ElemType::Int(gpu::IntKind::I8),
121                        gpu::ElemType::Int(gpu::IntKind::I32),
122                        gpu::ElemType::Int(gpu::IntKind::I32),
123                    ),
124                    (
125                        gpu::ElemType::Int(gpu::IntKind::I8),
126                        gpu::ElemType::Int(gpu::IntKind::I8),
127                        gpu::ElemType::Int(gpu::IntKind::I32),
128                    ),
129                    (
130                        gpu::ElemType::Float(gpu::FloatKind::F16),
131                        gpu::ElemType::Float(gpu::FloatKind::F32),
132                        gpu::ElemType::Float(gpu::FloatKind::F32),
133                    ),
134                    (
135                        gpu::ElemType::Float(gpu::FloatKind::F16),
136                        gpu::ElemType::Float(gpu::FloatKind::F16),
137                        gpu::ElemType::Float(gpu::FloatKind::F32),
138                    ),
139                    (
140                        gpu::ElemType::Float(gpu::FloatKind::F16),
141                        gpu::ElemType::Float(gpu::FloatKind::F16),
142                        gpu::ElemType::Float(gpu::FloatKind::F16),
143                    ),
144                    (
145                        gpu::ElemType::Float(gpu::FloatKind::BF16),
146                        gpu::ElemType::Float(gpu::FloatKind::F32),
147                        gpu::ElemType::Float(gpu::FloatKind::F32),
148                    ),
149                    (
150                        gpu::ElemType::Float(gpu::FloatKind::BF16),
151                        gpu::ElemType::Float(gpu::FloatKind::BF16),
152                        gpu::ElemType::Float(gpu::FloatKind::F32),
153                    ),
154                    (
155                        gpu::ElemType::Float(gpu::FloatKind::BF16),
156                        gpu::ElemType::Float(gpu::FloatKind::BF16),
157                        gpu::ElemType::Float(gpu::FloatKind::BF16),
158                    ),
159                ];
160
161                // Combine all type-dimension pairs
162                types_16_16_32
163                    .into_iter()
164                    .map(|it| (it, tdims_16_16_32.clone()))
165                    .chain(
166                        types_16_16_16
167                            .into_iter()
168                            .map(|it| (it, tdims_16_16_16.clone())),
169                    )
170                    .collect()
171            }
172            AMDArchitecture::GFX10 | AMDArchitecture::GFX11 => {
173                // For gfx11 the supported tile dimensions are always the same
174                //                                   m   n   k
175                let tdims = vec![(16, 16, 16), (16, 16, 32)];
176                let types = vec![
177                    (
178                        gpu::ElemType::Float(gpu::FloatKind::F16), // m / i
179                        gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
180                        gpu::ElemType::Float(gpu::FloatKind::F32), // k / c
181                    ),
182                    (
183                        gpu::ElemType::Float(gpu::FloatKind::F16),
184                        gpu::ElemType::Float(gpu::FloatKind::F16),
185                        gpu::ElemType::Float(gpu::FloatKind::F32),
186                    ),
187                    (
188                        gpu::ElemType::Float(gpu::FloatKind::F16),
189                        gpu::ElemType::Float(gpu::FloatKind::F16),
190                        gpu::ElemType::Float(gpu::FloatKind::F16),
191                    ),
192                    (
193                        gpu::ElemType::Float(gpu::FloatKind::BF16),
194                        gpu::ElemType::Float(gpu::FloatKind::F32),
195                        gpu::ElemType::Float(gpu::FloatKind::F32),
196                    ),
197                    (
198                        gpu::ElemType::Float(gpu::FloatKind::BF16),
199                        gpu::ElemType::Float(gpu::FloatKind::BF16),
200                        gpu::ElemType::Float(gpu::FloatKind::F32),
201                    ),
202                    (
203                        gpu::ElemType::Float(gpu::FloatKind::BF16),
204                        gpu::ElemType::Float(gpu::FloatKind::BF16),
205                        gpu::ElemType::Float(gpu::FloatKind::BF16),
206                    ),
207                ];
208                types.into_iter().map(|it| (it, tdims.clone())).collect()
209            }
210            AMDArchitecture::GFX908 => {
211                vec![
212                    (
213                        (
214                            gpu::ElemType::Float(gpu::FloatKind::F32), // m / i
215                            gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
216                            gpu::ElemType::Float(gpu::FloatKind::F32),
217                        ), // k / c
218                        vec![
219                            //m  n   k
220                            (16, 16, 4),
221                            (16, 16, 8),
222                            (16, 16, 16),
223                            (16, 16, 32),
224                            (32, 32, 2),
225                            (32, 32, 4),
226                            (32, 32, 8),
227                            (32, 32, 16),
228                            (32, 32, 32),
229                        ],
230                    ),
231                    (
232                        (
233                            gpu::ElemType::Float(gpu::FloatKind::F16),
234                            gpu::ElemType::Float(gpu::FloatKind::F32),
235                            gpu::ElemType::Float(gpu::FloatKind::F32),
236                        ),
237                        vec![
238                            (16, 16, 16),
239                            (16, 16, 32),
240                            (32, 32, 8),
241                            (32, 32, 16),
242                            (32, 32, 32),
243                        ],
244                    ),
245                    (
246                        (
247                            gpu::ElemType::Float(gpu::FloatKind::F16),
248                            gpu::ElemType::Float(gpu::FloatKind::F16),
249                            gpu::ElemType::Float(gpu::FloatKind::F32),
250                        ),
251                        vec![
252                            (16, 16, 16),
253                            (16, 16, 32),
254                            (32, 32, 8),
255                            (32, 32, 16),
256                            (32, 32, 32),
257                        ],
258                    ),
259                    (
260                        (
261                            gpu::ElemType::Float(gpu::FloatKind::F16),
262                            gpu::ElemType::Float(gpu::FloatKind::F16),
263                            gpu::ElemType::Float(gpu::FloatKind::F16),
264                        ),
265                        vec![
266                            (16, 16, 16),
267                            (16, 16, 32),
268                            (32, 32, 8),
269                            (32, 32, 16),
270                            (32, 32, 32),
271                        ],
272                    ),
273                    (
274                        (
275                            gpu::ElemType::Float(gpu::FloatKind::BF16),
276                            gpu::ElemType::Float(gpu::FloatKind::F32),
277                            gpu::ElemType::Float(gpu::FloatKind::F32),
278                        ),
279                        vec![
280                            (16, 16, 8),
281                            (16, 16, 16),
282                            (16, 16, 32),
283                            (32, 32, 4),
284                            (32, 32, 8),
285                            (32, 32, 16),
286                            (32, 32, 32),
287                        ],
288                    ),
289                    (
290                        (
291                            gpu::ElemType::Float(gpu::FloatKind::BF16),
292                            gpu::ElemType::Float(gpu::FloatKind::BF16),
293                            gpu::ElemType::Float(gpu::FloatKind::F32),
294                        ),
295                        vec![
296                            (16, 16, 8),
297                            (16, 16, 16),
298                            (16, 16, 32),
299                            (32, 32, 4),
300                            (32, 32, 8),
301                            (32, 32, 16),
302                            (32, 32, 32),
303                        ],
304                    ),
305                    (
306                        (
307                            gpu::ElemType::Float(gpu::FloatKind::BF16),
308                            gpu::ElemType::Float(gpu::FloatKind::BF16),
309                            gpu::ElemType::Float(gpu::FloatKind::BF16),
310                        ),
311                        vec![
312                            (16, 16, 8),
313                            (16, 16, 16),
314                            (16, 16, 32),
315                            (32, 32, 4),
316                            (32, 32, 8),
317                            (32, 32, 16),
318                            (32, 32, 32),
319                        ],
320                    ),
321                ]
322            }
323            AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => {
324                vec![
325                    (
326                        (
327                            gpu::ElemType::Float(gpu::FloatKind::F32), // m / i
328                            gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
329                            gpu::ElemType::Float(gpu::FloatKind::F32),
330                        ), // k / c
331                        vec![
332                            //m  n   k
333                            (16, 16, 4),
334                            (16, 16, 8),
335                            (16, 16, 16),
336                            (16, 16, 32),
337                            (32, 32, 2),
338                            (32, 32, 4),
339                            (32, 32, 8),
340                            (32, 32, 16),
341                            (32, 32, 32),
342                        ],
343                    ),
344                    (
345                        (
346                            gpu::ElemType::Float(gpu::FloatKind::F16),
347                            gpu::ElemType::Float(gpu::FloatKind::F32),
348                            gpu::ElemType::Float(gpu::FloatKind::F32),
349                        ),
350                        vec![
351                            (16, 16, 16),
352                            (16, 16, 32),
353                            (32, 32, 8),
354                            (32, 32, 16),
355                            (32, 32, 32),
356                        ],
357                    ),
358                    (
359                        (
360                            gpu::ElemType::Float(gpu::FloatKind::F16),
361                            gpu::ElemType::Float(gpu::FloatKind::F16),
362                            gpu::ElemType::Float(gpu::FloatKind::F32),
363                        ),
364                        vec![
365                            (16, 16, 16),
366                            (16, 16, 32),
367                            (32, 32, 8),
368                            (32, 32, 16),
369                            (32, 32, 32),
370                        ],
371                    ),
372                    (
373                        (
374                            gpu::ElemType::Float(gpu::FloatKind::F16),
375                            gpu::ElemType::Float(gpu::FloatKind::F16),
376                            gpu::ElemType::Float(gpu::FloatKind::F16),
377                        ),
378                        vec![
379                            (16, 16, 16),
380                            (16, 16, 32),
381                            (32, 32, 8),
382                            (32, 32, 16),
383                            (32, 32, 32),
384                        ],
385                    ),
386                    (
387                        (
388                            gpu::ElemType::Float(gpu::FloatKind::BF16),
389                            gpu::ElemType::Float(gpu::FloatKind::F32),
390                            gpu::ElemType::Float(gpu::FloatKind::F32),
391                        ),
392                        vec![
393                            (16, 16, 16),
394                            (16, 16, 32),
395                            (32, 32, 8),
396                            (32, 32, 16),
397                            (32, 32, 32),
398                        ],
399                    ),
400                    (
401                        (
402                            gpu::ElemType::Float(gpu::FloatKind::BF16),
403                            gpu::ElemType::Float(gpu::FloatKind::BF16),
404                            gpu::ElemType::Float(gpu::FloatKind::F32),
405                        ),
406                        vec![
407                            (16, 16, 16),
408                            (16, 16, 32),
409                            (32, 32, 8),
410                            (32, 32, 16),
411                            (32, 32, 32),
412                        ],
413                    ),
414                    (
415                        (
416                            gpu::ElemType::Float(gpu::FloatKind::BF16),
417                            gpu::ElemType::Float(gpu::FloatKind::BF16),
418                            gpu::ElemType::Float(gpu::FloatKind::BF16),
419                        ),
420                        vec![
421                            (16, 16, 16),
422                            (16, 16, 32),
423                            (32, 32, 8),
424                            (32, 32, 16),
425                            (32, 32, 32),
426                        ],
427                    ),
428                ]
429            }
430            AMDArchitecture::Other => vec![],
431        };
432        combinations
433            .into_iter()
434            .flat_map(|(ty, sizes)| sizes.into_iter().map(move |size| (ty, size)))
435            .map(|((i, o, c), (m, n, k))| MmaConfig {
436                a_type: i.into(),
437                b_type: o.into(),
438                cd_type: c.into(),
439                m,
440                n,
441                k,
442            })
443            .collect()
444    }
445
446    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
447        supported_mma_combinations(arch)
448    }
449}