cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6    Dialect,
7    cuda::{
8        extension::{Fragment, MmaExecute, MmaExecuteScaled, MmaExtension},
9        processors::CudaMmaProcessor,
10        ptx::*,
11    },
12    shared::{
13        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16        Variable, WarpInstruction, unary,
17    },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// CUDA flavour of the C++ code-generation dialect.
///
/// `M` is the cooperative-matrix (WMMA/MMA) compiler that the dialect forwards matrix code
/// generation to. It exists only at the type level (through [`PhantomData`]), so a
/// `CudaDialect` value carries no runtime data.
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28    type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32    type Extension = Extension<Self>;
33
34    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35        f.write_str("#include <cuda_runtime.h>\n")?;
36        if flags.elem_fp4 {
37            f.write_str("#include <cuda_fp4.h>\n")?;
38        }
39        if flags.elem_fp6 {
40            f.write_str("#include <cuda_fp6.h>\n")?;
41        }
42        if flags.elem_fp8 {
43            f.write_str("#include <cuda_fp8.h>\n")?;
44        }
45        if flags.elem_bf16 {
46            f.write_str("#include <cuda_bf16.h>\n")?;
47        }
48        if flags.elem_f16 {
49            f.write_str("#include <cuda_fp16.h>\n")?;
50        }
51
52        // tf32 conversion function is in mma header
53        if flags.inst_wmma || flags.elem_tf32 {
54            Self::compile_wmma_includes(f, flags)?;
55        }
56
57        if flags.op_pipeline {
58            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59            f.write_str("#include <cuda/pipeline>\n")?;
60        }
61        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62            f.write_str("#include <cooperative_groups.h>\n")?;
63            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64            f.write_str("#include <cuda/barrier>\n")?;
65        }
66        if flags.inst_ptx_wrappers {
67            f.write_str("#include <cuda/ptx>\n")?;
68        }
69        if flags.inst_tma {
70            f.write_str(
71                "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74            )?;
75        }
76        Ok(())
77    }
78
79    fn compile_extensions(
80        f: &mut std::fmt::Formatter<'_>,
81        extensions: &[Self::Extension],
82    ) -> std::fmt::Result {
83        for extension in extensions {
84            match extension {
85                Extension::NoExtension => {}
86                Extension::Mma(mma) => mma.format_extension(f)?,
87            }
88        }
89        Ok(())
90    }
91
92    fn register_instruction_extension(
93        _extensions: &mut Vec<Self::Extension>,
94        _instruction: &Instruction<Self>,
95    ) {
96    }
97
98    fn register_warp_instruction_extension(
99        _extensions: &mut Vec<Self::Extension>,
100        _instruction: &WarpInstruction<Self>,
101    ) {
102    }
103
104    fn register_wmma_instruction_extension(
105        extensions: &mut Vec<Self::Extension>,
106        instruction: &shared::WmmaInstruction<Self>,
107    ) {
108        match instruction {
109            shared::WmmaInstruction::ExecuteManual {
110                shape,
111                frag_a,
112                frag_b,
113                frag_c,
114                frag_d,
115            } => {
116                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117                    *shape,
118                    vars_to_frag(frag_a),
119                    vars_to_frag(frag_b),
120                    vars_to_frag(frag_c),
121                    Fragment(frag_d.elem()),
122                )));
123                if !extensions.contains(&ext) {
124                    extensions.push(ext);
125                }
126            }
127            shared::WmmaInstruction::ExecuteScaled {
128                shape,
129                frag_a,
130                frag_b,
131                frag_c,
132                frag_d,
133                scales_a,
134                scales_factor,
135                ..
136            } => {
137                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138                    *shape,
139                    vars_to_frag(frag_a),
140                    vars_to_frag(frag_b),
141                    vars_to_frag(frag_c),
142                    Fragment(frag_d.elem()),
143                    scales_a.elem(),
144                    *scales_factor,
145                )));
146                if !extensions.contains(&ext) {
147                    extensions.push(ext);
148                }
149            }
150            _ => {}
151        }
152    }
153}
154
155fn vars_to_frag<D: Dialect>(vars: &[Variable<D>]) -> Fragment<D> {
156    let elem = vars[0].elem();
157    Fragment(elem)
158}
159
160// Types
161
162impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
163    fn item_can_be_optimized() -> bool {
164        true
165    }
166
167    fn compile_type_definitions(
168        f: &mut std::fmt::Formatter<'_>,
169        items: &HashSet<Item<Self>>,
170        scalars: &[(Elem<Self>, usize)],
171        flags: &Flags,
172    ) -> std::fmt::Result {
173        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
174        let mut items_deduplicated = HashSet::new();
175
176        for item in items {
177            let mut item = *item;
178            match item.elem() {
179                Elem::FP4(_) => {
180                    item.elem = Elem::FP4(FP4Kind::E2M1);
181                }
182                Elem::FP4x2(_) => {
183                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
184                }
185                Elem::FP6(_) => {
186                    item.elem = Elem::FP6(FP6Kind::E2M3);
187                }
188                Elem::FP6x2(_) => {
189                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
190                }
191                Elem::FP8(_) => {
192                    item.elem = Elem::FP8(FP8Kind::E4M3);
193                }
194                Elem::FP8x2(_) => {
195                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
196                }
197                _ => {}
198            }
199            items_deduplicated.insert(item);
200        }
201
202        shared::type_definitions::<Self>(f)?;
203        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
204
205        if flags.use_grid_constants {
206            shared::type_scalar_definitions::<Self>(f, scalars)?;
207            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
208        }
209
210        if flags.inst_wmma {
211            Self::compile_wmma_type_definitions(f, flags)?;
212        }
213
214        Ok(())
215    }
216
217    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
218        if flags.inst_tma_im2col {
219            writeln!(f, "{TMA_LOAD_IM2COL}")?;
220        }
221        Ok(())
222    }
223
224    fn compile_elem(
225        f: &mut std::fmt::Formatter<'_>,
226        elem: &shared::Elem<Self>,
227        words: bool,
228    ) -> std::fmt::Result {
229        if words {
230            match elem {
231                shared::Elem::F32 => f.write_str("float"),
232                shared::Elem::F64 => f.write_str("double"),
233                shared::Elem::TF32 => f.write_str("float"),
234                shared::Elem::I8 => f.write_str("char"),
235                shared::Elem::I16 => f.write_str("short"),
236                shared::Elem::I32 => f.write_str("int"),
237                shared::Elem::I64 => f.write_str("long"),
238                shared::Elem::U8 => f.write_str("uchar"),
239                shared::Elem::U16 => f.write_str("ushort"),
240                shared::Elem::U32 => f.write_str("uint"),
241                shared::Elem::U64 => f.write_str("ulong"),
242                _ => Self::compile_elem(f, elem, false),
243            }
244        } else {
245            match elem {
246                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
247                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
248                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
249                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
250                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
251                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
252                shared::Elem::F16 => f.write_str("__half"),
253                shared::Elem::F16x2 => f.write_str("__half2"),
254                shared::Elem::F32 => f.write_str("float"),
255                shared::Elem::F64 => f.write_str("double"),
256                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
257                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
258                shared::Elem::TF32 => f.write_str("float"),
259                shared::Elem::I8 => f.write_str("int8"),
260                shared::Elem::I16 => f.write_str("int16"),
261                shared::Elem::I32 => f.write_str("int32"),
262                shared::Elem::I64 => f.write_str("int64"),
263                shared::Elem::U8 => f.write_str("uint8"),
264                shared::Elem::U16 => f.write_str("uint16"),
265                shared::Elem::U32 => f.write_str("uint32"),
266                shared::Elem::U64 => f.write_str("uint64"),
267                shared::Elem::Bool => f.write_str("bool"),
268                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
269                shared::Elem::_Dialect(_) => Ok(()),
270            }
271        }
272    }
273
274    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
275        if 1 == item.vectorization {
276            return write!(f, "{}", item.elem);
277        }
278        if item.native {
279            // native types use the word form of types only
280            Self::compile_elem(f, &item.elem, true)?;
281            write!(f, "{}", item.vectorization)
282        } else {
283            write!(f, "{}_{}", item.elem, item.vectorization)
284        }
285    }
286
287    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        Ok(())
289    }
290}
291
292// Kernel argument bindings
293
294impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
295    fn compile_kernel_signature(
296        f: &mut std::fmt::Formatter<'_>,
297        kernel_name: &str,
298        tensor_maps: &[Binding<Self>],
299        buffers: &[Binding<Self>],
300        scalars: &[(Elem<Self>, usize)],
301        flags: &Flags,
302    ) -> std::fmt::Result {
303        write!(
304            f,
305            "
306
307extern \"C\" __global__ void __launch_bounds__({})",
308            flags.cube_dim.num_elems()
309        )?;
310        if let Some(cluster_dim) = flags.cluster_dim {
311            write!(
312                f,
313                "__cluster_dims__({}, {}, {}) ",
314                cluster_dim.x, cluster_dim.y, cluster_dim.z
315            )?;
316        }
317        writeln!(f, "{kernel_name} (")?;
318        let has_scalars =
319            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
320        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
321        if flags.use_grid_constants {
322            shared::compile_scalars_static(f, scalars, flags)?;
323        } else {
324            shared::compile_scalars_dynamic(f, scalars)?;
325        }
326        f.write_str("\n)")?;
327        //
328        Ok(())
329    }
330
331    fn compile_bindings_body(
332        f: &mut std::fmt::Formatter<'_>,
333        body: &shared::Body<Self>,
334    ) -> std::fmt::Result {
335        if !body.shared_memories.is_empty() {
336            let max_align = body
337                .shared_memories
338                .iter()
339                .map(|smem| smem.align)
340                .max()
341                .unwrap();
342            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
343            // with `extern __shared__ alignas` and doesn't properly parse it.
344            writeln!(
345                f,
346                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
347            )?;
348        }
349        Ok(())
350    }
351}
352
353impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
354
355// Cube builtins dialect
356
357impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
358    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        write!(f, "cluster.block_rank()")
360    }
361
362    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        write!(f, "cluster.block_index().x")
364    }
365
366    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        write!(f, "cluster.block_index().y")
368    }
369
370    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        write!(f, "cluster.block_index().z")
372    }
373}
374
375// Instructions
376
377impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
378    // sync
379    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        writeln!(f, "__syncthreads();\n")
381    }
382
383    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        writeln!(f, "__syncwarp();\n")
385    }
386
387    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        writeln!(f, "__threadfence();")
389    }
390
391    // unary
392    fn compile_instruction_find_first_set<T: Component<Self>>(
393        f: &mut std::fmt::Formatter<'_>,
394        input: T,
395        out_elem: Elem<Self>,
396    ) -> std::fmt::Result {
397        write!(f, "{out_elem}(")?;
398        match input.elem() {
399            Elem::I32 => write!(f, "__ffs({input})"),
400            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
401            Elem::I64 => write!(f, "__ffsll({input})"),
402            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
403            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
404        }?;
405        write!(f, ")")
406    }
407
408    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
409        f: &mut std::fmt::Formatter<'_>,
410        input: T,
411        out_elem: Elem<Self>,
412    ) -> std::fmt::Result {
413        write!(f, "{out_elem}(")?;
414        match input.elem() {
415            Elem::I32 => write!(f, "__clz({input})"),
416            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
417            Elem::I64 => write!(f, "__clzll({input})"),
418            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
419            in_elem => write!(
420                f,
421                "{out_elem}(__clz({}) - {})",
422                unary::zero_extend(input),
423                (size_of::<u32>() - in_elem.size()) * 8
424            ),
425        }?;
426        write!(f, ")")
427    }
428
429    fn compile_saturating_add(
430        f: &mut std::fmt::Formatter<'_>,
431        lhs: impl Display,
432        rhs: impl Display,
433        item: Item<Self>,
434    ) -> std::fmt::Result {
435        let elem = item.elem();
436        match elem {
437            Elem::I32 => {
438                write!(
439                    f,
440                    r#"[&]() -> {elem} {{
441    {elem} result;
442    asm("add.sat.s32 %0, %1, %2;"
443        : "=r"(result)
444        : "r"({lhs}), "r"({rhs}));
445    return result;
446        }}()"#
447                )
448            }
449            _ => unreachable!("Should be replaced by polyfill"),
450        }
451    }
452
453    fn compile_saturating_sub(
454        f: &mut std::fmt::Formatter<'_>,
455        lhs: impl Display,
456        rhs: impl Display,
457        item: Item<Self>,
458    ) -> std::fmt::Result {
459        let elem = item.elem();
460        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
461        match elem {
462            Elem::I32 => {
463                write!(
464                    f,
465                    r#"[&]() -> {elem} {{
466    {elem} result;
467    asm("sub.sat.s32 %0, %1, %2;"
468        : "=r"(result)
469        : "r"({lhs}), "r"({rhs}));
470    return result;
471        }}()"#
472                )
473            }
474            _ => unreachable!("Should be replaced by polyfill"),
475        }
476    }
477
478    // others
479    fn compile_instruction_max_function_name(
480        f: &mut std::fmt::Formatter<'_>,
481        item: Item<Self>,
482    ) -> std::fmt::Result {
483        let max = match item.elem() {
484            Elem::F16 | Elem::BF16 => "__hmax",
485            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
486            _ => "max",
487        };
488        write!(f, "{max}")
489    }
490
491    fn compile_instruction_min_function_name(
492        f: &mut std::fmt::Formatter<'_>,
493        item: Item<Self>,
494    ) -> std::fmt::Result {
495        let min = match item.elem() {
496            Elem::F16 | Elem::BF16 => "__hmin",
497            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
498            _ => "min",
499        };
500        write!(f, "{min}")
501    }
502
503    // warp
504    fn compile_warp_shuffle(
505        f: &mut std::fmt::Formatter<'_>,
506        var: &str,
507        source: &str,
508    ) -> std::fmt::Result {
509        write!(f, "__shfl_sync(-1, {var}, {source})")
510    }
511    fn compile_warp_shuffle_xor(
512        f: &mut std::fmt::Formatter<'_>,
513        var: &str,
514        _elem: &Elem<Self>,
515        offset: &str,
516    ) -> std::fmt::Result {
517        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
518    }
519    fn compile_warp_shuffle_up(
520        f: &mut std::fmt::Formatter<'_>,
521        var: &str,
522        offset: &str,
523    ) -> std::fmt::Result {
524        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
525    }
526    fn compile_warp_shuffle_down(
527        f: &mut std::fmt::Formatter<'_>,
528        var: &str,
529        offset: &str,
530    ) -> std::fmt::Result {
531        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
532    }
533    fn compile_warp_all<T: Component<Self>>(
534        f: &mut std::fmt::Formatter<'_>,
535        input: &T,
536    ) -> std::fmt::Result {
537        write!(f, "__all_sync(-1, {input})")
538    }
539    fn compile_warp_any<T: Component<Self>>(
540        f: &mut std::fmt::Formatter<'_>,
541        input: &T,
542    ) -> std::fmt::Result {
543        write!(f, "__any_sync(-1, {input})")
544    }
545
546    fn compile_warp_ballot(
547        f: &mut std::fmt::Formatter<'_>,
548        input: &Variable<Self>,
549        _out_elem: &Elem<Self>,
550    ) -> std::fmt::Result {
551        write!(f, "__ballot_sync(-1, {input})")
552    }
553
554    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
555        let elem = Elem::<Self>::Bool;
556        let uint32 = Elem::<Self>::U32;
557        // Used to have a wrapper but it has been removed in newer version due to being
558        // "incomplete". We only need the predicate and have a fixed mask, so it's trivial to
559        // implement.
560        writeln!(
561            f,
562            r#"{out} = {elem}([&]() -> {uint32} {{
563    {uint32} pred = 0;
564    asm volatile(
565        "{{\n"
566        "     .reg .pred %%px;\n"
567        "     elect.sync _|%%px, 0xffffffff;\n"
568        "     selp.b32 %0, 1, 0, %%px;\n"
569        "}}\n"
570        : "+r"(pred));
571    return pred;
572        }}());"#
573        )
574    }
575}
576
577// Coop Matrices dialect
578
579impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
580    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
581        M::compile_wmma_includes(f, flags)
582    }
583
584    fn compile_wmma_type_definitions(
585        f: &mut std::fmt::Formatter<'_>,
586        flags: &Flags,
587    ) -> std::fmt::Result {
588        M::compile_wmma_type_definitions(f, flags)
589    }
590
591    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        M::compile_wmma_local_variables(f)
593    }
594
595    fn compile_wmma_fragment_declaration(
596        f: &mut std::fmt::Formatter<'_>,
597        var: &Variable<Self>,
598    ) -> std::fmt::Result {
599        M::compile_wmma_fragment_declaration(f, var)
600    }
601
602    fn compile_wwma_fragment_ident(
603        f: &mut std::fmt::Formatter<'_>,
604        ident: &crate::shared::FragmentIdent<Self>,
605    ) -> std::fmt::Result {
606        M::compile_wwma_fragment_ident(f, ident)
607    }
608
609    fn compile_wmma_fragment_layout(
610        f: &mut std::fmt::Formatter<'_>,
611        layout: &crate::shared::FragmentLayout<Self>,
612    ) -> std::fmt::Result {
613        M::compile_wmma_fragment_layout(f, layout)
614    }
615
616    fn compile_wmma_fragment(
617        f: &mut std::fmt::Formatter<'_>,
618        fragment: &crate::shared::Fragment<Self>,
619    ) -> std::fmt::Result {
620        M::compile_wmma_fragment(f, fragment)
621    }
622
623    fn compile_wmma_instruction(
624        f: &mut std::fmt::Formatter<'_>,
625        instruction: &crate::shared::WmmaInstruction<Self>,
626    ) -> std::fmt::Result {
627        M::compile_wmma_instruction(f, instruction)
628    }
629
630    fn compile_manual_mma(
631        f: &mut std::fmt::Formatter<'_>,
632        mma: ManualMma<Self>,
633    ) -> std::fmt::Result {
634        M::compile_manual_mma(f, mma)
635    }
636
637    fn compile_scaled_mma(
638        f: &mut std::fmt::Formatter<'_>,
639        mma: ManualMma<Self>,
640        scales_a: Variable<Self>,
641        scales_b: Variable<Self>,
642        scales_factor: u32,
643    ) -> std::fmt::Result {
644        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
645    }
646
647    fn supported_wmma_combinations(
648        arch: &CudaArchitecture,
649    ) -> crate::shared::SupportedMmaCombinations {
650        M::supported_wmma_combinations(arch)
651    }
652
653    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
654        M::supported_mma_combinations(arch)
655    }
656
657    fn supported_scaled_mma_combinations(
658        arch: &CudaArchitecture,
659    ) -> shared::SupportedScaledMmaCombinations {
660        M::supported_scaled_mma_combinations(arch)
661    }
662}
663
664impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
665    fn processors() -> Vec<Box<dyn Processor>> {
666        vec![
667            Box::new(CudaMmaProcessor),
668            Box::new(SaturatingArithmeticProcessor::new(false)),
669        ]
670    }
671}