// cubecl_cpp/hip/dialect.rs

1use core::any::TypeId;
2use std::fmt::Display;
3use std::{collections::HashSet, marker::PhantomData};
4
5use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
6
7use crate::shared::DialectWarpReduceCompiler;
8use crate::{
9    Dialect,
10    shared::{
11        self, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectTypes,
12        DialectWmmaCompiler, Flags, Item, KernelArg, ManualMma,
13    },
14};
15use crate::{
16    hip::processors::HipMmaProcessor,
17    shared::{
18        Component, DialectInstructions, DialectProcessors, Elem, Instruction, Variable, unary,
19        variable_to_frag,
20    },
21};
22
23use super::Extension;
24use super::arch::AMDArchitecture;
25use super::extension::{WmmaExtension, format_f162bf16, format_max, format_min};
26use super::mma::{WmmaCast, WmmaExecute, WmmaFill, WmmaIntrinsicCompiler, WmmaLoad, WmmaStore};
27
/// C++/HIP code-generation dialect targeting AMD GPUs.
///
/// The type parameter `M` selects the compiler that lowers cooperative-matrix (WMMA)
/// instructions. It only exists at the type level, so a `HipDialect` value carries no data.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HipDialect<M> {
    /// Zero-sized marker tying this dialect to its WMMA compiler `M`.
    _wmma_compiler: PhantomData<M>,
}
32
33// Base dialect
34
35impl<M: DialectWmmaCompiler<Self>> Dialect for HipDialect<M> {
36    type Architecture = AMDArchitecture;
37}
38
39impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for HipDialect<M> {}
40
41// Includes
42
43impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for HipDialect<M> {
44    type Extension = Extension<Self>;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
47        f.write_str("#include <hip/hip_runtime.h>\n")?;
48        if flags.elem_bf16 {
49            f.write_str("#include <hip/hip_bf16.h>\n")?;
50        }
51        if flags.elem_f16 {
52            f.write_str("#include <hip/hip_fp16.h>\n")?;
53        }
54        if flags.inst_wmma {
55            Self::compile_wmma_includes(f, flags)?;
56        }
57        Ok(())
58    }
59
60    fn compile_extensions(
61        f: &mut std::fmt::Formatter<'_>,
62        extensions: &[Self::Extension],
63    ) -> std::fmt::Result {
64        for extension in extensions {
65            match extension {
66                Extension::F162BF16 => format_f162bf16(f)?,
67                Extension::Max(var) => format_max::<Self>(f, var)?,
68                Extension::Min(var) => format_min::<Self>(f, var)?,
69                Extension::NoExtension => {}
70                Extension::Wmma(inst) => inst.format_wmma(f)?,
71            }
72        }
73        Ok(())
74    }
75
76    fn register_instruction_extension(
77        extensions: &mut Vec<Self::Extension>,
78        instruction: &Instruction<Self>,
79    ) {
80        let mut register_extension = |extension: Self::Extension| {
81            if !extensions.contains(&extension) {
82                extensions.push(extension);
83            }
84        };
85        #[allow(clippy::single_match)]
86        match instruction {
87            shared::Instruction::<Self>::Max(op) => {
88                register_extension(Extension::Max(*op.lhs.item().elem()));
89            }
90            shared::Instruction::<Self>::Min(op) => {
91                register_extension(Extension::Min(*op.lhs.item().elem()));
92            }
93            _ => {}
94        }
95    }
96
97    fn register_warp_instruction_extension(
98        extensions: &mut Vec<Self::Extension>,
99        instruction: &shared::WarpInstruction<Self>,
100    ) {
101        let mut register_extension = |extension: Self::Extension| {
102            if !extensions.contains(&extension) {
103                extensions.push(extension);
104            }
105        };
106
107        #[allow(clippy::single_match)]
108        match instruction {
109            shared::WarpInstruction::<Self>::ReduceMax { input, .. } => {
110                let input_item = input.item();
111                let input_elem = input_item.elem();
112                if *input_elem == Elem::<Self>::BF16 {
113                    register_extension(Extension::F162BF16);
114                }
115                register_extension(Extension::Max(*input_elem));
116            }
117            shared::WarpInstruction::<Self>::ReduceMin { input, .. } => {
118                let input_item = input.item();
119                let input_elem = input_item.elem();
120                if *input_elem == Elem::<Self>::BF16 {
121                    register_extension(Extension::F162BF16);
122                }
123                register_extension(Extension::Min(*input_elem));
124            }
125            shared::WarpInstruction::<Self>::ReduceProd { input, .. } => {
126                let input_item = input.item();
127                let input_elem = input_item.elem();
128                if *input_elem == Elem::<Self>::BF16 {
129                    register_extension(Extension::F162BF16);
130                }
131            }
132            shared::WarpInstruction::<Self>::ReduceSum { input, .. } => {
133                let input_item = input.item();
134                let input_elem = input_item.elem();
135                if *input_elem == Elem::<Self>::BF16 {
136                    register_extension(Extension::F162BF16);
137                }
138            }
139            _ => {}
140        }
141    }
142
143    fn register_wmma_instruction_extension(
144        extensions: &mut Vec<Self::Extension>,
145        instruction: &shared::WmmaInstruction<Self>,
146    ) {
147        if TypeId::of::<M>() == TypeId::of::<WmmaIntrinsicCompiler>() {
148            let extension = match instruction {
149                shared::WmmaInstruction::Fill { frag, .. } => {
150                    Extension::Wmma(WmmaExtension::Fill(WmmaFill::new(variable_to_frag(frag))))
151                }
152                shared::WmmaInstruction::Load { frag, layout, .. } => Extension::Wmma(
153                    WmmaExtension::Load(WmmaLoad::new(variable_to_frag(frag), *layout)),
154                ),
155                shared::WmmaInstruction::LdMatrix { .. }
156                | shared::WmmaInstruction::StMatrix { .. } => {
157                    panic!("Invalid extension: StMatrix & LdMatrix not supported for HIP");
158                }
159                shared::WmmaInstruction::Execute {
160                    frag_a,
161                    frag_b,
162                    frag_c,
163                    frag_d,
164                    warp_size: _,
165                } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::new(
166                    variable_to_frag(frag_a),
167                    variable_to_frag(frag_b),
168                    variable_to_frag(frag_c),
169                    variable_to_frag(frag_d),
170                ))),
171                shared::WmmaInstruction::ExecuteManual {
172                    shape,
173                    frag_a,
174                    frag_c,
175                    ..
176                } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
177                    *shape,
178                    frag_a.elem(),
179                    frag_c.elem(),
180                ))),
181                shared::WmmaInstruction::ExecuteScaled { .. } => {
182                    panic!("Invalid extension: ExecuteScaled not supported for HIP");
183                }
184                shared::WmmaInstruction::Store { frag, layout, .. } => Extension::Wmma(
185                    WmmaExtension::Store(WmmaStore::new(variable_to_frag(frag), *layout)),
186                ),
187                shared::WmmaInstruction::Cast { input, output } => {
188                    Extension::Wmma(WmmaExtension::Cast(WmmaCast::new(
189                        variable_to_frag(input),
190                        variable_to_frag(output),
191                    )))
192                }
193            };
194
195            if !extensions.contains(&extension) {
196                extensions.push(extension);
197            }
198        } else if let shared::WmmaInstruction::ExecuteManual {
199            shape,
200            frag_a,
201            frag_c,
202            ..
203        } = instruction
204        {
205            let extension = Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
206                *shape,
207                frag_a.elem(),
208                frag_c.elem(),
209            )));
210
211            if !extensions.contains(&extension) {
212                extensions.push(extension);
213            }
214        }
215    }
216}
217
218// Types
219
220impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for HipDialect<M> {
221    fn item_can_be_optimized() -> bool {
222        // for now deactivate support for half2 and bfloat162 because the HIP API lack support for it.
223        false
224    }
225
226    fn compile_type_definitions(
227        f: &mut std::fmt::Formatter<'_>,
228        items: &HashSet<Item<Self>>,
229        scalars: &[(Elem<Self>, usize)],
230        info: &cubecl_core::Info,
231        flags: &Flags<Self>,
232    ) -> std::fmt::Result {
233        shared::type_definitions::<Self>(f)?;
234        shared::type_vectorized_definitions::<Self>(f, items)?;
235
236        shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
237
238        if flags.inst_wmma {
239            Self::compile_wmma_type_definitions(f, flags)?;
240        }
241
242        Ok(())
243    }
244
245    fn compile_elem(
246        f: &mut std::fmt::Formatter<'_>,
247        elem: &shared::Elem<Self>,
248        words: bool,
249    ) -> std::fmt::Result {
250        if words {
251            match elem {
252                shared::Elem::F32 => f.write_str("float"),
253                shared::Elem::F64 => f.write_str("double"),
254                shared::Elem::TF32 => f.write_str("float"),
255                shared::Elem::I8 => f.write_str("char"),
256                shared::Elem::I16 => f.write_str("short"),
257                shared::Elem::I32 => f.write_str("int"),
258                shared::Elem::I64 => f.write_str("long"),
259                shared::Elem::U8 => f.write_str("uchar"),
260                shared::Elem::U16 => f.write_str("ushort"),
261                shared::Elem::U32 => f.write_str("uint"),
262                shared::Elem::U64 => f.write_str("ulong"),
263                _ => Self::compile_elem(f, elem, false),
264            }
265        } else {
266            match elem {
267                shared::Elem::FP4(_)
268                | shared::Elem::FP4x2(_)
269                | shared::Elem::FP6(_)
270                | shared::Elem::FP6x2(_)
271                | shared::Elem::FP8(_)
272                | shared::Elem::FP8x2(_) => {
273                    f.write_str("#error FP4/FP6/FP8 not supported in HIP\n")
274                }
275                shared::Elem::F16 => f.write_str("__half"),
276                shared::Elem::F16x2 => f.write_str("__half2"),
277                shared::Elem::F32 => f.write_str("float"),
278                shared::Elem::F64 => f.write_str("double"),
279                shared::Elem::BF16 => f.write_str("__hip_bfloat16"),
280                shared::Elem::BF16x2 => f.write_str("__hip_bfloat162"),
281                shared::Elem::TF32 => f.write_str("float"),
282                shared::Elem::I8 => f.write_str("int8"),
283                shared::Elem::I16 => f.write_str("int16"),
284                shared::Elem::I32 => f.write_str("int32"),
285                shared::Elem::I64 => f.write_str("int64"),
286                shared::Elem::U8 => f.write_str("uint8"),
287                shared::Elem::U16 => f.write_str("uint16"),
288                shared::Elem::U32 => f.write_str("uint32"),
289                shared::Elem::U64 => f.write_str("uint64"),
290                shared::Elem::Bool => f.write_str("bool"),
291                shared::Elem::Barrier(_) => panic!("Barrier object not supported in HIP"),
292                shared::Elem::Atomic(inner) => inner.fmt(f),
293                shared::Elem::_Dialect(_) => Ok(()),
294            }
295        }
296    }
297
298    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
299        if 1 == item.vectorization {
300            return write!(f, "{}", item.elem);
301        }
302        if item.native {
303            // native types use the word form of types only
304            Self::compile_elem(f, &item.elem, true)?;
305            write!(f, "{}", item.vectorization)
306        } else {
307            write!(f, "{}_{}", item.elem, item.vectorization)
308        }
309    }
310
311    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        Ok(())
313    }
314}
315
316// Kernel argument bindings
317
318impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for HipDialect<M> {
319    fn compile_kernel_signature(
320        f: &mut std::fmt::Formatter<'_>,
321        kernel_name: &str,
322        tensor_maps: &[KernelArg<Self>],
323        buffers: &[KernelArg<Self>],
324        flags: &Flags<Self>,
325    ) -> std::fmt::Result {
326        write!(
327            f,
328            "
329
330extern \"C\" __global__ void __launch_bounds__({}) {kernel_name}(
331",
332            flags.cube_dim.num_elems()
333        )?;
334        shared::compile_bindings::<Self>(f, tensor_maps, buffers, flags.has_info)?;
335        shared::compile_info_dynamic::<Self>(f, flags)?;
336        f.write_str("\n)")?;
337
338        Ok(())
339    }
340
341    fn compile_bindings_body(
342        f: &mut std::fmt::Formatter<'_>,
343        body: &shared::Body<Self>,
344    ) -> std::fmt::Result {
345        if !body.shared_memories.is_empty() {
346            let max_align = body
347                .shared_memories
348                .iter()
349                .map(|smem| smem.align())
350                .max()
351                .unwrap();
352            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
353            // with `extern __shared__ alignas` and doesn't properly parse it.
354            writeln!(
355                f,
356                "extern __shared__ __align__({max_align}) uchar dynamic_shared_mem[];"
357            )?;
358        }
359        if body.info_by_ptr {
360            f.write_str("const info_st& info = *info_ptr;\n")?;
361            // Could use `info_ptr + 1` but that seems dirty, so use manual `sizeof` instead
362            writeln!(
363                f,
364                "const {addr}* dynamic_meta = reinterpret_cast<const {addr}*>(
365                    reinterpret_cast<const char*>(info_ptr) + sizeof(info_st)
366                );\n",
367                addr = body.address_type,
368            )?;
369        }
370        Ok(())
371    }
372}
373
374// Cube builtins dialect
375
376impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for HipDialect<M> {}
377
378// Instructions
379
380impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
381    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
382        writeln!(f, "__syncthreads();\n")
383    }
384
385    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        writeln!(f, "#error Sync warp is unimplemented on hip\n")
387    }
388
389    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        writeln!(f, "__threadfence();")
391    }
392
393    // unary
394    fn compile_instruction_find_first_set<T: Component<Self>>(
395        f: &mut std::fmt::Formatter<'_>,
396        input: T,
397        out_elem: Elem<Self>,
398    ) -> std::fmt::Result {
399        write!(f, "{out_elem}(")?;
400        match input.elem() {
401            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
402            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
403            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::U32),
404        }?;
405        write!(f, ")")
406    }
407
408    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
409        f: &mut std::fmt::Formatter<'_>,
410        input: T,
411        out_elem: Elem<Self>,
412    ) -> std::fmt::Result {
413        write!(f, "{out_elem}(")?;
414        match input.elem() {
415            Elem::I32 | Elem::U32 => write!(f, "__clz({input})"),
416            Elem::I64 | Elem::U64 => write!(f, "__clzll({input})"),
417            in_elem => write!(
418                f,
419                "__clz({}) - {}",
420                unary::zero_extend(input),
421                (size_of::<u32>() - in_elem.size()) * 8
422            ),
423        }?;
424        write!(f, ")")
425    }
426
427    fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
428        f: &mut std::fmt::Formatter<'_>,
429        input: T,
430        out_elem: Elem<Self>,
431    ) -> std::fmt::Result {
432        // trailing_zeros = ffs - 1 for non-zero, or bit_width for zero
433        // __ffs returns 1-based index of least significant set bit, or 0 if input is 0
434        write!(f, "{out_elem}(")?;
435        match input.elem() {
436            Elem::I32 | Elem::U32 => {
437                write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
438            }
439            Elem::I64 | Elem::U64 => {
440                write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
441            }
442            in_elem => {
443                let bits = in_elem.size() * 8;
444                let extended = unary::zero_extend(input);
445                write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
446            }
447        }?;
448        write!(f, ")")
449    }
450
451    fn compile_saturating_add(
452        f: &mut std::fmt::Formatter<'_>,
453        _lhs: impl Display,
454        _rhs: impl Display,
455        _item: Item<Self>,
456    ) -> std::fmt::Result {
457        f.write_str(
458            "#error No native saturating add exists, TODO: Should be replaced in a preprocessor\n",
459        )
460    }
461
462    fn compile_saturating_sub(
463        f: &mut std::fmt::Formatter<'_>,
464        _lhs: impl Display,
465        _rhs: impl Display,
466        _item: Item<Self>,
467    ) -> std::fmt::Result {
468        f.write_str(
469            "#error No native saturating sub exists, TODO: Should be replaced in a preprocessor\n",
470        )
471    }
472
473    // others
474    fn compile_instruction_max_function_name(
475        f: &mut std::fmt::Formatter<'_>,
476        item: Item<Self>,
477    ) -> std::fmt::Result {
478        let max = match item.elem() {
479            Elem::F16 => "__hmax",
480            Elem::BF16 => "__hmax",
481            _ => "max",
482        };
483        write!(f, "{max}")
484    }
485
486    fn compile_instruction_min_function_name(
487        f: &mut std::fmt::Formatter<'_>,
488        item: Item<Self>,
489    ) -> std::fmt::Result {
490        let min = match item.elem() {
491            Elem::F16 => "__hmin",
492            Elem::BF16 => "__hmin",
493            _ => "min",
494        };
495        write!(f, "{min}")
496    }
497
498    // Warp
499    fn compile_warp_shuffle(
500        f: &mut std::fmt::Formatter<'_>,
501        var: &str,
502        source: &str,
503    ) -> std::fmt::Result {
504        write!(f, "__shfl({var}, {source})")
505    }
506    fn compile_warp_shuffle_xor(
507        f: &mut std::fmt::Formatter<'_>,
508        var: &str,
509        elem: &Elem<Self>,
510        offset: &str,
511    ) -> std::fmt::Result {
512        match elem {
513            Elem::BF16 => write!(
514                f,
515                "half_to_bfloat16(__shfl_xor(reinterpret_cast<__half&>({var}), {offset}))"
516            ),
517            _ => write!(f, "__shfl_xor({var}, {offset})"),
518        }
519    }
520    fn compile_warp_shuffle_up(
521        f: &mut std::fmt::Formatter<'_>,
522        var: &str,
523        offset: &str,
524    ) -> std::fmt::Result {
525        write!(f, "__shfl_up({var}, {offset})")
526    }
527    fn compile_warp_shuffle_down(
528        f: &mut std::fmt::Formatter<'_>,
529        var: &str,
530        offset: &str,
531    ) -> std::fmt::Result {
532        write!(f, "__shfl_down({var}, {offset})")
533    }
534    fn compile_warp_all<T: Component<Self>>(
535        f: &mut std::fmt::Formatter<'_>,
536        input: &T,
537    ) -> std::fmt::Result {
538        let item = input.item();
539        let elem = item.elem;
540        write!(f, "static_cast<{elem}>(__all({input}))")
541    }
542    fn compile_warp_any<T: Component<Self>>(
543        f: &mut std::fmt::Formatter<'_>,
544        input: &T,
545    ) -> std::fmt::Result {
546        let item = input.item();
547        let elem = item.elem;
548        write!(f, "static_cast<{elem}>(__any({input}))")
549    }
550    fn compile_warp_ballot(
551        f: &mut std::fmt::Formatter<'_>,
552        input: &Variable<Self>,
553        out_elem: &Elem<Self>,
554    ) -> std::fmt::Result {
555        write!(f, "{out_elem}(__ballot({input}))")
556    }
557
558    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559        write!(f, "__builtin_unreachable();")
560    }
561}
562
563// Coop Matrices dialect
564
565impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for HipDialect<M> {
566    fn compile_wmma_includes(
567        f: &mut std::fmt::Formatter<'_>,
568        flags: &Flags<Self>,
569    ) -> std::fmt::Result {
570        M::compile_wmma_includes(f, flags)
571    }
572
573    fn compile_wmma_type_definitions(
574        f: &mut std::fmt::Formatter<'_>,
575        flags: &Flags<Self>,
576    ) -> std::fmt::Result {
577        M::compile_wmma_type_definitions(f, flags)
578    }
579
580    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581        M::compile_wmma_local_variables(f)
582    }
583
584    fn compile_wmma_fragment_declaration(
585        f: &mut std::fmt::Formatter<'_>,
586        var: &Variable<Self>,
587    ) -> std::fmt::Result {
588        M::compile_wmma_fragment_declaration(f, var)
589    }
590
591    fn compile_wwma_fragment_ident(
592        f: &mut std::fmt::Formatter<'_>,
593        ident: &crate::shared::FragmentIdent<Self>,
594    ) -> std::fmt::Result {
595        M::compile_wwma_fragment_ident(f, ident)
596    }
597
598    fn compile_wmma_fragment_layout(
599        f: &mut std::fmt::Formatter<'_>,
600        layout: &crate::shared::FragmentLayout<Self>,
601    ) -> std::fmt::Result {
602        M::compile_wmma_fragment_layout(f, layout)
603    }
604
605    fn compile_wmma_fragment(
606        f: &mut std::fmt::Formatter<'_>,
607        fragment: &crate::shared::Fragment<Self>,
608    ) -> std::fmt::Result {
609        M::compile_wmma_fragment(f, fragment)
610    }
611
612    fn compile_wmma_instruction(
613        f: &mut std::fmt::Formatter<'_>,
614        instruction: &crate::shared::WmmaInstruction<Self>,
615    ) -> std::fmt::Result {
616        M::compile_wmma_instruction(f, instruction)
617    }
618
619    fn compile_manual_mma(
620        f: &mut std::fmt::Formatter<'_>,
621        mma: ManualMma<Self>,
622    ) -> std::fmt::Result {
623        M::compile_manual_mma(f, mma)
624    }
625
626    fn supported_wmma_combinations(
627        arch: &AMDArchitecture,
628    ) -> crate::shared::SupportedMmaCombinations {
629        M::supported_wmma_combinations(arch)
630    }
631
632    fn supported_mma_combinations(arch: &AMDArchitecture) -> shared::SupportedMmaCombinations {
633        M::supported_mma_combinations(arch)
634    }
635
636    fn compile_scaled_mma(
637        _f: &mut std::fmt::Formatter<'_>,
638        _mma: ManualMma<Self>,
639        _scales_a: Variable<Self>,
640        _scales_b: Variable<Self>,
641        _scales_factor: u32,
642    ) -> std::fmt::Result {
643        panic!("Scaled MMA not supporter in HIP")
644    }
645}
646
647impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for HipDialect<M> {
648    fn processors() -> Vec<Box<dyn Processor>> {
649        vec![
650            Box::new(HipMmaProcessor),
651            Box::new(SaturatingArithmeticProcessor::new(true)),
652        ]
653    }
654}