cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6    Dialect,
7    cuda::{
8        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
9        processors::CudaMmaProcessor,
10        ptx::*,
11    },
12    shared::{
13        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16        Variable, WarpInstruction, unary,
17    },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Zero-sized marker type selecting the CUDA dialect.
///
/// `M` is only recorded through [`PhantomData`]; no value of `M` is stored.
/// The trait impls in this file require `M: DialectWmmaCompiler<Self>` and
/// forward their WMMA-related work to it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    /// Ties the dialect to its WMMA compiler type without storing one.
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28    type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32    type Extension = Extension<Self>;
33
34    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35        f.write_str("#include <cuda_runtime.h>\n")?;
36        if flags.elem_fp4 {
37            f.write_str("#include <cuda_fp4.h>\n")?;
38        }
39        if flags.elem_fp6 {
40            f.write_str("#include <cuda_fp6.h>\n")?;
41        }
42        if flags.elem_fp8 {
43            f.write_str("#include <cuda_fp8.h>\n")?;
44        }
45        if flags.elem_bf16 {
46            f.write_str("#include <cuda_bf16.h>\n")?;
47        }
48        if flags.elem_f16 {
49            f.write_str("#include <cuda_fp16.h>\n")?;
50        }
51
52        // tf32 conversion function is in mma header
53        if flags.inst_wmma || flags.elem_tf32 {
54            Self::compile_wmma_includes(f, flags)?;
55        }
56
57        if flags.op_pipeline {
58            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59            f.write_str("#include <cuda/pipeline>\n")?;
60        }
61        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62            f.write_str("#include <cooperative_groups.h>\n")?;
63            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64            f.write_str("#include <cuda/barrier>\n")?;
65        }
66        if flags.inst_ptx_wrappers {
67            f.write_str("#include <cuda/ptx>\n")?;
68        }
69        if flags.inst_tma {
70            f.write_str(
71                "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74            )?;
75        }
76        Ok(())
77    }
78
79    fn compile_extensions(
80        f: &mut std::fmt::Formatter<'_>,
81        extensions: &[Self::Extension],
82    ) -> std::fmt::Result {
83        for extension in extensions {
84            match extension {
85                Extension::NoExtension => {}
86                Extension::Mma(mma) => mma.format_extension(f)?,
87            }
88        }
89        Ok(())
90    }
91
92    fn register_instruction_extension(
93        _extensions: &mut Vec<Self::Extension>,
94        _instruction: &Instruction<Self>,
95    ) {
96    }
97
98    fn register_warp_instruction_extension(
99        _extensions: &mut Vec<Self::Extension>,
100        _instruction: &WarpInstruction<Self>,
101    ) {
102    }
103
104    fn register_wmma_instruction_extension(
105        extensions: &mut Vec<Self::Extension>,
106        instruction: &shared::WmmaInstruction<Self>,
107    ) {
108        match instruction {
109            shared::WmmaInstruction::ExecuteManual {
110                shape,
111                frag_a,
112                frag_b,
113                frag_c,
114                frag_d,
115            } => {
116                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117                    *shape,
118                    Fragment(frag_a.elem()),
119                    Fragment(frag_b.elem()),
120                    Fragment(frag_c.elem()),
121                    Fragment(frag_d.elem()),
122                )));
123                if !extensions.contains(&ext) {
124                    extensions.push(ext);
125                }
126            }
127            shared::WmmaInstruction::ExecuteScaled {
128                shape,
129                frag_a,
130                frag_b,
131                frag_c,
132                frag_d,
133                scales_a,
134                scales_factor,
135                ..
136            } => {
137                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138                    *shape,
139                    Fragment(frag_a.elem()),
140                    Fragment(frag_b.elem()),
141                    Fragment(frag_c.elem()),
142                    Fragment(frag_d.elem()),
143                    scales_a.elem(),
144                    *scales_factor,
145                )));
146                if !extensions.contains(&ext) {
147                    extensions.push(ext);
148                }
149            }
150            shared::WmmaInstruction::LdMatrix {
151                output,
152                factor,
153                transpose,
154                ..
155            } => {
156                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
157                    output.elem(),
158                    *factor,
159                    *transpose,
160                )));
161                if !extensions.contains(&ext) {
162                    extensions.push(ext);
163                }
164            }
165            shared::WmmaInstruction::StMatrix {
166                registers,
167                factor,
168                transpose,
169                ..
170            } => {
171                let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
172                    registers.elem(),
173                    *factor,
174                    *transpose,
175                )));
176                if !extensions.contains(&ext) {
177                    extensions.push(ext);
178                }
179            }
180            _ => {}
181        }
182    }
183}
184
185// Types
186
187impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
188    fn item_can_be_optimized() -> bool {
189        true
190    }
191
192    fn compile_type_definitions(
193        f: &mut std::fmt::Formatter<'_>,
194        items: &HashSet<Item<Self>>,
195        scalars: &[(Elem<Self>, usize)],
196        flags: &Flags,
197    ) -> std::fmt::Result {
198        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
199        let mut items_deduplicated = HashSet::new();
200
201        for item in items {
202            let mut item = *item;
203            match item.elem() {
204                Elem::FP4(_) => {
205                    item.elem = Elem::FP4(FP4Kind::E2M1);
206                }
207                Elem::FP4x2(_) => {
208                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
209                }
210                Elem::FP6(_) => {
211                    item.elem = Elem::FP6(FP6Kind::E2M3);
212                }
213                Elem::FP6x2(_) => {
214                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
215                }
216                Elem::FP8(_) => {
217                    item.elem = Elem::FP8(FP8Kind::E4M3);
218                }
219                Elem::FP8x2(_) => {
220                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
221                }
222                _ => {}
223            }
224            items_deduplicated.insert(item);
225        }
226
227        shared::type_definitions::<Self>(f)?;
228        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
229
230        if flags.use_grid_constants {
231            shared::type_scalar_definitions::<Self>(f, scalars)?;
232            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
233        }
234
235        if flags.inst_wmma {
236            Self::compile_wmma_type_definitions(f, flags)?;
237        }
238
239        Ok(())
240    }
241
242    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
243        if flags.inst_tma_im2col {
244            writeln!(f, "{TMA_LOAD_IM2COL}")?;
245        }
246        Ok(())
247    }
248
249    fn compile_elem(
250        f: &mut std::fmt::Formatter<'_>,
251        elem: &shared::Elem<Self>,
252        words: bool,
253    ) -> std::fmt::Result {
254        if words {
255            match elem {
256                shared::Elem::F32 => f.write_str("float"),
257                shared::Elem::F64 => f.write_str("double"),
258                shared::Elem::TF32 => f.write_str("float"),
259                shared::Elem::I8 => f.write_str("char"),
260                shared::Elem::I16 => f.write_str("short"),
261                shared::Elem::I32 => f.write_str("int"),
262                shared::Elem::I64 => f.write_str("long"),
263                shared::Elem::U8 => f.write_str("uchar"),
264                shared::Elem::U16 => f.write_str("ushort"),
265                shared::Elem::U32 => f.write_str("uint"),
266                shared::Elem::U64 => f.write_str("ulong"),
267                _ => Self::compile_elem(f, elem, false),
268            }
269        } else {
270            match elem {
271                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
272                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
273                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
274                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
275                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
276                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
277                shared::Elem::F16 => f.write_str("__half"),
278                shared::Elem::F16x2 => f.write_str("__half2"),
279                shared::Elem::F32 => f.write_str("float"),
280                shared::Elem::F64 => f.write_str("double"),
281                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
282                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
283                shared::Elem::TF32 => f.write_str("float"),
284                shared::Elem::I8 => f.write_str("int8"),
285                shared::Elem::I16 => f.write_str("int16"),
286                shared::Elem::I32 => f.write_str("int32"),
287                shared::Elem::I64 => f.write_str("int64"),
288                shared::Elem::U8 => f.write_str("uint8"),
289                shared::Elem::U16 => f.write_str("uint16"),
290                shared::Elem::U32 => f.write_str("uint32"),
291                shared::Elem::U64 => f.write_str("uint64"),
292                shared::Elem::Bool => f.write_str("bool"),
293                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
294                shared::Elem::_Dialect(_) => Ok(()),
295            }
296        }
297    }
298
299    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
300        if 1 == item.vectorization {
301            return write!(f, "{}", item.elem);
302        }
303        if item.native {
304            // native types use the word form of types only
305            Self::compile_elem(f, &item.elem, true)?;
306            write!(f, "{}", item.vectorization)
307        } else {
308            write!(f, "{}_{}", item.elem, item.vectorization)
309        }
310    }
311
312    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        Ok(())
314    }
315}
316
317// Kernel argument bindings
318
319impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
320    fn compile_kernel_signature(
321        f: &mut std::fmt::Formatter<'_>,
322        kernel_name: &str,
323        tensor_maps: &[Binding<Self>],
324        buffers: &[Binding<Self>],
325        scalars: &[(Elem<Self>, usize)],
326        flags: &Flags,
327    ) -> std::fmt::Result {
328        write!(
329            f,
330            "
331
332extern \"C\" __global__ void __launch_bounds__({})",
333            flags.cube_dim.num_elems()
334        )?;
335        if let Some(cluster_dim) = flags.cluster_dim {
336            write!(
337                f,
338                "__cluster_dims__({}, {}, {}) ",
339                cluster_dim.x, cluster_dim.y, cluster_dim.z
340            )?;
341        }
342        writeln!(f, "{kernel_name} (")?;
343        let has_scalars =
344            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
345        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
346        if flags.use_grid_constants {
347            shared::compile_scalars_static(f, scalars, flags)?;
348        } else {
349            shared::compile_scalars_dynamic(f, scalars)?;
350        }
351        f.write_str("\n)")?;
352        //
353        Ok(())
354    }
355
356    fn compile_bindings_body(
357        f: &mut std::fmt::Formatter<'_>,
358        body: &shared::Body<Self>,
359    ) -> std::fmt::Result {
360        if !body.shared_memories.is_empty() {
361            let max_align = body
362                .shared_memories
363                .iter()
364                .map(|smem| smem.align)
365                .max()
366                .unwrap();
367            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
368            // with `extern __shared__ alignas` and doesn't properly parse it.
369            writeln!(
370                f,
371                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
372            )?;
373        }
374        Ok(())
375    }
376}
377
378impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
379
380// Cube builtins dialect
381
382impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
383    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        write!(f, "cluster.block_rank()")
385    }
386
387    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        write!(f, "cluster.block_index().x")
389    }
390
391    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        write!(f, "cluster.block_index().y")
393    }
394
395    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        write!(f, "cluster.block_index().z")
397    }
398}
399
400// Instructions
401
402impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
403    // sync
404    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405        writeln!(f, "__syncthreads();\n")
406    }
407
408    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409        writeln!(f, "__syncwarp();\n")
410    }
411
412    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        writeln!(f, "__threadfence();")
414    }
415
416    // unary
417    fn compile_instruction_find_first_set<T: Component<Self>>(
418        f: &mut std::fmt::Formatter<'_>,
419        input: T,
420        out_elem: Elem<Self>,
421    ) -> std::fmt::Result {
422        write!(f, "{out_elem}(")?;
423        match input.elem() {
424            Elem::I32 => write!(f, "__ffs({input})"),
425            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
426            Elem::I64 => write!(f, "__ffsll({input})"),
427            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
428            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
429        }?;
430        write!(f, ")")
431    }
432
433    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
434        f: &mut std::fmt::Formatter<'_>,
435        input: T,
436        out_elem: Elem<Self>,
437    ) -> std::fmt::Result {
438        write!(f, "{out_elem}(")?;
439        match input.elem() {
440            Elem::I32 => write!(f, "__clz({input})"),
441            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
442            Elem::I64 => write!(f, "__clzll({input})"),
443            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
444            in_elem => write!(
445                f,
446                "{out_elem}(__clz({}) - {})",
447                unary::zero_extend(input),
448                (size_of::<u32>() - in_elem.size()) * 8
449            ),
450        }?;
451        write!(f, ")")
452    }
453
454    fn compile_saturating_add(
455        f: &mut std::fmt::Formatter<'_>,
456        lhs: impl Display,
457        rhs: impl Display,
458        item: Item<Self>,
459    ) -> std::fmt::Result {
460        let elem = item.elem();
461        match elem {
462            Elem::I32 => {
463                write!(
464                    f,
465                    r#"[&]() -> {elem} {{
466    {elem} result;
467    asm("add.sat.s32 %0, %1, %2;"
468        : "=r"(result)
469        : "r"({lhs}), "r"({rhs}));
470    return result;
471        }}()"#
472                )
473            }
474            _ => unreachable!("Should be replaced by polyfill"),
475        }
476    }
477
478    fn compile_saturating_sub(
479        f: &mut std::fmt::Formatter<'_>,
480        lhs: impl Display,
481        rhs: impl Display,
482        item: Item<Self>,
483    ) -> std::fmt::Result {
484        let elem = item.elem();
485        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
486        match elem {
487            Elem::I32 => {
488                write!(
489                    f,
490                    r#"[&]() -> {elem} {{
491    {elem} result;
492    asm("sub.sat.s32 %0, %1, %2;"
493        : "=r"(result)
494        : "r"({lhs}), "r"({rhs}));
495    return result;
496        }}()"#
497                )
498            }
499            _ => unreachable!("Should be replaced by polyfill"),
500        }
501    }
502
503    // others
504    fn compile_instruction_max_function_name(
505        f: &mut std::fmt::Formatter<'_>,
506        item: Item<Self>,
507    ) -> std::fmt::Result {
508        let max = match item.elem() {
509            Elem::F16 | Elem::BF16 => "__hmax",
510            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
511            _ => "max",
512        };
513        write!(f, "{max}")
514    }
515
516    fn compile_instruction_min_function_name(
517        f: &mut std::fmt::Formatter<'_>,
518        item: Item<Self>,
519    ) -> std::fmt::Result {
520        let min = match item.elem() {
521            Elem::F16 | Elem::BF16 => "__hmin",
522            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
523            _ => "min",
524        };
525        write!(f, "{min}")
526    }
527
528    // warp
529    fn compile_warp_shuffle(
530        f: &mut std::fmt::Formatter<'_>,
531        var: &str,
532        source: &str,
533    ) -> std::fmt::Result {
534        write!(f, "__shfl_sync(-1, {var}, {source})")
535    }
536    fn compile_warp_shuffle_xor(
537        f: &mut std::fmt::Formatter<'_>,
538        var: &str,
539        _elem: &Elem<Self>,
540        offset: &str,
541    ) -> std::fmt::Result {
542        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
543    }
544    fn compile_warp_shuffle_up(
545        f: &mut std::fmt::Formatter<'_>,
546        var: &str,
547        offset: &str,
548    ) -> std::fmt::Result {
549        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
550    }
551    fn compile_warp_shuffle_down(
552        f: &mut std::fmt::Formatter<'_>,
553        var: &str,
554        offset: &str,
555    ) -> std::fmt::Result {
556        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
557    }
558    fn compile_warp_all<T: Component<Self>>(
559        f: &mut std::fmt::Formatter<'_>,
560        input: &T,
561    ) -> std::fmt::Result {
562        write!(f, "__all_sync(-1, {input})")
563    }
564    fn compile_warp_any<T: Component<Self>>(
565        f: &mut std::fmt::Formatter<'_>,
566        input: &T,
567    ) -> std::fmt::Result {
568        write!(f, "__any_sync(-1, {input})")
569    }
570
571    fn compile_warp_ballot(
572        f: &mut std::fmt::Formatter<'_>,
573        input: &Variable<Self>,
574        _out_elem: &Elem<Self>,
575    ) -> std::fmt::Result {
576        write!(f, "__ballot_sync(-1, {input})")
577    }
578
579    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
580        let elem = Elem::<Self>::Bool;
581        let uint32 = Elem::<Self>::U32;
582        // Used to have a wrapper but it has been removed in newer version due to being
583        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
584        // implement.
585        writeln!(
586            f,
587            r#"{out} = {elem}([&]() -> {uint32} {{
588    {uint32} pred = 0;
589    asm volatile(
590        "{{\n"
591        "     .reg .pred %%px;\n"
592        "     elect.sync _|%%px, 0xffffffff;\n"
593        "     selp.b32 %0, 1, 0, %%px;\n"
594        "}}\n"
595        : "+r"(pred));
596    return pred;
597        }}());"#
598        )
599    }
600}
601
602// Coop Matrices dialect
603
604impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
605    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
606        M::compile_wmma_includes(f, flags)
607    }
608
609    fn compile_wmma_type_definitions(
610        f: &mut std::fmt::Formatter<'_>,
611        flags: &Flags,
612    ) -> std::fmt::Result {
613        M::compile_wmma_type_definitions(f, flags)
614    }
615
616    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617        M::compile_wmma_local_variables(f)
618    }
619
620    fn compile_wmma_fragment_declaration(
621        f: &mut std::fmt::Formatter<'_>,
622        var: &Variable<Self>,
623    ) -> std::fmt::Result {
624        M::compile_wmma_fragment_declaration(f, var)
625    }
626
627    fn compile_wwma_fragment_ident(
628        f: &mut std::fmt::Formatter<'_>,
629        ident: &crate::shared::FragmentIdent<Self>,
630    ) -> std::fmt::Result {
631        M::compile_wwma_fragment_ident(f, ident)
632    }
633
634    fn compile_wmma_fragment_layout(
635        f: &mut std::fmt::Formatter<'_>,
636        layout: &crate::shared::FragmentLayout<Self>,
637    ) -> std::fmt::Result {
638        M::compile_wmma_fragment_layout(f, layout)
639    }
640
641    fn compile_wmma_fragment(
642        f: &mut std::fmt::Formatter<'_>,
643        fragment: &crate::shared::Fragment<Self>,
644    ) -> std::fmt::Result {
645        M::compile_wmma_fragment(f, fragment)
646    }
647
648    fn compile_wmma_instruction(
649        f: &mut std::fmt::Formatter<'_>,
650        instruction: &crate::shared::WmmaInstruction<Self>,
651    ) -> std::fmt::Result {
652        M::compile_wmma_instruction(f, instruction)
653    }
654
655    fn compile_manual_mma(
656        f: &mut std::fmt::Formatter<'_>,
657        mma: ManualMma<Self>,
658    ) -> std::fmt::Result {
659        M::compile_manual_mma(f, mma)
660    }
661
662    fn compile_scaled_mma(
663        f: &mut std::fmt::Formatter<'_>,
664        mma: ManualMma<Self>,
665        scales_a: Variable<Self>,
666        scales_b: Variable<Self>,
667        scales_factor: u32,
668    ) -> std::fmt::Result {
669        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
670    }
671
672    fn supported_wmma_combinations(
673        arch: &CudaArchitecture,
674    ) -> crate::shared::SupportedMmaCombinations {
675        M::supported_wmma_combinations(arch)
676    }
677
678    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
679        M::supported_mma_combinations(arch)
680    }
681
682    fn supported_scaled_mma_combinations(
683        arch: &CudaArchitecture,
684    ) -> shared::SupportedScaledMmaCombinations {
685        M::supported_scaled_mma_combinations(arch)
686    }
687}
688
689impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
690    fn processors() -> Vec<Box<dyn Processor>> {
691        vec![
692            Box::new(CudaMmaProcessor),
693            Box::new(SaturatingArithmeticProcessor::new(false)),
694        ]
695    }
696}