cubecl_cpp/hip/mma/wmma_intrinsics_compiler.rs

use std::fmt::Formatter;

use crate::{
    Dialect,
    hip::{HipDialect, arch::AMDArchitecture},
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
        wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent};
use cubecl_runtime::MmaConfig;

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    frag: Fragment<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    layout: Option<FragmentLayout<D>>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    layout: FragmentLayout<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}

impl<D: Dialect> WmmaFill<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        write!(
            f,
            "
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
    #pragma unroll
    for (uint i = 0; i < 8; ++i) {{
      frag[i] = value;
    }}
}}
        "
        )
    }
}
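
// Illustration (a hedged sketch, not verbatim compiler output; the exact name
// pieces come from `Elem`'s Display impl and the `frag_*_str` helpers): for an
// f32 accumulator fragment of shape 16x16x16, `format_extension` emits a HIP
// device function shaped like
//
//   __device__ void wmma_fill_<elem>_<ident>_16x16x16_<layout>(float8_t& frag, float value) {
//       #pragma unroll
//       for (uint i = 0; i < 8; ++i) { frag[i] = value; }
//   }
//
// which every lane calls to write `value` into all 8 of its accumulator slots.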

impl<D: Dialect> WmmaLoad<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    /// Matrix A must be in column-major layout (so fragments correspond to rows).
    /// Matrices B, C and D must be in row-major layout (so fragments correspond to columns).
    ///
    /// Each lane is a thread, so each column gets 8 VGPRs in which to store fragments.
    /// Here is the layout of the C and D matrices and how they map to registers:
    ///
    /// Lane index   0      1      2      3      ...     13     14     15     ...     17     18     ...     30     31
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR0      | 1,1  | 1,2  | 1,3  | 1,4  | ...  | 1,13 | 1,14 | 1,15 | ...  | 2,1  | 2,2  | ...  | 2,15 | 2,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR1      | 3,1  | 3,2  | 3,3  | 3,4  | ...  | 3,13 | 3,14 | 3,15 | ...  | 4,1  | 4,2  | ...  | 4,15 | 4,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR2      | 5,1  | 5,2  | 5,3  | 5,4  | ...  | 5,13 | 5,14 | 5,15 | ...  | 6,1  | 6,2  | ...  | 6,15 | 6,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR3      | 7,1  | 7,2  | 7,3  | 7,4  | ...  | 7,13 | 7,14 | 7,15 | ...  | 8,1  | 8,2  | ...  | 8,15 | 8,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR4      | 9,1  | 9,2  | 9,3  | 9,4  | ...  | 9,13 | 9,14 | 9,15 | ...  | 10,1 | 10,2 | ...  | 10,15| 10,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR5      | 11,1 | 11,2 | 11,3 | 11,4 | ...  | 11,13| 11,14| 11,15| ...  | 12,1 | 12,2 | ...  | 12,15| 12,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR6      | 13,1 | 13,2 | 13,3 | 13,4 | ...  | 13,13| 13,14| 13,15| ...  | 14,1 | 14,2 | ...  | 14,15| 14,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR7      | 15,1 | 15,2 | 15,3 | 15,4 | ...  | 15,13| 15,14| 15,15| ...  | 16,1 | 16,2 | ...  | 16,15| 16,16|
    /// --------------------------------------------------------------------------------------------------------------
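    ///
    /// Worked example (derived from the accumulator index expression below, for
    /// a row-major source): lane 17 has wmmaLane = 17 % 16 = 1 and
    /// threadIdx.x / 16 = 1, so iteration i = 1 reads
    /// index = (1 * 2 + 1) * stride + 1, i.e. row 3, column 1 (zero-based),
    /// matching the "4,2" entry for VGPR1 of lane 17 in the table above.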
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                // Fragments A and B are always in half precision; they need no special handling
                // for how they are stored in memory, since matrices A and B are also in half
                // precision.
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
      const uint index = {index_body};
      frag[i * {step}] = value_ptr[index];
    }}
}}
        "
        )
    }
}

impl<D: Dialect> WmmaStore<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        // frag holds a result column where threads 0-15 of the wavefront hold the even rows and
        // threads 16-31 the odd rows. Moreover, since we set OPSEL to false in the Execute
        // instruction for f16 output, the output elements are stored at even indices
        // (0, 2, 4, ...) of frag (the low 16 bits of each VGPR).
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
        // FragmentLayout here represents the desired layout of the matrix C.
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };

        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
      const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
      output_ptr[{output_idx}] = frag[{frag_idx}];
    }}
}}
        "
        )
    }
}
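
// Illustration (derived from the indices above): for an f16 accumulator being
// stored row-major, lane 17 (wmmaLane = 1) at elemIdx = 1 reads frag[2] and
// writes it to output_ptr[1 + (1 * 2 + 1) * stride], i.e. row 3, column 1
// (zero-based) of the output tile.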

impl<D: Dialect> WmmaExecute<D> {
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }

    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            other => panic!("unsupported a/b element {other} for wmma execute"),
        };
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            other => panic!("unsupported accumulator element {other} for wmma execute"),
        };
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
        ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}
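
// Illustration (follows directly from the format string above): with f16
// inputs and an f32 accumulator, the emitted body is
//
//   frag_d = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(frag_a, frag_b, frag_c);
//
// while an f16 or bf16 accumulator switches the type suffix and appends the
// trailing OPSEL argument `, false`.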

impl<D: Dialect> WmmaCast<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag_input.layout);
        let ident = frag_ident_str(&self.frag_input.ident);
        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
        let elem = self.frag_input.elem;
        let elem_out = self.frag_output.elem;

        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        let step = match output.ident {
            FragmentIdent::Accumulator => {
                get_output_accumulator_index_step(&self.frag_input.elem, &output)
            }
            _ => 1,
        };

        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
      output[elemIdx * {step}] = input[elemIdx];
    }}
}}
        "
        )
    }
}
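
// Illustration (derived from `get_output_accumulator_index_step`): casting an
// f32 accumulator into an f16 accumulator uses step = 2, so the 8 input values
// land at output indices 0, 2, 4, ..., 14, matching the even-index storage
// convention described in `WmmaStore`.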

impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    other => panic!("expected a wmma fragment variable, got {other}"),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error LdMatrix & StMatrix are not supported on HIP\n")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                if *warp_size != 32 {
                    f.write_str(
                        "#error Only warp size of 32 supported for Wmma::Execute on HIP\n",
                    )?;
                }

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        f.write_str("#error scaled mma not supported in HIP\n")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            // Types fully supported.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}

fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    // Each VGPR is 32 bits wide and there are 8 VGPRs per lane, so an accumulator can be either:
    // - a vector of 8 floats
    // - a vector of 16 halves
    // Depending on the input precision, either the whole 32 bits of each register are used or
    // only 16 bits. In the latter case we always use the lower 16 bits (OPSEL set to false),
    // which means we only assign values to even indices of the accumulator (0, 2, 4, ...).

    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);

    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => {
            match output.elem {
                // loading into an accumulator of 16 half-precision values
                Elem::F16 | Elem::BF16 => 2,
                // loading into an accumulator of 8 full-precision values
                Elem::F32 => 1,
                other => panic!("unsupported format {other} for {output}"),
            }
        }
        other => panic!("unsupported input format {other}"),
    }
}
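
// Example values (read off the match above): the step depends only on the
// accumulator element type. An f16/bf16 accumulator gets step = 2, so values
// land at even indices only; an f32 accumulator gets step = 1 and is packed
// densely.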

pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    // Need to reconstruct the fragments from an array of lines to a single vector type.
    // This requires double indexing over both the array index and the line index.
    // Will generate something like
    // `float8_t {arr[0].i_0, arr[0].i_1, arr[1].i_0, ...}`
    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    // The C matrix needs to be padded for f16, because the intrinsic only uses the low 16 bits
    // of each register. The simplest way is to replicate the same f16 in both halves of the
    // register.
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    // Should be optimized out by the compiler.
    let name = extension.fn_name();

    // The item type of the temporary is irrelevant here.
    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}
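
// Illustration (numbers follow from the code above): for a 16x16x16 tile the
// accumulator holds 256 / 32 = 8 elements per lane. An f32 C fragment gives
// frag_cd_step = 1, while an f16 C fragment gives frag_cd_step = 2,
// duplicating each value into both halves of its 32-bit register as the
// comment above describes.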

pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    // NOTE: these combinations currently produce incorrect results; set `ENABLED` to `false`
    // to disable them.
    const ENABLED: bool = true;

    if !ENABLED {
        return Vec::new();
    }

    // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
    // Feel free to add more if additional intrinsics are supported for execute
    let mut result: SupportedMmaCombinations = vec![];
    if arch.is_wmma_capable() {
        // Types fully supported.
        let types = vec![
            (
                gpu::ElemType::Float(gpu::FloatKind::F16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
            (
                gpu::ElemType::Float(gpu::FloatKind::BF16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
        ];
        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        });
        result.extend(combinations);
    }
    result
}

pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> u32 {
    // Don't exceed swizzle atom and load width
    let max_line_size = 16 / matrix.storage.size();
    match ident {
        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size) as u32,
        MatrixIdent::Accumulator => 1,
    }
}
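
// Example (from the arithmetic above): f16 storage (2 bytes) gives
// max_line_size = 16 / 2 = 8, so A/B tiles can use up to 8 contiguous elements
// per access; f32 storage (4 bytes) gives 4. Accumulators always report 1.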

// Threads 0-15 and threads 16-31 of the wavefront hold the same fragments;
// in other words, fragments are duplicated, so lane pairs 0/16, 1/17, ..., 15/31
// are identical.
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";
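
// Example (follows from the modulo above): lanes 0 and 16 both compute
// wmmaLane = 0, lanes 1 and 17 both compute wmmaLane = 1, and so on; the only
// place the two half-wavefronts diverge in the indices above is the
// `threadIdx.x / 16` term used for accumulator rows.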