use std::{collections::HashSet, fmt::Display, marker::PhantomData};

use cubecl_core::{
    ir::{BarrierLevel, Processor},
    post_processing::saturating::SaturatingArithmeticProcessor,
};

use crate::{
    Dialect,
    cuda::{
        extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
        processors::CudaMmaProcessor,
        ptx::*,
    },
    shared::{
        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
        Variable, WarpInstruction, unary,
    },
};

use super::{Extension, arch::CudaArchitecture};

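/// Dialect implementation targeting CUDA C++. The type parameter `M` selects the
/// WMMA/MMA compiler used to lower matrix instructions for this dialect.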
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}

impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
    type Architecture = CudaArchitecture;
}

impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
    type Extension = Extension<Self>;

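    /// Emits the `#include` directives required by the generated kernel, based on which
    /// element types and instruction families are recorded in `flags`. For TMA, a
    /// `CUtensorMap` typedef (128 bytes of opaque storage, 64-byte aligned) is written
    /// inline.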
    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        f.write_str("#include <cuda_runtime.h>\n")?;
        if flags.elem_fp4 {
            f.write_str("#include <cuda_fp4.h>\n")?;
        }
        if flags.elem_fp6 {
            f.write_str("#include <cuda_fp6.h>\n")?;
        }
        if flags.elem_fp8 {
            f.write_str("#include <cuda_fp8.h>\n")?;
        }
        if flags.elem_bf16 {
            f.write_str("#include <cuda_bf16.h>\n")?;
        }
        if flags.elem_f16 {
            f.write_str("#include <cuda_fp16.h>\n")?;
        }

        if flags.inst_wmma || flags.elem_tf32 {
            Self::compile_wmma_includes(f, flags)?;
        }

        if flags.op_pipeline {
            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
            f.write_str("#include <cuda/pipeline>\n")?;
        }
        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
            f.write_str("#include <cooperative_groups.h>\n")?;
            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
            f.write_str("#include <cuda/barrier>\n")?;
        }
        if flags.inst_ptx_wrappers {
            f.write_str("#include <cuda/ptx>\n")?;
        }
        if flags.inst_tma {
            f.write_str(
                "typedef struct CUtensorMap_st {
alignas(64) unsigned long long int opaque[16];
} CUtensorMap;\n",
            )?;
        }
        Ok(())
    }

    fn compile_extensions(
        f: &mut std::fmt::Formatter<'_>,
        extensions: &[Self::Extension],
    ) -> std::fmt::Result {
        for extension in extensions {
            match extension {
                Extension::NoExtension => {}
                Extension::Mma(mma) => mma.format_extension(f)?,
            }
        }
        Ok(())
    }

    fn register_instruction_extension(
        _extensions: &mut Vec<Self::Extension>,
        _instruction: &Instruction<Self>,
    ) {
    }

    fn register_warp_instruction_extension(
        _extensions: &mut Vec<Self::Extension>,
        _instruction: &WarpInstruction<Self>,
    ) {
    }

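    /// Collects the PTX extension wrappers (mma execute, scaled execute, `ldmatrix`,
    /// `stmatrix`) needed by manual WMMA instructions, pushing each distinct extension
    /// only once.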
    fn register_wmma_instruction_extension(
        extensions: &mut Vec<Self::Extension>,
        instruction: &shared::WmmaInstruction<Self>,
    ) {
        match instruction {
            shared::WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
                    *shape,
                    Fragment(frag_a.elem()),
                    Fragment(frag_b.elem()),
                    Fragment(frag_c.elem()),
                    Fragment(frag_d.elem()),
                )));
                if !extensions.contains(&ext) {
                    extensions.push(ext);
                }
            }
            shared::WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_factor,
                ..
            } => {
                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
                    *shape,
                    Fragment(frag_a.elem()),
                    Fragment(frag_b.elem()),
                    Fragment(frag_c.elem()),
                    Fragment(frag_d.elem()),
                    scales_a.elem(),
                    *scales_factor,
                )));
                if !extensions.contains(&ext) {
                    extensions.push(ext);
                }
            }
            shared::WmmaInstruction::LdMatrix {
                output,
                factor,
                transpose,
                ..
            } => {
                let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
                    output.elem(),
                    *factor,
                    *transpose,
                )));
                if !extensions.contains(&ext) {
                    extensions.push(ext);
                }
            }
            shared::WmmaInstruction::StMatrix {
                registers,
                factor,
                transpose,
                ..
            } => {
                let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
                    registers.elem(),
                    *factor,
                    *transpose,
                )));
                if !extensions.contains(&ext) {
                    extensions.push(ext);
                }
            }
            _ => {}
        }
    }
}

impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
    fn item_can_be_optimized() -> bool {
        true
    }

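    /// Emits the scalar and vectorized type definitions used by the kernel. All FP4/FP6/FP8
    /// kinds are collapsed to a single representative kind before emission, since they map
    /// to the same `__nv_fp*_storage_t` storage types (see `compile_elem`).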
    fn compile_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        items: &HashSet<Item<Self>>,
        scalars: &[(Elem<Self>, usize)],
        flags: &Flags,
    ) -> std::fmt::Result {
        let mut items_deduplicated = HashSet::new();

        for item in items {
            let mut item = *item;
            match item.elem() {
                Elem::FP4(_) => {
                    item.elem = Elem::FP4(FP4Kind::E2M1);
                }
                Elem::FP4x2(_) => {
                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
                }
                Elem::FP6(_) => {
                    item.elem = Elem::FP6(FP6Kind::E2M3);
                }
                Elem::FP6x2(_) => {
                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
                }
                Elem::FP8(_) => {
                    item.elem = Elem::FP8(FP8Kind::E4M3);
                }
                Elem::FP8x2(_) => {
                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
                }
                _ => {}
            }
            items_deduplicated.insert(item);
        }

        shared::type_definitions::<Self>(f)?;
        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;

        if flags.use_grid_constants {
            shared::type_scalar_definitions::<Self>(f, scalars)?;
            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
        }

        if flags.inst_wmma {
            Self::compile_wmma_type_definitions(f, flags)?;
        }

        Ok(())
    }

    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        if flags.inst_tma_im2col {
            writeln!(f, "{TMA_LOAD_IM2COL}")?;
        }
        if flags.inst_async_copy {
            writeln!(f, "{COPY_ASYNC}")?;
        }
        Ok(())
    }

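    /// Writes the CUDA C++ spelling of an element type. With `words == true` the built-in
    /// word names are used (`float`, `int`, `uint`, ...), which the native vector types are
    /// derived from; otherwise the generated sized aliases (`int32`, `uint32`, ...) and
    /// NVIDIA intrinsic types (`__half`, `__nv_bfloat16`, `__nv_fp8_storage_t`, ...) are used.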
    fn compile_elem(
        f: &mut std::fmt::Formatter<'_>,
        elem: &shared::Elem<Self>,
        words: bool,
    ) -> std::fmt::Result {
        if words {
            match elem {
                shared::Elem::F32 => f.write_str("float"),
                shared::Elem::F64 => f.write_str("double"),
                shared::Elem::TF32 => f.write_str("float"),
                shared::Elem::I8 => f.write_str("char"),
                shared::Elem::I16 => f.write_str("short"),
                shared::Elem::I32 => f.write_str("int"),
                shared::Elem::I64 => f.write_str("long"),
                shared::Elem::U8 => f.write_str("uchar"),
                shared::Elem::U16 => f.write_str("ushort"),
                shared::Elem::U32 => f.write_str("uint"),
                shared::Elem::U64 => f.write_str("ulong"),
                _ => Self::compile_elem(f, elem, false),
            }
        } else {
            match elem {
                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
                shared::Elem::F16 => f.write_str("__half"),
                shared::Elem::F16x2 => f.write_str("__half2"),
                shared::Elem::F32 => f.write_str("float"),
                shared::Elem::F64 => f.write_str("double"),
                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
                shared::Elem::TF32 => f.write_str("float"),
                shared::Elem::I8 => f.write_str("int8"),
                shared::Elem::I16 => f.write_str("int16"),
                shared::Elem::I32 => f.write_str("int32"),
                shared::Elem::I64 => f.write_str("int64"),
                shared::Elem::U8 => f.write_str("uint8"),
                shared::Elem::U16 => f.write_str("uint16"),
                shared::Elem::U32 => f.write_str("uint32"),
                shared::Elem::U64 => f.write_str("uint64"),
                shared::Elem::Bool => f.write_str("bool"),
                shared::Elem::Barrier(BarrierLevel::Unit) => {
                    f.write_str("cuda::barrier<cuda::thread_scope_thread>")
                }
                shared::Elem::Barrier(BarrierLevel::Cube) => {
                    f.write_str("cuda::barrier<cuda::thread_scope_block>")
                }
                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
                shared::Elem::_Dialect(_) => Ok(()),
            }
        }
    }

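    /// Formats an item (element + vectorization): the bare element for vectorization 1,
    /// the native CUDA vector type (e.g. `float4`) when `native` is set, and the generated
    /// `{elem}_{n}` struct (e.g. `float_4`) otherwise.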
    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
        if 1 == item.vectorization {
            return write!(f, "{}", item.elem);
        }
        if item.native {
            Self::compile_elem(f, &item.elem, true)?;
            write!(f, "{}", item.vectorization)
        } else {
            write!(f, "{}_{}", item.elem, item.vectorization)
        }
    }

    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
}

impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
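    /// Writes the `extern "C" __global__` kernel signature: `__launch_bounds__` from the
    /// cube dimension, an optional `__cluster_dims__` attribute, then the tensor map,
    /// buffer and scalar parameters (scalars are passed statically when grid constants
    /// are enabled, dynamically otherwise).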
    fn compile_kernel_signature(
        f: &mut std::fmt::Formatter<'_>,
        kernel_name: &str,
        tensor_maps: &[Binding<Self>],
        buffers: &[Binding<Self>],
        scalars: &[(Elem<Self>, usize)],
        flags: &Flags,
    ) -> std::fmt::Result {
        write!(
            f,
            "

extern \"C\" __global__ void __launch_bounds__({})",
            flags.cube_dim.num_elems()
        )?;
        if let Some(cluster_dim) = flags.cluster_dim {
            write!(
                f,
                "__cluster_dims__({}, {}, {}) ",
                cluster_dim.x, cluster_dim.y, cluster_dim.z
            )?;
        }
        writeln!(f, "{kernel_name} (")?;
        let has_scalars =
            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
        if flags.use_grid_constants {
            shared::compile_scalars_static(f, scalars, flags)?;
        } else {
            shared::compile_scalars_dynamic(f, scalars)?;
        }
        f.write_str("\n)")?;
        Ok(())
    }

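    /// Declares the dynamic shared-memory array when the kernel uses shared memories,
    /// aligned to the strictest alignment among them.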
    fn compile_bindings_body(
        f: &mut std::fmt::Formatter<'_>,
        body: &shared::Body<Self>,
    ) -> std::fmt::Result {
        if !body.shared_memories.is_empty() {
            let max_align = body
                .shared_memories
                .iter()
                .map(|smem| smem.align())
                .max()
                .unwrap();
            writeln!(
                f,
                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
            )?;
        }
        Ok(())
    }
}

impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}

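/// Cluster-related builtins are mapped onto the cooperative-groups `cluster` handle
/// (`cluster.block_rank()` / `cluster.block_index()`).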
impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cluster.block_rank()")
    }

    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cluster.block_index().x")
    }

    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cluster.block_index().y")
    }

    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cluster.block_index().z")
    }
}

impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__syncthreads();\n")
    }

    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__syncwarp();\n")
    }

    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__threadfence();")
    }

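    /// Lowers find-first-set to `__ffs`/`__ffsll`, casting unsigned inputs to the signed
    /// argument type expected by the intrinsic and defaulting smaller integers to the
    /// 32-bit variant.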
    fn compile_instruction_find_first_set<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 => write!(f, "__ffs({input})"),
            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
            Elem::I64 => write!(f, "__ffsll({input})"),
            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
        }?;
        write!(f, ")")
    }

    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 => write!(f, "__clz({input})"),
            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
            Elem::I64 => write!(f, "__clzll({input})"),
            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
            in_elem => write!(
                f,
                "{out_elem}(__clz({}) - {})",
                unary::zero_extend(input),
                (size_of::<u32>() - in_elem.size()) * 8
            ),
        }?;
        write!(f, ")")
    }

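    /// Saturating i32 add, emitted as inline PTX (`add.sat.s32`) wrapped in an
    /// immediately invoked lambda. Other element types are expected to have been
    /// rewritten by the saturating-arithmetic polyfill before reaching this point.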
    fn compile_saturating_add(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let elem = item.elem();
        match elem {
            Elem::I32 => {
                write!(
                    f,
                    r#"[&]() -> {elem} {{
    {elem} result;
    asm("add.sat.s32 %0, %1, %2;"
        : "=r"(result)
        : "r"({lhs}), "r"({rhs}));
    return result;
}}()"#
                )
            }
            _ => unreachable!("Should be replaced by polyfill"),
        }
    }

    fn compile_saturating_sub(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let elem = item.elem();
        match elem {
            Elem::I32 => {
                write!(
                    f,
                    r#"[&]() -> {elem} {{
    {elem} result;
    asm("sub.sat.s32 %0, %1, %2;"
        : "=r"(result)
        : "r"({lhs}), "r"({rhs}));
    return result;
}}()"#
                )
            }
            _ => unreachable!("Should be replaced by polyfill"),
        }
    }

    fn compile_instruction_max_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let max = match item.elem() {
            Elem::F16 | Elem::BF16 => "__hmax",
            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
            _ => "max",
        };
        write!(f, "{max}")
    }

    fn compile_instruction_min_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let min = match item.elem() {
            Elem::F16 | Elem::BF16 => "__hmin",
            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
            _ => "min",
        };
        write!(f, "{min}")
    }

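    // The warp shuffles and votes below use the `*_sync` intrinsics with a full mask
    // (-1, i.e. 0xffffffff), so every lane in the warp participates.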
    fn compile_warp_shuffle(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        source: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_sync(-1, {var}, {source})")
    }

    fn compile_warp_shuffle_xor(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        _elem: &Elem<Self>,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
    }

    fn compile_warp_shuffle_up(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
    }

    fn compile_warp_shuffle_down(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
    }

    fn compile_warp_all<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result {
        write!(f, "__all_sync(-1, {input})")
    }

    fn compile_warp_any<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result {
        write!(f, "__any_sync(-1, {input})")
    }

    fn compile_warp_ballot(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<Self>,
        _out_elem: &Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "__ballot_sync(-1, {input})")
    }

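    /// Elects a single leader lane per warp via the `elect.sync` PTX instruction and
    /// assigns the resulting predicate (as a bool) to `out`.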
    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
        let elem = Elem::<Self>::Bool;
        let uint32 = Elem::<Self>::U32;
        writeln!(
            f,
            r#"{out} = {elem}([&]() -> {uint32} {{
    {uint32} pred = 0;
    asm volatile(
        "{{\n"
        " .reg .pred %%px;\n"
        " elect.sync _|%%px, 0xffffffff;\n"
        " selp.b32 %0, 1, 0, %%px;\n"
        "}}\n"
        : "+r"(pred));
    return pred;
}}());"#
        )
    }
}

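/// WMMA/MMA lowering is delegated wholesale to the configured compiler `M`.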
impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
        M::compile_wmma_includes(f, flags)
    }

    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        M::compile_wmma_type_definitions(f, flags)
    }

    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        M::compile_wmma_local_variables(f)
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment_declaration(f, var)
    }

    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &crate::shared::FragmentIdent<Self>,
    ) -> std::fmt::Result {
        M::compile_wwma_fragment_ident(f, ident)
    }

    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &crate::shared::FragmentLayout<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment_layout(f, layout)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &crate::shared::Fragment<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment(f, fragment)
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &crate::shared::WmmaInstruction<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_instruction(f, instruction)
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<Self>,
    ) -> std::fmt::Result {
        M::compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<Self>,
        scales_a: Variable<Self>,
        scales_b: Variable<Self>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    fn supported_wmma_combinations(
        arch: &CudaArchitecture,
    ) -> crate::shared::SupportedMmaCombinations {
        M::supported_wmma_combinations(arch)
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
        M::supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> shared::SupportedScaledMmaCombinations {
        M::supported_scaled_mma_combinations(arch)
    }
}

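/// IR post-processors for this dialect: manual-MMA handling plus the saturating-arithmetic
/// pass (see `compile_saturating_add`/`compile_saturating_sub`, which only handle `i32`
/// natively).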
impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
    fn processors() -> Vec<Box<dyn Processor>> {
        vec![
            Box::new(CudaMmaProcessor),
            Box::new(SaturatingArithmeticProcessor::new(false)),
        ]
    }
}