cubecl_cpp/cuda/
dialect.rs

1use std::{collections::HashSet, fmt::Display, marker::PhantomData};
2
3use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
4
5use crate::{
6    Dialect,
7    cuda::{
8        extension::{Fragment, MmaExecute, MmaExecuteScaled, MmaExtension},
9        processors::CudaMmaProcessor,
10        ptx::*,
11    },
12    shared::{
13        self, Binding, Component, DialectBindings, DialectCubeBuiltins, DialectIncludes,
14        DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler,
15        DialectWmmaCompiler, Elem, FP4Kind, FP6Kind, FP8Kind, Flags, Instruction, Item, ManualMma,
16        Variable, WarpInstruction, unary,
17    },
18};
19
20use super::{Extension, arch::CudaArchitecture};
21
/// Marker type selecting the CUDA flavour of the C++ code generator.
///
/// `M` is never stored; it is carried through [`PhantomData`] only. The trait
/// impls in this file bound it by `DialectWmmaCompiler<Self>` and forward the
/// WMMA/MMA hooks to it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaDialect<M> {
    /// Zero-sized marker tying the dialect to its WMMA compiler type.
    _wmma_compiler: PhantomData<M>,
}
26
27impl<M: DialectWmmaCompiler<Self>> Dialect for CudaDialect<M> {
28    type Architecture = CudaArchitecture;
29}
30
31impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for CudaDialect<M> {
32    type Extension = Extension<Self>;
33
34    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
35        f.write_str("#include <cuda_runtime.h>\n")?;
36        if flags.elem_fp4 {
37            f.write_str("#include <cuda_fp4.h>\n")?;
38        }
39        if flags.elem_fp6 {
40            f.write_str("#include <cuda_fp6.h>\n")?;
41        }
42        if flags.elem_fp8 {
43            f.write_str("#include <cuda_fp8.h>\n")?;
44        }
45        if flags.elem_bf16 {
46            f.write_str("#include <cuda_bf16.h>\n")?;
47        }
48        if flags.elem_f16 {
49            f.write_str("#include <cuda_fp16.h>\n")?;
50        }
51
52        // tf32 conversion function is in mma header
53        if flags.inst_wmma || flags.elem_tf32 {
54            Self::compile_wmma_includes(f, flags)?;
55        }
56
57        if flags.op_pipeline {
58            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
59            f.write_str("#include <cuda/pipeline>\n")?;
60        }
61        if flags.op_barrier || flags.inst_tma || flags.indexes.cluster_pos {
62            f.write_str("#include <cooperative_groups.h>\n")?;
63            f.write_str("#include <cooperative_groups/memcpy_async.h>\n")?;
64            f.write_str("#include <cuda/barrier>\n")?;
65        }
66        if flags.inst_tma {
67            f.write_str(
68                "typedef struct CUtensorMap_st {
69alignas(64) unsigned long long int opaque[16];
70} CUtensorMap;\n",
71            )?;
72        }
73        Ok(())
74    }
75
76    fn compile_extensions(
77        f: &mut std::fmt::Formatter<'_>,
78        extensions: &[Self::Extension],
79    ) -> std::fmt::Result {
80        for extension in extensions {
81            match extension {
82                Extension::NoExtension => {}
83                Extension::Mma(mma) => mma.format_extension(f)?,
84            }
85        }
86        Ok(())
87    }
88
89    fn register_instruction_extension(
90        _extensions: &mut Vec<Self::Extension>,
91        _instruction: &Instruction<Self>,
92    ) {
93    }
94
95    fn register_warp_instruction_extension(
96        _extensions: &mut Vec<Self::Extension>,
97        _instruction: &WarpInstruction<Self>,
98    ) {
99    }
100
101    fn register_wmma_instruction_extension(
102        extensions: &mut Vec<Self::Extension>,
103        instruction: &shared::WmmaInstruction<Self>,
104    ) {
105        match instruction {
106            shared::WmmaInstruction::ExecuteManual {
107                shape,
108                frag_a,
109                frag_b,
110                frag_c,
111                frag_d,
112            } => {
113                let ext = Extension::Mma(MmaExtension::Execute(MmaExecute::new(
114                    *shape,
115                    vars_to_frag(frag_a),
116                    vars_to_frag(frag_b),
117                    vars_to_frag(frag_c),
118                    Fragment(frag_d.elem()),
119                )));
120                if !extensions.contains(&ext) {
121                    extensions.push(ext);
122                }
123            }
124            shared::WmmaInstruction::ExecuteScaled {
125                shape,
126                frag_a,
127                frag_b,
128                frag_c,
129                frag_d,
130                scales_a,
131                scales_factor,
132                ..
133            } => {
134                let ext = Extension::Mma(MmaExtension::ExecuteScaled(MmaExecuteScaled::new(
135                    *shape,
136                    vars_to_frag(frag_a),
137                    vars_to_frag(frag_b),
138                    vars_to_frag(frag_c),
139                    Fragment(frag_d.elem()),
140                    scales_a.elem(),
141                    *scales_factor,
142                )));
143                if !extensions.contains(&ext) {
144                    extensions.push(ext);
145                }
146            }
147            _ => {}
148        }
149    }
150}
151
152fn vars_to_frag<D: Dialect>(vars: &[Variable<D>]) -> Fragment<D> {
153    let elem = vars[0].elem();
154    Fragment(elem)
155}
156
// Types
158
159impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for CudaDialect<M> {
160    fn item_can_be_optimized() -> bool {
161        true
162    }
163
164    fn compile_type_definitions(
165        f: &mut std::fmt::Formatter<'_>,
166        items: &HashSet<Item<Self>>,
167        scalars: &[(Elem<Self>, usize)],
168        flags: &Flags,
169    ) -> std::fmt::Result {
170        // All FP4/FP6/FP8 elems map to the same type, so we need to deduplicate them
171        let mut items_deduplicated = HashSet::new();
172
173        for item in items {
174            let mut item = *item;
175            match item.elem() {
176                Elem::FP4(_) => {
177                    item.elem = Elem::FP4(FP4Kind::E2M1);
178                }
179                Elem::FP4x2(_) => {
180                    item.elem = Elem::FP4x2(FP4Kind::E2M1);
181                }
182                Elem::FP6(_) => {
183                    item.elem = Elem::FP6(FP6Kind::E2M3);
184                }
185                Elem::FP6x2(_) => {
186                    item.elem = Elem::FP6x2(FP6Kind::E2M3);
187                }
188                Elem::FP8(_) => {
189                    item.elem = Elem::FP8(FP8Kind::E4M3);
190                }
191                Elem::FP8x2(_) => {
192                    item.elem = Elem::FP8x2(FP8Kind::E4M3);
193                }
194                _ => {}
195            }
196            items_deduplicated.insert(item);
197        }
198
199        shared::type_definitions::<Self>(f)?;
200        shared::type_vectorized_definitions::<Self>(f, &items_deduplicated)?;
201
202        if flags.use_grid_constants {
203            shared::type_scalar_definitions::<Self>(f, scalars)?;
204            shared::type_info_definition::<Self>(f, flags.static_meta_length)?;
205        }
206
207        if flags.inst_wmma {
208            Self::compile_wmma_type_definitions(f, flags)?;
209        }
210
211        Ok(())
212    }
213
214    fn compile_polyfills(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
215        if flags.inst_tma_im2col {
216            writeln!(f, "{TMA_LOAD_IM2COL}")?;
217        }
218        Ok(())
219    }
220
221    fn compile_elem(
222        f: &mut std::fmt::Formatter<'_>,
223        elem: &shared::Elem<Self>,
224        words: bool,
225    ) -> std::fmt::Result {
226        if words {
227            match elem {
228                shared::Elem::F32 => f.write_str("float"),
229                shared::Elem::F64 => f.write_str("double"),
230                shared::Elem::TF32 => f.write_str("float"),
231                shared::Elem::I8 => f.write_str("char"),
232                shared::Elem::I16 => f.write_str("short"),
233                shared::Elem::I32 => f.write_str("int"),
234                shared::Elem::I64 => f.write_str("long"),
235                shared::Elem::U8 => f.write_str("uchar"),
236                shared::Elem::U16 => f.write_str("ushort"),
237                shared::Elem::U32 => f.write_str("uint"),
238                shared::Elem::U64 => f.write_str("ulong"),
239                _ => Self::compile_elem(f, elem, false),
240            }
241        } else {
242            match elem {
243                shared::Elem::FP4(_) => write!(f, "__nv_fp4_storage_t"),
244                shared::Elem::FP4x2(_) => write!(f, "__nv_fp4x2_storage_t"),
245                shared::Elem::FP6(_) => write!(f, "__nv_fp6_storage_t"),
246                shared::Elem::FP6x2(_) => write!(f, "__nv_fp6x2_storage_t"),
247                shared::Elem::FP8(_) => write!(f, "__nv_fp8_storage_t"),
248                shared::Elem::FP8x2(_) => write!(f, "__nv_fp8x2_storage_t"),
249                shared::Elem::F16 => f.write_str("__half"),
250                shared::Elem::F16x2 => f.write_str("__half2"),
251                shared::Elem::F32 => f.write_str("float"),
252                shared::Elem::F64 => f.write_str("double"),
253                shared::Elem::BF16 => f.write_str("__nv_bfloat16"),
254                shared::Elem::BF16x2 => f.write_str("__nv_bfloat162"),
255                shared::Elem::TF32 => f.write_str("float"),
256                shared::Elem::I8 => f.write_str("int8"),
257                shared::Elem::I16 => f.write_str("int16"),
258                shared::Elem::I32 => f.write_str("int32"),
259                shared::Elem::I64 => f.write_str("int64"),
260                shared::Elem::U8 => f.write_str("uint8"),
261                shared::Elem::U16 => f.write_str("uint16"),
262                shared::Elem::U32 => f.write_str("uint32"),
263                shared::Elem::U64 => f.write_str("uint64"),
264                shared::Elem::Bool => f.write_str("bool"),
265                shared::Elem::Atomic(inner) => write!(f, "{inner}"),
266                shared::Elem::_Dialect(_) => Ok(()),
267            }
268        }
269    }
270
271    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
272        if 1 == item.vectorization {
273            return write!(f, "{}", item.elem);
274        }
275        if item.native {
276            // native types use the word form of types only
277            Self::compile_elem(f, &item.elem, true)?;
278            write!(f, "{}", item.vectorization)
279        } else {
280            write!(f, "{}_{}", item.elem, item.vectorization)
281        }
282    }
283
284    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        Ok(())
286    }
287}
288
// Kernel argument bindings
290
291impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for CudaDialect<M> {
292    fn compile_kernel_signature(
293        f: &mut std::fmt::Formatter<'_>,
294        kernel_name: &str,
295        tensor_maps: &[Binding<Self>],
296        buffers: &[Binding<Self>],
297        scalars: &[(Elem<Self>, usize)],
298        flags: &Flags,
299    ) -> std::fmt::Result {
300        write!(
301            f,
302            "
303
304extern \"C\" __global__ void __launch_bounds__({})",
305            flags.cube_dim.num_elems()
306        )?;
307        if let Some(cluster_dim) = flags.cluster_dim {
308            write!(
309                f,
310                "__cluster_dims__({}, {}, {}) ",
311                cluster_dim.x, cluster_dim.y, cluster_dim.z
312            )?;
313        }
314        writeln!(f, "{kernel_name} (")?;
315        let has_scalars =
316            !scalars.is_empty() || (flags.use_grid_constants && flags.static_meta_length > 0);
317        shared::compile_bindings(f, tensor_maps, buffers, has_scalars, flags)?;
318        if flags.use_grid_constants {
319            shared::compile_scalars_static(f, scalars, flags)?;
320        } else {
321            shared::compile_scalars_dynamic(f, scalars)?;
322        }
323        f.write_str("\n)")?;
324        //
325        Ok(())
326    }
327
328    fn compile_bindings_body(
329        f: &mut std::fmt::Formatter<'_>,
330        body: &shared::Body<Self>,
331    ) -> std::fmt::Result {
332        if !body.shared_memories.is_empty() {
333            let max_align = body
334                .shared_memories
335                .iter()
336                .map(|smem| smem.align)
337                .max()
338                .unwrap();
339            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
340            // with `extern __shared__ alignas` and doesn't properly parse it.
341            writeln!(
342                f,
343                "extern __shared__ __align__({max_align}) uint8 dynamic_shared_mem[];"
344            )?;
345        }
346        Ok(())
347    }
348}
349
350impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for CudaDialect<M> {}
351
// Cube builtins dialect
353
354impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for CudaDialect<M> {
355    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        write!(f, "cluster.block_rank()")
357    }
358
359    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        write!(f, "cluster.block_index().x")
361    }
362
363    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        write!(f, "cluster.block_index().y")
365    }
366
367    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        write!(f, "cluster.block_index().z")
369    }
370}
371
// Instructions
373
374impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for CudaDialect<M> {
375    // sync
376    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        writeln!(f, "__syncthreads();\n")
378    }
379
380    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381        writeln!(f, "__syncwarp();\n")
382    }
383
384    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        writeln!(f, "__threadfence();")
386    }
387
388    // unary
389    fn compile_instruction_find_first_set<T: Component<Self>>(
390        f: &mut std::fmt::Formatter<'_>,
391        input: T,
392        out_elem: Elem<Self>,
393    ) -> std::fmt::Result {
394        write!(f, "{out_elem}(")?;
395        match input.elem() {
396            Elem::I32 => write!(f, "__ffs({input})"),
397            Elem::U32 => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
398            Elem::I64 => write!(f, "__ffsll({input})"),
399            Elem::U64 => write!(f, "__ffsll({}({input}))", Elem::<Self>::I64),
400            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::I32),
401        }?;
402        write!(f, ")")
403    }
404
405    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
406        f: &mut std::fmt::Formatter<'_>,
407        input: T,
408        out_elem: Elem<Self>,
409    ) -> std::fmt::Result {
410        write!(f, "{out_elem}(")?;
411        match input.elem() {
412            Elem::I32 => write!(f, "__clz({input})"),
413            Elem::U32 => write!(f, "__clz({}({input}))", Elem::<Self>::I32),
414            Elem::I64 => write!(f, "__clzll({input})"),
415            Elem::U64 => write!(f, "__clzll({}({input}))", Elem::<Self>::I64),
416            in_elem => write!(
417                f,
418                "{out_elem}(__clz({}) - {})",
419                unary::zero_extend(input),
420                (size_of::<u32>() - in_elem.size()) * 8
421            ),
422        }?;
423        write!(f, ")")
424    }
425
426    fn compile_saturating_add(
427        f: &mut std::fmt::Formatter<'_>,
428        lhs: impl Display,
429        rhs: impl Display,
430        item: Item<Self>,
431    ) -> std::fmt::Result {
432        let elem = item.elem();
433        match elem {
434            Elem::I32 => {
435                write!(
436                    f,
437                    r#"[&]() -> {elem} {{
438    {elem} result;
439    asm("add.sat.s32 %0, %1, %2;"
440        : "=r"(result)
441        : "r"({lhs}), "r"({rhs}));
442    return result;
443        }}()"#
444                )
445            }
446            _ => unreachable!("Should be replaced by polyfill"),
447        }
448    }
449
450    fn compile_saturating_sub(
451        f: &mut std::fmt::Formatter<'_>,
452        lhs: impl Display,
453        rhs: impl Display,
454        item: Item<Self>,
455    ) -> std::fmt::Result {
456        let elem = item.elem();
457        // Native instruction only exists for signed int, unsigned should be removed in a preprocessor
458        match elem {
459            Elem::I32 => {
460                write!(
461                    f,
462                    r#"[&]() -> {elem} {{
463    {elem} result;
464    asm("sub.sat.s32 %0, %1, %2;"
465        : "=r"(result)
466        : "r"({lhs}), "r"({rhs}));
467    return result;
468        }}()"#
469                )
470            }
471            _ => unreachable!("Should be replaced by polyfill"),
472        }
473    }
474
475    // others
476    fn compile_instruction_max_function_name(
477        f: &mut std::fmt::Formatter<'_>,
478        item: Item<Self>,
479    ) -> std::fmt::Result {
480        let max = match item.elem() {
481            Elem::F16 | Elem::BF16 => "__hmax",
482            Elem::F16x2 | Elem::BF16x2 => "__hmax2",
483            _ => "max",
484        };
485        write!(f, "{max}")
486    }
487
488    fn compile_instruction_min_function_name(
489        f: &mut std::fmt::Formatter<'_>,
490        item: Item<Self>,
491    ) -> std::fmt::Result {
492        let min = match item.elem() {
493            Elem::F16 | Elem::BF16 => "__hmin",
494            Elem::F16x2 | Elem::BF16x2 => "__hmin2",
495            _ => "min",
496        };
497        write!(f, "{min}")
498    }
499
500    // warp
501    fn compile_warp_shuffle(
502        f: &mut std::fmt::Formatter<'_>,
503        var: &str,
504        source: &str,
505    ) -> std::fmt::Result {
506        write!(f, "__shfl_sync(-1, {var}, {source})")
507    }
508    fn compile_warp_shuffle_xor(
509        f: &mut std::fmt::Formatter<'_>,
510        var: &str,
511        _elem: &Elem<Self>,
512        offset: &str,
513    ) -> std::fmt::Result {
514        write!(f, "__shfl_xor_sync(-1, {var}, {offset})")
515    }
516    fn compile_warp_shuffle_up(
517        f: &mut std::fmt::Formatter<'_>,
518        var: &str,
519        offset: &str,
520    ) -> std::fmt::Result {
521        write!(f, "__shfl_up_sync(-1, {var}, {offset})")
522    }
523    fn compile_warp_shuffle_down(
524        f: &mut std::fmt::Formatter<'_>,
525        var: &str,
526        offset: &str,
527    ) -> std::fmt::Result {
528        write!(f, "__shfl_down_sync(-1, {var}, {offset})")
529    }
530    fn compile_warp_all<T: Component<Self>>(
531        f: &mut std::fmt::Formatter<'_>,
532        input: &T,
533    ) -> std::fmt::Result {
534        write!(f, "__all_sync(-1, {input})")
535    }
536    fn compile_warp_any<T: Component<Self>>(
537        f: &mut std::fmt::Formatter<'_>,
538        input: &T,
539    ) -> std::fmt::Result {
540        write!(f, "__any_sync(-1, {input})")
541    }
542
543    fn compile_warp_ballot(
544        f: &mut std::fmt::Formatter<'_>,
545        input: &Variable<Self>,
546        _out_elem: &Elem<Self>,
547    ) -> std::fmt::Result {
548        write!(f, "__ballot_sync(-1, {input})")
549    }
550}
551
// Coop Matrices dialect
553
554impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for CudaDialect<M> {
555    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
556        M::compile_wmma_includes(f, flags)
557    }
558
559    fn compile_wmma_type_definitions(
560        f: &mut std::fmt::Formatter<'_>,
561        flags: &Flags,
562    ) -> std::fmt::Result {
563        M::compile_wmma_type_definitions(f, flags)
564    }
565
566    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
567        M::compile_wmma_local_variables(f)
568    }
569
570    fn compile_wmma_fragment_declaration(
571        f: &mut std::fmt::Formatter<'_>,
572        var: &Variable<Self>,
573    ) -> std::fmt::Result {
574        M::compile_wmma_fragment_declaration(f, var)
575    }
576
577    fn compile_wwma_fragment_ident(
578        f: &mut std::fmt::Formatter<'_>,
579        ident: &crate::shared::FragmentIdent<Self>,
580    ) -> std::fmt::Result {
581        M::compile_wwma_fragment_ident(f, ident)
582    }
583
584    fn compile_wmma_fragment_layout(
585        f: &mut std::fmt::Formatter<'_>,
586        layout: &crate::shared::FragmentLayout<Self>,
587    ) -> std::fmt::Result {
588        M::compile_wmma_fragment_layout(f, layout)
589    }
590
591    fn compile_wmma_fragment(
592        f: &mut std::fmt::Formatter<'_>,
593        fragment: &crate::shared::Fragment<Self>,
594    ) -> std::fmt::Result {
595        M::compile_wmma_fragment(f, fragment)
596    }
597
598    fn compile_wmma_instruction(
599        f: &mut std::fmt::Formatter<'_>,
600        instruction: &crate::shared::WmmaInstruction<Self>,
601    ) -> std::fmt::Result {
602        M::compile_wmma_instruction(f, instruction)
603    }
604
605    fn compile_manual_mma(
606        f: &mut std::fmt::Formatter<'_>,
607        mma: ManualMma<Self>,
608    ) -> std::fmt::Result {
609        M::compile_manual_mma(f, mma)
610    }
611
612    fn compile_scaled_mma(
613        f: &mut std::fmt::Formatter<'_>,
614        mma: ManualMma<Self>,
615        scales_a: Variable<Self>,
616        scales_b: Variable<Self>,
617        scales_factor: u32,
618    ) -> std::fmt::Result {
619        M::compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
620    }
621
622    fn supported_wmma_combinations(
623        arch: &CudaArchitecture,
624    ) -> crate::shared::SupportedMmaCombinations {
625        M::supported_wmma_combinations(arch)
626    }
627
628    fn supported_mma_combinations(arch: &CudaArchitecture) -> shared::SupportedMmaCombinations {
629        M::supported_mma_combinations(arch)
630    }
631
632    fn supported_scaled_mma_combinations(
633        arch: &CudaArchitecture,
634    ) -> shared::SupportedScaledMmaCombinations {
635        M::supported_scaled_mma_combinations(arch)
636    }
637}
638
639impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for CudaDialect<M> {
640    fn processors() -> Vec<Box<dyn Processor>> {
641        vec![
642            Box::new(CudaMmaProcessor),
643            Box::new(SaturatingArithmeticProcessor::new(false)),
644        ]
645    }
646}