cubecl_cpp/hip/mma/
wmma_intrinsics_compiler.rs

1use std::fmt::Formatter;
2
3use crate::{
4    Dialect,
5    hip::{HipDialect, arch::AMDArchitecture},
6    shared::{
7        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
8        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
9        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
10        wmma_api_base,
11    },
12};
13use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent, features::MmaConfig};
14
/// WMMA compiler backend that lowers matrix ops to raw AMD
/// `__builtin_amdgcn_wmma_*` intrinsics (rather than a wrapper API).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}
17
/// Extension generating a device helper that fills a fragment with a value.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    // Target fragment description (ident, shape, element type, layout).
    frag: Fragment<D>,
}
22
/// Extension generating a device helper that loads memory into a fragment.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    // Layout of the source data; required when loading into an accumulator.
    layout: Option<FragmentLayout<D>>,
}
28
/// Extension generating a device helper that stores an accumulator fragment
/// back to memory in the requested layout.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    // Desired layout of the output matrix in memory.
    layout: FragmentLayout<D>,
}
34
/// Extension generating a device helper wrapping one WMMA intrinsic call
/// (d = a * b + c).
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}
42
/// Extension generating a device helper that copies a fragment while
/// converting its element type.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}
48
49impl<D: Dialect> WmmaFill<D> {
50    pub fn fn_name(&self) -> String {
51        let layout = frag_layout_str(&self.frag.layout);
52        let ident = frag_ident_str(&self.frag.ident);
53        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
54        let elem = self.frag.elem;
55
56        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}",)
57    }
58
59    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60        let elem = self.frag.elem;
61        let frag = self.frag;
62        let name = self.fn_name();
63
64        write!(
65            f,
66            "
67// Fill the fragment.
68__device__ void {name}({frag}& frag, {elem} value) {{
69    #pragma unroll
70    for (uint i = 0; i < 8; ++i) {{
71      frag[i] = value;
72    }}
73}}
74        "
75        )
76    }
77}
78
impl<D: Dialect> WmmaLoad<D> {
    /// Mangled name of the generated load helper; encodes both the fragment's
    /// own layout and the layout of the source data.
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
    }

    /// Matrix A must be in column major layout (so fragments correspond to a row)
    /// Matrices B, C and D must be in row major layout (so fragments correspond to a column)
    ///
    /// Each lane is a thread so each column get 8 VGPRs used to store fragments
    /// Here is the layout for C and D matrices and how they map to registers
    ///
    /// Lane index   0      1      2      3      ...     13     14     15     ...     17     18     ...     30     31
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR0      | 1,1  | 1,2  | 1,3  | 1,4  | ...  | 1,13 | 1,14 | 1,15 | ...  | 2,1  | 2,2  | ...  | 2,15 | 2,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR1      | 3,1  | 3,2  | 3,3  | 3,4  | ...  | 3,13 | 3,14 | 3,15 | ...  | 4,1  | 4,2  | ...  | 4,15 | 4,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR2      | 5,1  | 5,2  | 5,3  | 5,4  | ...  | 5,13 | 5,14 | 5,15 | ...  | 6,1  | 6,2  | ...  | 6,15 | 6,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR3      | 7,1  | 7,2  | 7,3  | 7,4  | ...  | 7,13 | 7,14 | 7,15 | ...  | 8,1  | 8,2  | ...  | 8,15 | 8,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR4      | 9,1  | 9,2  | 9,3  | 9,4  | ...  | 9,13 | 9,14 | 9,15 | ...  | 10,1 | 10,2 | ...  | 10,15| 10,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR5      | 11,1 | 11,2 | 11,3 | 11,4 | ...  | 11,13| 11,14| 11,15| ...  | 12,1 | 12,2 | ...  | 12,15| 12,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR6      | 13,1 | 13,2 | 13,3 | 13,4 | ...  | 13,13| 13,14| 13,15| ...  | 14,1 | 14,2 | ...  | 14,15| 14,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR7      | 15,1 | 15,2 | 15,3 | 15,4 | ...  | 15,13| 15,14| 15,15| ...  | 16,1 | 16,2 | ...  | 16,15| 16,16|
    /// --------------------------------------------------------------------------------------------------------------
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        // `index_body` is a C++ expression (in terms of loop var `i`, `stride`
        // and `wmmaLane`) selecting the source element for register slot `i`.
        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                // fragment a and b are always in half precision and they don't require special attention
                // to how they are stored in memory as matrix A and B are also in half precision
                // NOTE: `layout` must be set for A/B fragments here (unwrap panics otherwise).
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    // fragment spans contiguously along the stride axis
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                // Accumulators hold 8 elements per lane; `step` skips odd
                // slots when the accumulator uses 16-bit elements.
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
      const uint index = {index_body};
      frag[i * {step}] = value_ptr[index];
    }}
}}
        "
        )
    }
}
172
173impl<D: Dialect> WmmaStore<D> {
174    pub fn fn_name(&self) -> String {
175        let layout_frag = frag_layout_str(&self.frag.layout);
176        let layout_option = Some(self.layout);
177        let layout = frag_layout_str(&layout_option);
178        let ident = frag_ident_str(&self.frag.ident);
179        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
180        let elem = self.frag.elem;
181
182        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
183    }
184
185    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
186        let elem = self.frag.elem;
187        let frag = self.frag;
188        let name = self.fn_name();
189        // frag holds a result column where threads 0-15 of the wavefront have the even rows and threads 16-31 the odd rows
190        // moreover, since we use OPSEL to false in the Execute instruction in f16 output format, the output elements are
191        // stored in even indexes (0, 2, 4, ...) (low 16-bits of the VGPR) in frag
192        let frag_idx = match elem {
193            Elem::F16 | Elem::BF16 => "elemIdx * 2",
194            Elem::F32 => "elemIdx",
195            other => {
196                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
197            }
198        };
199        // FragmentLayout here represents the desired layout of the matrix C
200        let output_idx = match self.layout {
201            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
202            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
203            FragmentLayout::_Dialect(_) => String::new(),
204        };
205
206        write!(
207            f,
208            "
209// Store the fragment.
210__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
211    {WMMA_LANE_DEF}
212
213    #pragma unroll
214    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
215      const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
216      output_ptr[{output_idx}] = frag[{frag_idx}];
217    }}
218}}
219        "
220        )
221    }
222}
223
224impl<D: Dialect> WmmaExecute<D> {
225    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
226        let frag_a = Fragment {
227            ident: FragmentIdent::A,
228            m: shape.m,
229            n: shape.n,
230            k: shape.k,
231            elem: ab_elem,
232            layout: Some(FragmentLayout::ColMajor),
233        };
234        let frag_b = Fragment {
235            ident: FragmentIdent::B,
236            layout: Some(FragmentLayout::RowMajor),
237            ..frag_a
238        };
239        let frag_cd = Fragment {
240            ident: FragmentIdent::Accumulator,
241            elem: cd_elem,
242            ..frag_b
243        };
244        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
245    }
246
247    pub fn fn_name(&self) -> String {
248        format!(
249            "wmma_execute_16x16x16_{}_{}",
250            self.frag_a.elem, self.frag_c.elem
251        )
252    }
253
254    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
255        let name = self.fn_name();
256        let ab_format = match self.frag_a.elem {
257            Elem::F32 => "f32",
258            Elem::BF16 => "bf16",
259            Elem::F16 => "f16",
260            _ => panic!(),
261        };
262        let (cd_format, opsel) = match self.frag_c.elem {
263            Elem::F32 => ("f32", ""),
264            Elem::BF16 => ("bf16", ", false"),
265            Elem::F16 => ("f16", ", false"),
266            _ => panic!(),
267        };
268        let warp_size = 32;
269        write!(
270            f,
271            "
272// Execute wmma.
273__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
274    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
275}}
276        ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
277        )
278    }
279}
280
281impl<D: Dialect> WmmaCast<D> {
282    pub fn fn_name(&self) -> String {
283        let layout = frag_layout_str(&self.frag_input.layout);
284        let ident = frag_ident_str(&self.frag_input.ident);
285        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
286        let elem = self.frag_input.elem;
287        let elem_out = self.frag_output.elem;
288
289        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}",)
290    }
291
292    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293        let input = self.frag_input;
294        let output = self.frag_output;
295        let name = self.fn_name();
296        let step = match output.ident {
297            FragmentIdent::Accumulator => {
298                get_output_accumulator_index_step(&self.frag_input.elem, &output)
299            }
300            _ => 1,
301        };
302
303        write!(
304            f,
305            "
306// Cast the fragment.
307__device__ void {name}({input}& input, {output}& output) {{
308    #pragma unroll
309    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
310      output[elemIdx * {step}] = input[elemIdx];
311    }}
312}}
313        "
314        )
315    }
316}
317
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    /// Emit the `ext_vector_type` typedefs used to represent fragments,
    /// gated on which element types the kernel actually uses.
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<HipDialect<Self>>,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        // float8_t is emitted unconditionally (f32 accumulator support).
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    /// Declare a fragment variable; delegates to the shared wmma API base.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    /// Map a fragment description to one of the typedef'd vector types above.
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                // 16-bit accumulators still use 16-wide vectors (only the low
                // halves of the VGPRs carry data — see WmmaStore).
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    /// Lower one WMMA instruction to a call of the matching generated helper.
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error LdMatrix & StMatrix are not supported on HIP\n")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                // The intrinsic wrapper is hard-coded to w32; surface a
                // compile-time error in the generated code otherwise.
                if *warp_size != 32 {
                    f.write_str(
                        "#error Only warp size of 32 supported for Wmma::Execute on HIP\n",
                    )?;
                }

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    /// Scaled MMA has no HIP intrinsic here; emit a generated-code error.
    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        f.write_str("#error scaled mma not supported in HIP\n")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            // Types fully supported.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}
