1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6 Dialect,
7 cuda::{
8 extension::{Fragment, LdMatrix, MmaExecute, MmaExecuteScaled, MmaExtension},
9 processors::CudaMmaProcessor,
10 ptx::*,
11 },
12 shared::{
13 self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16 Variable, WarpInstruction, unary,
17 },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Marker type selecting the CUDA code-generation dialect.
///
/// The type carries no runtime data. The parameter `M` is only recorded through
/// [`PhantomData`]; the trait implementations in this file require
/// `M: DialectWmmaCompiler<Self>` and forward every WMMA/MMA-related call to it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    // Zero-sized: ties the dialect to its WMMA compiler `M` without storing one.
    _wmma_compiler: PhantomData<M>,
}
26
/// Registers [`CudaDialect`] as a dialect whose architecture description is
/// [`CudaArchitecture`].
impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
    type Architecture = CudaArchitecture;
}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32 type Extension = Extension<Self>;
33
34 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35 f.write_str("#include <cuda_runtime.h>\n")?;
36 if flags.elem_fp4 {
37 f.write_str("#include <cuda_fp4.h>\n")?;
38 }
39 if flags.elem_fp6 {
40 f.write_str("#include <cuda_fp6.h>\n")?;
41 }
42 if flags.elem_fp8 {
43 f.write_str("#include <cuda_fp8.h>\n")?;
44 }
45 if flags.elem_bf16 {
46 f.write_str("#include <cuda_bf16.h>\n")?;
47 }
48 if flags.elem_f16 {
49 f.write_str("#include <cuda_fp16.h>\n")?;
50 }
51
52 if flags.inst_wmma || flags.elem_tf32 {
54 Self::compile_wmma_includes(f, flags)?;
55 }
56
57 if flags.op_pipeline {
58 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59 f.write_str("#include <cuda/pipeline>\n")?;
60 }
61 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62 f.write_str("#include <cooperative_groups.h>\n")?;
63 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64 f.write_str("#include <cuda/barrier>\n")?;
65 }
66 if flags.inst_ptx_wrappers {
67 f.write_str("#include <cuda/ptx>\n")?;
68 }
69 if flags.inst_tma {
70 f.write_str(
71 "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74 )?;
75 }
76 Ok(())
77 }
78
79 fn compile_extensions(
80 f: &mut std::fmt::Formatter<'_>,
81 extensions: &[Self::Extension],
82 ) -> std::fmt::Result {
83 for extension in extensions {
84 match extension {
85 Extension::NoExtension => {}
86 Extension::Mma(mma) => mma.format_extension(f)?,
87 }
88 }
89 Ok(())
90 }
91
92 fn register_instruction_extension(
93 _extensions: &mut Vec<Self::Extension>,
94 _instruction: &Instruction<Self>,
95 ) {
96 }
97
98 fn register_warp_instruction_extension(
99 _extensions: &mut Vec<Self::Extension>,
100 _instruction: &WarpInstruction<Self>,
101 ) {
102 }
103
104 fn register_wmma_instruction_extension(
105 extensions: &mut Vec<Self::Extension>,
106 instruction: &shared::WmmaInstruction<Self>,
107 ) {
108 match instruction {
109 shared::WmmaInstruction::ExecuteManual {
110 shape,
111 frag_a,
112 frag_b,
113 frag_c,
114 frag_d,
115 } => {
116 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117 *shape,
118 Fragment(frag_a.elem()),
119 Fragment(frag_b.elem()),
120 Fragment(frag_c.elem()),
121 Fragment(frag_d.elem()),
122 )));
123 if !extensions.contains(&ext) {
124 extensions.push(ext);
125 }
126 }
127 shared::WmmaInstruction::ExecuteScaled {
128 shape,
129 frag_a,
130 frag_b,
131 frag_c,
132 frag_d,
133 scales_a,
134 scales_factor,
135 ..
136 } => {
137 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138 *shape,
139 Fragment(frag_a.elem()),
140 Fragment(frag_b.elem()),
141 Fragment(frag_c.elem()),
142 Fragment(frag_d.elem()),
143 scales_a.elem(),
144 *scales_factor,
145 )));
146 if !extensions.contains(&ext) {
147 extensions.push(ext);
148 }
149 }
150 shared::WmmaInstruction::LdMatrix {
151 output,
152 factor,
153 transpose,
154 ..
155 } => {
156 let ext = Extension::Mma(MmaExtension::LdMatrix(LdMatrix::new(
157 output.elem(),
158 *factor,
159 *transpose,
160 )));
161 if !extensions.contains(&ext) {
162 extensions.push(ext);
163 }
164 }
165 _ => {}
166 }
167 }
168}
169
170impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
173 fn item_can_be_optimized() -> bool {
174 true
175 }
176
177 fn compile_type_definitions(
178 f: &mut std::fmt::Formatter<'_>,
179 items: &HashSet<Item<Self>>,
180 scalars: &[(Elem<Self>, usize)],
181 flags: &Flags,
182 ) -> std::fmt::Result {
183 let mut items_deduplicated = HashSet::new();
185
186 for item in items {
187 let mut item = *item;
188 match item.elem() {
189 Elem::FP4(_) => {
190 item.elem = Elem::FP4(FP4Kind::E2M1);
191 }
192 Elem::FP4x2(_) => {
193 item.elem = Elem::FP4x2(FP4Kind::E2M1);
194 }
195 Elem::FP6(_) => {
196 item.elem = Elem::FP6(FP6Kind::E2M3);
197 }
198 Elem::FP6x2(_) => {
199 item.elem = Elem::FP6x2(FP6Kind::E2M3);
200 }
201 Elem::FP8(_) => {
202 item.elem = Elem::FP8(FP8Kind::E4M3);
203 }
204 Elem::FP8x2(_) => {
205 item.elem = Elem::FP8x2(FP8Kind::E4M3);
206 }
207 _ => {}
208 }
209 items_deduplicated.insert(item);
210 }
211
212 shared::type_definitions::<Self>(f)?;
213 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
214
215 if flags.use_grid_constants {
216 shared::type_scalar_definitions::<Self>(f, scalars)?;
217 shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
218 }
219
220 if flags.inst_wmma {
221 Self::compile_wmma_type_definitions(f, flags)?;
222 }
223
224 Ok(())
225 }
226
227 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
228 if flags.inst_tma_im2col {
229 writeln!(f, "{TMA_LOAD_IM2COL}")?;
230 }
231 Ok(())
232 }
233
234 fn compile_elem(
235 f: &mut std::fmt::Formatter<'_>,
236 elem: &shared::Elem<Self>,
237 words: bool,
238 ) -> std::fmt::Result {
239 if words {
240 match elem {
241 shared::Elem::F32 => f.write_str("float"),
242 shared::Elem::F64 => f.write_str("double"),
243 shared::Elem::TF32 => f.write_str("float"),
244 shared::Elem::I8 => f.write_str("char"),
245 shared::Elem::I16 => f.write_str("short"),
246 shared::Elem::I32 => f.write_str("int"),
247 shared::Elem::I64 => f.write_str("long"),
248 shared::Elem::U8 => f.write_str("uchar"),
249 shared::Elem::U16 => f.write_str("ushort"),
250 shared::Elem::U32 => f.write_str("uint"),
251 shared::Elem::U64 => f.write_str("ulong"),
252 _ => Self::compile_elem(f, elem, false),
253 }
254 } else {
255 match elem {
256 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
257 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
258 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
259 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
260 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
261 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
262 shared::Elem::F16 => f.write_str("__half"),
263 shared::Elem::F16x2 => f.write_str("__half2"),
264 shared::Elem::F32 => f.write_str("float"),
265 shared::Elem::F64 => f.write_str("double"),
266 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
267 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
268 shared::Elem::TF32 => f.write_str("float"),
269 shared::Elem::I8 => f.write_str("int8"),
270 shared::Elem::I16 => f.write_str("int16"),
271 shared::Elem::I32 => f.write_str("int32"),
272 shared::Elem::I64 => f.write_str("int64"),
273 shared::Elem::U8 => f.write_str("uint8"),
274 shared::Elem::U16 => f.write_str("uint16"),
275 shared::Elem::U32 => f.write_str("uint32"),
276 shared::Elem::U64 => f.write_str("uint64"),
277 shared::Elem::Bool => f.write_str("bool"),
278 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
279 shared::Elem::_Dialect(_) => Ok(()),
280 }
281 }
282 }
283
284 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
285 if 1 == item.vectorization {
286 return write!(f, "{}", item.elem);
287 }
288 if item.native {
289 Self::compile_elem(f, &item.elem, true)?;
291 write!(f, "{}", item.vectorization)
292 } else {
293 write!(f, "{}_{}", item.elem, item.vectorization)
294 }
295 }
296
297 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 Ok(())
299 }
300}
301
302impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
305 fn compile_kernel_signature(
306 f: &mut std::fmt::Formatter<'_>,
307 kernel_name: &str,
308 tensor_maps: &[Binding<Self>],
309 buffers: &[Binding<Self>],
310 scalars: &[(Elem<Self>, usize)],
311 flags: &Flags,
312 ) -> std::fmt::Result {
313 write!(
314 f,
315 "
316
317extern \"C\" __global__ void __launch_bounds__({})",
318 flags.cube_dim.num_elems()
319 )?;
320 if let Some(cluster_dim) = flags.cluster_dim {
321 write!(
322 f,
323 "__cluster_dims__({}, {}, {}) ",
324 cluster_dim.x, cluster_dim.y, cluster_dim.z
325 )?;
326 }
327 writeln!(f, "{kernel_name} (")?;
328 let has_scalars =
329 !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
330 shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
331 if flags.use_grid_constants {
332 shared::compile_scalars_static(f, scalars, flags)?;
333 } else {
334 shared::compile_scalars_dynamic(f, scalars)?;
335 }
336 f.write_str("\n)")?;
337 Ok(())
339 }
340
341 fn compile_bindings_body(
342 f: &mut std::fmt::Formatter<'_>,
343 body: &shared::Body<Self>,
344 ) -> std::fmt::Result {
345 if !body.shared_memories.is_empty() {
346 let max_align = body
347 .shared_memories
348 .iter()
349 .map(|smem| smem.align)
350 .max()
351 .unwrap();
352 writeln!(
355 f,
356 "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
357 )?;
358 }
359 Ok(())
360 }
361}
362
/// Empty implementation: nothing is overridden, so this dialect relies entirely
/// on the items the `DialectWarpReduceCompiler` trait itself provides.
impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
364
365impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
368 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 write!(f, "cluster.block_rank()")
370 }
371
372 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 write!(f, "cluster.block_index().x")
374 }
375
376 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 write!(f, "cluster.block_index().y")
378 }
379
380 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 write!(f, "cluster.block_index().z")
382 }
383}
384
385impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
388 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 writeln!(f, "__syncthreads();\n")
391 }
392
393 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 writeln!(f, "__syncwarp();\n")
395 }
396
397 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398 writeln!(f, "__threadfence();")
399 }
400
401 fn compile_instruction_find_first_set<T: Component<Self>>(
403 f: &mut std::fmt::Formatter<'_>,
404 input: T,
405 out_elem: Elem<Self>,
406 ) -> std::fmt::Result {
407 write!(f, "{out_elem}(")?;
408 match input.elem() {
409 Elem::I32 => write!(f, "__ffs({input})"),
410 Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
411 Elem::I64 => write!(f, "__ffsll({input})"),
412 Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
413 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
414 }?;
415 write!(f, ")")
416 }
417
418 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
419 f: &mut std::fmt::Formatter<'_>,
420 input: T,
421 out_elem: Elem<Self>,
422 ) -> std::fmt::Result {
423 write!(f, "{out_elem}(")?;
424 match input.elem() {
425 Elem::I32 => write!(f, "__clz({input})"),
426 Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
427 Elem::I64 => write!(f, "__clzll({input})"),
428 Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
429 in_elem => write!(
430 f,
431 "{out_elem}(__clz({}) - {})",
432 unary::zero_extend(input),
433 (size_of::<u32>() - in_elem.size()) * 8
434 ),
435 }?;
436 write!(f, ")")
437 }
438
439 fn compile_saturating_add(
440 f: &mut std::fmt::Formatter<'_>,
441 lhs: impl Display,
442 rhs: impl Display,
443 item: Item<Self>,
444 ) -> std::fmt::Result {
445 let elem = item.elem();
446 match elem {
447 Elem::I32 => {
448 write!(
449 f,
450 r#"[&]() -> {elem} {{
451 {elem} result;
452 asm("add.sat.s32 %0, %1, %2;"
453 : "=r"(result)
454 : "r"({lhs}), "r"({rhs}));
455 return result;
456 }}()"#
457 )
458 }
459 _ => unreachable!("Should be replaced by polyfill"),
460 }
461 }
462
463 fn compile_saturating_sub(
464 f: &mut std::fmt::Formatter<'_>,
465 lhs: impl Display,
466 rhs: impl Display,
467 item: Item<Self>,
468 ) -> std::fmt::Result {
469 let elem = item.elem();
470 match elem {
472 Elem::I32 => {
473 write!(
474 f,
475 r#"[&]() -> {elem} {{
476 {elem} result;
477 asm("sub.sat.s32 %0, %1, %2;"
478 : "=r"(result)
479 : "r"({lhs}), "r"({rhs}));
480 return result;
481 }}()"#
482 )
483 }
484 _ => unreachable!("Should be replaced by polyfill"),
485 }
486 }
487
488 fn compile_instruction_max_function_name(
490 f: &mut std::fmt::Formatter<'_>,
491 item: Item<Self>,
492 ) -> std::fmt::Result {
493 let max = match item.elem() {
494 Elem::F16 | Elem::BF16 => "__hmax",
495 Elem::F16x2 | Elem::BF16x2 => "__hmax2",
496 _ => "max",
497 };
498 write!(f, "{max}")
499 }
500
501 fn compile_instruction_min_function_name(
502 f: &mut std::fmt::Formatter<'_>,
503 item: Item<Self>,
504 ) -> std::fmt::Result {
505 let min = match item.elem() {
506 Elem::F16 | Elem::BF16 => "__hmin",
507 Elem::F16x2 | Elem::BF16x2 => "__hmin2",
508 _ => "min",
509 };
510 write!(f, "{min}")
511 }
512
513 fn compile_warp_shuffle(
515 f: &mut std::fmt::Formatter<'_>,
516 var: &str,
517 source: &str,
518 ) -> std::fmt::Result {
519 write!(f, "__shfl_sync(-1, {var}, {source})")
520 }
521 fn compile_warp_shuffle_xor(
522 f: &mut std::fmt::Formatter<'_>,
523 var: &str,
524 _elem: &Elem<Self>,
525 offset: &str,
526 ) -> std::fmt::Result {
527 write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
528 }
529 fn compile_warp_shuffle_up(
530 f: &mut std::fmt::Formatter<'_>,
531 var: &str,
532 offset: &str,
533 ) -> std::fmt::Result {
534 write!(f, "__shfl_up_sync(-1, {var}, {offset})")
535 }
536 fn compile_warp_shuffle_down(
537 f: &mut std::fmt::Formatter<'_>,
538 var: &str,
539 offset: &str,
540 ) -> std::fmt::Result {
541 write!(f, "__shfl_down_sync(-1, {var}, {offset})")
542 }
543 fn compile_warp_all<T: Component<Self>>(
544 f: &mut std::fmt::Formatter<'_>,
545 input: &T,
546 ) -> std::fmt::Result {
547 write!(f, "__all_sync(-1, {input})")
548 }
549 fn compile_warp_any<T: Component<Self>>(
550 f: &mut std::fmt::Formatter<'_>,
551 input: &T,
552 ) -> std::fmt::Result {
553 write!(f, "__any_sync(-1, {input})")
554 }
555
556 fn compile_warp_ballot(
557 f: &mut std::fmt::Formatter<'_>,
558 input: &Variable<Self>,
559 _out_elem: &Elem<Self>,
560 ) -> std::fmt::Result {
561 write!(f, "__ballot_sync(-1, {input})")
562 }
563
564 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
565 let elem = Elem::<Self>::Bool;
566 let uint32 = Elem::<Self>::U32;
567 writeln!(
571 f,
572 r#"{out} = {elem}([&]() -> {uint32} {{
573 {uint32} pred = 0;
574 asm volatile(
575 "{{\n"
576 " .reg .pred %%px;\n"
577 " elect.sync _|%%px, 0xffffffff;\n"
578 " selp.b32 %0, 1, 0, %%px;\n"
579 "}}\n"
580 : "+r"(pred));
581 return pred;
582 }}());"#
583 )
584 }
585}
586
587impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
590 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
591 M::compile_wmma_includes(f, flags)
592 }
593
594 fn compile_wmma_type_definitions(
595 f: &mut std::fmt::Formatter<'_>,
596 flags: &Flags,
597 ) -> std::fmt::Result {
598 M::compile_wmma_type_definitions(f, flags)
599 }
600
601 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
602 M::compile_wmma_local_variables(f)
603 }
604
605 fn compile_wmma_fragment_declaration(
606 f: &mut std::fmt::Formatter<'_>,
607 var: &Variable<Self>,
608 ) -> std::fmt::Result {
609 M::compile_wmma_fragment_declaration(f, var)
610 }
611
612 fn compile_wwma_fragment_ident(
613 f: &mut std::fmt::Formatter<'_>,
614 ident: &crate::shared::FragmentIdent<Self>,
615 ) -> std::fmt::Result {
616 M::compile_wwma_fragment_ident(f, ident)
617 }
618
619 fn compile_wmma_fragment_layout(
620 f: &mut std::fmt::Formatter<'_>,
621 layout: &crate::shared::FragmentLayout<Self>,
622 ) -> std::fmt::Result {
623 M::compile_wmma_fragment_layout(f, layout)
624 }
625
626 fn compile_wmma_fragment(
627 f: &mut std::fmt::Formatter<'_>,
628 fragment: &crate::shared::Fragment<Self>,
629 ) -> std::fmt::Result {
630 M::compile_wmma_fragment(f, fragment)
631 }
632
633 fn compile_wmma_instruction(
634 f: &mut std::fmt::Formatter<'_>,
635 instruction: &crate::shared::WmmaInstruction<Self>,
636 ) -> std::fmt::Result {
637 M::compile_wmma_instruction(f, instruction)
638 }
639
640 fn compile_manual_mma(
641 f: &mut std::fmt::Formatter<'_>,
642 mma: ManualMma<Self>,
643 ) -> std::fmt::Result {
644 M::compile_manual_mma(f, mma)
645 }
646
647 fn compile_scaled_mma(
648 f: &mut std::fmt::Formatter<'_>,
649 mma: ManualMma<Self>,
650 scales_a: Variable<Self>,
651 scales_b: Variable<Self>,
652 scales_factor: u32,
653 ) -> std::fmt::Result {
654 M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
655 }
656
657 fn supported_wmma_combinations(
658 arch: &CudaArchitecture,
659 ) -> crate::shared::SupportedMmaCombinations {
660 M::supported_wmma_combinations(arch)
661 }
662
663 fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
664 M::supported_mma_combinations(arch)
665 }
666
667 fn supported_scaled_mma_combinations(
668 arch: &CudaArchitecture,
669 ) -> shared::SupportedScaledMmaCombinations {
670 M::supported_scaled_mma_combinations(arch)
671 }
672}
673
674impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
675 fn processors() -> Vec<Box<dyn Processor>> {
676 vec![
677 Box::new(CudaMmaProcessor),
678 Box::new(SaturatingArithmeticProcessor::new(false)),
679 ]
680 }
681}