1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6 Dialect,
7 cuda::{
8 extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension, StMatrix},
9 processors::CudaMmaProcessor,
10 ptx::*,
11 },
12 shared::{
13 self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16 Variable, WarpInstruction, unary,
17 },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Marker type selecting the CUDA dialect of the C++ code generator.
///
/// `M` is only carried at the type level through a zero-sized [`PhantomData`] field; the
/// `DialectWmmaCompiler` implementation of this type delegates to `M`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28 type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32 type Extension = Extension<Self>;
33
34 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35 f.write_str("#include <cuda_runtime.h>\n")?;
36 if flags.elem_fp4 {
37 f.write_str("#include <cuda_fp4.h>\n")?;
38 }
39 if flags.elem_fp6 {
40 f.write_str("#include <cuda_fp6.h>\n")?;
41 }
42 if flags.elem_fp8 {
43 f.write_str("#include <cuda_fp8.h>\n")?;
44 }
45 if flags.elem_bf16 {
46 f.write_str("#include <cuda_bf16.h>\n")?;
47 }
48 if flags.elem_f16 {
49 f.write_str("#include <cuda_fp16.h>\n")?;
50 }
51
52 if flags.inst_wmma || flags.elem_tf32 {
54 Self::compile_wmma_includes(f, flags)?;
55 }
56
57 if flags.op_pipeline {
58 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59 f.write_str("#include <cuda/pipeline>\n")?;
60 }
61 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62 f.write_str("#include <cooperative_groups.h>\n")?;
63 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64 f.write_str("#include <cuda/barrier>\n")?;
65 }
66 if flags.inst_ptx_wrappers {
67 f.write_str("#include <cuda/ptx>\n")?;
68 }
69 if flags.inst_tma {
70 f.write_str(
71 "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74 )?;
75 }
76 Ok(())
77 }
78
79 fn compile_extensions(
80 f: &mut std::fmt::Formatter<'_>,
81 extensions: &[Self::Extension],
82 ) -> std::fmt::Result {
83 for extension in extensions {
84 match extension {
85 Extension::NoExtension => {}
86 Extension::Mma(mma) => mma.format_extension(f)?,
87 }
88 }
89 Ok(())
90 }
91
92 fn register_instruction_extension(
93 _extensions: &mut Vec<Self::Extension>,
94 _instruction: &Instruction<Self>,
95 ) {
96 }
97
98 fn register_warp_instruction_extension(
99 _extensions: &mut Vec<Self::Extension>,
100 _instruction: &WarpInstruction<Self>,
101 ) {
102 }
103
104 fn register_wmma_instruction_extension(
105 extensions: &mut Vec<Self::Extension>,
106 instruction: &shared::WmmaInstruction<Self>,
107 ) {
108 match instruction {
109 shared::WmmaInstruction::ExecuteManual {
110 shape,
111 frag_a,
112 frag_b,
113 frag_c,
114 frag_d,
115 } => {
116 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117 *shape,
118 Fragment(frag_a.elem()),
119 Fragment(frag_b.elem()),
120 Fragment(frag_c.elem()),
121 Fragment(frag_d.elem()),
122 )));
123 if !extensions.contains(&ext) {
124 extensions.push(ext);
125 }
126 }
127 shared::WmmaInstruction::ExecuteScaled {
128 shape,
129 frag_a,
130 frag_b,
131 frag_c,
132 frag_d,
133 scales_a,
134 scales_factor,
135 ..
136 } => {
137 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138 *shape,
139 Fragment(frag_a.elem()),
140 Fragment(frag_b.elem()),
141 Fragment(frag_c.elem()),
142 Fragment(frag_d.elem()),
143 scales_a.elem(),
144 *scales_factor,
145 )));
146 if !extensions.contains(&ext) {
147 extensions.push(ext);
148 }
149 }
150 shared::WmmaInstruction::LdMatrix {
151 output,
152 factor,
153 transpose,
154 ..
155 } => {
156 let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
157 output.elem(),
158 *factor,
159 *transpose,
160 )));
161 if !extensions.contains(&ext) {
162 extensions.push(ext);
163 }
164 }
165 shared::WmmaInstruction::StMatrix {
166 registers,
167 factor,
168 transpose,
169 ..
170 } => {
171 let ext = Extension::Mma(MmaExtension::StMatrix(StMatrix::new(
172 registers.elem(),
173 *factor,
174 *transpose,
175 )));
176 if !extensions.contains(&ext) {
177 extensions.push(ext);
178 }
179 }
180 _ => {}
181 }
182 }
183}
184
185impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
188 fn item_can_be_optimized() -> bool {
189 true
190 }
191
192 fn compile_type_definitions(
193 f: &mut std::fmt::Formatter<'_>,
194 items: &HashSet<Item<Self>>,
195 scalars: &[(Elem<Self>, usize)],
196 flags: &Flags,
197 ) -> std::fmt::Result {
198 let mut items_deduplicated = HashSet::new();
200
201 for item in items {
202 let mut item = *item;
203 match item.elem() {
204 Elem::FP4(_) => {
205 item.elem = Elem::FP4(FP4Kind::E2M1);
206 }
207 Elem::FP4x2(_) => {
208 item.elem = Elem::FP4x2(FP4Kind::E2M1);
209 }
210 Elem::FP6(_) => {
211 item.elem = Elem::FP6(FP6Kind::E2M3);
212 }
213 Elem::FP6x2(_) => {
214 item.elem = Elem::FP6x2(FP6Kind::E2M3);
215 }
216 Elem::FP8(_) => {
217 item.elem = Elem::FP8(FP8Kind::E4M3);
218 }
219 Elem::FP8x2(_) => {
220 item.elem = Elem::FP8x2(FP8Kind::E4M3);
221 }
222 _ => {}
223 }
224 items_deduplicated.insert(item);
225 }
226
227 shared::type_definitions::<Self>(f)?;
228 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
229
230 if flags.use_grid_constants {
231 shared::type_scalar_definitions::<Self>(f, scalars)?;
232 shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
233 }
234
235 if flags.inst_wmma {
236 Self::compile_wmma_type_definitions(f, flags)?;
237 }
238
239 Ok(())
240 }
241
242 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
243 if flags.inst_tma_im2col {
244 writeln!(f, "{TMA_LOAD_IM2COL}")?;
245 }
246 Ok(())
247 }
248
249 fn compile_elem(
250 f: &mut std::fmt::Formatter<'_>,
251 elem: &shared::Elem<Self>,
252 words: bool,
253 ) -> std::fmt::Result {
254 if words {
255 match elem {
256 shared::Elem::F32 => f.write_str("float"),
257 shared::Elem::F64 => f.write_str("double"),
258 shared::Elem::TF32 => f.write_str("float"),
259 shared::Elem::I8 => f.write_str("char"),
260 shared::Elem::I16 => f.write_str("short"),
261 shared::Elem::I32 => f.write_str("int"),
262 shared::Elem::I64 => f.write_str("long"),
263 shared::Elem::U8 => f.write_str("uchar"),
264 shared::Elem::U16 => f.write_str("ushort"),
265 shared::Elem::U32 => f.write_str("uint"),
266 shared::Elem::U64 => f.write_str("ulong"),
267 _ => Self::compile_elem(f, elem, false),
268 }
269 } else {
270 match elem {
271 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
272 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
273 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
274 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
275 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
276 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
277 shared::Elem::F16 => f.write_str("__half"),
278 shared::Elem::F16x2 => f.write_str("__half2"),
279 shared::Elem::F32 => f.write_str("float"),
280 shared::Elem::F64 => f.write_str("double"),
281 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
282 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
283 shared::Elem::TF32 => f.write_str("float"),
284 shared::Elem::I8 => f.write_str("int8"),
285 shared::Elem::I16 => f.write_str("int16"),
286 shared::Elem::I32 => f.write_str("int32"),
287 shared::Elem::I64 => f.write_str("int64"),
288 shared::Elem::U8 => f.write_str("uint8"),
289 shared::Elem::U16 => f.write_str("uint16"),
290 shared::Elem::U32 => f.write_str("uint32"),
291 shared::Elem::U64 => f.write_str("uint64"),
292 shared::Elem::Bool => f.write_str("bool"),
293 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
294 shared::Elem::_Dialect(_) => Ok(()),
295 }
296 }
297 }
298
299 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
300 if 1 == item.vectorization {
301 return write!(f, "{}", item.elem);
302 }
303 if item.native {
304 Self::compile_elem(f, &item.elem, true)?;
306 write!(f, "{}", item.vectorization)
307 } else {
308 write!(f, "{}_{}", item.elem, item.vectorization)
309 }
310 }
311
312 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 Ok(())
314 }
315}
316
317impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
320 fn compile_kernel_signature(
321 f: &mut std::fmt::Formatter<'_>,
322 kernel_name: &str,
323 tensor_maps: &[Binding<Self>],
324 buffers: &[Binding<Self>],
325 scalars: &[(Elem<Self>, usize)],
326 flags: &Flags,
327 ) -> std::fmt::Result {
328 write!(
329 f,
330 "
331
332extern \"C\" __global__ void __launch_bounds__({})",
333 flags.cube_dim.num_elems()
334 )?;
335 if let Some(cluster_dim) = flags.cluster_dim {
336 write!(
337 f,
338 "__cluster_dims__({}, {}, {}) ",
339 cluster_dim.x, cluster_dim.y, cluster_dim.z
340 )?;
341 }
342 writeln!(f, "{kernel_name} (")?;
343 let has_scalars =
344 !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
345 shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
346 if flags.use_grid_constants {
347 shared::compile_scalars_static(f, scalars, flags)?;
348 } else {
349 shared::compile_scalars_dynamic(f, scalars)?;
350 }
351 f.write_str("\n)")?;
352 Ok(())
354 }
355
356 fn compile_bindings_body(
357 f: &mut std::fmt::Formatter<'_>,
358 body: &shared::Body<Self>,
359 ) -> std::fmt::Result {
360 if !body.shared_memories.is_empty() {
361 let max_align = body
362 .shared_memories
363 .iter()
364 .map(|smem| smem.align)
365 .max()
366 .unwrap();
367 writeln!(
370 f,
371 "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
372 )?;
373 }
374 Ok(())
375 }
376}
377
378impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
379
380impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
383 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 write!(f, "cluster.block_rank()")
385 }
386
387 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 write!(f, "cluster.block_index().x")
389 }
390
391 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 write!(f, "cluster.block_index().y")
393 }
394
395 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 write!(f, "cluster.block_index().z")
397 }
398}
399
400impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
403 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 writeln!(f, "__syncthreads();\n")
406 }
407
408 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 writeln!(f, "__syncwarp();\n")
410 }
411
412 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 writeln!(f, "__threadfence();")
414 }
415
416 fn compile_instruction_find_first_set<T: Component<Self>>(
418 f: &mut std::fmt::Formatter<'_>,
419 input: T,
420 out_elem: Elem<Self>,
421 ) -> std::fmt::Result {
422 write!(f, "{out_elem}(")?;
423 match input.elem() {
424 Elem::I32 => write!(f, "__ffs({input})"),
425 Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
426 Elem::I64 => write!(f, "__ffsll({input})"),
427 Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
428 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
429 }?;
430 write!(f, ")")
431 }
432
433 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
434 f: &mut std::fmt::Formatter<'_>,
435 input: T,
436 out_elem: Elem<Self>,
437 ) -> std::fmt::Result {
438 write!(f, "{out_elem}(")?;
439 match input.elem() {
440 Elem::I32 => write!(f, "__clz({input})"),
441 Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
442 Elem::I64 => write!(f, "__clzll({input})"),
443 Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
444 in_elem => write!(
445 f,
446 "{out_elem}(__clz({}) - {})",
447 unary::zero_extend(input),
448 (size_of::<u32>() - in_elem.size()) * 8
449 ),
450 }?;
451 write!(f, ")")
452 }
453
454 fn compile_saturating_add(
455 f: &mut std::fmt::Formatter<'_>,
456 lhs: impl Display,
457 rhs: impl Display,
458 item: Item<Self>,
459 ) -> std::fmt::Result {
460 let elem = item.elem();
461 match elem {
462 Elem::I32 => {
463 write!(
464 f,
465 r#"[&]() -> {elem} {{
466 {elem} result;
467 asm("add.sat.s32 %0, %1, %2;"
468 : "=r"(result)
469 : "r"({lhs}), "r"({rhs}));
470 return result;
471 }}()"#
472 )
473 }
474 _ => unreachable!("Should be replaced by polyfill"),
475 }
476 }
477
478 fn compile_saturating_sub(
479 f: &mut std::fmt::Formatter<'_>,
480 lhs: impl Display,
481 rhs: impl Display,
482 item: Item<Self>,
483 ) -> std::fmt::Result {
484 let elem = item.elem();
485 match elem {
487 Elem::I32 => {
488 write!(
489 f,
490 r#"[&]() -> {elem} {{
491 {elem} result;
492 asm("sub.sat.s32 %0, %1, %2;"
493 : "=r"(result)
494 : "r"({lhs}), "r"({rhs}));
495 return result;
496 }}()"#
497 )
498 }
499 _ => unreachable!("Should be replaced by polyfill"),
500 }
501 }
502
503 fn compile_instruction_max_function_name(
505 f: &mut std::fmt::Formatter<'_>,
506 item: Item<Self>,
507 ) -> std::fmt::Result {
508 let max = match item.elem() {
509 Elem::F16 | Elem::BF16 => "__hmax",
510 Elem::F16x2 | Elem::BF16x2 => "__hmax2",
511 _ => "max",
512 };
513 write!(f, "{max}")
514 }
515
516 fn compile_instruction_min_function_name(
517 f: &mut std::fmt::Formatter<'_>,
518 item: Item<Self>,
519 ) -> std::fmt::Result {
520 let min = match item.elem() {
521 Elem::F16 | Elem::BF16 => "__hmin",
522 Elem::F16x2 | Elem::BF16x2 => "__hmin2",
523 _ => "min",
524 };
525 write!(f, "{min}")
526 }
527
528 fn compile_warp_shuffle(
530 f: &mut std::fmt::Formatter<'_>,
531 var: &str,
532 source: &str,
533 ) -> std::fmt::Result {
534 write!(f, "__shfl_sync(-1, {var}, {source})")
535 }
536 fn compile_warp_shuffle_xor(
537 f: &mut std::fmt::Formatter<'_>,
538 var: &str,
539 _elem: &Elem<Self>,
540 offset: &str,
541 ) -> std::fmt::Result {
542 write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
543 }
544 fn compile_warp_shuffle_up(
545 f: &mut std::fmt::Formatter<'_>,
546 var: &str,
547 offset: &str,
548 ) -> std::fmt::Result {
549 write!(f, "__shfl_up_sync(-1, {var}, {offset})")
550 }
551 fn compile_warp_shuffle_down(
552 f: &mut std::fmt::Formatter<'_>,
553 var: &str,
554 offset: &str,
555 ) -> std::fmt::Result {
556 write!(f, "__shfl_down_sync(-1, {var}, {offset})")
557 }
558 fn compile_warp_all<T: Component<Self>>(
559 f: &mut std::fmt::Formatter<'_>,
560 input: &T,
561 ) -> std::fmt::Result {
562 write!(f, "__all_sync(-1, {input})")
563 }
564 fn compile_warp_any<T: Component<Self>>(
565 f: &mut std::fmt::Formatter<'_>,
566 input: &T,
567 ) -> std::fmt::Result {
568 write!(f, "__any_sync(-1, {input})")
569 }
570
571 fn compile_warp_ballot(
572 f: &mut std::fmt::Formatter<'_>,
573 input: &Variable<Self>,
574 _out_elem: &Elem<Self>,
575 ) -> std::fmt::Result {
576 write!(f, "__ballot_sync(-1, {input})")
577 }
578
579 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
580 let elem = Elem::<Self>::Bool;
581 let uint32 = Elem::<Self>::U32;
582 writeln!(
586 f,
587 r#"{out} = {elem}([&]() -> {uint32} {{
588 {uint32} pred = 0;
589 asm volatile(
590 "{{\n"
591 " .reg .pred %%px;\n"
592 " elect.sync _|%%px, 0xffffffff;\n"
593 " selp.b32 %0, 1, 0, %%px;\n"
594 "}}\n"
595 : "+r"(pred));
596 return pred;
597 }}());"#
598 )
599 }
600}
601
602impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
605 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
606 M::compile_wmma_includes(f, flags)
607 }
608
609 fn compile_wmma_type_definitions(
610 f: &mut std::fmt::Formatter<'_>,
611 flags: &Flags,
612 ) -> std::fmt::Result {
613 M::compile_wmma_type_definitions(f, flags)
614 }
615
616 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617 M::compile_wmma_local_variables(f)
618 }
619
620 fn compile_wmma_fragment_declaration(
621 f: &mut std::fmt::Formatter<'_>,
622 var: &Variable<Self>,
623 ) -> std::fmt::Result {
624 M::compile_wmma_fragment_declaration(f, var)
625 }
626
627 fn compile_wwma_fragment_ident(
628 f: &mut std::fmt::Formatter<'_>,
629 ident: &crate::shared::FragmentIdent<Self>,
630 ) -> std::fmt::Result {
631 M::compile_wwma_fragment_ident(f, ident)
632 }
633
634 fn compile_wmma_fragment_layout(
635 f: &mut std::fmt::Formatter<'_>,
636 layout: &crate::shared::FragmentLayout<Self>,
637 ) -> std::fmt::Result {
638 M::compile_wmma_fragment_layout(f, layout)
639 }
640
641 fn compile_wmma_fragment(
642 f: &mut std::fmt::Formatter<'_>,
643 fragment: &crate::shared::Fragment<Self>,
644 ) -> std::fmt::Result {
645 M::compile_wmma_fragment(f, fragment)
646 }
647
648 fn compile_wmma_instruction(
649 f: &mut std::fmt::Formatter<'_>,
650 instruction: &crate::shared::WmmaInstruction<Self>,
651 ) -> std::fmt::Result {
652 M::compile_wmma_instruction(f, instruction)
653 }
654
655 fn compile_manual_mma(
656 f: &mut std::fmt::Formatter<'_>,
657 mma: ManualMma<Self>,
658 ) -> std::fmt::Result {
659 M::compile_manual_mma(f, mma)
660 }
661
662 fn compile_scaled_mma(
663 f: &mut std::fmt::Formatter<'_>,
664 mma: ManualMma<Self>,
665 scales_a: Variable<Self>,
666 scales_b: Variable<Self>,
667 scales_factor: u32,
668 ) -> std::fmt::Result {
669 M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
670 }
671
672 fn supported_wmma_combinations(
673 arch: &CudaArchitecture,
674 ) -> crate::shared::SupportedMmaCombinations {
675 M::supported_wmma_combinations(arch)
676 }
677
678 fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
679 M::supported_mma_combinations(arch)
680 }
681
682 fn supported_scaled_mma_combinations(
683 arch: &CudaArchitecture,
684 ) -> shared::SupportedScaledMmaCombinations {
685 M::supported_scaled_mma_combinations(arch)
686 }
687}
688
689impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
690 fn processors() -> Vec<Box<dyn Processor>> {
691 vec![
692 Box::new(CudaMmaProcessor),
693 Box::new(SaturatingArithmeticProcessor::new(false)),
694 ]
695 }
696}