1use core::any::TypeId;
2use std::fmt::Display;
3use std::{collections::HashSet, marker::PhantomData};
4
5use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
6
7use crate::shared::DialectWarpReduceCompiler;
8use crate::{
9 Dialect,
10 shared::{
11 self, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectTypes,
12 DialectWmmaCompiler, Flags, Item, KernelArg, ManualMma,
13 },
14};
15use crate::{
16 hip::processors::HipMmaProcessor,
17 shared::{
18 Component, DialectInstructions, DialectProcessors, Elem, Instruction, Variable, unary,
19 variable_to_frag,
20 },
21};
22
23use super::Extension;
24use super::arch::AMDArchitecture;
25use super::extension::{WmmaExtension, format_f162bf16, format_max, format_min};
26use super::mma::{WmmaCast, WmmaExecute, WmmaFill, WmmaIntrinsicCompiler, WmmaLoad, WmmaStore};
27
/// Code-generation dialect for HIP (targets `AMDArchitecture`).
///
/// `M` selects the WMMA/MMA compiler backend. It is stored only as a `PhantomData` marker, so
/// values of this type carry no runtime state.
#[derive(Debug, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct HipDialect<M> {
    _wmma_compiler: PhantomData<M>,
}
32
33impl<M: DialectWmmaCompiler<Self>> Dialect for HipDialect<M> {
36 type Architecture = AMDArchitecture;
37}
38
39impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for HipDialect<M> {}
40
41impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for HipDialect<M> {
44 type Extension = Extension<Self>;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<Self>) -> std::fmt::Result {
47 f.write_str("#include <hip/hip_runtime.h>\n")?;
48 if flags.elem_bf16 {
49 f.write_str("#include <hip/hip_bf16.h>\n")?;
50 }
51 if flags.elem_f16 {
52 f.write_str("#include <hip/hip_fp16.h>\n")?;
53 }
54 if flags.inst_wmma {
55 Self::compile_wmma_includes(f, flags)?;
56 }
57 Ok(())
58 }
59
60 fn compile_extensions(
61 f: &mut std::fmt::Formatter<'_>,
62 extensions: &[Self::Extension],
63 ) -> std::fmt::Result {
64 for extension in extensions {
65 match extension {
66 Extension::F162BF16 => format_f162bf16(f)?,
67 Extension::Max(var) => format_max::<Self>(f, var)?,
68 Extension::Min(var) => format_min::<Self>(f, var)?,
69 Extension::NoExtension => {}
70 Extension::Wmma(inst) => inst.format_wmma(f)?,
71 }
72 }
73 Ok(())
74 }
75
76 fn register_instruction_extension(
77 extensions: &mut Vec<Self::Extension>,
78 instruction: &Instruction<Self>,
79 ) {
80 let mut register_extension = |extension: Self::Extension| {
81 if !extensions.contains(&extension) {
82 extensions.push(extension);
83 }
84 };
85 #[allow(clippy::single_match)]
86 match instruction {
87 shared::Instruction::<Self>::Max(op) => {
88 register_extension(Extension::Max(*op.lhs.item().elem()));
89 }
90 shared::Instruction::<Self>::Min(op) => {
91 register_extension(Extension::Min(*op.lhs.item().elem()));
92 }
93 _ => {}
94 }
95 }
96
97 fn register_warp_instruction_extension(
98 extensions: &mut Vec<Self::Extension>,
99 instruction: &shared::WarpInstruction<Self>,
100 ) {
101 let mut register_extension = |extension: Self::Extension| {
102 if !extensions.contains(&extension) {
103 extensions.push(extension);
104 }
105 };
106
107 #[allow(clippy::single_match)]
108 match instruction {
109 shared::WarpInstruction::<Self>::ReduceMax { input, .. } => {
110 let input_item = input.item();
111 let input_elem = input_item.elem();
112 if *input_elem == Elem::<Self>::BF16 {
113 register_extension(Extension::F162BF16);
114 }
115 register_extension(Extension::Max(*input_elem));
116 }
117 shared::WarpInstruction::<Self>::ReduceMin { input, .. } => {
118 let input_item = input.item();
119 let input_elem = input_item.elem();
120 if *input_elem == Elem::<Self>::BF16 {
121 register_extension(Extension::F162BF16);
122 }
123 register_extension(Extension::Min(*input_elem));
124 }
125 shared::WarpInstruction::<Self>::ReduceProd { input, .. } => {
126 let input_item = input.item();
127 let input_elem = input_item.elem();
128 if *input_elem == Elem::<Self>::BF16 {
129 register_extension(Extension::F162BF16);
130 }
131 }
132 shared::WarpInstruction::<Self>::ReduceSum { input, .. } => {
133 let input_item = input.item();
134 let input_elem = input_item.elem();
135 if *input_elem == Elem::<Self>::BF16 {
136 register_extension(Extension::F162BF16);
137 }
138 }
139 _ => {}
140 }
141 }
142
143 fn register_wmma_instruction_extension(
144 extensions: &mut Vec<Self::Extension>,
145 instruction: &shared::WmmaInstruction<Self>,
146 ) {
147 if TypeId::of::<M>() == TypeId::of::<WmmaIntrinsicCompiler>() {
148 let extension = match instruction {
149 shared::WmmaInstruction::Fill { frag, .. } => {
150 Extension::Wmma(WmmaExtension::Fill(WmmaFill::new(variable_to_frag(frag))))
151 }
152 shared::WmmaInstruction::Load { frag, layout, .. } => Extension::Wmma(
153 WmmaExtension::Load(WmmaLoad::new(variable_to_frag(frag), *layout)),
154 ),
155 shared::WmmaInstruction::LdMatrix { .. }
156 | shared::WmmaInstruction::StMatrix { .. } => {
157 panic!("Invalid extension: StMatrix & LdMatrix not supported for HIP");
158 }
159 shared::WmmaInstruction::Execute {
160 frag_a,
161 frag_b,
162 frag_c,
163 frag_d,
164 warp_size: _,
165 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::new(
166 variable_to_frag(frag_a),
167 variable_to_frag(frag_b),
168 variable_to_frag(frag_c),
169 variable_to_frag(frag_d),
170 ))),
171 shared::WmmaInstruction::ExecuteManual {
172 shape,
173 frag_a,
174 frag_c,
175 ..
176 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
177 *shape,
178 frag_a.elem(),
179 frag_c.elem(),
180 ))),
181 shared::WmmaInstruction::ExecuteScaled { .. } => {
182 panic!("Invalid extension: ExecuteScaled not supported for HIP");
183 }
184 shared::WmmaInstruction::Store { frag, layout, .. } => Extension::Wmma(
185 WmmaExtension::Store(WmmaStore::new(variable_to_frag(frag), *layout)),
186 ),
187 shared::WmmaInstruction::Cast { input, output } => {
188 Extension::Wmma(WmmaExtension::Cast(WmmaCast::new(
189 variable_to_frag(input),
190 variable_to_frag(output),
191 )))
192 }
193 };
194
195 if !extensions.contains(&extension) {
196 extensions.push(extension);
197 }
198 } else if let shared::WmmaInstruction::ExecuteManual {
199 shape,
200 frag_a,
201 frag_c,
202 ..
203 } = instruction
204 {
205 let extension = Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
206 *shape,
207 frag_a.elem(),
208 frag_c.elem(),
209 )));
210
211 if !extensions.contains(&extension) {
212 extensions.push(extension);
213 }
214 }
215 }
216}
217
218impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for HipDialect<M> {
221 fn item_can_be_optimized() -> bool {
222 false
224 }
225
226 fn compile_type_definitions(
227 f: &mut std::fmt::Formatter<'_>,
228 items: &HashSet<Item<Self>>,
229 scalars: &[(Elem<Self>, usize)],
230 info: &cubecl_core::Info,
231 flags: &Flags<Self>,
232 ) -> std::fmt::Result {
233 shared::type_definitions::<Self>(f)?;
234 shared::type_vectorized_definitions::<Self>(f, items)?;
235
236 shared::type_info_definition_sized(f, info, scalars, flags.address_type)?;
237
238 if flags.inst_wmma {
239 Self::compile_wmma_type_definitions(f, flags)?;
240 }
241
242 Ok(())
243 }
244
245 fn compile_elem(
246 f: &mut std::fmt::Formatter<'_>,
247 elem: &shared::Elem<Self>,
248 words: bool,
249 ) -> std::fmt::Result {
250 if words {
251 match elem {
252 shared::Elem::F32 => f.write_str("float"),
253 shared::Elem::F64 => f.write_str("double"),
254 shared::Elem::TF32 => f.write_str("float"),
255 shared::Elem::I8 => f.write_str("char"),
256 shared::Elem::I16 => f.write_str("short"),
257 shared::Elem::I32 => f.write_str("int"),
258 shared::Elem::I64 => f.write_str("long"),
259 shared::Elem::U8 => f.write_str("uchar"),
260 shared::Elem::U16 => f.write_str("ushort"),
261 shared::Elem::U32 => f.write_str("uint"),
262 shared::Elem::U64 => f.write_str("ulong"),
263 _ => Self::compile_elem(f, elem, false),
264 }
265 } else {
266 match elem {
267 shared::Elem::FP4(_)
268 | shared::Elem::FP4x2(_)
269 | shared::Elem::FP6(_)
270 | shared::Elem::FP6x2(_)
271 | shared::Elem::FP8(_)
272 | shared::Elem::FP8x2(_) => {
273 f.write_str("#error FP4/FP6/FP8 not supported in HIP\n")
274 }
275 shared::Elem::F16 => f.write_str("__half"),
276 shared::Elem::F16x2 => f.write_str("__half2"),
277 shared::Elem::F32 => f.write_str("float"),
278 shared::Elem::F64 => f.write_str("double"),
279 shared::Elem::BF16 => f.write_str("__hip_bfloat16"),
280 shared::Elem::BF16x2 => f.write_str("__hip_bfloat162"),
281 shared::Elem::TF32 => f.write_str("float"),
282 shared::Elem::I8 => f.write_str("int8"),
283 shared::Elem::I16 => f.write_str("int16"),
284 shared::Elem::I32 => f.write_str("int32"),
285 shared::Elem::I64 => f.write_str("int64"),
286 shared::Elem::U8 => f.write_str("uint8"),
287 shared::Elem::U16 => f.write_str("uint16"),
288 shared::Elem::U32 => f.write_str("uint32"),
289 shared::Elem::U64 => f.write_str("uint64"),
290 shared::Elem::Bool => f.write_str("bool"),
291 shared::Elem::Barrier(_) => panic!("Barrier object not supported in HIP"),
292 shared::Elem::Atomic(inner) => inner.fmt(f),
293 shared::Elem::_Dialect(_) => Ok(()),
294 }
295 }
296 }
297
298 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
299 if 1 == item.vectorization {
300 return write!(f, "{}", item.elem);
301 }
302 if item.native {
303 Self::compile_elem(f, &item.elem, true)?;
305 write!(f, "{}", item.vectorization)
306 } else {
307 write!(f, "{}_{}", item.elem, item.vectorization)
308 }
309 }
310
311 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 Ok(())
313 }
314}
315
316impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for HipDialect<M> {
319 fn compile_kernel_signature(
320 f: &mut std::fmt::Formatter<'_>,
321 kernel_name: &str,
322 tensor_maps: &[KernelArg<Self>],
323 buffers: &[KernelArg<Self>],
324 flags: &Flags<Self>,
325 ) -> std::fmt::Result {
326 write!(
327 f,
328 "
329
330extern \"C\" __global__ void __launch_bounds__({}) {kernel_name}(
331",
332 flags.cube_dim.num_elems()
333 )?;
334 shared::compile_bindings::<Self>(f, tensor_maps, buffers, flags.has_info)?;
335 shared::compile_info_dynamic::<Self>(f, flags)?;
336 f.write_str("\n)")?;
337
338 Ok(())
339 }
340
341 fn compile_bindings_body(
342 f: &mut std::fmt::Formatter<'_>,
343 body: &shared::Body<Self>,
344 ) -> std::fmt::Result {
345 if !body.shared_memories.is_empty() {
346 let max_align = body
347 .shared_memories
348 .iter()
349 .map(|smem| smem.align())
350 .max()
351 .unwrap();
352 writeln!(
355 f,
356 "extern __shared__ __align__({max_align}) uchar dynamic_shared_mem[];"
357 )?;
358 }
359 if body.info_by_ptr {
360 f.write_str("const info_st& info = *info_ptr;\n")?;
361 writeln!(
363 f,
364 "const {addr}* dynamic_meta = reinterpret_cast<const {addr}*>(
365 reinterpret_cast<const char*>(info_ptr) + sizeof(info_st)
366 );\n",
367 addr = body.address_type,
368 )?;
369 }
370 Ok(())
371 }
372}
373
374impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for HipDialect<M> {}
377
378impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
381 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
382 writeln!(f, "__syncthreads();\n")
383 }
384
385 fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386 writeln!(f, "#error Sync warp is unimplemented on hip\n")
387 }
388
389 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 writeln!(f, "__threadfence();")
391 }
392
393 fn compile_instruction_find_first_set<T: Component<Self>>(
395 f: &mut std::fmt::Formatter<'_>,
396 input: T,
397 out_elem: Elem<Self>,
398 ) -> std::fmt::Result {
399 write!(f, "{out_elem}(")?;
400 match input.elem() {
401 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
402 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
403 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::U32),
404 }?;
405 write!(f, ")")
406 }
407
408 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
409 f: &mut std::fmt::Formatter<'_>,
410 input: T,
411 out_elem: Elem<Self>,
412 ) -> std::fmt::Result {
413 write!(f, "{out_elem}(")?;
414 match input.elem() {
415 Elem::I32 | Elem::U32 => write!(f, "__clz({input})"),
416 Elem::I64 | Elem::U64 => write!(f, "__clzll({input})"),
417 in_elem => write!(
418 f,
419 "__clz({}) - {}",
420 unary::zero_extend(input),
421 (size_of::<u32>() - in_elem.size()) * 8
422 ),
423 }?;
424 write!(f, ")")
425 }
426
427 fn compile_instruction_trailing_zeros_scalar<T: Component<Self>>(
428 f: &mut std::fmt::Formatter<'_>,
429 input: T,
430 out_elem: Elem<Self>,
431 ) -> std::fmt::Result {
432 write!(f, "{out_elem}(")?;
435 match input.elem() {
436 Elem::I32 | Elem::U32 => {
437 write!(f, "({input} == 0 ? 32 : __ffs({input}) - 1)")
438 }
439 Elem::I64 | Elem::U64 => {
440 write!(f, "({input} == 0 ? 64 : __ffsll({input}) - 1)")
441 }
442 in_elem => {
443 let bits = in_elem.size() * 8;
444 let extended = unary::zero_extend(input);
445 write!(f, "({extended} == 0 ? {bits} : __ffs({extended}) - 1)")
446 }
447 }?;
448 write!(f, ")")
449 }
450
451 fn compile_saturating_add(
452 f: &mut std::fmt::Formatter<'_>,
453 _lhs: impl Display,
454 _rhs: impl Display,
455 _item: Item<Self>,
456 ) -> std::fmt::Result {
457 f.write_str(
458 "#error No native saturating add exists, TODO: Should be replaced in a preprocessor\n",
459 )
460 }
461
462 fn compile_saturating_sub(
463 f: &mut std::fmt::Formatter<'_>,
464 _lhs: impl Display,
465 _rhs: impl Display,
466 _item: Item<Self>,
467 ) -> std::fmt::Result {
468 f.write_str(
469 "#error No native saturating sub exists, TODO: Should be replaced in a preprocessor\n",
470 )
471 }
472
473 fn compile_instruction_max_function_name(
475 f: &mut std::fmt::Formatter<'_>,
476 item: Item<Self>,
477 ) -> std::fmt::Result {
478 let max = match item.elem() {
479 Elem::F16 => "__hmax",
480 Elem::BF16 => "__hmax",
481 _ => "max",
482 };
483 write!(f, "{max}")
484 }
485
486 fn compile_instruction_min_function_name(
487 f: &mut std::fmt::Formatter<'_>,
488 item: Item<Self>,
489 ) -> std::fmt::Result {
490 let min = match item.elem() {
491 Elem::F16 => "__hmin",
492 Elem::BF16 => "__hmin",
493 _ => "min",
494 };
495 write!(f, "{min}")
496 }
497
498 fn compile_warp_shuffle(
500 f: &mut std::fmt::Formatter<'_>,
501 var: &str,
502 source: &str,
503 ) -> std::fmt::Result {
504 write!(f, "__shfl({var}, {source})")
505 }
506 fn compile_warp_shuffle_xor(
507 f: &mut std::fmt::Formatter<'_>,
508 var: &str,
509 elem: &Elem<Self>,
510 offset: &str,
511 ) -> std::fmt::Result {
512 match elem {
513 Elem::BF16 => write!(
514 f,
515 "half_to_bfloat16(__shfl_xor(reinterpret_cast<__half&>({var}), {offset}))"
516 ),
517 _ => write!(f, "__shfl_xor({var}, {offset})"),
518 }
519 }
520 fn compile_warp_shuffle_up(
521 f: &mut std::fmt::Formatter<'_>,
522 var: &str,
523 offset: &str,
524 ) -> std::fmt::Result {
525 write!(f, "__shfl_up({var}, {offset})")
526 }
527 fn compile_warp_shuffle_down(
528 f: &mut std::fmt::Formatter<'_>,
529 var: &str,
530 offset: &str,
531 ) -> std::fmt::Result {
532 write!(f, "__shfl_down({var}, {offset})")
533 }
534 fn compile_warp_all<T: Component<Self>>(
535 f: &mut std::fmt::Formatter<'_>,
536 input: &T,
537 ) -> std::fmt::Result {
538 let item = input.item();
539 let elem = item.elem;
540 write!(f, "static_cast<{elem}>(__all({input}))")
541 }
542 fn compile_warp_any<T: Component<Self>>(
543 f: &mut std::fmt::Formatter<'_>,
544 input: &T,
545 ) -> std::fmt::Result {
546 let item = input.item();
547 let elem = item.elem;
548 write!(f, "static_cast<{elem}>(__any({input}))")
549 }
550 fn compile_warp_ballot(
551 f: &mut std::fmt::Formatter<'_>,
552 input: &Variable<Self>,
553 out_elem: &Elem<Self>,
554 ) -> std::fmt::Result {
555 write!(f, "{out_elem}(__ballot({input}))")
556 }
557
558 fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
559 write!(f, "__builtin_unreachable();")
560 }
561}
562
563impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for HipDialect<M> {
566 fn compile_wmma_includes(
567 f: &mut std::fmt::Formatter<'_>,
568 flags: &Flags<Self>,
569 ) -> std::fmt::Result {
570 M::compile_wmma_includes(f, flags)
571 }
572
573 fn compile_wmma_type_definitions(
574 f: &mut std::fmt::Formatter<'_>,
575 flags: &Flags<Self>,
576 ) -> std::fmt::Result {
577 M::compile_wmma_type_definitions(f, flags)
578 }
579
580 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 M::compile_wmma_local_variables(f)
582 }
583
584 fn compile_wmma_fragment_declaration(
585 f: &mut std::fmt::Formatter<'_>,
586 var: &Variable<Self>,
587 ) -> std::fmt::Result {
588 M::compile_wmma_fragment_declaration(f, var)
589 }
590
591 fn compile_wwma_fragment_ident(
592 f: &mut std::fmt::Formatter<'_>,
593 ident: &crate::shared::FragmentIdent<Self>,
594 ) -> std::fmt::Result {
595 M::compile_wwma_fragment_ident(f, ident)
596 }
597
598 fn compile_wmma_fragment_layout(
599 f: &mut std::fmt::Formatter<'_>,
600 layout: &crate::shared::FragmentLayout<Self>,
601 ) -> std::fmt::Result {
602 M::compile_wmma_fragment_layout(f, layout)
603 }
604
605 fn compile_wmma_fragment(
606 f: &mut std::fmt::Formatter<'_>,
607 fragment: &crate::shared::Fragment<Self>,
608 ) -> std::fmt::Result {
609 M::compile_wmma_fragment(f, fragment)
610 }
611
612 fn compile_wmma_instruction(
613 f: &mut std::fmt::Formatter<'_>,
614 instruction: &crate::shared::WmmaInstruction<Self>,
615 ) -> std::fmt::Result {
616 M::compile_wmma_instruction(f, instruction)
617 }
618
619 fn compile_manual_mma(
620 f: &mut std::fmt::Formatter<'_>,
621 mma: ManualMma<Self>,
622 ) -> std::fmt::Result {
623 M::compile_manual_mma(f, mma)
624 }
625
626 fn supported_wmma_combinations(
627 arch: &AMDArchitecture,
628 ) -> crate::shared::SupportedMmaCombinations {
629 M::supported_wmma_combinations(arch)
630 }
631
632 fn supported_mma_combinations(arch: &AMDArchitecture) -> shared::SupportedMmaCombinations {
633 M::supported_mma_combinations(arch)
634 }
635
636 fn compile_scaled_mma(
637 _f: &mut std::fmt::Formatter<'_>,
638 _mma: ManualMma<Self>,
639 _scales_a: Variable<Self>,
640 _scales_b: Variable<Self>,
641 _scales_factor: u32,
642 ) -> std::fmt::Result {
643 panic!("Scaled MMA not supporter in HIP")
644 }
645}
646
647impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for HipDialect<M> {
648 fn processors() -> Vec<Box<dyn Processor>> {
649 vec![
650 Box::new(HipMmaProcessor),
651 Box::new(SaturatingArithmeticProcessor::new(true)),
652 ]
653 }
654}