cubecl_cpp/hip/
dialect.rs

1use core::any::TypeId;
2use std::fmt::Display;
3use std::{collections::HashSet, marker::PhantomData};
4
5use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
6
7use crate::shared::DialectWarpReduceCompiler;
8use crate::{
9    Dialect,
10    shared::{
11        self, Binding, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectTypes,
12        DialectWmmaCompiler, Flags, Item, ManualMma,
13    },
14};
15use crate::{
16    hip::processors::HipMmaProcessor,
17    shared::{
18        Component, DialectInstructions, DialectProcessors, Elem, Instruction, Variable, unary,
19        variable_to_frag,
20    },
21};
22
23use super::Extension;
24use super::arch::AMDArchitecture;
25use super::extension::{WmmaExtension, format_f162bf16, format_max, format_min};
26use super::mma::{WmmaCast, WmmaExecute, WmmaFill, WmmaIntrinsicCompiler, WmmaLoad, WmmaStore};
27
/// Marker type selecting the HIP dialect, parameterised by the WMMA compiler `M`.
///
/// The type carries no runtime data: `M` only appears inside a [`PhantomData`],
/// and the trait implementations in this file forward WMMA-related calls to it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HipDialect<M> {
    /// Zero-sized marker recording the WMMA compiler type; no `M` value is stored.
    _wmma_compiler: PhantomData<M>,
}
32
33// Base dialect
34
35impl<M: DialectWmmaCompiler<Self>> Dialect for HipDialect<M> {
36    type Architecture = AMDArchitecture;
37}
38
39impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for HipDialect<M> {}
40
41// Includes
42
43impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for HipDialect<M> {
44    type Extension = Extension<Self>;
45
46    fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
47        f.write_str("#include <hip/hip_runtime.h>\n")?;
48        if flags.elem_bf16 {
49            f.write_str("#include <hip/hip_bf16.h>\n")?;
50        }
51        if flags.elem_f16 {
52            f.write_str("#include <hip/hip_fp16.h>\n")?;
53        }
54        if flags.inst_wmma {
55            Self::compile_wmma_includes(f, flags)?;
56        }
57        Ok(())
58    }
59
60    fn compile_extensions(
61        f: &mut std::fmt::Formatter<'_>,
62        extensions: &[Self::Extension],
63    ) -> std::fmt::Result {
64        for extension in extensions {
65            match extension {
66                Extension::F162BF16 => format_f162bf16(f)?,
67                Extension::Max(var) => format_max::<Self>(f, var)?,
68                Extension::Min(var) => format_min::<Self>(f, var)?,
69                Extension::NoExtension => {}
70                Extension::Wmma(inst) => inst.format_wmma(f)?,
71            }
72        }
73        Ok(())
74    }
75
76    fn register_instruction_extension(
77        extensions: &mut Vec<Self::Extension>,
78        instruction: &Instruction<Self>,
79    ) {
80        let mut register_extension = |extension: Self::Extension| {
81            if !extensions.contains(&extension) {
82                extensions.push(extension);
83            }
84        };
85        #[allow(clippy::single_match)]
86        match instruction {
87            shared::Instruction::<Self>::Max(op) => {
88                register_extension(Extension::Max(*op.lhs.item().elem()));
89            }
90            shared::Instruction::<Self>::Min(op) => {
91                register_extension(Extension::Min(*op.lhs.item().elem()));
92            }
93            _ => {}
94        }
95    }
96
97    fn register_warp_instruction_extension(
98        extensions: &mut Vec<Self::Extension>,
99        instruction: &shared::WarpInstruction<Self>,
100    ) {
101        let mut register_extension = |extension: Self::Extension| {
102            if !extensions.contains(&extension) {
103                extensions.push(extension);
104            }
105        };
106
107        #[allow(clippy::single_match)]
108        match instruction {
109            shared::WarpInstruction::<Self>::ReduceMax { input, .. } => {
110                let input_item = input.item();
111                let input_elem = input_item.elem();
112                if *input_elem == Elem::<Self>::BF16 {
113                    register_extension(Extension::F162BF16);
114                }
115                register_extension(Extension::Max(*input_elem));
116            }
117            shared::WarpInstruction::<Self>::ReduceMin { input, .. } => {
118                let input_item = input.item();
119                let input_elem = input_item.elem();
120                if *input_elem == Elem::<Self>::BF16 {
121                    register_extension(Extension::F162BF16);
122                }
123                register_extension(Extension::Min(*input_elem));
124            }
125            shared::WarpInstruction::<Self>::ReduceProd { input, .. } => {
126                let input_item = input.item();
127                let input_elem = input_item.elem();
128                if *input_elem == Elem::<Self>::BF16 {
129                    register_extension(Extension::F162BF16);
130                }
131            }
132            shared::WarpInstruction::<Self>::ReduceSum { input, .. } => {
133                let input_item = input.item();
134                let input_elem = input_item.elem();
135                if *input_elem == Elem::<Self>::BF16 {
136                    register_extension(Extension::F162BF16);
137                }
138            }
139            _ => {}
140        }
141    }
142
143    fn register_wmma_instruction_extension(
144        extensions: &mut Vec<Self::Extension>,
145        instruction: &shared::WmmaInstruction<Self>,
146    ) {
147        if TypeId::of::<M>() == TypeId::of::<WmmaIntrinsicCompiler>() {
148            let extension = match instruction {
149                shared::WmmaInstruction::Fill { frag, .. } => {
150                    Extension::Wmma(WmmaExtension::Fill(WmmaFill::new(variable_to_frag(frag))))
151                }
152                shared::WmmaInstruction::Load { frag, layout, .. } => Extension::Wmma(
153                    WmmaExtension::Load(WmmaLoad::new(variable_to_frag(frag), *layout)),
154                ),
155                shared::WmmaInstruction::Execute {
156                    frag_a,
157                    frag_b,
158                    frag_c,
159                    frag_d,
160                    warp_size: _,
161                } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::new(
162                    variable_to_frag(frag_a),
163                    variable_to_frag(frag_b),
164                    variable_to_frag(frag_c),
165                    variable_to_frag(frag_d),
166                ))),
167                shared::WmmaInstruction::ExecuteManual {
168                    shape,
169                    frag_a,
170                    frag_c,
171                    ..
172                } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
173                    *shape,
174                    frag_a[0].elem(),
175                    frag_c[0].elem(),
176                ))),
177                shared::WmmaInstruction::ExecuteScaled { .. } => {
178                    unimplemented!("Not supported in HIP")
179                }
180                shared::WmmaInstruction::Store { frag, layout, .. } => Extension::Wmma(
181                    WmmaExtension::Store(WmmaStore::new(variable_to_frag(frag), *layout)),
182                ),
183                shared::WmmaInstruction::Cast { input, output } => {
184                    Extension::Wmma(WmmaExtension::Cast(WmmaCast::new(
185                        variable_to_frag(input),
186                        variable_to_frag(output),
187                    )))
188                }
189            };
190
191            if !extensions.contains(&extension) {
192                extensions.push(extension);
193            }
194        } else if let shared::WmmaInstruction::ExecuteManual {
195            shape,
196            frag_a,
197            frag_c,
198            ..
199        } = instruction
200        {
201            let extension = Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
202                *shape,
203                frag_a[0].elem(),
204                frag_c[0].elem(),
205            )));
206
207            if !extensions.contains(&extension) {
208                extensions.push(extension);
209            }
210        }
211    }
212}
213
214// Types
215
216impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for HipDialect<M> {
217    fn item_can_be_optimized() -> bool {
218        // for now deactivate support for half2 and bfloat162 because the HIP API lack support for it.
219        false
220    }
221
222    fn compile_type_definitions(
223        f: &mut std::fmt::Formatter<'_>,
224        items: &HashSet<Item<Self>>,
225        _scalars: &[(Elem<Self>, usize)],
226        flags: &Flags,
227    ) -> std::fmt::Result {
228        shared::type_definitions::<Self>(f)?;
229        shared::type_vectorized_definitions::<Self>(f, items)?;
230
231        if flags.inst_wmma {
232            Self::compile_wmma_type_definitions(f, flags)?;
233        }
234
235        Ok(())
236    }
237
238    fn compile_elem(
239        f: &mut std::fmt::Formatter<'_>,
240        elem: &shared::Elem<Self>,
241        words: bool,
242    ) -> std::fmt::Result {
243        if words {
244            match elem {
245                shared::Elem::F32 => f.write_str("float"),
246                shared::Elem::F64 => f.write_str("double"),
247                shared::Elem::TF32 => f.write_str("float"),
248                shared::Elem::I8 => f.write_str("char"),
249                shared::Elem::I16 => f.write_str("short"),
250                shared::Elem::I32 => f.write_str("int"),
251                shared::Elem::I64 => f.write_str("long"),
252                shared::Elem::U8 => f.write_str("uchar"),
253                shared::Elem::U16 => f.write_str("ushort"),
254                shared::Elem::U32 => f.write_str("uint"),
255                shared::Elem::U64 => f.write_str("ulong"),
256                _ => Self::compile_elem(f, elem, false),
257            }
258        } else {
259            match elem {
260                shared::Elem::FP4(_)
261                | shared::Elem::FP4x2(_)
262                | shared::Elem::FP6(_)
263                | shared::Elem::FP6x2(_)
264                | shared::Elem::FP8(_)
265                | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in HIP"),
266                shared::Elem::F16 => f.write_str("__half"),
267                shared::Elem::F16x2 => f.write_str("__half2"),
268                shared::Elem::F32 => f.write_str("float"),
269                shared::Elem::F64 => f.write_str("double"),
270                shared::Elem::BF16 => f.write_str("__bf16"),
271                shared::Elem::BF16x2 => f.write_str("__bf162"),
272                shared::Elem::TF32 => f.write_str("float"),
273                shared::Elem::I8 => f.write_str("int8"),
274                shared::Elem::I16 => f.write_str("int16"),
275                shared::Elem::I32 => f.write_str("int32"),
276                shared::Elem::I64 => f.write_str("int64"),
277                shared::Elem::U8 => f.write_str("uint8"),
278                shared::Elem::U16 => f.write_str("uint16"),
279                shared::Elem::U32 => f.write_str("uint32"),
280                shared::Elem::U64 => f.write_str("uint64"),
281                shared::Elem::Bool => f.write_str("bool"),
282                shared::Elem::Atomic(inner) => inner.fmt(f),
283                shared::Elem::_Dialect(_) => Ok(()),
284            }
285        }
286    }
287
288    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
289        if 1 == item.vectorization {
290            return write!(f, "{}", item.elem);
291        }
292        if item.native {
293            // native types use the word form of types only
294            Self::compile_elem(f, &item.elem, true)?;
295            write!(f, "{}", item.vectorization)
296        } else {
297            write!(f, "{}_{}", item.elem, item.vectorization)
298        }
299    }
300
301    fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        Ok(())
303    }
304}
305
306// Kernel argument bindings
307
308impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for HipDialect<M> {
309    fn compile_kernel_signature(
310        f: &mut std::fmt::Formatter<'_>,
311        kernel_name: &str,
312        tensor_maps: &[Binding<Self>],
313        buffers: &[Binding<Self>],
314        scalars: &[(Elem<Self>, usize)],
315        flags: &Flags,
316    ) -> std::fmt::Result {
317        write!(
318            f,
319            "
320
321extern \"C\" __global__ void __launch_bounds__({}) {kernel_name}(
322",
323            flags.cube_dim.num_elems()
324        )?;
325        shared::compile_bindings::<Self>(f, tensor_maps, buffers, !scalars.is_empty(), flags)?;
326        shared::compile_scalars_dynamic::<Self>(f, scalars)?;
327        f.write_str("\n)")?;
328
329        Ok(())
330    }
331
332    fn compile_bindings_body(
333        f: &mut std::fmt::Formatter<'_>,
334        body: &shared::Body<Self>,
335    ) -> std::fmt::Result {
336        if !body.shared_memories.is_empty() {
337            let max_align = body
338                .shared_memories
339                .iter()
340                .map(|smem| smem.align)
341                .max()
342                .unwrap();
343            // The `__align__` instead of `alignas` is on purpose - the compiler is currently bugged
344            // with `extern __shared__ alignas` and doesn't properly parse it.
345            writeln!(
346                f,
347                "extern __shared__ __align__({max_align}) uchar dynamic_shared_mem[];"
348            )?;
349        }
350        Ok(())
351    }
352}
353
354// Cube builtins dialect
355
356impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for HipDialect<M> {}
357
358// Instructions
359
360impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
361    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        writeln!(f, "__syncthreads();\n")
363    }
364
365    fn compile_instruction_sync_warp(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        panic!("Sync warp is unimplemented on hip")
367    }
368
369    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        writeln!(f, "__threadfence();")
371    }
372
373    // unary
374    fn compile_instruction_find_first_set<T: Component<Self>>(
375        f: &mut std::fmt::Formatter<'_>,
376        input: T,
377        out_elem: Elem<Self>,
378    ) -> std::fmt::Result {
379        write!(f, "{out_elem}(")?;
380        match input.elem() {
381            Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
382            Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
383            _ => write!(f, "__ffs({}({input}))", Elem::<Self>::U32),
384        }?;
385        write!(f, ")")
386    }
387
388    fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
389        f: &mut std::fmt::Formatter<'_>,
390        input: T,
391        out_elem: Elem<Self>,
392    ) -> std::fmt::Result {
393        write!(f, "{out_elem}(")?;
394        match input.elem() {
395            Elem::I32 | Elem::U32 => write!(f, "__clz({input})"),
396            Elem::I64 | Elem::U64 => write!(f, "__clzll({input})"),
397            in_elem => write!(
398                f,
399                "__clz({}) - {}",
400                unary::zero_extend(input),
401                (size_of::<u32>() - in_elem.size()) * 8
402            ),
403        }?;
404        write!(f, ")")
405    }
406
407    fn compile_saturating_add(
408        _f: &mut std::fmt::Formatter<'_>,
409        _lhs: impl Display,
410        _rhs: impl Display,
411        _item: Item<Self>,
412    ) -> std::fmt::Result {
413        unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
414    }
415
416    fn compile_saturating_sub(
417        _f: &mut std::fmt::Formatter<'_>,
418        _lhs: impl Display,
419        _rhs: impl Display,
420        _item: Item<Self>,
421    ) -> std::fmt::Result {
422        unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
423    }
424
425    // others
426    fn compile_instruction_max_function_name(
427        f: &mut std::fmt::Formatter<'_>,
428        item: Item<Self>,
429    ) -> std::fmt::Result {
430        let max = match item.elem() {
431            Elem::F16 => "__hmax",
432            Elem::BF16 => "max_bfloat16",
433            _ => "max",
434        };
435        write!(f, "{max}")
436    }
437
438    fn compile_instruction_min_function_name(
439        f: &mut std::fmt::Formatter<'_>,
440        item: Item<Self>,
441    ) -> std::fmt::Result {
442        let min = match item.elem() {
443            Elem::F16 => "__hmin",
444            Elem::BF16 => "min_bfloat16",
445            _ => "min",
446        };
447        write!(f, "{min}")
448    }
449
450    // Warp
451    fn compile_warp_shuffle(
452        f: &mut std::fmt::Formatter<'_>,
453        var: &str,
454        source: &str,
455    ) -> std::fmt::Result {
456        write!(f, "__shfl({var}, {source})")
457    }
458    fn compile_warp_shuffle_xor(
459        f: &mut std::fmt::Formatter<'_>,
460        var: &str,
461        elem: &Elem<Self>,
462        offset: &str,
463    ) -> std::fmt::Result {
464        match elem {
465            Elem::BF16 => write!(
466                f,
467                "half_to_bfloat16(__shfl_xor(reinterpret_cast<__half&>({var}), {offset}))"
468            ),
469            _ => write!(f, "__shfl_xor({var}, {offset})"),
470        }
471    }
472    fn compile_warp_shuffle_up(
473        f: &mut std::fmt::Formatter<'_>,
474        var: &str,
475        offset: &str,
476    ) -> std::fmt::Result {
477        write!(f, "__shfl_up({var}, {offset})")
478    }
479    fn compile_warp_shuffle_down(
480        f: &mut std::fmt::Formatter<'_>,
481        var: &str,
482        offset: &str,
483    ) -> std::fmt::Result {
484        write!(f, "__shfl_down({var}, {offset})")
485    }
486    fn compile_warp_all<T: Component<Self>>(
487        f: &mut std::fmt::Formatter<'_>,
488        input: &T,
489    ) -> std::fmt::Result {
490        let item = input.item();
491        let elem = item.elem;
492        write!(f, "static_cast<{elem}>(__all({input}))")
493    }
494    fn compile_warp_any<T: Component<Self>>(
495        f: &mut std::fmt::Formatter<'_>,
496        input: &T,
497    ) -> std::fmt::Result {
498        let item = input.item();
499        let elem = item.elem;
500        write!(f, "static_cast<{elem}>(__any({input}))")
501    }
502    fn compile_warp_ballot(
503        f: &mut std::fmt::Formatter<'_>,
504        input: &Variable<Self>,
505        out_elem: &Elem<Self>,
506    ) -> std::fmt::Result {
507        write!(f, "{out_elem}(__ballot({input}))")
508    }
509}
510
511// Coop Matrices dialect
512
513impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for HipDialect<M> {
514    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
515        M::compile_wmma_includes(f, flags)
516    }
517
518    fn compile_wmma_type_definitions(
519        f: &mut std::fmt::Formatter<'_>,
520        flags: &Flags,
521    ) -> std::fmt::Result {
522        M::compile_wmma_type_definitions(f, flags)
523    }
524
525    fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        M::compile_wmma_local_variables(f)
527    }
528
529    fn compile_wmma_fragment_declaration(
530        f: &mut std::fmt::Formatter<'_>,
531        var: &Variable<Self>,
532    ) -> std::fmt::Result {
533        M::compile_wmma_fragment_declaration(f, var)
534    }
535
536    fn compile_wwma_fragment_ident(
537        f: &mut std::fmt::Formatter<'_>,
538        ident: &crate::shared::FragmentIdent<Self>,
539    ) -> std::fmt::Result {
540        M::compile_wwma_fragment_ident(f, ident)
541    }
542
543    fn compile_wmma_fragment_layout(
544        f: &mut std::fmt::Formatter<'_>,
545        layout: &crate::shared::FragmentLayout<Self>,
546    ) -> std::fmt::Result {
547        M::compile_wmma_fragment_layout(f, layout)
548    }
549
550    fn compile_wmma_fragment(
551        f: &mut std::fmt::Formatter<'_>,
552        fragment: &crate::shared::Fragment<Self>,
553    ) -> std::fmt::Result {
554        M::compile_wmma_fragment(f, fragment)
555    }
556
557    fn compile_wmma_instruction(
558        f: &mut std::fmt::Formatter<'_>,
559        instruction: &crate::shared::WmmaInstruction<Self>,
560    ) -> std::fmt::Result {
561        M::compile_wmma_instruction(f, instruction)
562    }
563
564    fn compile_manual_mma(
565        f: &mut std::fmt::Formatter<'_>,
566        mma: ManualMma<Self>,
567    ) -> std::fmt::Result {
568        M::compile_manual_mma(f, mma)
569    }
570
571    fn supported_wmma_combinations(
572        arch: &AMDArchitecture,
573    ) -> crate::shared::SupportedMmaCombinations {
574        M::supported_wmma_combinations(arch)
575    }
576
577    fn supported_mma_combinations(arch: &AMDArchitecture) -> shared::SupportedMmaCombinations {
578        M::supported_mma_combinations(arch)
579    }
580
581    fn compile_scaled_mma(
582        _f: &mut std::fmt::Formatter<'_>,
583        _mma: ManualMma<Self>,
584        _scales_a: Variable<Self>,
585        _scales_b: Variable<Self>,
586        _scales_factor: u32,
587    ) -> std::fmt::Result {
588        panic!("Scaled MMA not supporter in HIP")
589    }
590}
591
592impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for HipDialect<M> {
593    fn processors() -> Vec<Box<dyn Processor>> {
594        vec![
595            Box::new(HipMmaProcessor),
596            Box::new(SaturatingArithmeticProcessor::new(true)),
597        ]
598    }
599}