515
516fn get_output_accumulator_index_step<D: Dialect>(
517    input_elem: &Elem<D>,
518    output: &Fragment<D>,
519) -> u32 {
520    // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either:
521    // - a vector of 8 float
522    // - a vector of 16 half
523    // Depending on the precision used for the input, the whole 32 bits per register will be used or
524    // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means
525    // that we only assign values to even indexes of the accumulator (0, 2, 4, ...)
526
527    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);
528
529    match input_elem {
530        Elem::F16 | Elem::BF16 | Elem::F32 => {
531            match output.elem {
532                // loading into accumulator of 16 half precision
533                Elem::F16 | Elem::BF16 => 2,
534                // loading into accumulator of 8 full precision
535                Elem::F32 => 1,
536                other => panic!("unsupported format {other} for {output}"),
537            }
538        }
539        other => panic!("unsupported format {other} for {input_elem}"),
540    }
541}
542
/// Lower a manual MMA to one call of the generated intrinsic wrapper.
///
/// The fragments arrive as arrays of (possibly vectorized) lines; this packs
/// them into the ext_vector types the wrapper expects, calls it through a raw
/// temporary, then scatters the result back into `frag_d`.
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    // Accumulator elements owned by each of the 32 lanes of the wavefront.
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    // 1 for 32-bit accumulator elements, 2 for 16-bit ones (each 32-bit VGPR
    // then holds one payload value in its low half).
    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    // Need to reconstruct the fragments from an array of lines to a single vector type.
    // This requires double indexing over both the array index and the line index.
    // Will generate something like
    // `float8_t {arr[0].i_0, arr[0].i_1, arr[1].i_0, ...}`
    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    // C matrix needs to be padded for f16, because it only uses the low bytes. The simplest way is
    // to just replicate the same f16 in both halves of the register.
    // (Each source element is repeated `frag_cd_step` times; the local `frag`
    // Vec below intentionally shadows the closure above.)
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    // Should optimize out
    let name = extension.fn_name();

    // Item is irrelevant
    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    // Unpack every `frag_cd_step`-th raw element back into frag_d's lines.
    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}
