cubecl_cpp/hip/mma/wmma_intrinsics_compiler.rs

1use std::fmt::Formatter;
2
3use crate::{
4    Dialect,
5    hip::{HipDialect, arch::AMDArchitecture},
6    shared::{
7        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
8        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
9        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
10        wmma_api_base,
11    },
12};
13use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent};
14use cubecl_runtime::MmaConfig;
15
/// WMMA compiler that lowers fragment operations to the raw AMD
/// `__builtin_amdgcn_wmma_*` intrinsics (16x16x16 shapes only, wavefront
/// size 32 — see the `Execute` path below).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}
18
/// Extension emitting a `__device__` helper that fills a fragment with a
/// single value.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    // Fragment being filled; determines the helper's name and signature.
    frag: Fragment<D>,
}
23
/// Extension emitting a `__device__` helper that loads a fragment from
/// memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    // Destination fragment.
    frag: Fragment<D>,
    // Layout of the source data; required when loading into an accumulator
    // (the helper panics otherwise).
    layout: Option<FragmentLayout<D>>,
}
29
/// Extension emitting a `__device__` helper that stores a fragment to
/// memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    // Source fragment holding the results to write back.
    frag: Fragment<D>,
    // Desired layout of the output matrix.
    layout: FragmentLayout<D>,
}
35
/// Extension emitting a `__device__` helper that wraps the AMD WMMA
/// execute builtin for the four fragments of `D = A * B + C`.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}
43
/// Extension emitting a `__device__` helper that copies one fragment into
/// another, possibly with a different element type.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}
49
50impl<D: Dialect> WmmaFill<D> {
51    pub fn fn_name(&self) -> String {
52        let layout = frag_layout_str(&self.frag.layout);
53        let ident = frag_ident_str(&self.frag.ident);
54        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
55        let elem = self.frag.elem;
56
57        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}",)
58    }
59
60    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61        let elem = self.frag.elem;
62        let frag = self.frag;
63        let name = self.fn_name();
64
65        write!(
66            f,
67            "
68// Fill the fragment.
69__device__ void {name}({frag}& frag, {elem} value) {{
70    #pragma unroll
71    for (uint i = 0; i < 8; ++i) {{
72      frag[i] = value;
73    }}
74}}
75        "
76        )
77    }
78}
79
impl<D: Dialect> WmmaLoad<D> {
    /// Name of the generated load helper; encodes element type, fragment
    /// identifier, shape, the fragment's own layout and the source data
    /// layout.
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
    }

    /// Matrix A must be in column major layout (so fragments correspond to a row)
    /// Matrices B, C and D must be in row major layout (so fragments correspond to a column)
    ///
    /// Each lane is a thread so each column get 8 VGPRs used to store fragments
    /// Here is the layout for C and D matrices and how they map to registers
    ///
    /// Lane index   0      1      2      3      ...     13     14     15     ...     17     18     ...     30     31
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR0      | 1,1  | 1,2  | 1,3  | 1,4  | ...  | 1,13 | 1,14 | 1,15 | ...  | 2,1  | 2,2  | ...  | 2,15 | 2,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR1      | 3,1  | 3,2  | 3,3  | 3,4  | ...  | 3,13 | 3,14 | 3,15 | ...  | 4,1  | 4,2  | ...  | 4,15 | 4,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR2      | 5,1  | 5,2  | 5,3  | 5,4  | ...  | 5,13 | 5,14 | 5,15 | ...  | 6,1  | 6,2  | ...  | 6,15 | 6,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR3      | 7,1  | 7,2  | 7,3  | 7,4  | ...  | 7,13 | 7,14 | 7,15 | ...  | 8,1  | 8,2  | ...  | 8,15 | 8,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR4      | 9,1  | 9,2  | 9,3  | 9,4  | ...  | 9,13 | 9,14 | 9,15 | ...  | 10,1 | 10,2 | ...  | 10,15| 10,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR5      | 11,1 | 11,2 | 11,3 | 11,4 | ...  | 11,13| 11,14| 11,15| ...  | 12,1 | 12,2 | ...  | 12,15| 12,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR6      | 13,1 | 13,2 | 13,3 | 13,4 | ...  | 13,13| 13,14| 13,15| ...  | 14,1 | 14,2 | ...  | 14,15| 14,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR7      | 15,1 | 15,2 | 15,3 | 15,4 | ...  | 15,13| 15,14| 15,15| ...  | 16,1 | 16,2 | ...  | 16,15| 16,16|
    /// --------------------------------------------------------------------------------------------------------------
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                // A/B fragments hold 16 dense elements per lane
                // (half16_t / bhalf16_t, see compile_wmma_fragment).
                let length = 16;
                let step = 1;
                // fragment a and b are always in half precision and they don't require special attention
                // to how they are stored in memory as matrix A and B are also in half precision
                // NOTE(review): frag.layout is assumed to be Some for A/B
                // fragments — the unwrap below panics otherwise.
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                // Accumulators expose 8 logical values per lane; for 16-bit
                // output elements only even indexes are written, hence the
                // index step computed below.
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                // Two consecutive rows are split between lanes 0-15 and
                // 16-31 (threadIdx.x / 16 picks the half-wavefront's row).
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
      const uint index = {index_body};
      frag[i * {step}] = value_ptr[index];
    }}
}}
        "
        )
    }
}
173
174impl<D: Dialect> WmmaStore<D> {
175    pub fn fn_name(&self) -> String {
176        let layout_frag = frag_layout_str(&self.frag.layout);
177        let layout_option = Some(self.layout);
178        let layout = frag_layout_str(&layout_option);
179        let ident = frag_ident_str(&self.frag.ident);
180        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
181        let elem = self.frag.elem;
182
183        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
184    }
185
186    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187        let elem = self.frag.elem;
188        let frag = self.frag;
189        let name = self.fn_name();
190        // frag holds a result column where threads 0-15 of the wavefront have the even rows and threads 16-31 the odd rows
191        // moreover, since we use OPSEL to false in the Execute instruction in f16 output format, the output elements are
192        // stored in even indexes (0, 2, 4, ...) (low 16-bits of the VGPR) in frag
193        let frag_idx = match elem {
194            Elem::F16 | Elem::BF16 => "elemIdx * 2",
195            Elem::F32 => "elemIdx",
196            other => {
197                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
198            }
199        };
200        // FragmentLayout here represents the desired layout of the matrix C
201        let output_idx = match self.layout {
202            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
203            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
204            FragmentLayout::_Dialect(_) => String::new(),
205        };
206
207        write!(
208            f,
209            "
210// Store the fragment.
211__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
212    {WMMA_LANE_DEF}
213
214    #pragma unroll
215    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
216      const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
217      output_ptr[{output_idx}] = frag[{frag_idx}];
218    }}
219}}
220        "
221        )
222    }
223}
224
225impl<D: Dialect> WmmaExecute<D> {
226    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
227        let frag_a = Fragment {
228            ident: FragmentIdent::A,
229            m: shape.m,
230            n: shape.n,
231            k: shape.k,
232            elem: ab_elem,
233            layout: Some(FragmentLayout::ColMajor),
234        };
235        let frag_b = Fragment {
236            ident: FragmentIdent::B,
237            layout: Some(FragmentLayout::RowMajor),
238            ..frag_a
239        };
240        let frag_cd = Fragment {
241            ident: FragmentIdent::Accumulator,
242            elem: cd_elem,
243            ..frag_b
244        };
245        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
246    }
247
248    pub fn fn_name(&self) -> String {
249        format!(
250            "wmma_execute_16x16x16_{}_{}",
251            self.frag_a.elem, self.frag_c.elem
252        )
253    }
254
255    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        let name = self.fn_name();
257        let ab_format = match self.frag_a.elem {
258            Elem::F32 => "f32",
259            Elem::BF16 => "bf16",
260            Elem::F16 => "f16",
261            _ => panic!(),
262        };
263        let (cd_format, opsel) = match self.frag_c.elem {
264            Elem::F32 => ("f32", ""),
265            Elem::BF16 => ("bf16", ", false"),
266            Elem::F16 => ("f16", ", false"),
267            _ => panic!(),
268        };
269        let warp_size = 32;
270        write!(
271            f,
272            "
273// Execute wmma.
274__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
275    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
276}}
277        ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
278        )
279    }
280}
281
282impl<D: Dialect> WmmaCast<D> {
283    pub fn fn_name(&self) -> String {
284        let layout = frag_layout_str(&self.frag_input.layout);
285        let ident = frag_ident_str(&self.frag_input.ident);
286        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
287        let elem = self.frag_input.elem;
288        let elem_out = self.frag_output.elem;
289
290        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}",)
291    }
292
293    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294        let input = self.frag_input;
295        let output = self.frag_output;
296        let name = self.fn_name();
297        let step = match output.ident {
298            FragmentIdent::Accumulator => {
299                get_output_accumulator_index_step(&self.frag_input.elem, &output)
300            }
301            _ => 1,
302        };
303
304        write!(
305            f,
306            "
307// Cast the fragment.
308__device__ void {name}({input}& input, {output}& output) {{
309    #pragma unroll
310    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
311      output[elemIdx * {step}] = input[elemIdx];
312    }}
313}}
314        "
315        )
316    }
317}
318
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    /// Emit the ext-vector typedefs fragments are lowered to. The 16-bit
    /// typedefs are gated on the element-type flags; `float8_t` is always
    /// emitted (no flag gates f32 here).
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    /// Declare a fragment variable; shared with the wmma API backend.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    /// Map a fragment to its vector typedef. Note that 16-bit accumulators
    /// use the 16-wide type even though only 8 values are logically held
    /// (one value per VGPR, low 16 bits — see the store/cast helpers).
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    /// Lower a WMMA instruction to a call of the matching generated
    /// extension helper (fill/load/execute/store/cast).
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } => {
                unimplemented!("Not supported in HIP")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                // The emitted builtin is the `_w32` variant, so a 64-wide
                // wavefront cannot use this path.
                assert_eq!(*warp_size, 32, "Only warp size of 32 supported");

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    /// Scaled MMA has no HIP intrinsic lowering here.
    fn compile_scaled_mma(
        _f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        unimplemented!("Not supported in HIP")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            // Types fully supported.
            // Each tuple is (a_type, b_type, cd_type); the shape is always
            // 16x16x16.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}
512
513fn get_output_accumulator_index_step<D: Dialect>(
514    input_elem: &Elem<D>,
515    output: &Fragment<D>,
516) -> u32 {
517    // Each VGPR is 32 bit wide and there is 8 VGPR per lane, an accumulator can then be either:
518    // - a vector of 8 float
519    // - a vector of 16 half
520    // Depending on the precision used for the input, the whole 32 bits per register will be used or
521    // just only 16 bits. In such a case we always use the lower 16 bits (opsel set to false) which means
522    // that we only assign values to even indexes of the accumulator (0, 2, 4, ...)
523
524    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);
525
526    match input_elem {
527        Elem::F16 | Elem::BF16 | Elem::F32 => {
528            match output.elem {
529                // loading into accumulator of 16 half precision
530                Elem::F16 | Elem::BF16 => 2,
531                // loading into accumulator of 8 full precision
532                Elem::F32 => 1,
533                other => panic!("unsupported format {other} for {output}"),
534            }
535        }
536        other => panic!("unsupported format {other} for {input_elem}"),
537    }
538}
539
/// Emit a manual MMA: rebuilds vector-typed fragments from arrays of
/// lines, calls the generated wmma execute helper, then scatters the
/// result back into `frag_d`.
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    // Accumulator elements owned by each lane (32 lanes per wavefront).
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    // 2 for 16-bit accumulator elements (4 bytes per VGPR / 2-byte elem),
    // 1 for f32.
    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    // Need to reconstruct the fragments from an array of lines to a single vector type.
    // This requires double indexing over both the array index and the line index.
    // Will generate something like
    // `float8_t {arr[0].i_0, arr[0].i_1, arr[1].i_0, ...}`
    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    // A and B fragments always carry 16 elements per lane.
    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    // C matrix needs to be padded for f16, because it only uses the low bytes. The simplest way is
    // to just replicate the same f16 in both halves of the register.
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    // Should optimize out
    let name = extension.fn_name();

    // Item is irrelevant
    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    // Copy the result out of the temporary vector, skipping padding slots
    // (every `frag_cd_step`-th entry holds a value).
    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}
