cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6    Dialect,
7    cuda::{
8        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension},
9        processors::CudaMmaProcessor,
10        ptx::*,
11    },
12    shared::{
13        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16        Variable, WarpInstruction, unary,
17    },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Dialect marker type for the CUDA backend.
///
/// `M` names the WMMA compiler the dialect is parameterised with. No value of `M` is stored;
/// the parameter is only carried through [`PhantomData`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28    type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32    type Extension = Extension<Self>;
33
34    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35        f.write_str("#include <cuda_runtime.h>\n")?;
36        if flags.elem_fp4 {
37            f.write_str("#include <cuda_fp4.h>\n")?;
38        }
39        if flags.elem_fp6 {
40            f.write_str("#include <cuda_fp6.h>\n")?;
41        }
42        if flags.elem_fp8 {
43            f.write_str("#include <cuda_fp8.h>\n")?;
44        }
45        if flags.elem_bf16 {
46            f.write_str("#include <cuda_bf16.h>\n")?;
47        }
48        if flags.elem_f16 {
49            f.write_str("#include <cuda_fp16.h>\n")?;
50        }
51
52        // tf32 conversion function is in mma header
53        if flags.inst_wmma || flags.elem_tf32 {
54            Self::compile_wmma_includes(f, flags)?;
55        }
56
57        if flags.op_pipeline {
58            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59            f.write_str("#include <cuda/pipeline>\n")?;
60        }
61        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62            f.write_str("#include <cooperative_groups.h>\n")?;
63            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64            f.write_str("#include <cuda/barrier>\n")?;
65        }
66        if flags.inst_ptx_wrappers {
67            f.write_str("#include <cuda/ptx>\n")?;
68        }
69        if flags.inst_tma {
70            f.write_str(
71                "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74            )?;
75        }
76        Ok(())
77    }
78
79    fn compile_extensions(
80        f: &mut std::fmt::Formatter<'_>,
81        extensions: &[Self::Extension],
82    ) -> std::fmt::Result {
83        for extension in extensions {
84            match extension {
85                Extension::NoExtension => {}
86                Extension::Mma(mma) => mma.format_extension(f)?,
87            }
88        }
89        Ok(())
90    }
91
92    fn register_instruction_extension(
93        _extensions: &mut Vec<Self::Extension>,
94        _instruction: &Instruction<Self>,
95    ) {
96    }
97
98    fn register_warp_instruction_extension(
99        _extensions: &mut Vec<Self::Extension>,
100        _instruction: &WarpInstruction<Self>,
101    ) {
102    }
103
104    fn register_wmma_instruction_extension(
105        extensions: &mut Vec<Self::Extension>,
106        instruction: &shared::WmmaInstruction<Self>,
107    ) {
108        match instruction {
109            shared::WmmaInstruction::ExecuteManual {
110                shape,
111                frag_a,
112                frag_b,
113                frag_c,
114                frag_d,
115            } => {
116                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117                    *shape,
118                    Fragment(frag_a.elem()),
119                    Fragment(frag_b.elem()),
120                    Fragment(frag_c.elem()),
121                    Fragment(frag_d.elem()),
122                )));
123                if !extensions.contains(&ext) {
124                    extensions.push(ext);
125                }
126            }
127            shared::WmmaInstruction::ExecuteScaled {
128                shape,
129                frag_a,
130                frag_b,
131                frag_c,
132                frag_d,
133                scales_a,
134                scales_factor,
135                ..
136            } => {
137                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138                    *shape,
139                    Fragment(frag_a.elem()),
140                    Fragment(frag_b.elem()),
141                    Fragment(frag_c.elem()),
142                    Fragment(frag_d.elem()),
143                    scales_a.elem(),
144                    *scales_factor,
145                )));
146                if !extensions.contains(&ext) {
147                    extensions.push(ext);
148                }
149            }
150            shared::WmmaInstruction::LdMatrix {
151                output,
152                factor,
153                transpose,
154                ..
155            } => {
156                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
157                    output.elem(),
158                    *factor,
159                    *transpose,
160                )));
161                if !extensions.contains(&ext) {
162                    extensions.push(ext);
163                }
164            }
165            _ => {}
166        }
167    }
168}
169
170// Types
171
172impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
173    fn item_can_be_optimized() -> bool {
174        true
175    }
176
177    fn compile_type_definitions(
178        f: &mut std::fmt::Formatter<'_>,
179        items: &HashSet<Item<Self>>,
180        scalars: &[(Elem<Self>, usize)],
181        flags: &Flags,
182    ) -> std::fmt::Result {
183        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
184        let mut items_deduplicated = HashSet::new();
185
186        for item in items {
187            let mut item = *item;
188            match item.elem() {
189                Elem::FP4(_) => {
190                    item.elem = Elem::FP4(FP4Kind::E2M1);
191                }
192                Elem::FP4x2(_) => {
193                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
194                }
195                Elem::FP6(_) => {
196                    item.elem = Elem::FP6(FP6Kind::E2M3);
197                }
198                Elem::FP6x2(_) => {
199                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
200                }
201                Elem::FP8(_) => {
202                    item.elem = Elem::FP8(FP8Kind::E4M3);
203                }
204                Elem::FP8x2(_) => {
205                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
206                }
207                _ => {}
208            }
209            items_deduplicated.insert(item);
210        }
211
212        shared::type_definitions::<Self>(f)?;
213        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
214
215        if flags.use_grid_constants {
216            shared::type_scalar_definitions::<Self>(f, scalars)?;
217            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
218        }
219
220        if flags.inst_wmma {
221            Self::compile_wmma_type_definitions(f, flags)?;
222        }
223
224        Ok(())
225    }
226
227    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
228        if flags.inst_tma_im2col {
229            writeln!(f, "{TMA_LOAD_IM2COL}")?;
230        }
231        Ok(())
232    }
233
234    fn compile_elem(
235        f: &mut std::fmt::Formatter<'_>,
236        elem: &shared::Elem<Self>,
237        words: bool,
238    ) -> std::fmt::Result {
239        if words {
240            match elem {
241                shared::Elem::F32 => f.write_str("float"),
242                shared::Elem::F64 => f.write_str("double"),
243                shared::Elem::TF32 => f.write_str("float"),
244                shared::Elem::I8 => f.write_str("char"),
245                shared::Elem::I16 => f.write_str("short"),
246                shared::Elem::I32 => f.write_str("int"),
247                shared::Elem::I64 => f.write_str("long"),
248                shared::Elem::U8 => f.write_str("uchar"),
249                shared::Elem::U16 => f.write_str("ushort"),
250                shared::Elem::U32 => f.write_str("uint"),
251                shared::Elem::U64 => f.write_str("ulong"),
252                _ => Self::compile_elem(f, elem, false),
253            }
254        } else {
255            match elem {
256                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
257                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
258                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
259                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
260                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
261                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
262                shared::Elem::F16 => f.write_str("__half"),
263                shared::Elem::F16x2 => f.write_str("__half2"),
264                shared::Elem::F32 => f.write_str("float"),
265                shared::Elem::F64 => f.write_str("double"),
266                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
267                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
268                shared::Elem::TF32 => f.write_str("float"),
269                shared::Elem::I8 => f.write_str("int8"),
270                shared::Elem::I16 => f.write_str("int16"),
271                shared::Elem::I32 => f.write_str("int32"),
272                shared::Elem::I64 => f.write_str("int64"),
273                shared::Elem::U8 => f.write_str("uint8"),
274                shared::Elem::U16 => f.write_str("uint16"),
275                shared::Elem::U32 => f.write_str("uint32"),
276                shared::Elem::U64 => f.write_str("uint64"),
277                shared::Elem::Bool => f.write_str("bool"),
278                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
279                shared::Elem::_Dialect(_) => Ok(()),
280            }
281        }
282    }
283
284    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
285        if 1 == item.vectorization {
286            return write!(f, "{}", item.elem);
287        }
288        if item.native {
289            // native types use the word form of types only
290            Self::compile_elem(f, &item.elem, true)?;
291            write!(f, "{}", item.vectorization)
292        } else {
293            write!(f, "{}_{}", item.elem, item.vectorization)
294        }
295    }
296
297    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        Ok(())
299    }
300}
301
302// Kernel argument bindings
303
304impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
305    fn compile_kernel_signature(
306        f: &mut std::fmt::Formatter<'_>,
307        kernel_name: &str,
308        tensor_maps: &[Binding<Self>],
309        buffers: &[Binding<Self>],
310        scalars: &[(Elem<Self>, usize)],
311        flags: &Flags,
312    ) -> std::fmt::Result {
313        write!(
314            f,
315            "
316
317extern \"C\" __global__ void __launch_bounds__({})",
318            flags.cube_dim.num_elems()
319        )?;
320        if let Some(cluster_dim) = flags.cluster_dim {
321            write!(
322                f,
323                "__cluster_dims__({}, {}, {}) ",
324                cluster_dim.x, cluster_dim.y, cluster_dim.z
325            )?;
326        }
327        writeln!(f, "{kernel_name} (")?;
328        let has_scalars =
329            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
330        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
331        if flags.use_grid_constants {
332            shared::compile_scalars_static(f, scalars, flags)?;
333        } else {
334            shared::compile_scalars_dynamic(f, scalars)?;
335        }
336        f.write_str("\n)")?;
337        //
338        Ok(())
339    }
340
341    fn compile_bindings_body(
342        f: &mut std::fmt::Formatter<'_>,
343        body: &shared::Body<Self>,
344    ) -> std::fmt::Result {
345        if !body.shared_memories.is_empty() {
346            let max_align = body
347                .shared_memories
348                .iter()
349                .map(|smem| smem.align)
350                .max()
351                .unwrap();
352            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
353            // with `extern __shared__ alignas` and doesn't properly parse it.
354            writeln!(
355                f,
356                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
357            )?;
358        }
359        Ok(())
360    }
361}
362
363impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
364
365// Cube builtins dialect
366
367impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
368    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        write!(f, "cluster.block_rank()")
370    }
371
372    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        write!(f, "cluster.block_index().x")
374    }
375
376    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        write!(f, "cluster.block_index().y")
378    }
379
380    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        write!(f, "cluster.block_index().z")
382    }
383}
384
385// Instructions
386
387impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
388    // sync
389    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        writeln!(f, "__syncthreads();\n")
391    }
392
393    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        writeln!(f, "__syncwarp();\n")
395    }
396
397    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        writeln!(f, "__threadfence();")
399    }
400
401    // unary
402    fn compile_instruction_find_first_set<T: Component<Self>>(
403        f: &mut std::fmt::Formatter<'_>,
404        input: T,
405        out_elem: Elem<Self>,
406    ) -> std::fmt::Result {
407        write!(f, "{out_elem}(")?;
408        match input.elem() {
409            Elem::I32 => write!(f, "__ffs({input})"),
410            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
411            Elem::I64 => write!(f, "__ffsll({input})"),
412            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
413            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
414        }?;
415        write!(f, ")")
416    }
417
418    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
419        f: &mut std::fmt::Formatter<'_>,
420        input: T,
421        out_elem: Elem<Self>,
422    ) -> std::fmt::Result {
423        write!(f, "{out_elem}(")?;
424        match input.elem() {
425            Elem::I32 => write!(f, "__clz({input})"),
426            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
427            Elem::I64 => write!(f, "__clzll({input})"),
428            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
429            in_elem => write!(
430                f,
431                "{out_elem}(__clz({}) - {})",
432                unary::zero_extend(input),
433                (size_of::<u32>() - in_elem.size()) * 8
434            ),
435        }?;
436        write!(f, ")")
437    }
438
439    fn compile_saturating_add(
440        f: &mut std::fmt::Formatter<'_>,
441        lhs: impl Display,
442        rhs: impl Display,
443        item: Item<Self>,
444    ) -> std::fmt::Result {
445        let elem = item.elem();
446        match elem {
447            Elem::I32 => {
448                write!(
449                    f,
450                    r#"[&]() -> {elem} {{
451    {elem} result;
452    asm("add.sat.s32 %0, %1, %2;"
453        : "=r"(result)
454        : "r"({lhs}), "r"({rhs}));
455    return result;
456        }}()"#
457                )
458            }
459            _ => unreachable!("Should be replaced by polyfill"),
460        }
461    }
462
463    fn compile_saturating_sub(
464        f: &mut std::fmt::Formatter<'_>,
465        lhs: impl Display,
466        rhs: impl Display,
467        item: Item<Self>,
468    ) -> std::fmt::Result {
469        let elem = item.elem();
470        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
471        match elem {
472            Elem::I32 => {
473                write!(
474                    f,
475                    r#"[&]() -> {elem} {{
476    {elem} result;
477    asm("sub.sat.s32 %0, %1, %2;"
478        : "=r"(result)
479        : "r"({lhs}), "r"({rhs}));
480    return result;
481        }}()"#
482                )
483            }
484            _ => unreachable!("Should be replaced by polyfill"),
485        }
486    }
487
488    // others
489    fn compile_instruction_max_function_name(
490        f: &mut std::fmt::Formatter<'_>,
491        item: Item<Self>,
492    ) -> std::fmt::Result {
493        let max = match item.elem() {
494            Elem::F16 | Elem::BF16 => "__hmax",
495            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
496            _ => "max",
497        };
498        write!(f, "{max}")
499    }
500
501    fn compile_instruction_min_function_name(
502        f: &mut std::fmt::Formatter<'_>,
503        item: Item<Self>,
504    ) -> std::fmt::Result {
505        let min = match item.elem() {
506            Elem::F16 | Elem::BF16 => "__hmin",
507            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
508            _ => "min",
509        };
510        write!(f, "{min}")
511    }
512
513    // warp
514    fn compile_warp_shuffle(
515        f: &mut std::fmt::Formatter<'_>,
516        var: &str,
517        source: &str,
518    ) -> std::fmt::Result {
519        write!(f, "__shfl_sync(-1, {var}, {source})")
520    }
521    fn compile_warp_shuffle_xor(
522        f: &mut std::fmt::Formatter<'_>,
523        var: &str,
524        _elem: &Elem<Self>,
525        offset: &str,
526    ) -> std::fmt::Result {
527        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
528    }
529    fn compile_warp_shuffle_up(
530        f: &mut std::fmt::Formatter<'_>,
531        var: &str,
532        offset: &str,
533    ) -> std::fmt::Result {
534        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
535    }
536    fn compile_warp_shuffle_down(
537        f: &mut std::fmt::Formatter<'_>,
538        var: &str,
539        offset: &str,
540    ) -> std::fmt::Result {
541        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
542    }
543    fn compile_warp_all<T: Component<Self>>(
544        f: &mut std::fmt::Formatter<'_>,
545        input: &T,
546    ) -> std::fmt::Result {
547        write!(f, "__all_sync(-1, {input})")
548    }
549    fn compile_warp_any<T: Component<Self>>(
550        f: &mut std::fmt::Formatter<'_>,
551        input: &T,
552    ) -> std::fmt::Result {
553        write!(f, "__any_sync(-1, {input})")
554    }
555
556    fn compile_warp_ballot(
557        f: &mut std::fmt::Formatter<'_>,
558        input: &Variable<Self>,
559        _out_elem: &Elem<Self>,
560    ) -> std::fmt::Result {
561        write!(f, "__ballot_sync(-1, {input})")
562    }
563
564    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
565        let elem = Elem::<Self>::Bool;
566        let uint32 = Elem::<Self>::U32;
567        // Used to have a wrapper but it has been removed in newer version due to being
568        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
569        // implement.
570        writeln!(
571            f,
572            r#"{out} = {elem}([&]() -> {uint32} {{
573    {uint32} pred = 0;
574    asm volatile(
575        "{{\n"
576        "     .reg .pred %%px;\n"
577        "     elect.sync _|%%px, 0xffffffff;\n"
578        "     selp.b32 %0, 1, 0, %%px;\n"
579        "}}\n"
580        : "+r"(pred));
581    return pred;
582        }}());"#
583        )
584    }
585}
586
587// Coop Matrices dialect
588
589impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
590    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
591        M::compile_wmma_includes(f, flags)
592    }
593
594    fn compile_wmma_type_definitions(
595        f: &mut std::fmt::Formatter<'_>,
596        flags: &Flags,
597    ) -> std::fmt::Result {
598        M::compile_wmma_type_definitions(f, flags)
599    }
600
601    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602        M::compile_wmma_local_variables(f)
603    }
604
605    fn compile_wmma_fragment_declaration(
606        f: &mut std::fmt::Formatter<'_>,
607        var: &Variable<Self>,
608    ) -> std::fmt::Result {
609        M::compile_wmma_fragment_declaration(f, var)
610    }
611
612    fn compile_wwma_fragment_ident(
613        f: &mut std::fmt::Formatter<'_>,
614        ident: &crate::shared::FragmentIdent<Self>,
615    ) -> std::fmt::Result {
616        M::compile_wwma_fragment_ident(f, ident)
617    }
618
619    fn compile_wmma_fragment_layout(
620        f: &mut std::fmt::Formatter<'_>,
621        layout: &crate::shared::FragmentLayout<Self>,
622    ) -> std::fmt::Result {
623        M::compile_wmma_fragment_layout(f, layout)
624    }
625
626    fn compile_wmma_fragment(
627        f: &mut std::fmt::Formatter<'_>,
628        fragment: &crate::shared::Fragment<Self>,
629    ) -> std::fmt::Result {
630        M::compile_wmma_fragment(f, fragment)
631    }
632
633    fn compile_wmma_instruction(
634        f: &mut std::fmt::Formatter<'_>,
635        instruction: &crate::shared::WmmaInstruction<Self>,
636    ) -> std::fmt::Result {
637        M::compile_wmma_instruction(f, instruction)
638    }
639
640    fn compile_manual_mma(
641        f: &mut std::fmt::Formatter<'_>,
642        mma: ManualMma<Self>,
643    ) -> std::fmt::Result {
644        M::compile_manual_mma(f, mma)
645    }
646
647    fn compile_scaled_mma(
648        f: &mut std::fmt::Formatter<'_>,
649        mma: ManualMma<Self>,
650        scales_a: Variable<Self>,
651        scales_b: Variable<Self>,
652        scales_factor: u32,
653    ) -> std::fmt::Result {
654        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
655    }
656
657    fn supported_wmma_combinations(
658        arch: &CudaArchitecture,
659    ) -> crate::shared::SupportedMmaCombinations {
660        M::supported_wmma_combinations(arch)
661    }
662
663    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
664        M::supported_mma_combinations(arch)
665    }
666
667    fn supported_scaled_mma_combinations(
668        arch: &CudaArchitecture,
669    ) -> shared::SupportedScaledMmaCombinations {
670        M::supported_scaled_mma_combinations(arch)
671    }
672}
673
674impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
675    fn processors() -> Vec<Box<dyn Processor>> {
676        vec![
677            Box::new(CudaMmaProcessor),
678            Box::new(SaturatingArithmeticProcessor::new(false)),
679        ]
680    }
681}