Skip to main content

cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{
4    ir::{BarrierLevel, Processor},
5    post_processing::saturating::SaturatingArithmeticProcessor,
6};
7
8use crate::{
9    Dialect,
10    cuda::{
11        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
12        processors::CudaMmaProcessor,
13        ptx::*,
14    },
15    shared::{
16        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
17        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
18        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
19        Variable, WarpInstruction, unary,
20    },
21};
22
23use super::{Extension, arch::CudaArchitecture};
24
/// CUDA C++ code-generation dialect.
///
/// The type parameter `M` selects the matrix-multiply (WMMA / manual MMA) backend. It is
/// kept only as a zero-sized [`PhantomData`] marker, so a `CudaDialect<M>` carries no
/// runtime state and all of its instances compare equal.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
29
30impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
31    type Architecture = CudaArchitecture;
32}
33
34impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
35    type Extension = Extension<Self>;
36
37    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
38        f.write_str("#include <cuda_runtime.h>\n")?;
39        if flags.elem_fp4 {
40            f.write_str("#include <cuda_fp4.h>\n")?;
41        }
42        if flags.elem_fp6 {
43            f.write_str("#include <cuda_fp6.h>\n")?;
44        }
45        if flags.elem_fp8 {
46            f.write_str("#include <cuda_fp8.h>\n")?;
47        }
48        if flags.elem_bf16 {
49            f.write_str("#include <cuda_bf16.h>\n")?;
50        }
51        if flags.elem_f16 {
52            f.write_str("#include <cuda_fp16.h>\n")?;
53        }
54
55        // tf32 conversion function is in mma header
56        if flags.inst_wmma || flags.elem_tf32 {
57            Self::compile_wmma_includes(f, flags)?;
58        }
59
60        if flags.op_pipeline {
61            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
62            f.write_str("#include <cuda/pipeline>\n")?;
63        }
64        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
65            f.write_str("#include <cooperative_groups.h>\n")?;
66            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
67            f.write_str("#include <cuda/barrier>\n")?;
68        }
69        if flags.inst_ptx_wrappers {
70            f.write_str("#include <cuda/ptx>\n")?;
71        }
72        if flags.inst_tma {
73            f.write_str(
74                "typedef struct CUtensorMap_st {
75alignas(64) unsigned long long int opaque[16];
76} CUtensorMap;\n",
77            )?;
78        }
79        Ok(())
80    }
81
82    fn compile_extensions(
83        f: &mut std::fmt::Formatter<'_>,
84        extensions: &[Self::Extension],
85    ) -> std::fmt::Result {
86        for extension in extensions {
87            match extension {
88                Extension::NoExtension => {}
89                Extension::Mma(mma) => mma.format_extension(f)?,
90            }
91        }
92        Ok(())
93    }
94
95    fn register_instruction_extension(
96        _extensions: &mut Vec<Self::Extension>,
97        _instruction: &Instruction<Self>,
98    ) {
99    }
100
101    fn register_warp_instruction_extension(
102        _extensions: &mut Vec<Self::Extension>,
103        _instruction: &WarpInstruction<Self>,
104    ) {
105    }
106
107    fn register_wmma_instruction_extension(
108        extensions: &mut Vec<Self::Extension>,
109        instruction: &shared::WmmaInstruction<Self>,
110    ) {
111        match instruction {
112            shared::WmmaInstruction::ExecuteManual {
113                shape,
114                frag_a,
115                frag_b,
116                frag_c,
117                frag_d,
118            } => {
119                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
120                    *shape,
121                    Fragment(frag_a.elem()),
122                    Fragment(frag_b.elem()),
123                    Fragment(frag_c.elem()),
124                    Fragment(frag_d.elem()),
125                )));
126                if !extensions.contains(&ext) {
127                    extensions.push(ext);
128                }
129            }
130            shared::WmmaInstruction::ExecuteScaled {
131                shape,
132                frag_a,
133                frag_b,
134                frag_c,
135                frag_d,
136                scales_a,
137                scales_factor,
138                ..
139            } => {
140                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
141                    *shape,
142                    Fragment(frag_a.elem()),
143                    Fragment(frag_b.elem()),
144                    Fragment(frag_c.elem()),
145                    Fragment(frag_d.elem()),
146                    scales_a.elem(),
147                    *scales_factor,
148                )));
149                if !extensions.contains(&ext) {
150                    extensions.push(ext);
151                }
152            }
153            shared::WmmaInstruction::LdMatrix {
154                output,
155                factor,
156                transpose,
157                ..
158            } => {
159                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
160                    output.elem(),
161                    *factor,
162                    *transpose,
163                )));
164                if !extensions.contains(&ext) {
165                    extensions.push(ext);
166                }
167            }
168            shared::WmmaInstruction::StMatrix {
169                registers,
170                factor,
171                transpose,
172                ..
173            } => {
174                let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
175                    registers.elem(),
176                    *factor,
177                    *transpose,
178                )));
179                if !extensions.contains(&ext) {
180                    extensions.push(ext);
181                }
182            }
183            _ => {}
184        }
185    }
186}
187
188// Types
189
190impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
191    fn item_can_be_optimized() -> bool {
192        true
193    }
194
195    fn compile_type_definitions(
196        f: &mut std::fmt::Formatter<'_>,
197        items: &HashSet<Item<Self>>,
198        scalars: &[(Elem<Self>, usize)],
199        flags: &Flags<Self>,
200    ) -> std::fmt::Result {
201        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
202        let mut items_deduplicated = HashSet::new();
203
204        for item in items {
205            let mut item = *item;
206            match item.elem() {
207                Elem::FP4(_) => {
208                    item.elem = Elem::FP4(FP4Kind::E2M1);
209                }
210                Elem::FP4x2(_) => {
211                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
212                }
213                Elem::FP6(_) => {
214                    item.elem = Elem::FP6(FP6Kind::E2M3);
215                }
216                Elem::FP6x2(_) => {
217                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
218                }
219                Elem::FP8(_) => {
220                    item.elem = Elem::FP8(FP8Kind::E4M3);
221                }
222                Elem::FP8x2(_) => {
223                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
224                }
225                _ => {}
226            }
227            items_deduplicated.insert(item);
228        }
229
230        shared::type_definitions::<Self>(f)?;
231        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
232
233        if flags.use_grid_constants {
234            shared::type_scalar_definitions::<Self>(f, scalars)?;
235            shared::type_info_definition::<Self>(f, flags.static_meta_length, flags.address_type)?;
236        }
237
238        if flags.inst_wmma {
239            Self::compile_wmma_type_definitions(f, flags)?;
240        }
241
242        Ok(())
243    }
244
245    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
246        if flags.inst_tma_im2col {
247            writeln!(f, "{TMA_LOAD_IM2COL}")?;
248        }
249        if flags.inst_async_copy {
250            writeln!(f, "{COPY_ASYNC}")?;
251        }
252        Ok(())
253    }
254
255    fn compile_elem(
256        f: &mut std::fmt::Formatter<'_>,
257        elem: &shared::Elem<Self>,
258        words: bool,
259    ) -> std::fmt::Result {
260        if words {
261            match elem {
262                shared::Elem::F32 => f.write_str("float"),
263                shared::Elem::F64 => f.write_str("double"),
264                shared::Elem::TF32 => f.write_str("float"),
265                shared::Elem::I8 => f.write_str("char"),
266                shared::Elem::I16 => f.write_str("short"),
267                shared::Elem::I32 => f.write_str("int"),
268                shared::Elem::I64 => f.write_str("long"),
269                shared::Elem::U8 => f.write_str("uchar"),
270                shared::Elem::U16 => f.write_str("ushort"),
271                shared::Elem::U32 => f.write_str("uint"),
272                shared::Elem::U64 => f.write_str("ulong"),
273                _ => Self::compile_elem(f, elem, false),
274            }
275        } else {
276            match elem {
277                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
278                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
279                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
280                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
281                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
282                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
283                shared::Elem::F16 => f.write_str("__half"),
284                shared::Elem::F16x2 => f.write_str("__half2"),
285                shared::Elem::F32 => f.write_str("float"),
286                shared::Elem::F64 => f.write_str("double"),
287                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
288                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
289                shared::Elem::TF32 => f.write_str("float"),
290                shared::Elem::I8 => f.write_str("int8"),
291                shared::Elem::I16 => f.write_str("int16"),
292                shared::Elem::I32 => f.write_str("int32"),
293                shared::Elem::I64 => f.write_str("int64"),
294                shared::Elem::U8 => f.write_str("uint8"),
295                shared::Elem::U16 => f.write_str("uint16"),
296                shared::Elem::U32 => f.write_str("uint32"),
297                shared::Elem::U64 => f.write_str("uint64"),
298                shared::Elem::Bool => f.write_str("bool"),
299                shared::Elem::Barrier(BarrierLevel::Unit) => {
300                    f.write_str("cuda::barrier<cuda::thread_scope_thread>")
301                }
302                shared::Elem::Barrier(BarrierLevel::Cube) => {
303                    f.write_str("cuda::barrier<cuda::thread_scope_block>")
304                }
305                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
306                shared::Elem::_Dialect(_) => Ok(()),
307            }
308        }
309    }
310
311    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
312        if 1 == item.vectorization {
313            return write!(f, "{}", item.elem);
314        }
315        if item.native {
316            // native types use the word form of types only
317            Self::compile_elem(f, &item.elem, true)?;
318            write!(f, "{}", item.vectorization)
319        } else {
320            write!(f, "{}_{}", item.elem, item.vectorization)
321        }
322    }
323
324    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        Ok(())
326    }
327}
328
329// Kernel argument bindings
330
331impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
332    fn compile_kernel_signature(
333        f: &mut std::fmt::Formatter<'_>,
334        kernel_name: &str,
335        tensor_maps: &[Binding<Self>],
336        buffers: &[Binding<Self>],
337        scalars: &[(Elem<Self>, usize)],
338        flags: &Flags<Self>,
339    ) -> std::fmt::Result {
340        write!(
341            f,
342            "
343
344extern \"C\" __global__ void __launch_bounds__({})",
345            flags.cube_dim.num_elems()
346        )?;
347        if let Some(cluster_dim) = flags.cluster_dim {
348            write!(
349                f,
350                "__cluster_dims__({}, {}, {}) ",
351                cluster_dim.x, cluster_dim.y, cluster_dim.z
352            )?;
353        }
354        writeln!(f, "{kernel_name} (")?;
355        let has_scalars =
356            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
357        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
358        if flags.use_grid_constants {
359            shared::compile_scalars_static(f, scalars, flags)?;
360        } else {
361            shared::compile_scalars_dynamic(f, scalars)?;
362        }
363        f.write_str("\n)")?;
364        //
365        Ok(())
366    }
367
368    fn compile_bindings_body(
369        f: &mut std::fmt::Formatter<'_>,
370        body: &shared::Body<Self>,
371    ) -> std::fmt::Result {
372        if !body.shared_memories.is_empty() {
373            let max_align = body
374                .shared_memories
375                .iter()
376                .map(|smem| smem.align())
377                .max()
378                .unwrap();
379            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
380            // with `extern __shared__ alignas` and doesn't properly parse it.
381            writeln!(
382                f,
383                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
384            )?;
385        }
386        Ok(())
387    }
388}
389
390impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
391
392// Cube builtins dialect
393
394impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
395    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        write!(f, "cluster.block_rank()")
397    }
398
399    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        write!(f, "cluster.block_index().x")
401    }
402
403    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        write!(f, "cluster.block_index().y")
405    }
406
407    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        write!(f, "cluster.block_index().z")
409    }
410}
411
412// Instructions
413
414impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
415    // sync
416    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        writeln!(f, "__syncthreads();\n")
418    }
419
420    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421        writeln!(f, "__syncwarp();\n")
422    }
423
424    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425        writeln!(f, "__threadfence();")
426    }
427
428    // unary
429    fn compile_instruction_find_first_set<T: Component<Self>>(
430        f: &mut std::fmt::Formatter<'_>,
431        input: T,
432        out_elem: Elem<Self>,
433    ) -> std::fmt::Result {
434        write!(f, "{out_elem}(")?;
435        match input.elem() {
436            Elem::I32 => write!(f, "__ffs({input})"),
437            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
438            Elem::I64 => write!(f, "__ffsll({input})"),
439            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
440            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
441        }?;
442        write!(f, ")")
443    }
444
445    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
446        f: &mut std::fmt::Formatter<'_>,
447        input: T,
448        out_elem: Elem<Self>,
449    ) -> std::fmt::Result {
450        write!(f, "{out_elem}(")?;
451        match input.elem() {
452            Elem::I32 => write!(f, "__clz({input})"),
453            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
454            Elem::I64 => write!(f, "__clzll({input})"),
455            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
456            in_elem => write!(
457                f,
458                "{out_elem}(__clz({}) - {})",
459                unary::zero_extend(input),
460                (size_of::<u32>() - in_elem.size()) * 8
461            ),
462        }?;
463        write!(f, ")")
464    }
465
466    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
467        f: &mut std::fmt::Formatter<'_>,
468        input: T,
469        out_elem: Elem<Self>,
470    ) -> std::fmt::Result {
471        // CUDA doesn't have a direct ctz intrinsic, but __ffs returns 1-indexed position
472        // of the first set bit from LSB (0 if no bit set).
473        // trailing_zeros(x) = x == 0 ? bitwidth : __ffs(x) - 1
474        write!(f, "{out_elem}(")?;
475        match input.elem() {
476            Elem::I32 | Elem::U32 => {
477                write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
478            }
479            Elem::I64 | Elem::U64 => {
480                write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
481            }
482            in_elem => {
483                let bits = in_elem.size() * 8;
484                let extended = unary::zero_extend(input);
485                write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
486            }
487        }?;
488        write!(f, ")")
489    }
490
491    fn compile_saturating_add(
492        f: &mut std::fmt::Formatter<'_>,
493        lhs: impl Display,
494        rhs: impl Display,
495        item: Item<Self>,
496    ) -> std::fmt::Result {
497        let elem = item.elem();
498        match elem {
499            Elem::I32 => {
500                write!(
501                    f,
502                    r#"[&]() -> {elem} {{
503    {elem} result;
504    asm("add.sat.s32 %0, %1, %2;"
505        : "=r"(result)
506        : "r"({lhs}), "r"({rhs}));
507    return result;
508        }}()"#
509                )
510            }
511            _ => unreachable!("Should be replaced by polyfill"),
512        }
513    }
514
515    fn compile_saturating_sub(
516        f: &mut std::fmt::Formatter<'_>,
517        lhs: impl Display,
518        rhs: impl Display,
519        item: Item<Self>,
520    ) -> std::fmt::Result {
521        let elem = item.elem();
522        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
523        match elem {
524            Elem::I32 => {
525                write!(
526                    f,
527                    r#"[&]() -> {elem} {{
528    {elem} result;
529    asm("sub.sat.s32 %0, %1, %2;"
530        : "=r"(result)
531        : "r"({lhs}), "r"({rhs}));
532    return result;
533        }}()"#
534                )
535            }
536            _ => unreachable!("Should be replaced by polyfill"),
537        }
538    }
539
540    // others
541    fn compile_instruction_max_function_name(
542        f: &mut std::fmt::Formatter<'_>,
543        item: Item<Self>,
544    ) -> std::fmt::Result {
545        let max = match item.elem() {
546            Elem::F16 | Elem::BF16 => "__hmax",
547            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
548            _ => "max",
549        };
550        write!(f, "{max}")
551    }
552
553    fn compile_instruction_min_function_name(
554        f: &mut std::fmt::Formatter<'_>,
555        item: Item<Self>,
556    ) -> std::fmt::Result {
557        let min = match item.elem() {
558            Elem::F16 | Elem::BF16 => "__hmin",
559            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
560            _ => "min",
561        };
562        write!(f, "{min}")
563    }
564
565    // warp
566    fn compile_warp_shuffle(
567        f: &mut std::fmt::Formatter<'_>,
568        var: &str,
569        source: &str,
570    ) -> std::fmt::Result {
571        write!(f, "__shfl_sync(-1, {var}, {source})")
572    }
573    fn compile_warp_shuffle_xor(
574        f: &mut std::fmt::Formatter<'_>,
575        var: &str,
576        _elem: &Elem<Self>,
577        offset: &str,
578    ) -> std::fmt::Result {
579        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
580    }
581    fn compile_warp_shuffle_up(
582        f: &mut std::fmt::Formatter<'_>,
583        var: &str,
584        offset: &str,
585    ) -> std::fmt::Result {
586        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
587    }
588    fn compile_warp_shuffle_down(
589        f: &mut std::fmt::Formatter<'_>,
590        var: &str,
591        offset: &str,
592    ) -> std::fmt::Result {
593        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
594    }
595    fn compile_warp_all<T: Component<Self>>(
596        f: &mut std::fmt::Formatter<'_>,
597        input: &T,
598    ) -> std::fmt::Result {
599        write!(f, "__all_sync(-1, {input})")
600    }
601    fn compile_warp_any<T: Component<Self>>(
602        f: &mut std::fmt::Formatter<'_>,
603        input: &T,
604    ) -> std::fmt::Result {
605        write!(f, "__any_sync(-1, {input})")
606    }
607
608    fn compile_warp_ballot(
609        f: &mut std::fmt::Formatter<'_>,
610        input: &Variable<Self>,
611        _out_elem: &Elem<Self>,
612    ) -> std::fmt::Result {
613        write!(f, "__ballot_sync(-1, {input})")
614    }
615
616    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
617        let elem = Elem::<Self>::Bool;
618        let uint32 = Elem::<Self>::U32;
619        // Used to have a wrapper but it has been removed in newer version due to being
620        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
621        // implement.
622        writeln!(
623            f,
624            r#"{out} = {elem}([&]() -> {uint32} {{
625    {uint32} pred = 0;
626    asm volatile(
627        "{{\n"
628        "     .reg .pred %%px;\n"
629        "     elect.sync _|%%px, 0xffffffff;\n"
630        "     selp.b32 %0, 1, 0, %%px;\n"
631        "}}\n"
632        : "+r"(pred));
633    return pred;
634        }}());"#
635        )
636    }
637}
638
639// Coop Matrices dialect
640
641impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
642    fn compile_wmma_includes(
643        f: &mut std::fmt::Formatter<'_>,
644        flags: &Flags<Self>,
645    ) -> std::fmt::Result {
646        M::compile_wmma_includes(f, flags)
647    }
648
649    fn compile_wmma_type_definitions(
650        f: &mut std::fmt::Formatter<'_>,
651        flags: &Flags<Self>,
652    ) -> std::fmt::Result {
653        M::compile_wmma_type_definitions(f, flags)
654    }
655
656    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
657        M::compile_wmma_local_variables(f)
658    }
659
660    fn compile_wmma_fragment_declaration(
661        f: &mut std::fmt::Formatter<'_>,
662        var: &Variable<Self>,
663    ) -> std::fmt::Result {
664        M::compile_wmma_fragment_declaration(f, var)
665    }
666
667    fn compile_wwma_fragment_ident(
668        f: &mut std::fmt::Formatter<'_>,
669        ident: &crate::shared::FragmentIdent<Self>,
670    ) -> std::fmt::Result {
671        M::compile_wwma_fragment_ident(f, ident)
672    }
673
674    fn compile_wmma_fragment_layout(
675        f: &mut std::fmt::Formatter<'_>,
676        layout: &crate::shared::FragmentLayout<Self>,
677    ) -> std::fmt::Result {
678        M::compile_wmma_fragment_layout(f, layout)
679    }
680
681    fn compile_wmma_fragment(
682        f: &mut std::fmt::Formatter<'_>,
683        fragment: &crate::shared::Fragment<Self>,
684    ) -> std::fmt::Result {
685        M::compile_wmma_fragment(f, fragment)
686    }
687
688    fn compile_wmma_instruction(
689        f: &mut std::fmt::Formatter<'_>,
690        instruction: &crate::shared::WmmaInstruction<Self>,
691    ) -> std::fmt::Result {
692        M::compile_wmma_instruction(f, instruction)
693    }
694
695    fn compile_manual_mma(
696        f: &mut std::fmt::Formatter<'_>,
697        mma: ManualMma<Self>,
698    ) -> std::fmt::Result {
699        M::compile_manual_mma(f, mma)
700    }
701
702    fn compile_scaled_mma(
703        f: &mut std::fmt::Formatter<'_>,
704        mma: ManualMma<Self>,
705        scales_a: Variable<Self>,
706        scales_b: Variable<Self>,
707        scales_factor: u32,
708    ) -> std::fmt::Result {
709        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
710    }
711
712    fn supported_wmma_combinations(
713        arch: &CudaArchitecture,
714    ) -> crate::shared::SupportedMmaCombinations {
715        M::supported_wmma_combinations(arch)
716    }
717
718    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
719        M::supported_mma_combinations(arch)
720    }
721
722    fn supported_scaled_mma_combinations(
723        arch: &CudaArchitecture,
724    ) -> shared::SupportedScaledMmaCombinations {
725        M::supported_scaled_mma_combinations(arch)
726    }
727}
728
729impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
730    fn processors() -> Vec<Box<dyn Processor>> {
731        vec![
732            Box::new(CudaMmaProcessor),
733            Box::new(SaturatingArithmeticProcessor::new(false)),
734        ]
735    }
736}