cubecl_cpp/hip/mma/wmma_intrinsics_compiler.rs

use std::fmt::Formatter;

use crate::{
    Dialect,
    cuda::ptx::comma_separated,
    hip::{HipDialect, arch::AMDArchitecture},
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
        wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu};
use cubecl_runtime::MmaConfig;

#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    frag: Fragment<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    layout: Option<FragmentLayout<D>>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    layout: FragmentLayout<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}

#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}

impl<D: Dialect> WmmaFill<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        write!(
            f,
            "
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
    #pragma unroll
    for (uint i = 0; i < 8; ++i) {{
      frag[i] = value;
    }}
}}
        "
        )
    }
}

impl<D: Dialect> WmmaLoad<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    /// Matrix A must be in column-major layout (so fragments correspond to a row).
    /// Matrices B, C and D must be in row-major layout (so fragments correspond to a column).
    ///
    /// Each lane is a thread, so each column gets 8 VGPRs used to store fragments.
    /// Here is the layout for the C and D matrices and how they map to registers:
    ///
    /// Lane index   0      1      2      3      ...     12     13     14     ...     16     17     ...     30     31
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR0      | 1,1  | 1,2  | 1,3  | 1,4  | ...  | 1,13 | 1,14 | 1,15 | ...  | 2,1  | 2,2  | ...  | 2,15 | 2,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR1      | 3,1  | 3,2  | 3,3  | 3,4  | ...  | 3,13 | 3,14 | 3,15 | ...  | 4,1  | 4,2  | ...  | 4,15 | 4,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR2      | 5,1  | 5,2  | 5,3  | 5,4  | ...  | 5,13 | 5,14 | 5,15 | ...  | 6,1  | 6,2  | ...  | 6,15 | 6,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR3      | 7,1  | 7,2  | 7,3  | 7,4  | ...  | 7,13 | 7,14 | 7,15 | ...  | 8,1  | 8,2  | ...  | 8,15 | 8,16 |
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR4      | 9,1  | 9,2  | 9,3  | 9,4  | ...  | 9,13 | 9,14 | 9,15 | ...  | 10,1 | 10,2 | ...  | 10,15| 10,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR5      | 11,1 | 11,2 | 11,3 | 11,4 | ...  | 11,13| 11,14| 11,15| ...  | 12,1 | 12,2 | ...  | 12,15| 12,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR6      | 13,1 | 13,2 | 13,3 | 13,4 | ...  | 13,13| 13,14| 13,15| ...  | 14,1 | 14,2 | ...  | 14,15| 14,16|
    /// --------------------------------------------------------------------------------------------------------------
    /// VGPR7      | 15,1 | 15,2 | 15,3 | 15,4 | ...  | 15,13| 15,14| 15,15| ...  | 16,1 | 16,2 | ...  | 16,15| 16,16|
    /// --------------------------------------------------------------------------------------------------------------
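    ///
    /// As an illustration of this mapping, using the accumulator indexing generated below
    /// (`wmmaLane = threadIdx.x % 16` and row `i * 2 + threadIdx.x / 16`, 0-indexed): the thread
    /// with `threadIdx.x == 17` and `i == 0` addresses row 1, column 1, i.e. element (2,2) of the
    /// table above in 1-indexed terms.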
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                // fragments a and b are always in half precision, and they don't require special
                // attention to how they are stored in memory since matrices A and B are also in
                // half precision
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };
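        // A/B fragments always contain 16 half-precision elements, hence `length == 16` and a
        // step of 1. Accumulators load 8 values; `step` is 2 for half-precision accumulators
        // (only even indexes are written, see `get_output_accumulator_index_step`) and 1 for f32.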

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
      const uint index = {index_body};
      frag[i * {step}] = value_ptr[index];
    }}
}}
        "
        )
    }
}

impl<D: Dialect> WmmaStore<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        // frag holds a result column where threads 0-15 of the wavefront hold the even rows and
        // threads 16-31 the odd rows. Moreover, since OPSEL is set to false in the execute
        // instruction for the f16 output format, the output elements are stored at even indexes
        // (0, 2, 4, ...) of frag (the low 16 bits of each VGPR).
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
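        // For example, with an f16/bf16 accumulator, elemIdx == 3 reads frag[6], and a thread in
        // the first half of the wavefront writes it to row rowIdx == 6 of the output tile.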
        // FragmentLayout here represents the desired layout of the matrix C
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };

        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
      const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
      output_ptr[{output_idx}] = frag[{frag_idx}];
    }}
}}
        "
        )
    }
}

impl<D: Dialect> WmmaExecute<D> {
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }

    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            _ => panic!(),
        };
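        // For half-precision accumulators the intrinsic takes an extra boolean (OPSEL); it is
        // set to false so results land in the low 16 bits of each VGPR (the even fragment
        // indexes).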
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            _ => panic!(),
        };
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
        ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}

impl<D: Dialect> WmmaCast<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag_input.layout);
        let ident = frag_ident_str(&self.frag_input.ident);
        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
        let elem = self.frag_input.elem;
        let elem_out = self.frag_output.elem;

        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        let step = match output.ident {
            FragmentIdent::Accumulator => {
                get_output_accumulator_index_step(&self.frag_input.elem, &output)
            }
            _ => 1,
        };

        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
      output[elemIdx * {step}] = input[elemIdx];
    }}
}}
        "
        )
    }
}

impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                assert_eq!(*warp_size, 32, "Only warp size of 32 supported");

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    fn compile_scaled_mma(
        _f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        unimplemented!("Not supported in HIP")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            // Types fully supported.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}

fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    // Each VGPR is 32 bits wide and there are 8 VGPRs per lane, so an accumulator is either:
    // - a vector of 8 floats
    // - a vector of 16 halves
    // Depending on the precision of the accumulator, either the whole 32 bits of each register
    // are used or only 16 bits. In the latter case we always use the lower 16 bits (OPSEL set to
    // false), which means we only assign values to the even indexes of the accumulator (0, 2, 4, ...).
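    // For example, loading into an f16/bf16 accumulator returns a step of 2 (frag[0], frag[2],
    // ... are written), while an f32 accumulator returns a step of 1.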

    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);

    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => {
            match output.elem {
                // loading into an accumulator of 16 half-precision elements
                Elem::F16 | Elem::BF16 => 2,
                // loading into an accumulator of 8 full-precision elements
                Elem::F32 => 1,
                other => panic!("unsupported format {other} for {output}"),
            }
        }
        other => panic!("unsupported input format {other} for accumulator {output}"),
    }
}

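/// Compiles a manual MMA by flattening the per-thread fragment registers of a, b and c into the
/// vector types expected by the wmma execute wrapper, accumulating into a temporary, then copying
/// the temporary back into the destination fragment registers.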
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &[Variable<D>],
    frag_b: &[Variable<D>],
    frag_c: &[Variable<D>],
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a[0].elem(), frag_c[0].elem());
    let frag_d_len = frag_c.len();

    let frag_a = comma_separated(
        frag_a
            .iter()
            .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
            .map(|it| format!("{it}")),
    );
    let frag_b = comma_separated(
        frag_b
            .iter()
            .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
            .map(|it| format!("{it}")),
    );
    let frag_c = comma_separated(
        frag_c
            .iter()
            .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
            .map(|it| format!("{it}")),
    );

    // The item type passed here is irrelevant; the temporary is declared with the fragment type below.
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    let name = extension.fn_name();
    writeln!(f, "{ty} {frag_d_tmp} = {ty}{{}};", ty = extension.frag_d)?;
    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    for i in 0..frag_d_len {
        writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i}];")?;
    }

    Ok(())
}

pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    // Reference: https://gpuopen.com/learn/wmma_on_rdna3/
    // Feel free to add more if additional intrinsics are supported for execute
    let mut result: SupportedMmaCombinations = vec![];
    if arch.is_wmma_capable() {
        // Types fully supported.
        let types = vec![
            (
                gpu::ElemType::Float(gpu::FloatKind::F16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
            (
                gpu::ElemType::Float(gpu::FloatKind::BF16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
        ];
        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        });
        result.extend(combinations);
    }
    result
}

// Threads 0-15 and threads 16-31 of the wavefront hold the same fragments; in other words,
// fragments are duplicated, so lanes 0/16, 1/17, ..., 15/31 hold the same values.
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";