cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{
4    ir::{BarrierLevel, Processor},
5    post_processing::saturating::SaturatingArithmeticProcessor,
6};
7
8use crate::{
9    Dialect,
10    cuda::{
11        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
12        processors::CudaMmaProcessor,
13        ptx::*,
14    },
15    shared::{
16        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
17        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
18        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
19        Variable, WarpInstruction, unary,
20    },
21};
22
23use super::{Extension, arch::CudaArchitecture};
24
/// CUDA flavour of the C++ code generation dialect.
///
/// The type parameter `M` selects the WMMA compiler the dialect delegates its
/// matrix-multiply code generation to. No value of `M` is ever stored; it is
/// only tracked at the type level.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    /// Zero-sized marker tying the dialect to its WMMA compiler type.
    _wmma_compiler: PhantomData<M>,
}
29
30impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
31    type Architecture = CudaArchitecture;
32}
33
34impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
35    type Extension = Extension<Self>;
36
37    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
38        f.write_str("#include <cuda_runtime.h>\n")?;
39        if flags.elem_fp4 {
40            f.write_str("#include <cuda_fp4.h>\n")?;
41        }
42        if flags.elem_fp6 {
43            f.write_str("#include <cuda_fp6.h>\n")?;
44        }
45        if flags.elem_fp8 {
46            f.write_str("#include <cuda_fp8.h>\n")?;
47        }
48        if flags.elem_bf16 {
49            f.write_str("#include <cuda_bf16.h>\n")?;
50        }
51        if flags.elem_f16 {
52            f.write_str("#include <cuda_fp16.h>\n")?;
53        }
54
55        // tf32 conversion function is in mma header
56        if flags.inst_wmma || flags.elem_tf32 {
57            Self::compile_wmma_includes(f, flags)?;
58        }
59
60        if flags.op_pipeline {
61            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
62            f.write_str("#include <cuda/pipeline>\n")?;
63        }
64        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
65            f.write_str("#include <cooperative_groups.h>\n")?;
66            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
67            f.write_str("#include <cuda/barrier>\n")?;
68        }
69        if flags.inst_ptx_wrappers {
70            f.write_str("#include <cuda/ptx>\n")?;
71        }
72        if flags.inst_tma {
73            f.write_str(
74                "typedef struct CUtensorMap_st {
75alignas(64) unsigned long long int opaque[16];
76} CUtensorMap;\n",
77            )?;
78        }
79        Ok(())
80    }
81
82    fn compile_extensions(
83        f: &mut std::fmt::Formatter<'_>,
84        extensions: &[Self::Extension],
85    ) -> std::fmt::Result {
86        for extension in extensions {
87            match extension {
88                Extension::NoExtension => {}
89                Extension::Mma(mma) => mma.format_extension(f)?,
90            }
91        }
92        Ok(())
93    }
94
95    fn register_instruction_extension(
96        _extensions: &mut Vec<Self::Extension>,
97        _instruction: &Instruction<Self>,
98    ) {
99    }
100
101    fn register_warp_instruction_extension(
102        _extensions: &mut Vec<Self::Extension>,
103        _instruction: &WarpInstruction<Self>,
104    ) {
105    }
106
107    fn register_wmma_instruction_extension(
108        extensions: &mut Vec<Self::Extension>,
109        instruction: &shared::WmmaInstruction<Self>,
110    ) {
111        match instruction {
112            shared::WmmaInstruction::ExecuteManual {
113                shape,
114                frag_a,
115                frag_b,
116                frag_c,
117                frag_d,
118            } => {
119                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
120                    *shape,
121                    Fragment(frag_a.elem()),
122                    Fragment(frag_b.elem()),
123                    Fragment(frag_c.elem()),
124                    Fragment(frag_d.elem()),
125                )));
126                if !extensions.contains(&ext) {
127                    extensions.push(ext);
128                }
129            }
130            shared::WmmaInstruction::ExecuteScaled {
131                shape,
132                frag_a,
133                frag_b,
134                frag_c,
135                frag_d,
136                scales_a,
137                scales_factor,
138                ..
139            } => {
140                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
141                    *shape,
142                    Fragment(frag_a.elem()),
143                    Fragment(frag_b.elem()),
144                    Fragment(frag_c.elem()),
145                    Fragment(frag_d.elem()),
146                    scales_a.elem(),
147                    *scales_factor,
148                )));
149                if !extensions.contains(&ext) {
150                    extensions.push(ext);
151                }
152            }
153            shared::WmmaInstruction::LdMatrix {
154                output,
155                factor,
156                transpose,
157                ..
158            } => {
159                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
160                    output.elem(),
161                    *factor,
162                    *transpose,
163                )));
164                if !extensions.contains(&ext) {
165                    extensions.push(ext);
166                }
167            }
168            shared::WmmaInstruction::StMatrix {
169                registers,
170                factor,
171                transpose,
172                ..
173            } => {
174                let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
175                    registers.elem(),
176                    *factor,
177                    *transpose,
178                )));
179                if !extensions.contains(&ext) {
180                    extensions.push(ext);
181                }
182            }
183            _ => {}
184        }
185    }
186}
187
188// Types
189
190impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
191    fn item_can_be_optimized() -> bool {
192        true
193    }
194
195    fn compile_type_definitions(
196        f: &mut std::fmt::Formatter<'_>,
197        items: &HashSet<Item<Self>>,
198        scalars: &[(Elem<Self>, usize)],
199        flags: &Flags,
200    ) -> std::fmt::Result {
201        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
202        let mut items_deduplicated = HashSet::new();
203
204        for item in items {
205            let mut item = *item;
206            match item.elem() {
207                Elem::FP4(_) => {
208                    item.elem = Elem::FP4(FP4Kind::E2M1);
209                }
210                Elem::FP4x2(_) => {
211                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
212                }
213                Elem::FP6(_) => {
214                    item.elem = Elem::FP6(FP6Kind::E2M3);
215                }
216                Elem::FP6x2(_) => {
217                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
218                }
219                Elem::FP8(_) => {
220                    item.elem = Elem::FP8(FP8Kind::E4M3);
221                }
222                Elem::FP8x2(_) => {
223                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
224                }
225                _ => {}
226            }
227            items_deduplicated.insert(item);
228        }
229
230        shared::type_definitions::<Self>(f)?;
231        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
232
233        if flags.use_grid_constants {
234            shared::type_scalar_definitions::<Self>(f, scalars)?;
235            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
236        }
237
238        if flags.inst_wmma {
239            Self::compile_wmma_type_definitions(f, flags)?;
240        }
241
242        Ok(())
243    }
244
245    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
246        if flags.inst_tma_im2col {
247            writeln!(f, "{TMA_LOAD_IM2COL}")?;
248        }
249        if flags.inst_async_copy {
250            writeln!(f, "{COPY_ASYNC}")?;
251        }
252        Ok(())
253    }
254
255    fn compile_elem(
256        f: &mut std::fmt::Formatter<'_>,
257        elem: &shared::Elem<Self>,
258        words: bool,
259    ) -> std::fmt::Result {
260        if words {
261            match elem {
262                shared::Elem::F32 => f.write_str("float"),
263                shared::Elem::F64 => f.write_str("double"),
264                shared::Elem::TF32 => f.write_str("float"),
265                shared::Elem::I8 => f.write_str("char"),
266                shared::Elem::I16 => f.write_str("short"),
267                shared::Elem::I32 => f.write_str("int"),
268                shared::Elem::I64 => f.write_str("long"),
269                shared::Elem::U8 => f.write_str("uchar"),
270                shared::Elem::U16 => f.write_str("ushort"),
271                shared::Elem::U32 => f.write_str("uint"),
272                shared::Elem::U64 => f.write_str("ulong"),
273                _ => Self::compile_elem(f, elem, false),
274            }
275        } else {
276            match elem {
277                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
278                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
279                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
280                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
281                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
282                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
283                shared::Elem::F16 => f.write_str("__half"),
284                shared::Elem::F16x2 => f.write_str("__half2"),
285                shared::Elem::F32 => f.write_str("float"),
286                shared::Elem::F64 => f.write_str("double"),
287                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
288                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
289                shared::Elem::TF32 => f.write_str("float"),
290                shared::Elem::I8 => f.write_str("int8"),
291                shared::Elem::I16 => f.write_str("int16"),
292                shared::Elem::I32 => f.write_str("int32"),
293                shared::Elem::I64 => f.write_str("int64"),
294                shared::Elem::U8 => f.write_str("uint8"),
295                shared::Elem::U16 => f.write_str("uint16"),
296                shared::Elem::U32 => f.write_str("uint32"),
297                shared::Elem::U64 => f.write_str("uint64"),
298                shared::Elem::Bool => f.write_str("bool"),
299                shared::Elem::Barrier(BarrierLevel::Unit) => {
300                    f.write_str("cuda::barrier<cuda::thread_scope_thread>")
301                }
302                shared::Elem::Barrier(BarrierLevel::Cube) => {
303                    f.write_str("cuda::barrier<cuda::thread_scope_block>")
304                }
305                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
306                shared::Elem::_Dialect(_) => Ok(()),
307            }
308        }
309    }
310
311    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
312        if 1 == item.vectorization {
313            return write!(f, "{}", item.elem);
314        }
315        if item.native {
316            // native types use the word form of types only
317            Self::compile_elem(f, &item.elem, true)?;
318            write!(f, "{}", item.vectorization)
319        } else {
320            write!(f, "{}_{}", item.elem, item.vectorization)
321        }
322    }
323
324    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        Ok(())
326    }
327}
328
329// Kernel argument bindings
330
331impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
332    fn compile_kernel_signature(
333        f: &mut std::fmt::Formatter<'_>,
334        kernel_name: &str,
335        tensor_maps: &[Binding<Self>],
336        buffers: &[Binding<Self>],
337        scalars: &[(Elem<Self>, usize)],
338        flags: &Flags,
339    ) -> std::fmt::Result {
340        write!(
341            f,
342            "
343
344extern \"C\" __global__ void __launch_bounds__({})",
345            flags.cube_dim.num_elems()
346        )?;
347        if let Some(cluster_dim) = flags.cluster_dim {
348            write!(
349                f,
350                "__cluster_dims__({}, {}, {}) ",
351                cluster_dim.x, cluster_dim.y, cluster_dim.z
352            )?;
353        }
354        writeln!(f, "{kernel_name} (")?;
355        let has_scalars =
356            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
357        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
358        if flags.use_grid_constants {
359            shared::compile_scalars_static(f, scalars, flags)?;
360        } else {
361            shared::compile_scalars_dynamic(f, scalars)?;
362        }
363        f.write_str("\n)")?;
364        //
365        Ok(())
366    }
367
368    fn compile_bindings_body(
369        f: &mut std::fmt::Formatter<'_>,
370        body: &shared::Body<Self>,
371    ) -> std::fmt::Result {
372        if !body.shared_memories.is_empty() {
373            let max_align = body
374                .shared_memories
375                .iter()
376                .map(|smem| smem.align())
377                .max()
378                .unwrap();
379            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
380            // with `extern __shared__ alignas` and doesn't properly parse it.
381            writeln!(
382                f,
383                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
384            )?;
385        }
386        Ok(())
387    }
388}
389
390impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
391
392// Cube builtins dialect
393
394impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
395    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        write!(f, "cluster.block_rank()")
397    }
398
399    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        write!(f, "cluster.block_index().x")
401    }
402
403    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        write!(f, "cluster.block_index().y")
405    }
406
407    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        write!(f, "cluster.block_index().z")
409    }
410}
411
412// Instructions
413
414impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
415    // sync
416    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        writeln!(f, "__syncthreads();\n")
418    }
419
420    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        writeln!(f, "__syncwarp();\n")
422    }
423
424    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        writeln!(f, "__threadfence();")
426    }
427
428    // unary
429    fn compile_instruction_find_first_set<T: Component<Self>>(
430        f: &mut std::fmt::Formatter<'_>,
431        input: T,
432        out_elem: Elem<Self>,
433    ) -> std::fmt::Result {
434        write!(f, "{out_elem}(")?;
435        match input.elem() {
436            Elem::I32 => write!(f, "__ffs({input})"),
437            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
438            Elem::I64 => write!(f, "__ffsll({input})"),
439            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
440            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
441        }?;
442        write!(f, ")")
443    }
444
445    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
446        f: &mut std::fmt::Formatter<'_>,
447        input: T,
448        out_elem: Elem<Self>,
449    ) -> std::fmt::Result {
450        write!(f, "{out_elem}(")?;
451        match input.elem() {
452            Elem::I32 => write!(f, "__clz({input})"),
453            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
454            Elem::I64 => write!(f, "__clzll({input})"),
455            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
456            in_elem => write!(
457                f,
458                "{out_elem}(__clz({}) - {})",
459                unary::zero_extend(input),
460                (size_of::<u32>() - in_elem.size()) * 8
461            ),
462        }?;
463        write!(f, ")")
464    }
465
466    fn compile_saturating_add(
467        f: &mut std::fmt::Formatter<'_>,
468        lhs: impl Display,
469        rhs: impl Display,
470        item: Item<Self>,
471    ) -> std::fmt::Result {
472        let elem = item.elem();
473        match elem {
474            Elem::I32 => {
475                write!(
476                    f,
477                    r#"[&]() -> {elem} {{
478    {elem} result;
479    asm("add.sat.s32 %0, %1, %2;"
480        : "=r"(result)
481        : "r"({lhs}), "r"({rhs}));
482    return result;
483        }}()"#
484                )
485            }
486            _ => unreachable!("Should be replaced by polyfill"),
487        }
488    }
489
490    fn compile_saturating_sub(
491        f: &mut std::fmt::Formatter<'_>,
492        lhs: impl Display,
493        rhs: impl Display,
494        item: Item<Self>,
495    ) -> std::fmt::Result {
496        let elem = item.elem();
497        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
498        match elem {
499            Elem::I32 => {
500                write!(
501                    f,
502                    r#"[&]() -> {elem} {{
503    {elem} result;
504    asm("sub.sat.s32 %0, %1, %2;"
505        : "=r"(result)
506        : "r"({lhs}), "r"({rhs}));
507    return result;
508        }}()"#
509                )
510            }
511            _ => unreachable!("Should be replaced by polyfill"),
512        }
513    }
514
515    // others
516    fn compile_instruction_max_function_name(
517        f: &mut std::fmt::Formatter<'_>,
518        item: Item<Self>,
519    ) -> std::fmt::Result {
520        let max = match item.elem() {
521            Elem::F16 | Elem::BF16 => "__hmax",
522            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
523            _ => "max",
524        };
525        write!(f, "{max}")
526    }
527
528    fn compile_instruction_min_function_name(
529        f: &mut std::fmt::Formatter<'_>,
530        item: Item<Self>,
531    ) -> std::fmt::Result {
532        let min = match item.elem() {
533            Elem::F16 | Elem::BF16 => "__hmin",
534            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
535            _ => "min",
536        };
537        write!(f, "{min}")
538    }
539
540    // warp
541    fn compile_warp_shuffle(
542        f: &mut std::fmt::Formatter<'_>,
543        var: &str,
544        source: &str,
545    ) -> std::fmt::Result {
546        write!(f, "__shfl_sync(-1, {var}, {source})")
547    }
548    fn compile_warp_shuffle_xor(
549        f: &mut std::fmt::Formatter<'_>,
550        var: &str,
551        _elem: &Elem<Self>,
552        offset: &str,
553    ) -> std::fmt::Result {
554        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
555    }
556    fn compile_warp_shuffle_up(
557        f: &mut std::fmt::Formatter<'_>,
558        var: &str,
559        offset: &str,
560    ) -> std::fmt::Result {
561        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
562    }
563    fn compile_warp_shuffle_down(
564        f: &mut std::fmt::Formatter<'_>,
565        var: &str,
566        offset: &str,
567    ) -> std::fmt::Result {
568        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
569    }
570    fn compile_warp_all<T: Component<Self>>(
571        f: &mut std::fmt::Formatter<'_>,
572        input: &T,
573    ) -> std::fmt::Result {
574        write!(f, "__all_sync(-1, {input})")
575    }
576    fn compile_warp_any<T: Component<Self>>(
577        f: &mut std::fmt::Formatter<'_>,
578        input: &T,
579    ) -> std::fmt::Result {
580        write!(f, "__any_sync(-1, {input})")
581    }
582
583    fn compile_warp_ballot(
584        f: &mut std::fmt::Formatter<'_>,
585        input: &Variable<Self>,
586        _out_elem: &Elem<Self>,
587    ) -> std::fmt::Result {
588        write!(f, "__ballot_sync(-1, {input})")
589    }
590
591    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
592        let elem = Elem::<Self>::Bool;
593        let uint32 = Elem::<Self>::U32;
594        // Used to have a wrapper but it has been removed in newer version due to being
595        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
596        // implement.
597        writeln!(
598            f,
599            r#"{out} = {elem}([&]() -> {uint32} {{
600    {uint32} pred = 0;
601    asm volatile(
602        "{{\n"
603        "     .reg .pred %%px;\n"
604        "     elect.sync _|%%px, 0xffffffff;\n"
605        "     selp.b32 %0, 1, 0, %%px;\n"
606        "}}\n"
607        : "+r"(pred));
608    return pred;
609        }}());"#
610        )
611    }
612}
613
614// Coop Matrices dialect
615
616impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
617    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
618        M::compile_wmma_includes(f, flags)
619    }
620
621    fn compile_wmma_type_definitions(
622        f: &mut std::fmt::Formatter<'_>,
623        flags: &Flags,
624    ) -> std::fmt::Result {
625        M::compile_wmma_type_definitions(f, flags)
626    }
627
628    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629        M::compile_wmma_local_variables(f)
630    }
631
632    fn compile_wmma_fragment_declaration(
633        f: &mut std::fmt::Formatter<'_>,
634        var: &Variable<Self>,
635    ) -> std::fmt::Result {
636        M::compile_wmma_fragment_declaration(f, var)
637    }
638
639    fn compile_wwma_fragment_ident(
640        f: &mut std::fmt::Formatter<'_>,
641        ident: &crate::shared::FragmentIdent<Self>,
642    ) -> std::fmt::Result {
643        M::compile_wwma_fragment_ident(f, ident)
644    }
645
646    fn compile_wmma_fragment_layout(
647        f: &mut std::fmt::Formatter<'_>,
648        layout: &crate::shared::FragmentLayout<Self>,
649    ) -> std::fmt::Result {
650        M::compile_wmma_fragment_layout(f, layout)
651    }
652
653    fn compile_wmma_fragment(
654        f: &mut std::fmt::Formatter<'_>,
655        fragment: &crate::shared::Fragment<Self>,
656    ) -> std::fmt::Result {
657        M::compile_wmma_fragment(f, fragment)
658    }
659
660    fn compile_wmma_instruction(
661        f: &mut std::fmt::Formatter<'_>,
662        instruction: &crate::shared::WmmaInstruction<Self>,
663    ) -> std::fmt::Result {
664        M::compile_wmma_instruction(f, instruction)
665    }
666
667    fn compile_manual_mma(
668        f: &mut std::fmt::Formatter<'_>,
669        mma: ManualMma<Self>,
670    ) -> std::fmt::Result {
671        M::compile_manual_mma(f, mma)
672    }
673
674    fn compile_scaled_mma(
675        f: &mut std::fmt::Formatter<'_>,
676        mma: ManualMma<Self>,
677        scales_a: Variable<Self>,
678        scales_b: Variable<Self>,
679        scales_factor: u32,
680    ) -> std::fmt::Result {
681        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
682    }
683
684    fn supported_wmma_combinations(
685        arch: &CudaArchitecture,
686    ) -> crate::shared::SupportedMmaCombinations {
687        M::supported_wmma_combinations(arch)
688    }
689
690    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
691        M::supported_mma_combinations(arch)
692    }
693
694    fn supported_scaled_mma_combinations(
695        arch: &CudaArchitecture,
696    ) -> shared::SupportedScaledMmaCombinations {
697        M::supported_scaled_mma_combinations(arch)
698    }
699}
700
701impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
702    fn processors() -> Vec<Box<dyn Processor>> {
703        vec![
704            Box::new(CudaMmaProcessor),
705            Box::new(SaturatingArithmeticProcessor::new(false)),
706        ]
707    }
708}