622
623pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
624    // Correctness is wrong.
625    const ENABLED: bool = true;
626
627    if !ENABLED {
628        return Vec::new();
629    }
630
631    // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
632    // Feel free to add more if additional intrinsics are supported for execute
633    let mut result: SupportedMmaCombinations = vec![];
634    if arch.is_wmma_capable() {
635        // Types fully supported.
636        let types = vec![
637            (
638                gpu::ElemType::Float(gpu::FloatKind::F16),
639                gpu::ElemType::Float(gpu::FloatKind::F32),
640            ),
641            (
642                gpu::ElemType::Float(gpu::FloatKind::BF16),
643                gpu::ElemType::Float(gpu::FloatKind::F32),
644            ),
645        ];
646        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
647            a_type: ab_elem.into(),
648            b_type: ab_elem.into(),
649            cd_type: cd_elem.into(),
650            m: 16,
651            n: 16,
652            k: 16,
653        });
654        result.extend(combinations);
655    }
656    result
657}
658
659pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> usize {
660    // Don't exceed swizzle atom and load width
661    let max_line_size = 16 / matrix.storage.size();
662    match ident {
663        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size),
664        MatrixIdent::Accumulator => 1,
665    }
666}
667
// C++ snippet shared by the generated load/store helpers to compute the lane
// index within one 16-wide half of the wavefront:
// threads 0-15 and threads 16-31 of the wavefront hold the same fragments respectively
// in other words fragments are duplicated
// so lanes 0,16 / 1,17 / ... / 15, 31 are the same
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";