cubecl_cpp/hip/mma/
rocwmma_compiler.rs

1use crate::{
2    hip::{
3        HipDialect,
4        arch::AMDArchitecture,
5        mma::{compile_manual_mma, supported_mma_combinations},
6    },
7    shared::{
8        DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout, ManualMma,
9        SupportedMmaCombinations, Variable, WmmaInstruction, wmma_api_base,
10    },
11};
12use cubecl_core::ir::{self as gpu};
13use cubecl_runtime::MmaConfig;
14
/// Namespace string passed to the shared `wmma_api_base` emitters so that the
/// generated fragment identifiers, layouts, fragments and instructions are
/// qualified for rocWMMA.
const ROCWMMA_NAMESPACE: &str = "rocwmma";
16
/// Field-less marker type that implements [`DialectWmmaCompiler`] for
/// [`HipDialect`], emitting code against the rocWMMA header
/// (`rocwmma/rocwmma.hpp`) and the `rocwmma` namespace.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct RocWmmaCompiler {}
19
20impl DialectWmmaCompiler<HipDialect<Self>> for RocWmmaCompiler {
21    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
22        f.write_str("#include <rocwmma/rocwmma.hpp>\n")
23    }
24
25    fn compile_wmma_type_definitions(
26        f: &mut std::fmt::Formatter<'_>,
27        flags: &Flags,
28    ) -> std::fmt::Result {
29        // For manual MMA, maybe add a flag for this at some point
30        if flags.elem_bf16 {
31            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
32            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
33        }
34        if flags.elem_f16 {
35            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
36            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
37        }
38        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
39    }
40
41    fn compile_wmma_local_variables(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        Ok(())
43    }
44
45    fn compile_wmma_fragment_declaration(
46        f: &mut std::fmt::Formatter<'_>,
47        var: &crate::shared::Variable<HipDialect<Self>>,
48    ) -> std::fmt::Result {
49        wmma_api_base::compile_fragment_declaration(f, var)
50    }
51
52    fn compile_wwma_fragment_ident(
53        f: &mut std::fmt::Formatter<'_>,
54        ident: &FragmentIdent<HipDialect<Self>>,
55    ) -> std::fmt::Result {
56        wmma_api_base::compile_fragment_ident(f, ROCWMMA_NAMESPACE, ident)
57    }
58
59    fn compile_wmma_fragment_layout(
60        f: &mut std::fmt::Formatter<'_>,
61        layout: &FragmentLayout<HipDialect<Self>>,
62    ) -> std::fmt::Result {
63        wmma_api_base::compile_fragment_layout(f, ROCWMMA_NAMESPACE, layout)
64    }
65
66    fn compile_wmma_fragment(
67        f: &mut std::fmt::Formatter<'_>,
68        fragment: &Fragment<HipDialect<Self>>,
69    ) -> std::fmt::Result {
70        wmma_api_base::compile_fragment(f, ROCWMMA_NAMESPACE, fragment)
71    }
72
73    fn compile_wmma_instruction(
74        f: &mut std::fmt::Formatter<'_>,
75        instruction: &WmmaInstruction<HipDialect<Self>>,
76    ) -> std::fmt::Result {
77        wmma_api_base::compile_instruction(f, ROCWMMA_NAMESPACE, instruction)
78    }
79
80    fn compile_manual_mma(
81        f: &mut std::fmt::Formatter<'_>,
82        mma: ManualMma<HipDialect<Self>>,
83    ) -> std::fmt::Result {
84        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
85    }
86
87    fn compile_scaled_mma(
88        _f: &mut std::fmt::Formatter<'_>,
89        _mma: ManualMma<HipDialect<Self>>,
90        _scales_a: Variable<HipDialect<Self>>,
91        _scales_b: Variable<HipDialect<Self>>,
92        _scales_factor: u32,
93    ) -> std::fmt::Result {
94        unimplemented!("Scaled MMA not supported in HIP")
95    }
96
97    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
98        let combinations = match arch {
99            AMDArchitecture::GFX10 | AMDArchitecture::GFX11 => {
100                // For gfx11 the supported tile dimensions are always the same
101                //                                   m   n   k
102                let tdims = vec![(16, 16, 16), (16, 16, 32)];
103                let types = vec![
104                    (
105                        gpu::ElemType::Float(gpu::FloatKind::F16), // m / i
106                        gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
107                        gpu::ElemType::Float(gpu::FloatKind::F32), // k / c
108                    ),
109                    (
110                        gpu::ElemType::Float(gpu::FloatKind::F16),
111                        gpu::ElemType::Float(gpu::FloatKind::F16),
112                        gpu::ElemType::Float(gpu::FloatKind::F32),
113                    ),
114                    (
115                        gpu::ElemType::Float(gpu::FloatKind::F16),
116                        gpu::ElemType::Float(gpu::FloatKind::F16),
117                        gpu::ElemType::Float(gpu::FloatKind::F16),
118                    ),
119                    (
120                        gpu::ElemType::Float(gpu::FloatKind::BF16),
121                        gpu::ElemType::Float(gpu::FloatKind::F32),
122                        gpu::ElemType::Float(gpu::FloatKind::F32),
123                    ),
124                    (
125                        gpu::ElemType::Float(gpu::FloatKind::BF16),
126                        gpu::ElemType::Float(gpu::FloatKind::BF16),
127                        gpu::ElemType::Float(gpu::FloatKind::F32),
128                    ),
129                    (
130                        gpu::ElemType::Float(gpu::FloatKind::BF16),
131                        gpu::ElemType::Float(gpu::FloatKind::BF16),
132                        gpu::ElemType::Float(gpu::FloatKind::BF16),
133                    ),
134                ];
135                types.into_iter().map(|it| (it, tdims.clone())).collect()
136            }
137            AMDArchitecture::GFX908 => {
138                vec![
139                    (
140                        (
141                            gpu::ElemType::Float(gpu::FloatKind::F32), // m / i
142                            gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
143                            gpu::ElemType::Float(gpu::FloatKind::F32),
144                        ), // k / c
145                        vec![
146                            //m  n   k
147                            (16, 16, 4),
148                            (16, 16, 8),
149                            (16, 16, 16),
150                            (16, 16, 32),
151                            (32, 32, 2),
152                            (32, 32, 4),
153                            (32, 32, 8),
154                            (32, 32, 16),
155                            (32, 32, 32),
156                        ],
157                    ),
158                    (
159                        (
160                            gpu::ElemType::Float(gpu::FloatKind::F16),
161                            gpu::ElemType::Float(gpu::FloatKind::F32),
162                            gpu::ElemType::Float(gpu::FloatKind::F32),
163                        ),
164                        vec![
165                            (16, 16, 16),
166                            (16, 16, 32),
167                            (32, 32, 8),
168                            (32, 32, 16),
169                            (32, 32, 32),
170                        ],
171                    ),
172                    (
173                        (
174                            gpu::ElemType::Float(gpu::FloatKind::F16),
175                            gpu::ElemType::Float(gpu::FloatKind::F16),
176                            gpu::ElemType::Float(gpu::FloatKind::F32),
177                        ),
178                        vec![
179                            (16, 16, 16),
180                            (16, 16, 32),
181                            (32, 32, 8),
182                            (32, 32, 16),
183                            (32, 32, 32),
184                        ],
185                    ),
186                    (
187                        (
188                            gpu::ElemType::Float(gpu::FloatKind::F16),
189                            gpu::ElemType::Float(gpu::FloatKind::F16),
190                            gpu::ElemType::Float(gpu::FloatKind::F16),
191                        ),
192                        vec![
193                            (16, 16, 16),
194                            (16, 16, 32),
195                            (32, 32, 8),
196                            (32, 32, 16),
197                            (32, 32, 32),
198                        ],
199                    ),
200                    (
201                        (
202                            gpu::ElemType::Float(gpu::FloatKind::BF16),
203                            gpu::ElemType::Float(gpu::FloatKind::F32),
204                            gpu::ElemType::Float(gpu::FloatKind::F32),
205                        ),
206                        vec![
207                            (16, 16, 8),
208                            (16, 16, 16),
209                            (16, 16, 32),
210                            (32, 32, 4),
211                            (32, 32, 8),
212                            (32, 32, 16),
213                            (32, 32, 32),
214                        ],
215                    ),
216                    (
217                        (
218                            gpu::ElemType::Float(gpu::FloatKind::BF16),
219                            gpu::ElemType::Float(gpu::FloatKind::BF16),
220                            gpu::ElemType::Float(gpu::FloatKind::F32),
221                        ),
222                        vec![
223                            (16, 16, 8),
224                            (16, 16, 16),
225                            (16, 16, 32),
226                            (32, 32, 4),
227                            (32, 32, 8),
228                            (32, 32, 16),
229                            (32, 32, 32),
230                        ],
231                    ),
232                    (
233                        (
234                            gpu::ElemType::Float(gpu::FloatKind::BF16),
235                            gpu::ElemType::Float(gpu::FloatKind::BF16),
236                            gpu::ElemType::Float(gpu::FloatKind::BF16),
237                        ),
238                        vec![
239                            (16, 16, 8),
240                            (16, 16, 16),
241                            (16, 16, 32),
242                            (32, 32, 4),
243                            (32, 32, 8),
244                            (32, 32, 16),
245                            (32, 32, 32),
246                        ],
247                    ),
248                ]
249            }
250            AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => {
251                vec![
252                    (
253                        (
254                            gpu::ElemType::Float(gpu::FloatKind::F32), // m / i
255                            gpu::ElemType::Float(gpu::FloatKind::F32), // n / o
256                            gpu::ElemType::Float(gpu::FloatKind::F32),
257                        ), // k / c
258                        vec![
259                            //m  n   k
260                            (16, 16, 4),
261                            (16, 16, 8),
262                            (16, 16, 16),
263                            (16, 16, 32),
264                            (32, 32, 2),
265                            (32, 32, 4),
266                            (32, 32, 8),
267                            (32, 32, 16),
268                            (32, 32, 32),
269                        ],
270                    ),
271                    (
272                        (
273                            gpu::ElemType::Float(gpu::FloatKind::F16),
274                            gpu::ElemType::Float(gpu::FloatKind::F32),
275                            gpu::ElemType::Float(gpu::FloatKind::F32),
276                        ),
277                        vec![
278                            (16, 16, 16),
279                            (16, 16, 32),
280                            (32, 32, 8),
281                            (32, 32, 16),
282                            (32, 32, 32),
283                        ],
284                    ),
285                    (
286                        (
287                            gpu::ElemType::Float(gpu::FloatKind::F16),
288                            gpu::ElemType::Float(gpu::FloatKind::F16),
289                            gpu::ElemType::Float(gpu::FloatKind::F32),
290                        ),
291                        vec![
292                            (16, 16, 16),
293                            (16, 16, 32),
294                            (32, 32, 8),
295                            (32, 32, 16),
296                            (32, 32, 32),
297                        ],
298                    ),
299                    (
300                        (
301                            gpu::ElemType::Float(gpu::FloatKind::F16),
302                            gpu::ElemType::Float(gpu::FloatKind::F16),
303                            gpu::ElemType::Float(gpu::FloatKind::F16),
304                        ),
305                        vec![
306                            (16, 16, 16),
307                            (16, 16, 32),
308                            (32, 32, 8),
309                            (32, 32, 16),
310                            (32, 32, 32),
311                        ],
312                    ),
313                    (
314                        (
315                            gpu::ElemType::Float(gpu::FloatKind::BF16),
316                            gpu::ElemType::Float(gpu::FloatKind::F32),
317                            gpu::ElemType::Float(gpu::FloatKind::F32),
318                        ),
319                        vec![
320                            (16, 16, 16),
321                            (16, 16, 32),
322                            (32, 32, 8),
323                            (32, 32, 16),
324                            (32, 32, 32),
325                        ],
326                    ),
327                    (
328                        (
329                            gpu::ElemType::Float(gpu::FloatKind::BF16),
330                            gpu::ElemType::Float(gpu::FloatKind::BF16),
331                            gpu::ElemType::Float(gpu::FloatKind::F32),
332                        ),
333                        vec![
334                            (16, 16, 16),
335                            (16, 16, 32),
336                            (32, 32, 8),
337                            (32, 32, 16),
338                            (32, 32, 32),
339                        ],
340                    ),
341                    (
342                        (
343                            gpu::ElemType::Float(gpu::FloatKind::BF16),
344                            gpu::ElemType::Float(gpu::FloatKind::BF16),
345                            gpu::ElemType::Float(gpu::FloatKind::BF16),
346                        ),
347                        vec![
348                            (16, 16, 16),
349                            (16, 16, 32),
350                            (32, 32, 8),
351                            (32, 32, 16),
352                            (32, 32, 32),
353                        ],
354                    ),
355                ]
356            }
357            AMDArchitecture::Other => vec![],
358        };
359        combinations
360            .into_iter()
361            .flat_map(|(ty, sizes)| sizes.into_iter().map(move |size| (ty, size)))
362            .map(|((i, o, c), (m, n, k))| MmaConfig {
363                a_type: i.into(),
364                b_type: o.into(),
365                cd_type: c.into(),
366                m,
367                n,
368                k,
369            })
370            .collect()
371    }
372
373    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
374        supported_mma_combinations(arch)
375    }
376}