1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6 Dialect,
7 cuda::{
8 extension::{Fragment, MmaExecute, MmaExecuteScaled, MmaExtension},
9 processors::CudaMmaProcessor,
10 ptx::*,
11 },
12 shared::{
13 self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16 Variable, WarpInstruction, unary,
17 },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Marker type for the CUDA C++ code-generation dialect.
///
/// The type parameter `M` selects the WMMA/MMA compiler backend that the
/// dialect forwards matrix-multiply hooks to; it is only carried at the type
/// level through [`PhantomData`] and no value of `M` is ever stored.
#[derive(Clone, Copy, Debug)]
#[derive(Default, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28 type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32 type Extension = Extension<Self>;
33
34 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35 f.write_str("#include <cuda_runtime.h>\n")?;
36 if flags.elem_fp4 {
37 f.write_str("#include <cuda_fp4.h>\n")?;
38 }
39 if flags.elem_fp6 {
40 f.write_str("#include <cuda_fp6.h>\n")?;
41 }
42 if flags.elem_fp8 {
43 f.write_str("#include <cuda_fp8.h>\n")?;
44 }
45 if flags.elem_bf16 {
46 f.write_str("#include <cuda_bf16.h>\n")?;
47 }
48 if flags.elem_f16 {
49 f.write_str("#include <cuda_fp16.h>\n")?;
50 }
51
52 if flags.inst_wmma || flags.elem_tf32 {
54 Self::compile_wmma_includes(f, flags)?;
55 }
56
57 if flags.op_pipeline {
58 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59 f.write_str("#include <cuda/pipeline>\n")?;
60 }
61 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62 f.write_str("#include <cooperative_groups.h>\n")?;
63 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64 f.write_str("#include <cuda/barrier>\n")?;
65 }
66 if flags.inst_ptx_wrappers {
67 f.write_str("#include <cuda/ptx>\n")?;
68 }
69 if flags.inst_tma {
70 f.write_str(
71 "typedef struct CUtensorMap_st {
72alignas(64) unsigned long long int opaque[16];
73} CUtensorMap;\n",
74 )?;
75 }
76 Ok(())
77 }
78
79 fn compile_extensions(
80 f: &mut std::fmt::Formatter<'_>,
81 extensions: &[Self::Extension],
82 ) -> std::fmt::Result {
83 for extension in extensions {
84 match extension {
85 Extension::NoExtension => {}
86 Extension::Mma(mma) => mma.format_extension(f)?,
87 }
88 }
89 Ok(())
90 }
91
92 fn register_instruction_extension(
93 _extensions: &mut Vec<Self::Extension>,
94 _instruction: &Instruction<Self>,
95 ) {
96 }
97
98 fn register_warp_instruction_extension(
99 _extensions: &mut Vec<Self::Extension>,
100 _instruction: &WarpInstruction<Self>,
101 ) {
102 }
103
104 fn register_wmma_instruction_extension(
105 extensions: &mut Vec<Self::Extension>,
106 instruction: &shared::WmmaInstruction<Self>,
107 ) {
108 match instruction {
109 shared::WmmaInstruction::ExecuteManual {
110 shape,
111 frag_a,
112 frag_b,
113 frag_c,
114 frag_d,
115 } => {
116 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
117 *shape,
118 vars_to_frag(frag_a),
119 vars_to_frag(frag_b),
120 vars_to_frag(frag_c),
121 Fragment(frag_d.elem()),
122 )));
123 if !extensions.contains(&ext) {
124 extensions.push(ext);
125 }
126 }
127 shared::WmmaInstruction::ExecuteScaled {
128 shape,
129 frag_a,
130 frag_b,
131 frag_c,
132 frag_d,
133 scales_a,
134 scales_factor,
135 ..
136 } => {
137 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
138 *shape,
139 vars_to_frag(frag_a),
140 vars_to_frag(frag_b),
141 vars_to_frag(frag_c),
142 Fragment(frag_d.elem()),
143 scales_a.elem(),
144 *scales_factor,
145 )));
146 if !extensions.contains(&ext) {
147 extensions.push(ext);
148 }
149 }
150 _ => {}
151 }
152 }
153}
154
155fn vars_to_frag<D: Dialect>(vars: &[Variable<D>]) -> Fragment<D> {
156 let elem = vars[0].elem();
157 Fragment(elem)
158}
159
160impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
163 fn item_can_be_optimized() -> bool {
164 true
165 }
166
167 fn compile_type_definitions(
168 f: &mut std::fmt::Formatter<'_>,
169 items: &HashSet<Item<Self>>,
170 scalars: &[(Elem<Self>, usize)],
171 flags: &Flags,
172 ) -> std::fmt::Result {
173 let mut items_deduplicated = HashSet::new();
175
176 for item in items {
177 let mut item = *item;
178 match item.elem() {
179 Elem::FP4(_) => {
180 item.elem = Elem::FP4(FP4Kind::E2M1);
181 }
182 Elem::FP4x2(_) => {
183 item.elem = Elem::FP4x2(FP4Kind::E2M1);
184 }
185 Elem::FP6(_) => {
186 item.elem = Elem::FP6(FP6Kind::E2M3);
187 }
188 Elem::FP6x2(_) => {
189 item.elem = Elem::FP6x2(FP6Kind::E2M3);
190 }
191 Elem::FP8(_) => {
192 item.elem = Elem::FP8(FP8Kind::E4M3);
193 }
194 Elem::FP8x2(_) => {
195 item.elem = Elem::FP8x2(FP8Kind::E4M3);
196 }
197 _ => {}
198 }
199 items_deduplicated.insert(item);
200 }
201
202 shared::type_definitions::<Self>(f)?;
203 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
204
205 if flags.use_grid_constants {
206 shared::type_scalar_definitions::<Self>(f, scalars)?;
207 shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
208 }
209
210 if flags.inst_wmma {
211 Self::compile_wmma_type_definitions(f, flags)?;
212 }
213
214 Ok(())
215 }
216
217 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
218 if flags.inst_tma_im2col {
219 writeln!(f, "{TMA_LOAD_IM2COL}")?;
220 }
221 Ok(())
222 }
223
224 fn compile_elem(
225 f: &mut std::fmt::Formatter<'_>,
226 elem: &shared::Elem<Self>,
227 words: bool,
228 ) -> std::fmt::Result {
229 if words {
230 match elem {
231 shared::Elem::F32 => f.write_str("float"),
232 shared::Elem::F64 => f.write_str("double"),
233 shared::Elem::TF32 => f.write_str("float"),
234 shared::Elem::I8 => f.write_str("char"),
235 shared::Elem::I16 => f.write_str("short"),
236 shared::Elem::I32 => f.write_str("int"),
237 shared::Elem::I64 => f.write_str("long"),
238 shared::Elem::U8 => f.write_str("uchar"),
239 shared::Elem::U16 => f.write_str("ushort"),
240 shared::Elem::U32 => f.write_str("uint"),
241 shared::Elem::U64 => f.write_str("ulong"),
242 _ => Self::compile_elem(f, elem, false),
243 }
244 } else {
245 match elem {
246 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
247 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
248 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
249 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
250 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
251 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
252 shared::Elem::F16 => f.write_str("__half"),
253 shared::Elem::F16x2 => f.write_str("__half2"),
254 shared::Elem::F32 => f.write_str("float"),
255 shared::Elem::F64 => f.write_str("double"),
256 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
257 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
258 shared::Elem::TF32 => f.write_str("float"),
259 shared::Elem::I8 => f.write_str("int8"),
260 shared::Elem::I16 => f.write_str("int16"),
261 shared::Elem::I32 => f.write_str("int32"),
262 shared::Elem::I64 => f.write_str("int64"),
263 shared::Elem::U8 => f.write_str("uint8"),
264 shared::Elem::U16 => f.write_str("uint16"),
265 shared::Elem::U32 => f.write_str("uint32"),
266 shared::Elem::U64 => f.write_str("uint64"),
267 shared::Elem::Bool => f.write_str("bool"),
268 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
269 shared::Elem::_Dialect(_) => Ok(()),
270 }
271 }
272 }
273
274 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
275 if 1 == item.vectorization {
276 return write!(f, "{}", item.elem);
277 }
278 if item.native {
279 Self::compile_elem(f, &item.elem, true)?;
281 write!(f, "{}", item.vectorization)
282 } else {
283 write!(f, "{}_{}", item.elem, item.vectorization)
284 }
285 }
286
287 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 Ok(())
289 }
290}
291
292impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
295 fn compile_kernel_signature(
296 f: &mut std::fmt::Formatter<'_>,
297 kernel_name: &str,
298 tensor_maps: &[Binding<Self>],
299 buffers: &[Binding<Self>],
300 scalars: &[(Elem<Self>, usize)],
301 flags: &Flags,
302 ) -> std::fmt::Result {
303 write!(
304 f,
305 "
306
307extern \"C\" __global__ void __launch_bounds__({})",
308 flags.cube_dim.num_elems()
309 )?;
310 if let Some(cluster_dim) = flags.cluster_dim {
311 write!(
312 f,
313 "__cluster_dims__({}, {}, {}) ",
314 cluster_dim.x, cluster_dim.y, cluster_dim.z
315 )?;
316 }
317 writeln!(f, "{kernel_name} (")?;
318 let has_scalars =
319 !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
320 shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
321 if flags.use_grid_constants {
322 shared::compile_scalars_static(f, scalars, flags)?;
323 } else {
324 shared::compile_scalars_dynamic(f, scalars)?;
325 }
326 f.write_str("\n)")?;
327 Ok(())
329 }
330
331 fn compile_bindings_body(
332 f: &mut std::fmt::Formatter<'_>,
333 body: &shared::Body<Self>,
334 ) -> std::fmt::Result {
335 if !body.shared_memories.is_empty() {
336 let max_align = body
337 .shared_memories
338 .iter()
339 .map(|smem| smem.align)
340 .max()
341 .unwrap();
342 writeln!(
345 f,
346 "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
347 )?;
348 }
349 Ok(())
350 }
351}
352
353impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
354
355impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
358 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 write!(f, "cluster.block_rank()")
360 }
361
362 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 write!(f, "cluster.block_index().x")
364 }
365
366 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 write!(f, "cluster.block_index().y")
368 }
369
370 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 write!(f, "cluster.block_index().z")
372 }
373}
374
375impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
378 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 writeln!(f, "__syncthreads();\n")
381 }
382
383 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384 writeln!(f, "__syncwarp();\n")
385 }
386
387 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 writeln!(f, "__threadfence();")
389 }
390
391 fn compile_instruction_find_first_set<T: Component<Self>>(
393 f: &mut std::fmt::Formatter<'_>,
394 input: T,
395 out_elem: Elem<Self>,
396 ) -> std::fmt::Result {
397 write!(f, "{out_elem}(")?;
398 match input.elem() {
399 Elem::I32 => write!(f, "__ffs({input})"),
400 Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
401 Elem::I64 => write!(f, "__ffsll({input})"),
402 Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
403 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
404 }?;
405 write!(f, ")")
406 }
407
408 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
409 f: &mut std::fmt::Formatter<'_>,
410 input: T,
411 out_elem: Elem<Self>,
412 ) -> std::fmt::Result {
413 write!(f, "{out_elem}(")?;
414 match input.elem() {
415 Elem::I32 => write!(f, "__clz({input})"),
416 Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
417 Elem::I64 => write!(f, "__clzll({input})"),
418 Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
419 in_elem => write!(
420 f,
421 "{out_elem}(__clz({}) - {})",
422 unary::zero_extend(input),
423 (size_of::<u32>() - in_elem.size()) * 8
424 ),
425 }?;
426 write!(f, ")")
427 }
428
429 fn compile_saturating_add(
430 f: &mut std::fmt::Formatter<'_>,
431 lhs: impl Display,
432 rhs: impl Display,
433 item: Item<Self>,
434 ) -> std::fmt::Result {
435 let elem = item.elem();
436 match elem {
437 Elem::I32 => {
438 write!(
439 f,
440 r#"[&]() -> {elem} {{
441 {elem} result;
442 asm("add.sat.s32 %0, %1, %2;"
443 : "=r"(result)
444 : "r"({lhs}), "r"({rhs}));
445 return result;
446 }}()"#
447 )
448 }
449 _ => unreachable!("Should be replaced by polyfill"),
450 }
451 }
452
453 fn compile_saturating_sub(
454 f: &mut std::fmt::Formatter<'_>,
455 lhs: impl Display,
456 rhs: impl Display,
457 item: Item<Self>,
458 ) -> std::fmt::Result {
459 let elem = item.elem();
460 match elem {
462 Elem::I32 => {
463 write!(
464 f,
465 r#"[&]() -> {elem} {{
466 {elem} result;
467 asm("sub.sat.s32 %0, %1, %2;"
468 : "=r"(result)
469 : "r"({lhs}), "r"({rhs}));
470 return result;
471 }}()"#
472 )
473 }
474 _ => unreachable!("Should be replaced by polyfill"),
475 }
476 }
477
478 fn compile_instruction_max_function_name(
480 f: &mut std::fmt::Formatter<'_>,
481 item: Item<Self>,
482 ) -> std::fmt::Result {
483 let max = match item.elem() {
484 Elem::F16 | Elem::BF16 => "__hmax",
485 Elem::F16x2 | Elem::BF16x2 => "__hmax2",
486 _ => "max",
487 };
488 write!(f, "{max}")
489 }
490
491 fn compile_instruction_min_function_name(
492 f: &mut std::fmt::Formatter<'_>,
493 item: Item<Self>,
494 ) -> std::fmt::Result {
495 let min = match item.elem() {
496 Elem::F16 | Elem::BF16 => "__hmin",
497 Elem::F16x2 | Elem::BF16x2 => "__hmin2",
498 _ => "min",
499 };
500 write!(f, "{min}")
501 }
502
503 fn compile_warp_shuffle(
505 f: &mut std::fmt::Formatter<'_>,
506 var: &str,
507 source: &str,
508 ) -> std::fmt::Result {
509 write!(f, "__shfl_sync(-1, {var}, {source})")
510 }
511 fn compile_warp_shuffle_xor(
512 f: &mut std::fmt::Formatter<'_>,
513 var: &str,
514 _elem: &Elem<Self>,
515 offset: &str,
516 ) -> std::fmt::Result {
517 write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
518 }
519 fn compile_warp_shuffle_up(
520 f: &mut std::fmt::Formatter<'_>,
521 var: &str,
522 offset: &str,
523 ) -> std::fmt::Result {
524 write!(f, "__shfl_up_sync(-1, {var}, {offset})")
525 }
526 fn compile_warp_shuffle_down(
527 f: &mut std::fmt::Formatter<'_>,
528 var: &str,
529 offset: &str,
530 ) -> std::fmt::Result {
531 write!(f, "__shfl_down_sync(-1, {var}, {offset})")
532 }
533 fn compile_warp_all<T: Component<Self>>(
534 f: &mut std::fmt::Formatter<'_>,
535 input: &T,
536 ) -> std::fmt::Result {
537 write!(f, "__all_sync(-1, {input})")
538 }
539 fn compile_warp_any<T: Component<Self>>(
540 f: &mut std::fmt::Formatter<'_>,
541 input: &T,
542 ) -> std::fmt::Result {
543 write!(f, "__any_sync(-1, {input})")
544 }
545
546 fn compile_warp_ballot(
547 f: &mut std::fmt::Formatter<'_>,
548 input: &Variable<Self>,
549 _out_elem: &Elem<Self>,
550 ) -> std::fmt::Result {
551 write!(f, "__ballot_sync(-1, {input})")
552 }
553
554 fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
555 let elem = Elem::<Self>::Bool;
556 let uint32 = Elem::<Self>::U32;
557 writeln!(
561 f,
562 r#"{out} = {elem}([&]() -> {uint32} {{
563 {uint32} pred = 0;
564 asm volatile(
565 "{{\n"
566 " .reg .pred %%px;\n"
567 " elect.sync _|%%px, 0xffffffff;\n"
568 " selp.b32 %0, 1, 0, %%px;\n"
569 "}}\n"
570 : "+r"(pred));
571 return pred;
572 }}());"#
573 )
574 }
575}
576
577impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
580 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
581 M::compile_wmma_includes(f, flags)
582 }
583
584 fn compile_wmma_type_definitions(
585 f: &mut std::fmt::Formatter<'_>,
586 flags: &Flags,
587 ) -> std::fmt::Result {
588 M::compile_wmma_type_definitions(f, flags)
589 }
590
591 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 M::compile_wmma_local_variables(f)
593 }
594
595 fn compile_wmma_fragment_declaration(
596 f: &mut std::fmt::Formatter<'_>,
597 var: &Variable<Self>,
598 ) -> std::fmt::Result {
599 M::compile_wmma_fragment_declaration(f, var)
600 }
601
602 fn compile_wwma_fragment_ident(
603 f: &mut std::fmt::Formatter<'_>,
604 ident: &crate::shared::FragmentIdent<Self>,
605 ) -> std::fmt::Result {
606 M::compile_wwma_fragment_ident(f, ident)
607 }
608
609 fn compile_wmma_fragment_layout(
610 f: &mut std::fmt::Formatter<'_>,
611 layout: &crate::shared::FragmentLayout<Self>,
612 ) -> std::fmt::Result {
613 M::compile_wmma_fragment_layout(f, layout)
614 }
615
616 fn compile_wmma_fragment(
617 f: &mut std::fmt::Formatter<'_>,
618 fragment: &crate::shared::Fragment<Self>,
619 ) -> std::fmt::Result {
620 M::compile_wmma_fragment(f, fragment)
621 }
622
623 fn compile_wmma_instruction(
624 f: &mut std::fmt::Formatter<'_>,
625 instruction: &crate::shared::WmmaInstruction<Self>,
626 ) -> std::fmt::Result {
627 M::compile_wmma_instruction(f, instruction)
628 }
629
630 fn compile_manual_mma(
631 f: &mut std::fmt::Formatter<'_>,
632 mma: ManualMma<Self>,
633 ) -> std::fmt::Result {
634 M::compile_manual_mma(f, mma)
635 }
636
637 fn compile_scaled_mma(
638 f: &mut std::fmt::Formatter<'_>,
639 mma: ManualMma<Self>,
640 scales_a: Variable<Self>,
641 scales_b: Variable<Self>,
642 scales_factor: u32,
643 ) -> std::fmt::Result {
644 M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
645 }
646
647 fn supported_wmma_combinations(
648 arch: &CudaArchitecture,
649 ) -> crate::shared::SupportedMmaCombinations {
650 M::supported_wmma_combinations(arch)
651 }
652
653 fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
654 M::supported_mma_combinations(arch)
655 }
656
657 fn supported_scaled_mma_combinations(
658 arch: &CudaArchitecture,
659 ) -> shared::SupportedScaledMmaCombinations {
660 M::supported_scaled_mma_combinations(arch)
661 }
662}
663
664impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
665 fn processors() -> Vec<Box<dyn Processor>> {
666 vec![
667 Box::new(CudaMmaProcessor),
668 Box::new(SaturatingArithmeticProcessor::new(false)),
669 ]
670 }
671}