1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6 Dialect,
7 cuda::{
8 extension::{Fragment, MmaExecute, MmaExecuteScaled, MmaExtension},
9 processors::CudaMmaProcessor,
10 ptx::*,
11 },
12 shared::{
13 self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14 DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15 DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16 Variable, WarpInstruction, unary,
17 },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Zero-sized marker type for the CUDA dialect, parameterized by `M`.
///
/// No runtime data is stored: the only field is a [`PhantomData`] that records `M` at the type
/// level. Every trait implementation in this file requires `M: DialectWmmaCompiler<Self>`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    // Keeps `M` part of the type without holding a value of it.
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28 type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32 type Extension = Extension<Self>;
33
34 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35 f.write_str("#include <cuda_runtime.h>\n")?;
36 if flags.elem_fp4 {
37 f.write_str("#include <cuda_fp4.h>\n")?;
38 }
39 if flags.elem_fp6 {
40 f.write_str("#include <cuda_fp6.h>\n")?;
41 }
42 if flags.elem_fp8 {
43 f.write_str("#include <cuda_fp8.h>\n")?;
44 }
45 if flags.elem_bf16 {
46 f.write_str("#include <cuda_bf16.h>\n")?;
47 }
48 if flags.elem_f16 {
49 f.write_str("#include <cuda_fp16.h>\n")?;
50 }
51
52 if flags.inst_wmma || flags.elem_tf32 {
54 Self::compile_wmma_includes(f, flags)?;
55 }
56
57 if flags.op_pipeline {
58 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59 f.write_str("#include <cuda/pipeline>\n")?;
60 }
61 if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62 f.write_str("#include <cooperative_groups.h>\n")?;
63 f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64 f.write_str("#include <cuda/barrier>\n")?;
65 }
66 if flags.inst_tma {
67 f.write_str(
68 "typedef struct CUtensorMap_st {
69alignas(64) unsigned long long int opaque[16];
70} CUtensorMap;\n",
71 )?;
72 }
73 Ok(())
74 }
75
76 fn compile_extensions(
77 f: &mut std::fmt::Formatter<'_>,
78 extensions: &[Self::Extension],
79 ) -> std::fmt::Result {
80 for extension in extensions {
81 match extension {
82 Extension::NoExtension => {}
83 Extension::Mma(mma) => mma.format_extension(f)?,
84 }
85 }
86 Ok(())
87 }
88
89 fn register_instruction_extension(
90 _extensions: &mut Vec<Self::Extension>,
91 _instruction: &Instruction<Self>,
92 ) {
93 }
94
95 fn register_warp_instruction_extension(
96 _extensions: &mut Vec<Self::Extension>,
97 _instruction: &WarpInstruction<Self>,
98 ) {
99 }
100
101 fn register_wmma_instruction_extension(
102 extensions: &mut Vec<Self::Extension>,
103 instruction: &shared::WmmaInstruction<Self>,
104 ) {
105 match instruction {
106 shared::WmmaInstruction::ExecuteManual {
107 shape,
108 frag_a,
109 frag_b,
110 frag_c,
111 frag_d,
112 } => {
113 let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
114 *shape,
115 vars_to_frag(frag_a),
116 vars_to_frag(frag_b),
117 vars_to_frag(frag_c),
118 Fragment(frag_d.elem()),
119 )));
120 if !extensions.contains(&ext) {
121 extensions.push(ext);
122 }
123 }
124 shared::WmmaInstruction::ExecuteScaled {
125 shape,
126 frag_a,
127 frag_b,
128 frag_c,
129 frag_d,
130 scales_a,
131 scales_factor,
132 ..
133 } => {
134 let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
135 *shape,
136 vars_to_frag(frag_a),
137 vars_to_frag(frag_b),
138 vars_to_frag(frag_c),
139 Fragment(frag_d.elem()),
140 scales_a.elem(),
141 *scales_factor,
142 )));
143 if !extensions.contains(&ext) {
144 extensions.push(ext);
145 }
146 }
147 _ => {}
148 }
149 }
150}
151
152fn vars_to_frag<D: Dialect>(vars: &[Variable<D>]) -> Fragment<D> {
153 let elem = vars[0].elem();
154 Fragment(elem)
155}
156
157impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
160 fn item_can_be_optimized() -> bool {
161 true
162 }
163
164 fn compile_type_definitions(
165 f: &mut std::fmt::Formatter<'_>,
166 items: &HashSet<Item<Self>>,
167 scalars: &[(Elem<Self>, usize)],
168 flags: &Flags,
169 ) -> std::fmt::Result {
170 let mut items_deduplicated = HashSet::new();
172
173 for item in items {
174 let mut item = *item;
175 match item.elem() {
176 Elem::FP4(_) => {
177 item.elem = Elem::FP4(FP4Kind::E2M1);
178 }
179 Elem::FP4x2(_) => {
180 item.elem = Elem::FP4x2(FP4Kind::E2M1);
181 }
182 Elem::FP6(_) => {
183 item.elem = Elem::FP6(FP6Kind::E2M3);
184 }
185 Elem::FP6x2(_) => {
186 item.elem = Elem::FP6x2(FP6Kind::E2M3);
187 }
188 Elem::FP8(_) => {
189 item.elem = Elem::FP8(FP8Kind::E4M3);
190 }
191 Elem::FP8x2(_) => {
192 item.elem = Elem::FP8x2(FP8Kind::E4M3);
193 }
194 _ => {}
195 }
196 items_deduplicated.insert(item);
197 }
198
199 shared::type_definitions::<Self>(f)?;
200 shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
201
202 if flags.use_grid_constants {
203 shared::type_scalar_definitions::<Self>(f, scalars)?;
204 shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
205 }
206
207 if flags.inst_wmma {
208 Self::compile_wmma_type_definitions(f, flags)?;
209 }
210
211 Ok(())
212 }
213
214 fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
215 if flags.inst_tma_im2col {
216 writeln!(f, "{TMA_LOAD_IM2COL}")?;
217 }
218 Ok(())
219 }
220
221 fn compile_elem(
222 f: &mut std::fmt::Formatter<'_>,
223 elem: &shared::Elem<Self>,
224 words: bool,
225 ) -> std::fmt::Result {
226 if words {
227 match elem {
228 shared::Elem::F32 => f.write_str("float"),
229 shared::Elem::F64 => f.write_str("double"),
230 shared::Elem::TF32 => f.write_str("float"),
231 shared::Elem::I8 => f.write_str("char"),
232 shared::Elem::I16 => f.write_str("short"),
233 shared::Elem::I32 => f.write_str("int"),
234 shared::Elem::I64 => f.write_str("long"),
235 shared::Elem::U8 => f.write_str("uchar"),
236 shared::Elem::U16 => f.write_str("ushort"),
237 shared::Elem::U32 => f.write_str("uint"),
238 shared::Elem::U64 => f.write_str("ulong"),
239 _ => Self::compile_elem(f, elem, false),
240 }
241 } else {
242 match elem {
243 shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
244 shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
245 shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
246 shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
247 shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
248 shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
249 shared::Elem::F16 => f.write_str("__half"),
250 shared::Elem::F16x2 => f.write_str("__half2"),
251 shared::Elem::F32 => f.write_str("float"),
252 shared::Elem::F64 => f.write_str("double"),
253 shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
254 shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
255 shared::Elem::TF32 => f.write_str("float"),
256 shared::Elem::I8 => f.write_str("int8"),
257 shared::Elem::I16 => f.write_str("int16"),
258 shared::Elem::I32 => f.write_str("int32"),
259 shared::Elem::I64 => f.write_str("int64"),
260 shared::Elem::U8 => f.write_str("uint8"),
261 shared::Elem::U16 => f.write_str("uint16"),
262 shared::Elem::U32 => f.write_str("uint32"),
263 shared::Elem::U64 => f.write_str("uint64"),
264 shared::Elem::Bool => f.write_str("bool"),
265 shared::Elem::Atomic(inner) => write!(f, "{inner}"),
266 shared::Elem::_Dialect(_) => Ok(()),
267 }
268 }
269 }
270
271 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
272 if 1 == item.vectorization {
273 return write!(f, "{}", item.elem);
274 }
275 if item.native {
276 Self::compile_elem(f, &item.elem, true)?;
278 write!(f, "{}", item.vectorization)
279 } else {
280 write!(f, "{}_{}", item.elem, item.vectorization)
281 }
282 }
283
284 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 Ok(())
286 }
287}
288
289impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
292 fn compile_kernel_signature(
293 f: &mut std::fmt::Formatter<'_>,
294 kernel_name: &str,
295 tensor_maps: &[Binding<Self>],
296 buffers: &[Binding<Self>],
297 scalars: &[(Elem<Self>, usize)],
298 flags: &Flags,
299 ) -> std::fmt::Result {
300 write!(
301 f,
302 "
303
304extern \"C\" __global__ void __launch_bounds__({})",
305 flags.cube_dim.num_elems()
306 )?;
307 if let Some(cluster_dim) = flags.cluster_dim {
308 write!(
309 f,
310 "__cluster_dims__({}, {}, {}) ",
311 cluster_dim.x, cluster_dim.y, cluster_dim.z
312 )?;
313 }
314 writeln!(f, "{kernel_name} (")?;
315 let has_scalars =
316 !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
317 shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
318 if flags.use_grid_constants {
319 shared::compile_scalars_static(f, scalars, flags)?;
320 } else {
321 shared::compile_scalars_dynamic(f, scalars)?;
322 }
323 f.write_str("\n)")?;
324 Ok(())
326 }
327
328 fn compile_bindings_body(
329 f: &mut std::fmt::Formatter<'_>,
330 body: &shared::Body<Self>,
331 ) -> std::fmt::Result {
332 if !body.shared_memories.is_empty() {
333 let max_align = body
334 .shared_memories
335 .iter()
336 .map(|smem| smem.align)
337 .max()
338 .unwrap();
339 writeln!(
342 f,
343 "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
344 )?;
345 }
346 Ok(())
347 }
348}
349
350impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
351
352impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
355 fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 write!(f, "cluster.block_rank()")
357 }
358
359 fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 write!(f, "cluster.block_index().x")
361 }
362
363 fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 write!(f, "cluster.block_index().y")
365 }
366
367 fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 write!(f, "cluster.block_index().z")
369 }
370}
371
372impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
375 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 writeln!(f, "__syncthreads();\n")
378 }
379
380 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 writeln!(f, "__syncwarp();\n")
382 }
383
384 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 writeln!(f, "__threadfence();")
386 }
387
388 fn compile_instruction_find_first_set<T: Component<Self>>(
390 f: &mut std::fmt::Formatter<'_>,
391 input: T,
392 out_elem: Elem<Self>,
393 ) -> std::fmt::Result {
394 write!(f, "{out_elem}(")?;
395 match input.elem() {
396 Elem::I32 => write!(f, "__ffs({input})"),
397 Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
398 Elem::I64 => write!(f, "__ffsll({input})"),
399 Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
400 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
401 }?;
402 write!(f, ")")
403 }
404
405 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
406 f: &mut std::fmt::Formatter<'_>,
407 input: T,
408 out_elem: Elem<Self>,
409 ) -> std::fmt::Result {
410 write!(f, "{out_elem}(")?;
411 match input.elem() {
412 Elem::I32 => write!(f, "__clz({input})"),
413 Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
414 Elem::I64 => write!(f, "__clzll({input})"),
415 Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
416 in_elem => write!(
417 f,
418 "{out_elem}(__clz({}) - {})",
419 unary::zero_extend(input),
420 (size_of::<u32>() - in_elem.size()) * 8
421 ),
422 }?;
423 write!(f, ")")
424 }
425
426 fn compile_saturating_add(
427 f: &mut std::fmt::Formatter<'_>,
428 lhs: impl Display,
429 rhs: impl Display,
430 item: Item<Self>,
431 ) -> std::fmt::Result {
432 let elem = item.elem();
433 match elem {
434 Elem::I32 => {
435 write!(
436 f,
437 r#"[&]() -> {elem} {{
438 {elem} result;
439 asm("add.sat.s32 %0, %1, %2;"
440 : "=r"(result)
441 : "r"({lhs}), "r"({rhs}));
442 return result;
443 }}()"#
444 )
445 }
446 _ => unreachable!("Should be replaced by polyfill"),
447 }
448 }
449
450 fn compile_saturating_sub(
451 f: &mut std::fmt::Formatter<'_>,
452 lhs: impl Display,
453 rhs: impl Display,
454 item: Item<Self>,
455 ) -> std::fmt::Result {
456 let elem = item.elem();
457 match elem {
459 Elem::I32 => {
460 write!(
461 f,
462 r#"[&]() -> {elem} {{
463 {elem} result;
464 asm("sub.sat.s32 %0, %1, %2;"
465 : "=r"(result)
466 : "r"({lhs}), "r"({rhs}));
467 return result;
468 }}()"#
469 )
470 }
471 _ => unreachable!("Should be replaced by polyfill"),
472 }
473 }
474
475 fn compile_instruction_max_function_name(
477 f: &mut std::fmt::Formatter<'_>,
478 item: Item<Self>,
479 ) -> std::fmt::Result {
480 let max = match item.elem() {
481 Elem::F16 | Elem::BF16 => "__hmax",
482 Elem::F16x2 | Elem::BF16x2 => "__hmax2",
483 _ => "max",
484 };
485 write!(f, "{max}")
486 }
487
488 fn compile_instruction_min_function_name(
489 f: &mut std::fmt::Formatter<'_>,
490 item: Item<Self>,
491 ) -> std::fmt::Result {
492 let min = match item.elem() {
493 Elem::F16 | Elem::BF16 => "__hmin",
494 Elem::F16x2 | Elem::BF16x2 => "__hmin2",
495 _ => "min",
496 };
497 write!(f, "{min}")
498 }
499
500 fn compile_warp_shuffle(
502 f: &mut std::fmt::Formatter<'_>,
503 var: &str,
504 source: &str,
505 ) -> std::fmt::Result {
506 write!(f, "__shfl_sync(-1, {var}, {source})")
507 }
508 fn compile_warp_shuffle_xor(
509 f: &mut std::fmt::Formatter<'_>,
510 var: &str,
511 _elem: &Elem<Self>,
512 offset: &str,
513 ) -> std::fmt::Result {
514 write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
515 }
516 fn compile_warp_shuffle_up(
517 f: &mut std::fmt::Formatter<'_>,
518 var: &str,
519 offset: &str,
520 ) -> std::fmt::Result {
521 write!(f, "__shfl_up_sync(-1, {var}, {offset})")
522 }
523 fn compile_warp_shuffle_down(
524 f: &mut std::fmt::Formatter<'_>,
525 var: &str,
526 offset: &str,
527 ) -> std::fmt::Result {
528 write!(f, "__shfl_down_sync(-1, {var}, {offset})")
529 }
530 fn compile_warp_all<T: Component<Self>>(
531 f: &mut std::fmt::Formatter<'_>,
532 input: &T,
533 ) -> std::fmt::Result {
534 write!(f, "__all_sync(-1, {input})")
535 }
536 fn compile_warp_any<T: Component<Self>>(
537 f: &mut std::fmt::Formatter<'_>,
538 input: &T,
539 ) -> std::fmt::Result {
540 write!(f, "__any_sync(-1, {input})")
541 }
542
543 fn compile_warp_ballot(
544 f: &mut std::fmt::Formatter<'_>,
545 input: &Variable<Self>,
546 _out_elem: &Elem<Self>,
547 ) -> std::fmt::Result {
548 write!(f, "__ballot_sync(-1, {input})")
549 }
550}
551
552impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
555 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
556 M::compile_wmma_includes(f, flags)
557 }
558
559 fn compile_wmma_type_definitions(
560 f: &mut std::fmt::Formatter<'_>,
561 flags: &Flags,
562 ) -> std::fmt::Result {
563 M::compile_wmma_type_definitions(f, flags)
564 }
565
566 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
567 M::compile_wmma_local_variables(f)
568 }
569
570 fn compile_wmma_fragment_declaration(
571 f: &mut std::fmt::Formatter<'_>,
572 var: &Variable<Self>,
573 ) -> std::fmt::Result {
574 M::compile_wmma_fragment_declaration(f, var)
575 }
576
577 fn compile_wwma_fragment_ident(
578 f: &mut std::fmt::Formatter<'_>,
579 ident: &crate::shared::FragmentIdent<Self>,
580 ) -> std::fmt::Result {
581 M::compile_wwma_fragment_ident(f, ident)
582 }
583
584 fn compile_wmma_fragment_layout(
585 f: &mut std::fmt::Formatter<'_>,
586 layout: &crate::shared::FragmentLayout<Self>,
587 ) -> std::fmt::Result {
588 M::compile_wmma_fragment_layout(f, layout)
589 }
590
591 fn compile_wmma_fragment(
592 f: &mut std::fmt::Formatter<'_>,
593 fragment: &crate::shared::Fragment<Self>,
594 ) -> std::fmt::Result {
595 M::compile_wmma_fragment(f, fragment)
596 }
597
598 fn compile_wmma_instruction(
599 f: &mut std::fmt::Formatter<'_>,
600 instruction: &crate::shared::WmmaInstruction<Self>,
601 ) -> std::fmt::Result {
602 M::compile_wmma_instruction(f, instruction)
603 }
604
605 fn compile_manual_mma(
606 f: &mut std::fmt::Formatter<'_>,
607 mma: ManualMma<Self>,
608 ) -> std::fmt::Result {
609 M::compile_manual_mma(f, mma)
610 }
611
612 fn compile_scaled_mma(
613 f: &mut std::fmt::Formatter<'_>,
614 mma: ManualMma<Self>,
615 scales_a: Variable<Self>,
616 scales_b: Variable<Self>,
617 scales_factor: u32,
618 ) -> std::fmt::Result {
619 M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
620 }
621
622 fn supported_wmma_combinations(
623 arch: &CudaArchitecture,
624 ) -> crate::shared::SupportedMmaCombinations {
625 M::supported_wmma_combinations(arch)
626 }
627
628 fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
629 M::supported_mma_combinations(arch)
630 }
631
632 fn supported_scaled_mma_combinations(
633 arch: &CudaArchitecture,
634 ) -> shared::SupportedScaledMmaCombinations {
635 M::supported_scaled_mma_combinations(arch)
636 }
637}
638
639impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
640 fn processors() -> Vec<Box<dyn Processor>> {
641 vec![
642 Box::new(CudaMmaProcessor),
643 Box::new(SaturatingArithmeticProcessor::new(false)),
644 ]
645 }
646}