Skip to main content

cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{
4    ir::{BarrierLevel, Processor},
5    post_processing::saturating::SaturatingArithmeticProcessor,
6};
7
8use crate::{
9    Dialect,
10    cuda::{
11        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
12        processors::CudaMmaProcessor,
13        ptx::*,
14    },
15    shared::{
16        self, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
17        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
18        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, KernelArg,
19        ManualMma, Variable, WarpInstruction, unary,
20    },
21};
22
23use super::{Extension, arch::CudaArchitecture};
24
/// Marker type naming the CUDA dialect.
///
/// `M` is held only through [`PhantomData`], so the struct stores no data. The trait
/// implementations in this file require `M: DialectWmmaCompiler<Self>` and forward the WMMA
/// hooks to it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
29
30impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
31    type Architecture = CudaArchitecture;
32}
33
34impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
35    type Extension = Extension<Self>;
36
37    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
38        f.write_str("#include <cuda_runtime.h>\n")?;
39        if flags.elem_fp4 {
40            f.write_str("#include <cuda_fp4.h>\n")?;
41        }
42        if flags.elem_fp6 {
43            f.write_str("#include <cuda_fp6.h>\n")?;
44        }
45        if flags.elem_fp8 {
46            f.write_str("#include <cuda_fp8.h>\n")?;
47        }
48        if flags.elem_bf16 {
49            f.write_str("#include <cuda_bf16.h>\n")?;
50        }
51        if flags.elem_f16 {
52            f.write_str("#include <cuda_fp16.h>\n")?;
53        }
54
55        // tf32 conversion function is in mma header
56        if flags.inst_wmma || flags.elem_tf32 {
57            Self::compile_wmma_includes(f, flags)?;
58        }
59
60        if flags.op_pipeline {
61            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
62            f.write_str("#include <cuda/pipeline>\n")?;
63        }
64        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
65            f.write_str("#include <cooperative_groups.h>\n")?;
66            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
67            f.write_str("#include <cuda/barrier>\n")?;
68        }
69        if flags.inst_ptx_wrappers {
70            f.write_str("#include <cuda/ptx>\n")?;
71        }
72        if flags.inst_tma {
73            f.write_str(
74                "typedef struct CUtensorMap_st {
75alignas(64) unsigned long long int opaque[16];
76} CUtensorMap;\n",
77            )?;
78        }
79        Ok(())
80    }
81
82    fn compile_extensions(
83        f: &mut std::fmt::Formatter<'_>,
84        extensions: &[Self::Extension],
85    ) -> std::fmt::Result {
86        for extension in extensions {
87            match extension {
88                Extension::NoExtension => {}
89                Extension::Mma(mma) => mma.format_extension(f)?,
90            }
91        }
92        Ok(())
93    }
94
95    fn register_instruction_extension(
96        _extensions: &mut Vec<Self::Extension>,
97        _instruction: &Instruction<Self>,
98    ) {
99    }
100
101    fn register_warp_instruction_extension(
102        _extensions: &mut Vec<Self::Extension>,
103        _instruction: &WarpInstruction<Self>,
104    ) {
105    }
106
107    fn register_wmma_instruction_extension(
108        extensions: &mut Vec<Self::Extension>,
109        instruction: &shared::WmmaInstruction<Self>,
110    ) {
111        match instruction {
112            shared::WmmaInstruction::ExecuteManual {
113                shape,
114                frag_a,
115                frag_b,
116                frag_c,
117                frag_d,
118            } => {
119                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
120                    *shape,
121                    Fragment(frag_a.elem()),
122                    Fragment(frag_b.elem()),
123                    Fragment(frag_c.elem()),
124                    Fragment(frag_d.elem()),
125                )));
126                if !extensions.contains(&ext) {
127                    extensions.push(ext);
128                }
129            }
130            shared::WmmaInstruction::ExecuteScaled {
131                shape,
132                frag_a,
133                frag_b,
134                frag_c,
135                frag_d,
136                scales_a,
137                scales_factor,
138                ..
139            } => {
140                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
141                    *shape,
142                    Fragment(frag_a.elem()),
143                    Fragment(frag_b.elem()),
144                    Fragment(frag_c.elem()),
145                    Fragment(frag_d.elem()),
146                    scales_a.elem(),
147                    *scales_factor,
148                )));
149                if !extensions.contains(&ext) {
150                    extensions.push(ext);
151                }
152            }
153            shared::WmmaInstruction::LdMatrix {
154                output,
155                factor,
156                transpose,
157                ..
158            } => {
159                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
160                    output.elem(),
161                    *factor,
162                    *transpose,
163                )));
164                if !extensions.contains(&ext) {
165                    extensions.push(ext);
166                }
167            }
168            shared::WmmaInstruction::StMatrix {
169                registers,
170                factor,
171                transpose,
172                ..
173            } => {
174                let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
175                    registers.elem(),
176                    *factor,
177                    *transpose,
178                )));
179                if !extensions.contains(&ext) {
180                    extensions.push(ext);
181                }
182            }
183            _ => {}
184        }
185    }
186}
187
188// Types
189
190impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
191    fn item_can_be_optimized() -> bool {
192        true
193    }
194
195    fn compile_type_definitions(
196        f: &mut std::fmt::Formatter<'_>,
197        items: &HashSet<Item<Self>>,
198        scalars: &[(Elem<Self>, usize)],
199        info: &cubecl_core::Info,
200        flags: &Flags<Self>,
201    ) -> std::fmt::Result {
202        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
203        let mut items_deduplicated = HashSet::new();
204
205        for item in items {
206            let mut item = *item;
207            match item.elem() {
208                Elem::FP4(_) => {
209                    item.elem = Elem::FP4(FP4Kind::E2M1);
210                }
211                Elem::FP4x2(_) => {
212                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
213                }
214                Elem::FP6(_) => {
215                    item.elem = Elem::FP6(FP6Kind::E2M3);
216                }
217                Elem::FP6x2(_) => {
218                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
219                }
220                Elem::FP8(_) => {
221                    item.elem = Elem::FP8(FP8Kind::E4M3);
222                }
223                Elem::FP8x2(_) => {
224                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
225                }
226                Elem::Atomic(inner) => {
227                    item.elem = inner.as_elem();
228                }
229                _ => {}
230            }
231            items_deduplicated.insert(item);
232        }
233
234        shared::type_definitions::<Self>(f)?;
235        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
236
237        shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
238
239        if flags.inst_wmma {
240            Self::compile_wmma_type_definitions(f, flags)?;
241        }
242
243        Ok(())
244    }
245
246    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
247        if flags.inst_tma_im2col {
248            writeln!(f, "{TMA_LOAD_IM2COL}")?;
249        }
250        if flags.inst_async_copy {
251            writeln!(f, "{COPY_ASYNC}")?;
252        }
253        Ok(())
254    }
255
256    fn compile_elem(
257        f: &mut std::fmt::Formatter<'_>,
258        elem: &shared::Elem<Self>,
259        words: bool,
260    ) -> std::fmt::Result {
261        if words {
262            match elem {
263                shared::Elem::F32 => f.write_str("float"),
264                shared::Elem::F64 => f.write_str("double"),
265                shared::Elem::TF32 => f.write_str("float"),
266                shared::Elem::I8 => f.write_str("char"),
267                shared::Elem::I16 => f.write_str("short"),
268                shared::Elem::I32 => f.write_str("int"),
269                shared::Elem::I64 => f.write_str("long"),
270                shared::Elem::U8 => f.write_str("uchar"),
271                shared::Elem::U16 => f.write_str("ushort"),
272                shared::Elem::U32 => f.write_str("uint"),
273                shared::Elem::U64 => f.write_str("ulong"),
274                _ => Self::compile_elem(f, elem, false),
275            }
276        } else {
277            match elem {
278                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
279                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
280                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
281                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
282                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
283                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
284                shared::Elem::F16 => f.write_str("__half"),
285                shared::Elem::F16x2 => f.write_str("__half2"),
286                shared::Elem::F32 => f.write_str("float"),
287                shared::Elem::F64 => f.write_str("double"),
288                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
289                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
290                shared::Elem::TF32 => f.write_str("float"),
291                shared::Elem::I8 => f.write_str("int8"),
292                shared::Elem::I16 => f.write_str("int16"),
293                shared::Elem::I32 => f.write_str("int32"),
294                shared::Elem::I64 => f.write_str("int64"),
295                shared::Elem::U8 => f.write_str("uint8"),
296                shared::Elem::U16 => f.write_str("uint16"),
297                shared::Elem::U32 => f.write_str("uint32"),
298                shared::Elem::U64 => f.write_str("uint64"),
299                shared::Elem::Bool => f.write_str("bool"),
300                shared::Elem::Barrier(BarrierLevel::Unit) => {
301                    f.write_str("cuda::barrier<cuda::thread_scope_thread>")
302                }
303                shared::Elem::Barrier(BarrierLevel::Cube) => {
304                    f.write_str("cuda::barrier<cuda::thread_scope_block>")
305                }
306                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
307                shared::Elem::_Dialect(_) => Ok(()),
308            }
309        }
310    }
311
312    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
313        if 1 == item.vectorization {
314            return write!(f, "{}", item.elem);
315        }
316        if item.native {
317            // native types use the word form of types only
318            Self::compile_elem(f, &item.elem, true)?;
319            write!(f, "{}", item.vectorization)
320        } else {
321            write!(f, "{}_{}", item.elem, item.vectorization)
322        }
323    }
324
325    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        Ok(())
327    }
328}
329
330// Kernel argument bindings
331
332impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
333    fn compile_kernel_signature(
334        f: &mut std::fmt::Formatter<'_>,
335        kernel_name: &str,
336        tensor_maps: &[KernelArg<Self>],
337        buffers: &[KernelArg<Self>],
338        flags: &Flags<Self>,
339    ) -> std::fmt::Result {
340        write!(
341            f,
342            "
343
344extern \"C\" __global__ void __launch_bounds__({})",
345            flags.cube_dim.num_elems()
346        )?;
347        if let Some(cluster_dim) = flags.cluster_dim {
348            write!(
349                f,
350                "__cluster_dims__({}, {}, {}) ",
351                cluster_dim.x, cluster_dim.y, cluster_dim.z
352            )?;
353        }
354        writeln!(f, "{kernel_name} (")?;
355
356        shared::compile_bindings(f, tensor_maps, buffers, flags.has_info)?;
357        if flags.use_grid_constants {
358            shared::compile_info_static(f, flags)?;
359        } else {
360            shared::compile_info_dynamic(f, flags)?;
361        }
362        f.write_str("\n)")?;
363        //
364        Ok(())
365    }
366
367    fn compile_bindings_body(
368        f: &mut std::fmt::Formatter<'_>,
369        body: &shared::Body<Self>,
370    ) -> std::fmt::Result {
371        if !body.shared_memories.is_empty() {
372            let max_align = body
373                .shared_memories
374                .iter()
375                .map(|smem| smem.align())
376                .max()
377                .unwrap();
378            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
379            // with `extern __shared__ alignas` and doesn't properly parse it.
380            writeln!(
381                f,
382                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
383            )?;
384        }
385        if body.info_by_ptr {
386            f.write_str("const info_st& info = *info_ptr;\n")?;
387            // Could use `info_ptr + 1` but that seems dirty, so use manual `sizeof` instead
388            writeln!(
389                f,
390                "const {addr}* dynamic_meta = reinterpret_cast<const {addr}*>(
391                    reinterpret_cast<const char*>(info_ptr) + sizeof(info_st)
392                );\n",
393                addr = body.address_type,
394            )?;
395        }
396        Ok(())
397    }
398}
399
400impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
401
402// Cube builtins dialect
403
404impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
405    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        write!(f, "cluster.block_rank()")
407    }
408
409    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        write!(f, "cluster.block_index().x")
411    }
412
413    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        write!(f, "cluster.block_index().y")
415    }
416
417    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418        write!(f, "cluster.block_index().z")
419    }
420}
421
422// Instructions
423
424impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
425    // sync
426    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427        writeln!(f, "__syncthreads();\n")
428    }
429
430    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431        writeln!(f, "__syncwarp();\n")
432    }
433
434    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        writeln!(f, "__threadfence();")
436    }
437
438    // unary
439    fn compile_instruction_find_first_set<T: Component<Self>>(
440        f: &mut std::fmt::Formatter<'_>,
441        input: T,
442        out_elem: Elem<Self>,
443    ) -> std::fmt::Result {
444        write!(f, "{out_elem}(")?;
445        match input.elem() {
446            Elem::I32 => write!(f, "__ffs({input})"),
447            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
448            Elem::I64 => write!(f, "__ffsll({input})"),
449            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
450            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
451        }?;
452        write!(f, ")")
453    }
454
455    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
456        f: &mut std::fmt::Formatter<'_>,
457        input: T,
458        out_elem: Elem<Self>,
459    ) -> std::fmt::Result {
460        write!(f, "{out_elem}(")?;
461        match input.elem() {
462            Elem::I32 => write!(f, "__clz({input})"),
463            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
464            Elem::I64 => write!(f, "__clzll({input})"),
465            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
466            in_elem => write!(
467                f,
468                "{out_elem}(__clz({}) - {})",
469                unary::zero_extend(input),
470                (size_of::<u32>() - in_elem.size()) * 8
471            ),
472        }?;
473        write!(f, ")")
474    }
475
476    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
477        f: &mut std::fmt::Formatter<'_>,
478        input: T,
479        out_elem: Elem<Self>,
480    ) -> std::fmt::Result {
481        // CUDA doesn't have a direct ctz intrinsic, but __ffs returns 1-indexed position
482        // of the first set bit from LSB (0 if no bit set).
483        // trailing_zeros(x) = x == 0 ? bitwidth : __ffs(x) - 1
484        write!(f, "{out_elem}(")?;
485        match input.elem() {
486            Elem::I32 | Elem::U32 => {
487                write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
488            }
489            Elem::I64 | Elem::U64 => {
490                write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
491            }
492            in_elem => {
493                let bits = in_elem.size() * 8;
494                let extended = unary::zero_extend(input);
495                write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
496            }
497        }?;
498        write!(f, ")")
499    }
500
501    fn compile_saturating_add(
502        f: &mut std::fmt::Formatter<'_>,
503        lhs: impl Display,
504        rhs: impl Display,
505        item: Item<Self>,
506    ) -> std::fmt::Result {
507        let elem = item.elem();
508        match elem {
509            Elem::I32 => {
510                write!(
511                    f,
512                    r#"[&]() -> {elem} {{
513    {elem} result;
514    asm("add.sat.s32 %0, %1, %2;"
515        : "=r"(result)
516        : "r"({lhs}), "r"({rhs}));
517    return result;
518        }}()"#
519                )
520            }
521            _ => unreachable!("Should be replaced by polyfill"),
522        }
523    }
524
525    fn compile_saturating_sub(
526        f: &mut std::fmt::Formatter<'_>,
527        lhs: impl Display,
528        rhs: impl Display,
529        item: Item<Self>,
530    ) -> std::fmt::Result {
531        let elem = item.elem();
532        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
533        match elem {
534            Elem::I32 => {
535                write!(
536                    f,
537                    r#"[&]() -> {elem} {{
538    {elem} result;
539    asm("sub.sat.s32 %0, %1, %2;"
540        : "=r"(result)
541        : "r"({lhs}), "r"({rhs}));
542    return result;
543        }}()"#
544                )
545            }
546            _ => unreachable!("Should be replaced by polyfill"),
547        }
548    }
549
550    // others
551    fn compile_instruction_max_function_name(
552        f: &mut std::fmt::Formatter<'_>,
553        item: Item<Self>,
554    ) -> std::fmt::Result {
555        let max = match item.elem() {
556            Elem::F16 | Elem::BF16 => "__hmax",
557            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
558            _ => "max",
559        };
560        write!(f, "{max}")
561    }
562
563    fn compile_instruction_min_function_name(
564        f: &mut std::fmt::Formatter<'_>,
565        item: Item<Self>,
566    ) -> std::fmt::Result {
567        let min = match item.elem() {
568            Elem::F16 | Elem::BF16 => "__hmin",
569            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
570            _ => "min",
571        };
572        write!(f, "{min}")
573    }
574
575    // warp
576    fn compile_warp_shuffle(
577        f: &mut std::fmt::Formatter<'_>,
578        var: &str,
579        source: &str,
580    ) -> std::fmt::Result {
581        write!(f, "__shfl_sync(-1, {var}, {source})")
582    }
583    fn compile_warp_shuffle_xor(
584        f: &mut std::fmt::Formatter<'_>,
585        var: &str,
586        _elem: &Elem<Self>,
587        offset: &str,
588    ) -> std::fmt::Result {
589        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
590    }
591    fn compile_warp_shuffle_up(
592        f: &mut std::fmt::Formatter<'_>,
593        var: &str,
594        offset: &str,
595    ) -> std::fmt::Result {
596        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
597    }
598    fn compile_warp_shuffle_down(
599        f: &mut std::fmt::Formatter<'_>,
600        var: &str,
601        offset: &str,
602    ) -> std::fmt::Result {
603        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
604    }
605    fn compile_warp_all<T: Component<Self>>(
606        f: &mut std::fmt::Formatter<'_>,
607        input: &T,
608    ) -> std::fmt::Result {
609        write!(f, "__all_sync(-1, {input})")
610    }
611    fn compile_warp_any<T: Component<Self>>(
612        f: &mut std::fmt::Formatter<'_>,
613        input: &T,
614    ) -> std::fmt::Result {
615        write!(f, "__any_sync(-1, {input})")
616    }
617
618    fn compile_warp_ballot(
619        f: &mut std::fmt::Formatter<'_>,
620        input: &Variable<Self>,
621        _out_elem: &Elem<Self>,
622    ) -> std::fmt::Result {
623        write!(f, "__ballot_sync(-1, {input})")
624    }
625
626    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
627        let elem = Elem::<Self>::Bool;
628        let uint32 = Elem::<Self>::U32;
629        // Used to have a wrapper but it has been removed in newer version due to being
630        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
631        // implement.
632        writeln!(
633            f,
634            r#"{out} = {elem}([&]() -> {uint32} {{
635    {uint32} pred = 0;
636    asm volatile(
637        "{{\n"
638        "     .reg .pred %%px;\n"
639        "     elect.sync _|%%px, 0xffffffff;\n"
640        "     selp.b32 %0, 1, 0, %%px;\n"
641        "}}\n"
642        : "+r"(pred));
643    return pred;
644        }}());"#
645        )
646    }
647
648    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
649        write!(f, "__builtin_unreachable();")
650    }
651}
652
653// Coop Matrices dialect
654
655impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
656    fn compile_wmma_includes(
657        f: &mut std::fmt::Formatter<'_>,
658        flags: &Flags<Self>,
659    ) -> std::fmt::Result {
660        M::compile_wmma_includes(f, flags)
661    }
662
663    fn compile_wmma_type_definitions(
664        f: &mut std::fmt::Formatter<'_>,
665        flags: &Flags<Self>,
666    ) -> std::fmt::Result {
667        M::compile_wmma_type_definitions(f, flags)
668    }
669
670    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
671        M::compile_wmma_local_variables(f)
672    }
673
674    fn compile_wmma_fragment_declaration(
675        f: &mut std::fmt::Formatter<'_>,
676        var: &Variable<Self>,
677    ) -> std::fmt::Result {
678        M::compile_wmma_fragment_declaration(f, var)
679    }
680
681    fn compile_wwma_fragment_ident(
682        f: &mut std::fmt::Formatter<'_>,
683        ident: &crate::shared::FragmentIdent<Self>,
684    ) -> std::fmt::Result {
685        M::compile_wwma_fragment_ident(f, ident)
686    }
687
688    fn compile_wmma_fragment_layout(
689        f: &mut std::fmt::Formatter<'_>,
690        layout: &crate::shared::FragmentLayout<Self>,
691    ) -> std::fmt::Result {
692        M::compile_wmma_fragment_layout(f, layout)
693    }
694
695    fn compile_wmma_fragment(
696        f: &mut std::fmt::Formatter<'_>,
697        fragment: &crate::shared::Fragment<Self>,
698    ) -> std::fmt::Result {
699        M::compile_wmma_fragment(f, fragment)
700    }
701
702    fn compile_wmma_instruction(
703        f: &mut std::fmt::Formatter<'_>,
704        instruction: &crate::shared::WmmaInstruction<Self>,
705    ) -> std::fmt::Result {
706        M::compile_wmma_instruction(f, instruction)
707    }
708
709    fn compile_manual_mma(
710        f: &mut std::fmt::Formatter<'_>,
711        mma: ManualMma<Self>,
712    ) -> std::fmt::Result {
713        M::compile_manual_mma(f, mma)
714    }
715
716    fn compile_scaled_mma(
717        f: &mut std::fmt::Formatter<'_>,
718        mma: ManualMma<Self>,
719        scales_a: Variable<Self>,
720        scales_b: Variable<Self>,
721        scales_factor: u32,
722    ) -> std::fmt::Result {
723        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
724    }
725
726    fn supported_wmma_combinations(
727        arch: &CudaArchitecture,
728    ) -> crate::shared::SupportedMmaCombinations {
729        M::supported_wmma_combinations(arch)
730    }
731
732    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
733        M::supported_mma_combinations(arch)
734    }
735
736    fn supported_scaled_mma_combinations(
737        arch: &CudaArchitecture,
738    ) -> shared::SupportedScaledMmaCombinations {
739        M::supported_scaled_mma_combinations(arch)
740    }
741}
742
743impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
744    fn processors() -> Vec<Box<dyn Processor>> {
745        vec![
746            Box::new(CudaMmaProcessor),
747            Box::new(SaturatingArithmeticProcessor::new(false)),
748        ]
749    }
750}