619
620pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
621    // Correctness is wrong.
622    const ENABLED: bool = true;
623
624    if !ENABLED {
625        return Vec::new();
626    }
627
628    // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
629    // Feel free to add more if additional intrinsics are supported for execute
630    let mut result: SupportedMmaCombinations = vec![];
631    if arch.is_wmma_capable() {
632        // Types fully supported.
633        let types = vec![
634            (
635                gpu::ElemType::Float(gpu::FloatKind::F16),
636                gpu::ElemType::Float(gpu::FloatKind::F32),
637            ),
638            (
639                gpu::ElemType::Float(gpu::FloatKind::BF16),
640                gpu::ElemType::Float(gpu::FloatKind::F32),
641            ),
642        ];
643        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
644            a_type: ab_elem.into(),
645            b_type: ab_elem.into(),
646            cd_type: cd_elem.into(),
647            m: 16,
648            n: 16,
649            k: 16,
650        });
651        result.extend(combinations);
652    }
653    result
654}
655
656pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> u32 {
657    // Don't exceed swizzle atom and load width
658    let max_line_size = 16 / matrix.storage.size();
659    match ident {
660        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size) as u32,
661        MatrixIdent::Accumulator => 1,
662    }
663}
664
// Snippet inlined into generated helpers to compute the lane's position
// within its half-wavefront.
// threads 0-15 and threads 16-31 of the wavefront hold the same fragments respectively
// in other words fragments are duplicated
// so lanes 0,16 / 1,17 / ... / 15, 31 are the same
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";