1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{
4 ir::{BarrierLevel, Processor},
5 post_processing::saturating::SaturatingArithmeticProcessor,
6};
7
8use crate::{
9 Dialect,
10 cuda::{
11 extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
12 processors::CudaMmaProcessor,
13 ptx::*,
14 },
15 shared::{
16 self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
17 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
18 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
19 Variable, WarpInstruction, unary,
20 },
21};
22
23use super::{Extension, arch::CudaArchitecture};
24
25#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
26pub struct CudaDialect<M> {
27 _wmma_compiler: PhantomData<M>,
28}
29
30impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
31 type Architecture = CudaArchitecture;
32}
33
34impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
35 type Extension = Extension<Self>;
36
37 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
38 f.write_str("#include <cuda_runtime.h>\n")?;
39 if flags.elem_fp4 {
40 f.write_str("#include <cuda_fp4.h>\n")?;
41 }
42 if flags.elem_fp6 {
43 f.write_str("#include <cuda_fp6.h>\n")?;
44 }
45 if flags.elem_fp8 {
46 f.write_str("#include <cuda_fp8.h>\n")?;
47 }
48 if flags.elem_bf16 {
49 f.write_str("#include <cuda_bf16.h>\n")?;
50 }
51 if flags.elem_f16 {
52 f.write_str("#include <cuda_fp16.h>\n")?;
53 }
54
55 if flags.inst_wmma || flags.elem_tf32 {
57 Self::compile_wmma_includes(f, flags)?;
58 }
59
60 if flags.op_pipeline {
61 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
62 f.write_str("#include <cuda/pipeline>\n")?;
63 }
64 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
65 f.write_str("#include <cooperative_groups.h>\n")?;
66 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
67 f.write_str("#include <cuda/barrier>\n")?;
68 }
69 if flags.inst_ptx_wrappers {
70 f.write_str("#include <cuda/ptx>\n")?;
71 }
72 if flags.inst_tma {
73 f.write_str(
74 "typedef struct CUtensorMap_st {
75alignas(64) unsigned long long int opaque[16];
76} CUtensorMap;\n",
77 )?;
78 }
79 Ok(())
80 }
81
82 fn compile_extensions(
83 f: &mut std::fmt::Formatter<'_>,
84 extensions: &[Self::Extension],
85 ) -> std::fmt::Result {
86 for extension in extensions {
87 match extension {
88 Extension::NoExtension => {}
89 Extension::Mma(mma) => mma.format_extension(f)?,
90 }
91 }
92 Ok(())
93 }
94
95 fn register_instruction_extension(
96 _extensions: &mut Vec<Self::Extension>,
97 _instruction: &Instruction<Self>,
98 ) {
99 }
100
101 fn register_warp_instruction_extension(
102 _extensions: &mut Vec<Self::Extension>,
103 _instruction: &WarpInstruction<Self>,
104 ) {
105 }
106
107 fn register_wmma_instruction_extension(
108 extensions: &mut Vec<Self::Extension>,
109 instruction: &shared::WmmaInstruction<Self>,
110 ) {
111 match instruction {
112 shared::WmmaInstruction::ExecuteManual {
113 shape,
114 frag_a,
115 frag_b,
116 frag_c,
117 frag_d,
118 } => {
119 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
120 *shape,
121 Fragment(frag_a.elem()),
122 Fragment(frag_b.elem()),
123 Fragment(frag_c.elem()),
124 Fragment(frag_d.elem()),
125 )));
126 if !extensions.contains(&ext) {
127 extensions.push(ext);
128 }
129 }
130 shared::WmmaInstruction::ExecuteScaled {
131 shape,
132 frag_a,
133 frag_b,
134 frag_c,
135 frag_d,
136 scales_a,
137 scales_factor,
138 ..
139 } => {
140 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
141 *shape,
142 Fragment(frag_a.elem()),
143 Fragment(frag_b.elem()),
144 Fragment(frag_c.elem()),
145 Fragment(frag_d.elem()),
146 scales_a.elem(),
147 *scales_factor,
148 )));
149 if !extensions.contains(&ext) {
150 extensions.push(ext);
151 }
152 }
153 shared::WmmaInstruction::LdMatrix {
154 output,
155 factor,
156 transpose,
157 ..
158 } => {
159 let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
160 output.elem(),
161 *factor,
162 *transpose,
163 )));
164 if !extensions.contains(&ext) {
165 extensions.push(ext);
166 }
167 }
168 shared::WmmaInstruction::StMatrix {
169 registers,
170 factor,
171 transpose,
172 ..
173 } => {
174 let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
175 registers.elem(),
176 *factor,
177 *transpose,
178 )));
179 if !extensions.contains(&ext) {
180 extensions.push(ext);
181 }
182 }
183 _ => {}
184 }
185 }
186}
187
188impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
191 fn item_can_be_optimized() -> bool {
192 true
193 }
194
195 fn compile_type_definitions(
196 f: &mut std::fmt::Formatter<'_>,
197 items: &HashSet<Item<Self>>,
198 scalars: &[(Elem<Self>, usize)],
199 flags: &Flags<Self>,
200 ) -> std::fmt::Result {
201 let mut items_deduplicated = HashSet::new();
203
204 for item in items {
205 let mut item = *item;
206 match item.elem() {
207 Elem::FP4(_) => {
208 item.elem = Elem::FP4(FP4Kind::E2M1);
209 }
210 Elem::FP4x2(_) => {
211 item.elem = Elem::FP4x2(FP4Kind::E2M1);
212 }
213 Elem::FP6(_) => {
214 item.elem = Elem::FP6(FP6Kind::E2M3);
215 }
216 Elem::FP6x2(_) => {
217 item.elem = Elem::FP6x2(FP6Kind::E2M3);
218 }
219 Elem::FP8(_) => {
220 item.elem = Elem::FP8(FP8Kind::E4M3);
221 }
222 Elem::FP8x2(_) => {
223 item.elem = Elem::FP8x2(FP8Kind::E4M3);
224 }
225 _ => {}
226 }
227 items_deduplicated.insert(item);
228 }
229
230 shared::type_definitions::<Self>(f)?;
231 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
232
233 if flags.use_grid_constants {
234 shared::type_scalar_definitions::<Self>(f, scalars)?;
235 shared::type_info_definition::<Self>(f, flags.static_meta_length, flags.address_type)?;
236 }
237
238 if flags.inst_wmma {
239 Self::compile_wmma_type_definitions(f, flags)?;
240 }
241
242 Ok(())
243 }
244
245 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
246 if flags.inst_tma_im2col {
247 writeln!(f, "{TMA_LOAD_IM2COL}")?;
248 }
249 if flags.inst_async_copy {
250 writeln!(f, "{COPY_ASYNC}")?;
251 }
252 Ok(())
253 }
254
255 fn compile_elem(
256 f: &mut std::fmt::Formatter<'_>,
257 elem: &shared::Elem<Self>,
258 words: bool,
259 ) -> std::fmt::Result {
260 if words {
261 match elem {
262 shared::Elem::F32 => f.write_str("float"),
263 shared::Elem::F64 => f.write_str("double"),
264 shared::Elem::TF32 => f.write_str("float"),
265 shared::Elem::I8 => f.write_str("char"),
266 shared::Elem::I16 => f.write_str("short"),
267 shared::Elem::I32 => f.write_str("int"),
268 shared::Elem::I64 => f.write_str("long"),
269 shared::Elem::U8 => f.write_str("uchar"),
270 shared::Elem::U16 => f.write_str("ushort"),
271 shared::Elem::U32 => f.write_str("uint"),
272 shared::Elem::U64 => f.write_str("ulong"),
273 _ => Self::compile_elem(f, elem, false),
274 }
275 } else {
276 match elem {
277 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
278 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
279 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
280 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
281 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
282 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
283 shared::Elem::F16 => f.write_str("__half"),
284 shared::Elem::F16x2 => f.write_str("__half2"),
285 shared::Elem::F32 => f.write_str("float"),
286 shared::Elem::F64 => f.write_str("double"),
287 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
288 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
289 shared::Elem::TF32 => f.write_str("float"),
290 shared::Elem::I8 => f.write_str("int8"),
291 shared::Elem::I16 => f.write_str("int16"),
292 shared::Elem::I32 => f.write_str("int32"),
293 shared::Elem::I64 => f.write_str("int64"),
294 shared::Elem::U8 => f.write_str("uint8"),
295 shared::Elem::U16 => f.write_str("uint16"),
296 shared::Elem::U32 => f.write_str("uint32"),
297 shared::Elem::U64 => f.write_str("uint64"),
298 shared::Elem::Bool => f.write_str("bool"),
299 shared::Elem::Barrier(BarrierLevel::Unit) => {
300 f.write_str("cuda::barrier<cuda::thread_scope_thread>")
301 }
302 shared::Elem::Barrier(BarrierLevel::Cube) => {
303 f.write_str("cuda::barrier<cuda::thread_scope_block>")
304 }
305 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
306 shared::Elem::_Dialect(_) => Ok(()),
307 }
308 }
309 }
310
311 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
312 if 1 == item.vectorization {
313 return write!(f, "{}", item.elem);
314 }
315 if item.native {
316 Self::compile_elem(f, &item.elem, true)?;
318 write!(f, "{}", item.vectorization)
319 } else {
320 write!(f, "{}_{}", item.elem, item.vectorization)
321 }
322 }
323
324 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 Ok(())
326 }
327}
328
329impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
332 fn compile_kernel_signature(
333 f: &mut std::fmt::Formatter<'_>,
334 kernel_name: &str,
335 tensor_maps: &[Binding<Self>],
336 buffers: &[Binding<Self>],
337 scalars: &[(Elem<Self>, usize)],
338 flags: &Flags<Self>,
339 ) -> std::fmt::Result {
340 write!(
341 f,
342 "
343
344extern \"C\" __global__ void __launch_bounds__({})",
345 flags.cube_dim.num_elems()
346 )?;
347 if let Some(cluster_dim) = flags.cluster_dim {
348 write!(
349 f,
350 "__cluster_dims__({}, {}, {}) ",
351 cluster_dim.x, cluster_dim.y, cluster_dim.z
352 )?;
353 }
354 writeln!(f, "{kernel_name} (")?;
355 let has_scalars =
356 !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
357 shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
358 if flags.use_grid_constants {
359 shared::compile_scalars_static(f, scalars, flags)?;
360 } else {
361 shared::compile_scalars_dynamic(f, scalars)?;
362 }
363 f.write_str("\n)")?;
364 Ok(())
366 }
367
368 fn compile_bindings_body(
369 f: &mut std::fmt::Formatter<'_>,
370 body: &shared::Body<Self>,
371 ) -> std::fmt::Result {
372 if !body.shared_memories.is_empty() {
373 let max_align = body
374 .shared_memories
375 .iter()
376 .map(|smem| smem.align())
377 .max()
378 .unwrap();
379 writeln!(
382 f,
383 "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
384 )?;
385 }
386 Ok(())
387 }
388}
389
390impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
391
392impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
395 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 write!(f, "cluster.block_rank()")
397 }
398
399 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 write!(f, "cluster.block_index().x")
401 }
402
403 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 write!(f, "cluster.block_index().y")
405 }
406
407 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 write!(f, "cluster.block_index().z")
409 }
410}
411
412impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
415 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 writeln!(f, "__syncthreads();\n")
418 }
419
420 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
421 writeln!(f, "__syncwarp();\n")
422 }
423
424 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425 writeln!(f, "__threadfence();")
426 }
427
428 fn compile_instruction_find_first_set<T: Component<Self>>(
430 f: &mut std::fmt::Formatter<'_>,
431 input: T,
432 out_elem: Elem<Self>,
433 ) -> std::fmt::Result {
434 write!(f, "{out_elem}(")?;
435 match input.elem() {
436 Elem::I32 => write!(f, "__ffs({input})"),
437 Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
438 Elem::I64 => write!(f, "__ffsll({input})"),
439 Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
440 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
441 }?;
442 write!(f, ")")
443 }
444
445 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
446 f: &mut std::fmt::Formatter<'_>,
447 input: T,
448 out_elem: Elem<Self>,
449 ) -> std::fmt::Result {
450 write!(f, "{out_elem}(")?;
451 match input.elem() {
452 Elem::I32 => write!(f, "__clz({input})"),
453 Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
454 Elem::I64 => write!(f, "__clzll({input})"),
455 Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
456 in_elem => write!(
457 f,
458 "{out_elem}(__clz({}) - {})",
459 unary::zero_extend(input),
460 (size_of::<u32>() - in_elem.size()) * 8
461 ),
462 }?;
463 write!(f, ")")
464 }
465
466 fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
467 f: &mut std::fmt::Formatter<'_>,
468 input: T,
469 out_elem: Elem<Self>,
470 ) -> std::fmt::Result {
471 write!(f, "{out_elem}(")?;
475 match input.elem() {
476 Elem::I32 | Elem::U32 => {
477 write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
478 }
479 Elem::I64 | Elem::U64 => {
480 write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
481 }
482 in_elem => {
483 let bits = in_elem.size() * 8;
484 let extended = unary::zero_extend(input);
485 write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
486 }
487 }?;
488 write!(f, ")")
489 }
490
491 fn compile_saturating_add(
492 f: &mut std::fmt::Formatter<'_>,
493 lhs: impl Display,
494 rhs: impl Display,
495 item: Item<Self>,
496 ) -> std::fmt::Result {
497 let elem = item.elem();
498 match elem {
499 Elem::I32 => {
500 write!(
501 f,
502 r#"[&]() -> {elem} {{
503 {elem} result;
504 asm("add.sat.s32 %0, %1, %2;"
505 : "=r"(result)
506 : "r"({lhs}), "r"({rhs}));
507 return result;
508 }}()"#
509 )
510 }
511 _ => unreachable!("Should be replaced by polyfill"),
512 }
513 }
514
515 fn compile_saturating_sub(
516 f: &mut std::fmt::Formatter<'_>,
517 lhs: impl Display,
518 rhs: impl Display,
519 item: Item<Self>,
520 ) -> std::fmt::Result {
521 let elem = item.elem();
522 match elem {
524 Elem::I32 => {
525 write!(
526 f,
527 r#"[&]() -> {elem} {{
528 {elem} result;
529 asm("sub.sat.s32 %0, %1, %2;"
530 : "=r"(result)
531 : "r"({lhs}), "r"({rhs}));
532 return result;
533 }}()"#
534 )
535 }
536 _ => unreachable!("Should be replaced by polyfill"),
537 }
538 }
539
540 fn compile_instruction_max_function_name(
542 f: &mut std::fmt::Formatter<'_>,
543 item: Item<Self>,
544 ) -> std::fmt::Result {
545 let max = match item.elem() {
546 Elem::F16 | Elem::BF16 => "__hmax",
547 Elem::F16x2 | Elem::BF16x2 => "__hmax2",
548 _ => "max",
549 };
550 write!(f, "{max}")
551 }
552
553 fn compile_instruction_min_function_name(
554 f: &mut std::fmt::Formatter<'_>,
555 item: Item<Self>,
556 ) -> std::fmt::Result {
557 let min = match item.elem() {
558 Elem::F16 | Elem::BF16 => "__hmin",
559 Elem::F16x2 | Elem::BF16x2 => "__hmin2",
560 _ => "min",
561 };
562 write!(f, "{min}")
563 }
564
565 fn compile_warp_shuffle(
567 f: &mut std::fmt::Formatter<'_>,
568 var: &str,
569 source: &str,
570 ) -> std::fmt::Result {
571 write!(f, "__shfl_sync(-1, {var}, {source})")
572 }
573 fn compile_warp_shuffle_xor(
574 f: &mut std::fmt::Formatter<'_>,
575 var: &str,
576 _elem: &Elem<Self>,
577 offset: &str,
578 ) -> std::fmt::Result {
579 write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
580 }
581 fn compile_warp_shuffle_up(
582 f: &mut std::fmt::Formatter<'_>,
583 var: &str,
584 offset: &str,
585 ) -> std::fmt::Result {
586 write!(f, "__shfl_up_sync(-1, {var}, {offset})")
587 }
588 fn compile_warp_shuffle_down(
589 f: &mut std::fmt::Formatter<'_>,
590 var: &str,
591 offset: &str,
592 ) -> std::fmt::Result {
593 write!(f, "__shfl_down_sync(-1, {var}, {offset})")
594 }
595 fn compile_warp_all<T: Component<Self>>(
596 f: &mut std::fmt::Formatter<'_>,
597 input: &T,
598 ) -> std::fmt::Result {
599 write!(f, "__all_sync(-1, {input})")
600 }
601 fn compile_warp_any<T: Component<Self>>(
602 f: &mut std::fmt::Formatter<'_>,
603 input: &T,
604 ) -> std::fmt::Result {
605 write!(f, "__any_sync(-1, {input})")
606 }
607
608 fn compile_warp_ballot(
609 f: &mut std::fmt::Formatter<'_>,
610 input: &Variable<Self>,
611 _out_elem: &Elem<Self>,
612 ) -> std::fmt::Result {
613 write!(f, "__ballot_sync(-1, {input})")
614 }
615
616 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
617 let elem = Elem::<Self>::Bool;
618 let uint32 = Elem::<Self>::U32;
619 writeln!(
623 f,
624 r#"{out} = {elem}([&]() -> {uint32} {{
625 {uint32} pred = 0;
626 asm volatile(
627 "{{\n"
628 " .reg .pred %%px;\n"
629 " elect.sync _|%%px, 0xffffffff;\n"
630 " selp.b32 %0, 1, 0, %%px;\n"
631 "}}\n"
632 : "+r"(pred));
633 return pred;
634 }}());"#
635 )
636 }
637}
638
639impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
642 fn compile_wmma_includes(
643 f: &mut std::fmt::Formatter<'_>,
644 flags: &Flags<Self>,
645 ) -> std::fmt::Result {
646 M::compile_wmma_includes(f, flags)
647 }
648
649 fn compile_wmma_type_definitions(
650 f: &mut std::fmt::Formatter<'_>,
651 flags: &Flags<Self>,
652 ) -> std::fmt::Result {
653 M::compile_wmma_type_definitions(f, flags)
654 }
655
656 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
657 M::compile_wmma_local_variables(f)
658 }
659
660 fn compile_wmma_fragment_declaration(
661 f: &mut std::fmt::Formatter<'_>,
662 var: &Variable<Self>,
663 ) -> std::fmt::Result {
664 M::compile_wmma_fragment_declaration(f, var)
665 }
666
667 fn compile_wwma_fragment_ident(
668 f: &mut std::fmt::Formatter<'_>,
669 ident: &crate::shared::FragmentIdent<Self>,
670 ) -> std::fmt::Result {
671 M::compile_wwma_fragment_ident(f, ident)
672 }
673
674 fn compile_wmma_fragment_layout(
675 f: &mut std::fmt::Formatter<'_>,
676 layout: &crate::shared::FragmentLayout<Self>,
677 ) -> std::fmt::Result {
678 M::compile_wmma_fragment_layout(f, layout)
679 }
680
681 fn compile_wmma_fragment(
682 f: &mut std::fmt::Formatter<'_>,
683 fragment: &crate::shared::Fragment<Self>,
684 ) -> std::fmt::Result {
685 M::compile_wmma_fragment(f, fragment)
686 }
687
688 fn compile_wmma_instruction(
689 f: &mut std::fmt::Formatter<'_>,
690 instruction: &crate::shared::WmmaInstruction<Self>,
691 ) -> std::fmt::Result {
692 M::compile_wmma_instruction(f, instruction)
693 }
694
695 fn compile_manual_mma(
696 f: &mut std::fmt::Formatter<'_>,
697 mma: ManualMma<Self>,
698 ) -> std::fmt::Result {
699 M::compile_manual_mma(f, mma)
700 }
701
702 fn compile_scaled_mma(
703 f: &mut std::fmt::Formatter<'_>,
704 mma: ManualMma<Self>,
705 scales_a: Variable<Self>,
706 scales_b: Variable<Self>,
707 scales_factor: u32,
708 ) -> std::fmt::Result {
709 M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
710 }
711
712 fn supported_wmma_combinations(
713 arch: &CudaArchitecture,
714 ) -> crate::shared::SupportedMmaCombinations {
715 M::supported_wmma_combinations(arch)
716 }
717
718 fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
719 M::supported_mma_combinations(arch)
720 }
721
722 fn supported_scaled_mma_combinations(
723 arch: &CudaArchitecture,
724 ) -> shared::SupportedScaledMmaCombinations {
725 M::supported_scaled_mma_combinations(arch)
726 }
727}
728
729impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
730 fn processors() -> Vec<Box<dyn Processor>> {
731 vec![
732 Box::new(CudaMmaProcessor),
733 Box::new(SaturatingArithmeticProcessor::new(false)),
734 ]
735 }
736}