1use core::any::TypeId;
2use std::fmt::Display;
3use std::{collections::HashSet, marker::PhantomData};
4
5use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
6
7use crate::shared::DialectWarpReduceCompiler;
8use crate::{
9 Dialect,
10 shared::{
11 self, Binding, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectTypes,
12 DialectWmmaCompiler, Flags, Item, ManualMma,
13 },
14};
15use crate::{
16 hip::processors::HipMmaProcessor,
17 shared::{
18 Component, DialectInstructions, DialectProcessors, Elem, Instruction, Variable, unary,
19 variable_to_frag,
20 },
21};
22
23use super::Extension;
24use super::arch::AMDArchitecture;
25use super::extension::{WmmaExtension, format_f162bf16, format_max, format_min};
26use super::mma::{WmmaCast, WmmaExecute, WmmaFill, WmmaIntrinsicCompiler, WmmaLoad, WmmaStore};
27
/// Marker type selecting the HIP code-generation dialect.
///
/// The type parameter `M` names the WMMA compiler used by this dialect. It is
/// only tracked at the type level through [`PhantomData`], so the struct holds
/// no runtime data.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HipDialect<M> {
    /// Zero-sized marker binding the dialect to its WMMA compiler type `M`.
    _wmma_compiler: PhantomData<M>,
}
32
33impl<M: DialectWmmaCompiler<Self>> Dialect for HipDialect<M> {
36 type Architecture = AMDArchitecture;
37}
38
39impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for HipDialect<M> {}
40
41impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for HipDialect<M> {
44 type Extension = Extension<Self>;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
47 f.write_str("#include <hip/hip_runtime.h>\n")?;
48 if flags.elem_bf16 {
49 f.write_str("#include <hip/hip_bf16.h>\n")?;
50 }
51 if flags.elem_f16 {
52 f.write_str("#include <hip/hip_fp16.h>\n")?;
53 }
54 if flags.inst_wmma {
55 Self::compile_wmma_includes(f, flags)?;
56 }
57 Ok(())
58 }
59
60 fn compile_extensions(
61 f: &mut std::fmt::Formatter<'_>,
62 extensions: &[Self::Extension],
63 ) -> std::fmt::Result {
64 for extension in extensions {
65 match extension {
66 Extension::F162BF16 => format_f162bf16(f)?,
67 Extension::Max(var) => format_max::<Self>(f, var)?,
68 Extension::Min(var) => format_min::<Self>(f, var)?,
69 Extension::NoExtension => {}
70 Extension::Wmma(inst) => inst.format_wmma(f)?,
71 }
72 }
73 Ok(())
74 }
75
76 fn register_instruction_extension(
77 extensions: &mut Vec<Self::Extension>,
78 instruction: &Instruction<Self>,
79 ) {
80 let mut register_extension = |extension: Self::Extension| {
81 if !extensions.contains(&extension) {
82 extensions.push(extension);
83 }
84 };
85 #[allow(clippy::single_match)]
86 match instruction {
87 shared::Instruction::<Self>::Max(op) => {
88 register_extension(Extension::Max(*op.lhs.item().elem()));
89 }
90 shared::Instruction::<Self>::Min(op) => {
91 register_extension(Extension::Min(*op.lhs.item().elem()));
92 }
93 _ => {}
94 }
95 }
96
97 fn register_warp_instruction_extension(
98 extensions: &mut Vec<Self::Extension>,
99 instruction: &shared::WarpInstruction<Self>,
100 ) {
101 let mut register_extension = |extension: Self::Extension| {
102 if !extensions.contains(&extension) {
103 extensions.push(extension);
104 }
105 };
106
107 #[allow(clippy::single_match)]
108 match instruction {
109 shared::WarpInstruction::<Self>::ReduceMax { input, .. } => {
110 let input_item = input.item();
111 let input_elem = input_item.elem();
112 if *input_elem == Elem::<Self>::BF16 {
113 register_extension(Extension::F162BF16);
114 }
115 register_extension(Extension::Max(*input_elem));
116 }
117 shared::WarpInstruction::<Self>::ReduceMin { input, .. } => {
118 let input_item = input.item();
119 let input_elem = input_item.elem();
120 if *input_elem == Elem::<Self>::BF16 {
121 register_extension(Extension::F162BF16);
122 }
123 register_extension(Extension::Min(*input_elem));
124 }
125 shared::WarpInstruction::<Self>::ReduceProd { input, .. } => {
126 let input_item = input.item();
127 let input_elem = input_item.elem();
128 if *input_elem == Elem::<Self>::BF16 {
129 register_extension(Extension::F162BF16);
130 }
131 }
132 shared::WarpInstruction::<Self>::ReduceSum { input, .. } => {
133 let input_item = input.item();
134 let input_elem = input_item.elem();
135 if *input_elem == Elem::<Self>::BF16 {
136 register_extension(Extension::F162BF16);
137 }
138 }
139 _ => {}
140 }
141 }
142
143 fn register_wmma_instruction_extension(
144 extensions: &mut Vec<Self::Extension>,
145 instruction: &shared::WmmaInstruction<Self>,
146 ) {
147 if TypeId::of::<M>() == TypeId::of::<WmmaIntrinsicCompiler>() {
148 let extension = match instruction {
149 shared::WmmaInstruction::Fill { frag, .. } => {
150 Extension::Wmma(WmmaExtension::Fill(WmmaFill::new(variable_to_frag(frag))))
151 }
152 shared::WmmaInstruction::Load { frag, layout, .. } => Extension::Wmma(
153 WmmaExtension::Load(WmmaLoad::new(variable_to_frag(frag), *layout)),
154 ),
155 shared::WmmaInstruction::LdMatrix { .. } => {
156 unimplemented!("Not supported for HIP");
157 }
158 shared::WmmaInstruction::Execute {
159 frag_a,
160 frag_b,
161 frag_c,
162 frag_d,
163 warp_size: _,
164 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::new(
165 variable_to_frag(frag_a),
166 variable_to_frag(frag_b),
167 variable_to_frag(frag_c),
168 variable_to_frag(frag_d),
169 ))),
170 shared::WmmaInstruction::ExecuteManual {
171 shape,
172 frag_a,
173 frag_c,
174 ..
175 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
176 *shape,
177 frag_a.elem(),
178 frag_c.elem(),
179 ))),
180 shared::WmmaInstruction::ExecuteScaled { .. } => {
181 unimplemented!("Not supported in HIP")
182 }
183 shared::WmmaInstruction::Store { frag, layout, .. } => Extension::Wmma(
184 WmmaExtension::Store(WmmaStore::new(variable_to_frag(frag), *layout)),
185 ),
186 shared::WmmaInstruction::Cast { input, output } => {
187 Extension::Wmma(WmmaExtension::Cast(WmmaCast::new(
188 variable_to_frag(input),
189 variable_to_frag(output),
190 )))
191 }
192 };
193
194 if !extensions.contains(&extension) {
195 extensions.push(extension);
196 }
197 } else if let shared::WmmaInstruction::ExecuteManual {
198 shape,
199 frag_a,
200 frag_c,
201 ..
202 } = instruction
203 {
204 let extension = Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
205 *shape,
206 frag_a.elem(),
207 frag_c.elem(),
208 )));
209
210 if !extensions.contains(&extension) {
211 extensions.push(extension);
212 }
213 }
214 }
215}
216
217impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for HipDialect<M> {
220 fn item_can_be_optimized() -> bool {
221 false
223 }
224
225 fn compile_type_definitions(
226 f: &mut std::fmt::Formatter<'_>,
227 items: &HashSet<Item<Self>>,
228 _scalars: &[(Elem<Self>, usize)],
229 flags: &Flags,
230 ) -> std::fmt::Result {
231 shared::type_definitions::<Self>(f)?;
232 shared::type_vectorized_definitions::<Self>(f, items)?;
233
234 if flags.inst_wmma {
235 Self::compile_wmma_type_definitions(f, flags)?;
236 }
237
238 Ok(())
239 }
240
241 fn compile_elem(
242 f: &mut std::fmt::Formatter<'_>,
243 elem: &shared::Elem<Self>,
244 words: bool,
245 ) -> std::fmt::Result {
246 if words {
247 match elem {
248 shared::Elem::F32 => f.write_str("float"),
249 shared::Elem::F64 => f.write_str("double"),
250 shared::Elem::TF32 => f.write_str("float"),
251 shared::Elem::I8 => f.write_str("char"),
252 shared::Elem::I16 => f.write_str("short"),
253 shared::Elem::I32 => f.write_str("int"),
254 shared::Elem::I64 => f.write_str("long"),
255 shared::Elem::U8 => f.write_str("uchar"),
256 shared::Elem::U16 => f.write_str("ushort"),
257 shared::Elem::U32 => f.write_str("uint"),
258 shared::Elem::U64 => f.write_str("ulong"),
259 _ => Self::compile_elem(f, elem, false),
260 }
261 } else {
262 match elem {
263 shared::Elem::FP4(_)
264 | shared::Elem::FP4x2(_)
265 | shared::Elem::FP6(_)
266 | shared::Elem::FP6x2(_)
267 | shared::Elem::FP8(_)
268 | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in HIP"),
269 shared::Elem::F16 => f.write_str("__half"),
270 shared::Elem::F16x2 => f.write_str("__half2"),
271 shared::Elem::F32 => f.write_str("float"),
272 shared::Elem::F64 => f.write_str("double"),
273 shared::Elem::BF16 => f.write_str("__bf16"),
274 shared::Elem::BF16x2 => f.write_str("__bf162"),
275 shared::Elem::TF32 => f.write_str("float"),
276 shared::Elem::I8 => f.write_str("int8"),
277 shared::Elem::I16 => f.write_str("int16"),
278 shared::Elem::I32 => f.write_str("int32"),
279 shared::Elem::I64 => f.write_str("int64"),
280 shared::Elem::U8 => f.write_str("uint8"),
281 shared::Elem::U16 => f.write_str("uint16"),
282 shared::Elem::U32 => f.write_str("uint32"),
283 shared::Elem::U64 => f.write_str("uint64"),
284 shared::Elem::Bool => f.write_str("bool"),
285 shared::Elem::Atomic(inner) => inner.fmt(f),
286 shared::Elem::_Dialect(_) => Ok(()),
287 }
288 }
289 }
290
291 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
292 if 1 == item.vectorization {
293 return write!(f, "{}", item.elem);
294 }
295 if item.native {
296 Self::compile_elem(f, &item.elem, true)?;
298 write!(f, "{}", item.vectorization)
299 } else {
300 write!(f, "{}_{}", item.elem, item.vectorization)
301 }
302 }
303
304 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 Ok(())
306 }
307}
308
309impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for HipDialect<M> {
312 fn compile_kernel_signature(
313 f: &mut std::fmt::Formatter<'_>,
314 kernel_name: &str,
315 tensor_maps: &[Binding<Self>],
316 buffers: &[Binding<Self>],
317 scalars: &[(Elem<Self>, usize)],
318 flags: &Flags,
319 ) -> std::fmt::Result {
320 write!(
321 f,
322 "
323
324extern \"C\" __global__ void __launch_bounds__({}) {kernel_name}(
325",
326 flags.cube_dim.num_elems()
327 )?;
328 shared::compile_bindings::<Self>(f, tensor_maps, buffers, !scalars.is_empty(), flags)?;
329 shared::compile_scalars_dynamic::<Self>(f, scalars)?;
330 f.write_str("\n)")?;
331
332 Ok(())
333 }
334
335 fn compile_bindings_body(
336 f: &mut std::fmt::Formatter<'_>,
337 body: &shared::Body<Self>,
338 ) -> std::fmt::Result {
339 if !body.shared_memories.is_empty() {
340 let max_align = body
341 .shared_memories
342 .iter()
343 .map(|smem| smem.align)
344 .max()
345 .unwrap();
346 writeln!(
349 f,
350 "extern __shared__ __align__({max_align}) uchar dynamic_shared_mem[];"
351 )?;
352 }
353 Ok(())
354 }
355}
356
357impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for HipDialect<M> {}
360
361impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
364 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 writeln!(f, "__syncthreads();\n")
366 }
367
368 fn compile_instruction_sync_warp(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 panic!("Sync warp is unimplemented on hip")
370 }
371
372 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373 writeln!(f, "__threadfence();")
374 }
375
376 fn compile_instruction_find_first_set<T: Component<Self>>(
378 f: &mut std::fmt::Formatter<'_>,
379 input: T,
380 out_elem: Elem<Self>,
381 ) -> std::fmt::Result {
382 write!(f, "{out_elem}(")?;
383 match input.elem() {
384 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
385 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
386 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::U32),
387 }?;
388 write!(f, ")")
389 }
390
391 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
392 f: &mut std::fmt::Formatter<'_>,
393 input: T,
394 out_elem: Elem<Self>,
395 ) -> std::fmt::Result {
396 write!(f, "{out_elem}(")?;
397 match input.elem() {
398 Elem::I32 | Elem::U32 => write!(f, "__clz({input})"),
399 Elem::I64 | Elem::U64 => write!(f, "__clzll({input})"),
400 in_elem => write!(
401 f,
402 "__clz({}) - {}",
403 unary::zero_extend(input),
404 (size_of::<u32>() - in_elem.size()) * 8
405 ),
406 }?;
407 write!(f, ")")
408 }
409
410 fn compile_saturating_add(
411 _f: &mut std::fmt::Formatter<'_>,
412 _lhs: impl Display,
413 _rhs: impl Display,
414 _item: Item<Self>,
415 ) -> std::fmt::Result {
416 unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
417 }
418
419 fn compile_saturating_sub(
420 _f: &mut std::fmt::Formatter<'_>,
421 _lhs: impl Display,
422 _rhs: impl Display,
423 _item: Item<Self>,
424 ) -> std::fmt::Result {
425 unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
426 }
427
428 fn compile_instruction_max_function_name(
430 f: &mut std::fmt::Formatter<'_>,
431 item: Item<Self>,
432 ) -> std::fmt::Result {
433 let max = match item.elem() {
434 Elem::F16 => "__hmax",
435 Elem::BF16 => "max_bfloat16",
436 _ => "max",
437 };
438 write!(f, "{max}")
439 }
440
441 fn compile_instruction_min_function_name(
442 f: &mut std::fmt::Formatter<'_>,
443 item: Item<Self>,
444 ) -> std::fmt::Result {
445 let min = match item.elem() {
446 Elem::F16 => "__hmin",
447 Elem::BF16 => "min_bfloat16",
448 _ => "min",
449 };
450 write!(f, "{min}")
451 }
452
453 fn compile_warp_shuffle(
455 f: &mut std::fmt::Formatter<'_>,
456 var: &str,
457 source: &str,
458 ) -> std::fmt::Result {
459 write!(f, "__shfl({var}, {source})")
460 }
461 fn compile_warp_shuffle_xor(
462 f: &mut std::fmt::Formatter<'_>,
463 var: &str,
464 elem: &Elem<Self>,
465 offset: &str,
466 ) -> std::fmt::Result {
467 match elem {
468 Elem::BF16 => write!(
469 f,
470 "half_to_bfloat16(__shfl_xor(reinterpret_cast<__half&>({var}), {offset}))"
471 ),
472 _ => write!(f, "__shfl_xor({var}, {offset})"),
473 }
474 }
475 fn compile_warp_shuffle_up(
476 f: &mut std::fmt::Formatter<'_>,
477 var: &str,
478 offset: &str,
479 ) -> std::fmt::Result {
480 write!(f, "__shfl_up({var}, {offset})")
481 }
482 fn compile_warp_shuffle_down(
483 f: &mut std::fmt::Formatter<'_>,
484 var: &str,
485 offset: &str,
486 ) -> std::fmt::Result {
487 write!(f, "__shfl_down({var}, {offset})")
488 }
489 fn compile_warp_all<T: Component<Self>>(
490 f: &mut std::fmt::Formatter<'_>,
491 input: &T,
492 ) -> std::fmt::Result {
493 let item = input.item();
494 let elem = item.elem;
495 write!(f, "static_cast<{elem}>(__all({input}))")
496 }
497 fn compile_warp_any<T: Component<Self>>(
498 f: &mut std::fmt::Formatter<'_>,
499 input: &T,
500 ) -> std::fmt::Result {
501 let item = input.item();
502 let elem = item.elem;
503 write!(f, "static_cast<{elem}>(__any({input}))")
504 }
505 fn compile_warp_ballot(
506 f: &mut std::fmt::Formatter<'_>,
507 input: &Variable<Self>,
508 out_elem: &Elem<Self>,
509 ) -> std::fmt::Result {
510 write!(f, "{out_elem}(__ballot({input}))")
511 }
512}
513
514impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for HipDialect<M> {
517 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
518 M::compile_wmma_includes(f, flags)
519 }
520
521 fn compile_wmma_type_definitions(
522 f: &mut std::fmt::Formatter<'_>,
523 flags: &Flags,
524 ) -> std::fmt::Result {
525 M::compile_wmma_type_definitions(f, flags)
526 }
527
528 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529 M::compile_wmma_local_variables(f)
530 }
531
532 fn compile_wmma_fragment_declaration(
533 f: &mut std::fmt::Formatter<'_>,
534 var: &Variable<Self>,
535 ) -> std::fmt::Result {
536 M::compile_wmma_fragment_declaration(f, var)
537 }
538
539 fn compile_wwma_fragment_ident(
540 f: &mut std::fmt::Formatter<'_>,
541 ident: &crate::shared::FragmentIdent<Self>,
542 ) -> std::fmt::Result {
543 M::compile_wwma_fragment_ident(f, ident)
544 }
545
546 fn compile_wmma_fragment_layout(
547 f: &mut std::fmt::Formatter<'_>,
548 layout: &crate::shared::FragmentLayout<Self>,
549 ) -> std::fmt::Result {
550 M::compile_wmma_fragment_layout(f, layout)
551 }
552
553 fn compile_wmma_fragment(
554 f: &mut std::fmt::Formatter<'_>,
555 fragment: &crate::shared::Fragment<Self>,
556 ) -> std::fmt::Result {
557 M::compile_wmma_fragment(f, fragment)
558 }
559
560 fn compile_wmma_instruction(
561 f: &mut std::fmt::Formatter<'_>,
562 instruction: &crate::shared::WmmaInstruction<Self>,
563 ) -> std::fmt::Result {
564 M::compile_wmma_instruction(f, instruction)
565 }
566
567 fn compile_manual_mma(
568 f: &mut std::fmt::Formatter<'_>,
569 mma: ManualMma<Self>,
570 ) -> std::fmt::Result {
571 M::compile_manual_mma(f, mma)
572 }
573
574 fn supported_wmma_combinations(
575 arch: &AMDArchitecture,
576 ) -> crate::shared::SupportedMmaCombinations {
577 M::supported_wmma_combinations(arch)
578 }
579
580 fn supported_mma_combinations(arch: &AMDArchitecture) -> shared::SupportedMmaCombinations {
581 M::supported_mma_combinations(arch)
582 }
583
584 fn compile_scaled_mma(
585 _f: &mut std::fmt::Formatter<'_>,
586 _mma: ManualMma<Self>,
587 _scales_a: Variable<Self>,
588 _scales_b: Variable<Self>,
589 _scales_factor: u32,
590 ) -> std::fmt::Result {
591 panic!("Scaled MMA not supporter in HIP")
592 }
593}
594
595impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for HipDialect<M> {
596 fn processors() -> Vec<Box<dyn Processor>> {
597 vec![
598 Box::new(HipMmaProcessor),
599 Box::new(SaturatingArithmeticProcessor::new(true)),
600 ]
601 }
602}