1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{
4 ir::{BarrierLevel, Processor},
5 post_processing::saturating::SaturatingArithmeticProcessor,
6};
7
8use crate::{
9 Dialect,
10 cuda::{
11 extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
12 processors::CudaMmaProcessor,
13 ptx::*,
14 },
15 shared::{
16 self, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
17 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
18 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, KernelArg,
19 ManualMma, Variable, WarpInstruction, unary,
20 },
21};
22
23use super::{Extension, arch::CudaArchitecture};
24
/// CUDA C++ dialect for the CubeCL compiler.
///
/// The `M` type parameter selects the WMMA/MMA compilation strategy used to
/// lower matrix instructions; it is carried only at the type level.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    // Zero-sized marker: the dialect stores no `M` value, it only fixes the
    // WMMA compiler type parameter.
    _wmma_compiler: PhantomData<M>,
}
29
/// Ties the CUDA dialect to the CUDA architecture description used for
/// capability queries (e.g. supported MMA shapes).
impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
    type Architecture = CudaArchitecture;
}
33
34impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
35 type Extension = Extension<Self>;
36
37 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
38 f.write_str("#include <cuda_runtime.h>\n")?;
39 if flags.elem_fp4 {
40 f.write_str("#include <cuda_fp4.h>\n")?;
41 }
42 if flags.elem_fp6 {
43 f.write_str("#include <cuda_fp6.h>\n")?;
44 }
45 if flags.elem_fp8 {
46 f.write_str("#include <cuda_fp8.h>\n")?;
47 }
48 if flags.elem_bf16 {
49 f.write_str("#include <cuda_bf16.h>\n")?;
50 }
51 if flags.elem_f16 {
52 f.write_str("#include <cuda_fp16.h>\n")?;
53 }
54
55 if flags.inst_wmma || flags.elem_tf32 {
57 Self::compile_wmma_includes(f, flags)?;
58 }
59
60 if flags.op_pipeline {
61 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
62 f.write_str("#include <cuda/pipeline>\n")?;
63 }
64 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
65 f.write_str("#include <cooperative_groups.h>\n")?;
66 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
67 f.write_str("#include <cuda/barrier>\n")?;
68 }
69 if flags.inst_ptx_wrappers {
70 f.write_str("#include <cuda/ptx>\n")?;
71 }
72 if flags.inst_tma {
73 f.write_str(
74 "typedef struct CUtensorMap_st {
75alignas(64) unsigned long long int opaque[16];
76} CUtensorMap;\n",
77 )?;
78 }
79 Ok(())
80 }
81
82 fn compile_extensions(
83 f: &mut std::fmt::Formatter<'_>,
84 extensions: &[Self::Extension],
85 ) -> std::fmt::Result {
86 for extension in extensions {
87 match extension {
88 Extension::NoExtension => {}
89 Extension::Mma(mma) => mma.format_extension(f)?,
90 }
91 }
92 Ok(())
93 }
94
95 fn register_instruction_extension(
96 _extensions: &mut Vec<Self::Extension>,
97 _instruction: &Instruction<Self>,
98 ) {
99 }
100
101 fn register_warp_instruction_extension(
102 _extensions: &mut Vec<Self::Extension>,
103 _instruction: &WarpInstruction<Self>,
104 ) {
105 }
106
107 fn register_wmma_instruction_extension(
108 extensions: &mut Vec<Self::Extension>,
109 instruction: &shared::WmmaInstruction<Self>,
110 ) {
111 match instruction {
112 shared::WmmaInstruction::ExecuteManual {
113 shape,
114 frag_a,
115 frag_b,
116 frag_c,
117 frag_d,
118 } => {
119 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
120 *shape,
121 Fragment(frag_a.elem()),
122 Fragment(frag_b.elem()),
123 Fragment(frag_c.elem()),
124 Fragment(frag_d.elem()),
125 )));
126 if !extensions.contains(&ext) {
127 extensions.push(ext);
128 }
129 }
130 shared::WmmaInstruction::ExecuteScaled {
131 shape,
132 frag_a,
133 frag_b,
134 frag_c,
135 frag_d,
136 scales_a,
137 scales_factor,
138 ..
139 } => {
140 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
141 *shape,
142 Fragment(frag_a.elem()),
143 Fragment(frag_b.elem()),
144 Fragment(frag_c.elem()),
145 Fragment(frag_d.elem()),
146 scales_a.elem(),
147 *scales_factor,
148 )));
149 if !extensions.contains(&ext) {
150 extensions.push(ext);
151 }
152 }
153 shared::WmmaInstruction::LdMatrix {
154 output,
155 factor,
156 transpose,
157 ..
158 } => {
159 let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
160 output.elem(),
161 *factor,
162 *transpose,
163 )));
164 if !extensions.contains(&ext) {
165 extensions.push(ext);
166 }
167 }
168 shared::WmmaInstruction::StMatrix {
169 registers,
170 factor,
171 transpose,
172 ..
173 } => {
174 let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
175 registers.elem(),
176 *factor,
177 *transpose,
178 )));
179 if !extensions.contains(&ext) {
180 extensions.push(ext);
181 }
182 }
183 _ => {}
184 }
185 }
186}
187
188impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
191 fn item_can_be_optimized() -> bool {
192 true
193 }
194
195 fn compile_type_definitions(
196 f: &mut std::fmt::Formatter<'_>,
197 items: &HashSet<Item<Self>>,
198 scalars: &[(Elem<Self>, usize)],
199 info: &cubecl_core::Info,
200 flags: &Flags<Self>,
201 ) -> std::fmt::Result {
202 let mut items_deduplicated = HashSet::new();
204
205 for item in items {
206 let mut item = *item;
207 match item.elem() {
208 Elem::FP4(_) => {
209 item.elem = Elem::FP4(FP4Kind::E2M1);
210 }
211 Elem::FP4x2(_) => {
212 item.elem = Elem::FP4x2(FP4Kind::E2M1);
213 }
214 Elem::FP6(_) => {
215 item.elem = Elem::FP6(FP6Kind::E2M3);
216 }
217 Elem::FP6x2(_) => {
218 item.elem = Elem::FP6x2(FP6Kind::E2M3);
219 }
220 Elem::FP8(_) => {
221 item.elem = Elem::FP8(FP8Kind::E4M3);
222 }
223 Elem::FP8x2(_) => {
224 item.elem = Elem::FP8x2(FP8Kind::E4M3);
225 }
226 Elem::Atomic(inner) => {
227 item.elem = inner.as_elem();
228 }
229 _ => {}
230 }
231 items_deduplicated.insert(item);
232 }
233
234 shared::type_definitions::<Self>(f)?;
235 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
236
237 shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
238
239 if flags.inst_wmma {
240 Self::compile_wmma_type_definitions(f, flags)?;
241 }
242
243 Ok(())
244 }
245
246 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
247 if flags.inst_tma_im2col {
248 writeln!(f, "{TMA_LOAD_IM2COL}")?;
249 }
250 if flags.inst_async_copy {
251 writeln!(f, "{COPY_ASYNC}")?;
252 }
253 Ok(())
254 }
255
256 fn compile_elem(
257 f: &mut std::fmt::Formatter<'_>,
258 elem: &shared::Elem<Self>,
259 words: bool,
260 ) -> std::fmt::Result {
261 if words {
262 match elem {
263 shared::Elem::F32 => f.write_str("float"),
264 shared::Elem::F64 => f.write_str("double"),
265 shared::Elem::TF32 => f.write_str("float"),
266 shared::Elem::I8 => f.write_str("char"),
267 shared::Elem::I16 => f.write_str("short"),
268 shared::Elem::I32 => f.write_str("int"),
269 shared::Elem::I64 => f.write_str("long"),
270 shared::Elem::U8 => f.write_str("uchar"),
271 shared::Elem::U16 => f.write_str("ushort"),
272 shared::Elem::U32 => f.write_str("uint"),
273 shared::Elem::U64 => f.write_str("ulong"),
274 _ => Self::compile_elem(f, elem, false),
275 }
276 } else {
277 match elem {
278 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
279 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
280 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
281 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
282 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
283 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
284 shared::Elem::F16 => f.write_str("__half"),
285 shared::Elem::F16x2 => f.write_str("__half2"),
286 shared::Elem::F32 => f.write_str("float"),
287 shared::Elem::F64 => f.write_str("double"),
288 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
289 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
290 shared::Elem::TF32 => f.write_str("float"),
291 shared::Elem::I8 => f.write_str("int8"),
292 shared::Elem::I16 => f.write_str("int16"),
293 shared::Elem::I32 => f.write_str("int32"),
294 shared::Elem::I64 => f.write_str("int64"),
295 shared::Elem::U8 => f.write_str("uint8"),
296 shared::Elem::U16 => f.write_str("uint16"),
297 shared::Elem::U32 => f.write_str("uint32"),
298 shared::Elem::U64 => f.write_str("uint64"),
299 shared::Elem::Bool => f.write_str("bool"),
300 shared::Elem::Barrier(BarrierLevel::Unit) => {
301 f.write_str("cuda::barrier<cuda::thread_scope_thread>")
302 }
303 shared::Elem::Barrier(BarrierLevel::Cube) => {
304 f.write_str("cuda::barrier<cuda::thread_scope_block>")
305 }
306 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
307 shared::Elem::_Dialect(_) => Ok(()),
308 }
309 }
310 }
311
312 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
313 if 1 == item.vectorization {
314 return write!(f, "{}", item.elem);
315 }
316 if item.native {
317 Self::compile_elem(f, &item.elem, true)?;
319 write!(f, "{}", item.vectorization)
320 } else {
321 write!(f, "{}_{}", item.elem, item.vectorization)
322 }
323 }
324
325 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 Ok(())
327 }
328}
329
/// Kernel signature and binding prologue generation for the CUDA dialect.
impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
    /// Emits the `extern "C" __global__` kernel signature: launch bounds,
    /// optional cluster dimensions, the kernel name, and the parameter list
    /// (tensor maps, buffers, then the metadata/info bindings).
    fn compile_kernel_signature(
        f: &mut std::fmt::Formatter<'_>,
        kernel_name: &str,
        tensor_maps: &[KernelArg<Self>],
        buffers: &[KernelArg<Self>],
        flags: &Flags<Self>,
    ) -> std::fmt::Result {
        // `__launch_bounds__` pins the maximum block size to the cube dim so
        // nvcc can allocate registers accordingly.
        write!(
            f,
            "

extern \"C\" __global__ void __launch_bounds__({})",
            flags.cube_dim.num_elems()
        )?;
        // Optional thread-block-cluster shape.
        // NOTE(review): when no cluster dim is set, no space separates the
        // launch-bounds attribute from the kernel name; the `)` token still
        // delimits them so the generated C++ parses — confirm intended.
        if let Some(cluster_dim) = flags.cluster_dim {
            write!(
                f,
                "__cluster_dims__({}, {}, {}) ",
                cluster_dim.x, cluster_dim.y, cluster_dim.z
            )?;
        }
        writeln!(f, "{kernel_name} (")?;

        shared::compile_bindings(f, tensor_maps, buffers, flags.has_info)?;
        // Metadata is passed either as a grid-constant struct or as dynamic
        // pointers, depending on target support.
        if flags.use_grid_constants {
            shared::compile_info_static(f, flags)?;
        } else {
            shared::compile_info_dynamic(f, flags)?;
        }
        f.write_str("\n)")?;
        Ok(())
    }

    /// Emits the statements at the top of the kernel body: the dynamic shared
    /// memory declaration and, when metadata is passed by pointer, the local
    /// aliases for the info struct and trailing dynamic metadata.
    fn compile_bindings_body(
        f: &mut std::fmt::Formatter<'_>,
        body: &shared::Body<Self>,
    ) -> std::fmt::Result {
        if !body.shared_memories.is_empty() {
            // A single extern shared buffer backs all shared memories; align
            // it to the strictest requirement among them.
            let max_align = body
                .shared_memories
                .iter()
                .map(|smem| smem.align())
                .max()
                .unwrap();
            writeln!(
                f,
                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
            )?;
        }
        if body.info_by_ptr {
            f.write_str("const info_st& info = *info_ptr;\n")?;
            // Dynamic metadata is laid out immediately after the info struct
            // in the same allocation.
            writeln!(
                f,
                "const {addr}* dynamic_meta = reinterpret_cast<const {addr}*>(
    reinterpret_cast<const char*>(info_ptr) + sizeof(info_st)
);\n",
                addr = body.address_type,
            )?;
        }
        Ok(())
    }
}
399
// Uses the trait's default warp-reduce lowering; no CUDA-specific overrides.
impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
401
402impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
405 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 write!(f, "cluster.block_rank()")
407 }
408
409 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 write!(f, "cluster.block_index().x")
411 }
412
413 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414 write!(f, "cluster.block_index().y")
415 }
416
417 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418 write!(f, "cluster.block_index().z")
419 }
420}
421
/// Instruction-level lowering to CUDA intrinsics and inline PTX.
impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
    // Block-wide barrier; the literal's trailing "\n" leaves a blank line
    // after the statement in the generated source.
    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__syncthreads();\n")
    }

    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__syncwarp();\n")
    }

    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "__threadfence();")
    }

    /// Find-first-set via `__ffs`/`__ffsll` (1-based index of the lowest set
    /// bit, 0 when the input is 0). Unsigned and narrow inputs are cast to
    /// the signed type the intrinsic expects; the result is cast to `out_elem`.
    fn compile_instruction_find_first_set<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 => write!(f, "__ffs({input})"),
            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
            Elem::I64 => write!(f, "__ffsll({input})"),
            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
        }?;
        write!(f, ")")
    }

    /// Leading zeros via `__clz`/`__clzll`. For types narrower than 32 bits,
    /// the input is zero-extended and the count is corrected by the number of
    /// extra high bits introduced by the extension.
    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 => write!(f, "__clz({input})"),
            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
            Elem::I64 => write!(f, "__clzll({input})"),
            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
            in_elem => write!(
                f,
                "{out_elem}(__clz({}) - {})",
                unary::zero_extend(input),
                // Width difference (in bits) between u32 and the input type.
                (size_of::<u32>() - in_elem.size()) * 8
            ),
        }?;
        write!(f, ")")
    }

    /// Trailing zeros built from find-first-set: `__ffs(x) - 1`, with an
    /// explicit zero check since `__ffs(0) == 0` would otherwise yield -1
    /// instead of the type's bit width.
    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 | Elem::U32 => {
                write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
            }
            Elem::I64 | Elem::U64 => {
                write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
            }
            in_elem => {
                // Narrow types: zero-extend, and report the original width
                // when the value is zero.
                let bits = in_elem.size() * 8;
                let extended = unary::zero_extend(input);
                write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
            }
        }?;
        write!(f, ")")
    }

    /// Saturating i32 addition via the PTX `add.sat.s32` instruction, wrapped
    /// in an immediately-invoked lambda so it can appear in expression
    /// position. Other element types are rewritten to a polyfill earlier
    /// (see `SaturatingArithmeticProcessor`), hence the `unreachable!`.
    fn compile_saturating_add(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let elem = item.elem();
        match elem {
            Elem::I32 => {
                write!(
                    f,
                    r#"[&]() -> {elem} {{
    {elem} result;
    asm("add.sat.s32 %0, %1, %2;"
        : "=r"(result)
        : "r"({lhs}), "r"({rhs}));
    return result;
}}()"#
                )
            }
            _ => unreachable!("Should be replaced by polyfill"),
        }
    }

    /// Saturating i32 subtraction via PTX `sub.sat.s32`; same structure and
    /// polyfill contract as `compile_saturating_add`.
    fn compile_saturating_sub(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let elem = item.elem();
        match elem {
            Elem::I32 => {
                write!(
                    f,
                    r#"[&]() -> {elem} {{
    {elem} result;
    asm("sub.sat.s32 %0, %1, %2;"
        : "=r"(result)
        : "r"({lhs}), "r"({rhs}));
    return result;
}}()"#
                )
            }
            _ => unreachable!("Should be replaced by polyfill"),
        }
    }

    // Half-precision types use the dedicated `__hmax`/`__hmax2` intrinsics;
    // everything else falls back to plain `max`.
    fn compile_instruction_max_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let max = match item.elem() {
            Elem::F16 | Elem::BF16 => "__hmax",
            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
            _ => "max",
        };
        write!(f, "{max}")
    }

    fn compile_instruction_min_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<Self>,
    ) -> std::fmt::Result {
        let min = match item.elem() {
            Elem::F16 | Elem::BF16 => "__hmin",
            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
            _ => "min",
        };
        write!(f, "{min}")
    }

    // Warp shuffles/votes: `-1` converts to the all-lanes mask 0xffffffff in
    // the generated C++.
    fn compile_warp_shuffle(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        source: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_sync(-1, {var}, {source})")
    }
    fn compile_warp_shuffle_xor(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        _elem: &Elem<Self>,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
    }
    fn compile_warp_shuffle_up(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
    }
    fn compile_warp_shuffle_down(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result {
        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
    }
    fn compile_warp_all<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result {
        write!(f, "__all_sync(-1, {input})")
    }
    fn compile_warp_any<T: Component<Self>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result {
        write!(f, "__any_sync(-1, {input})")
    }

    fn compile_warp_ballot(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<Self>,
        _out_elem: &Elem<Self>,
    ) -> std::fmt::Result {
        write!(f, "__ballot_sync(-1, {input})")
    }

    /// Elects a single leader lane via the PTX `elect.sync` instruction: the
    /// predicate is materialized into a 0/1 u32 with `selp.b32`, then cast to
    /// bool and assigned to `out`.
    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
        let elem = Elem::<Self>::Bool;
        let uint32 = Elem::<Self>::U32;
        writeln!(
            f,
            r#"{out} = {elem}([&]() -> {uint32} {{
    {uint32} pred = 0;
    asm volatile(
        "{{\n"
        " .reg .pred %%px;\n"
        " elect.sync _|%%px, 0xffffffff;\n"
        " selp.b32 %0, 1, 0, %%px;\n"
        "}}\n"
        : "+r"(pred));
    return pred;
}}());"#
        )
    }

    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "__builtin_unreachable();")
    }
}
652
/// WMMA/MMA compilation for the CUDA dialect.
///
/// Every method delegates to the strategy type `M`, so the concrete lowering
/// (e.g. `nvcuda::wmma` vs raw PTX `mma`) is chosen by the type parameter of
/// [`CudaDialect`] rather than at runtime.
impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
    fn compile_wmma_includes(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_includes(f, flags)
    }

    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_type_definitions(f, flags)
    }

    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        M::compile_wmma_local_variables(f)
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &Variable<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment_declaration(f, var)
    }

    // Name kept as-is ("wwma"): it is the trait's method name.
    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &crate::shared::FragmentIdent<Self>,
    ) -> std::fmt::Result {
        M::compile_wwma_fragment_ident(f, ident)
    }

    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &crate::shared::FragmentLayout<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment_layout(f, layout)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &crate::shared::Fragment<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_fragment(f, fragment)
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &crate::shared::WmmaInstruction<Self>,
    ) -> std::fmt::Result {
        M::compile_wmma_instruction(f, instruction)
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<Self>,
    ) -> std::fmt::Result {
        M::compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<Self>,
        scales_a: Variable<Self>,
        scales_b: Variable<Self>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

    // Capability queries are architecture-dependent and also deferred to `M`.
    fn supported_wmma_combinations(
        arch: &CudaArchitecture,
    ) -> crate::shared::SupportedMmaCombinations {
        M::supported_wmma_combinations(arch)
    }

    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
        M::supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> shared::SupportedScaledMmaCombinations {
        M::supported_scaled_mma_combinations(arch)
    }
}
742
743impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
744 fn processors() -> Vec<Box<dyn Processor>> {
745 vec![
746 Box::new(CudaMmaProcessor),
747 Box::new(SaturatingArithmeticProcessor::new(false)),
748 ]
749 }
750}