1use core::any::TypeId;
2use std::fmt::Display;
3use std::{collections::HashSet, marker::PhantomData};
4
5use cubecl_core::{ir::Processor, post_processing::saturating::SaturatingArithmeticProcessor};
6
7use crate::shared::DialectWarpReduceCompiler;
8use crate::{
9 Dialect,
10 shared::{
11 self, Binding, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectTypes,
12 DialectWmmaCompiler, Flags, Item, ManualMma,
13 },
14};
15use crate::{
16 hip::processors::HipMmaProcessor,
17 shared::{
18 Component, DialectInstructions, DialectProcessors, Elem, Instruction, Variable, unary,
19 variable_to_frag,
20 },
21};
22
23use super::Extension;
24use super::arch::AMDArchitecture;
25use super::extension::{WmmaExtension, format_f162bf16, format_max, format_min};
26use super::mma::{WmmaCast, WmmaExecute, WmmaFill, WmmaIntrinsicCompiler, WmmaLoad, WmmaStore};
27
/// Marker type that selects the HIP code-generation dialect.
///
/// `M` names the WMMA compiler that the dialect's WMMA hooks delegate to. No
/// value of `M` is ever stored, which is why the only field is a zero-sized
/// [`PhantomData`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HipDialect<M> {
    /// Ties the otherwise unused type parameter `M` to the struct.
    _wmma_compiler: PhantomData<M>,
}
32
33impl<M: DialectWmmaCompiler<Self>> Dialect for HipDialect<M> {
36 type Architecture = AMDArchitecture;
37}
38
39impl<M: DialectWmmaCompiler<Self>> DialectWarpReduceCompiler<Self> for HipDialect<M> {}
40
41impl<M: DialectWmmaCompiler<Self>> DialectIncludes<Self> for HipDialect<M> {
44 type Extension = Extension<Self>;
45
46 fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
47 f.write_str("#include <hip/hip_runtime.h>\n")?;
48 if flags.elem_bf16 {
49 f.write_str("#include <hip/hip_bf16.h>\n")?;
50 }
51 if flags.elem_f16 {
52 f.write_str("#include <hip/hip_fp16.h>\n")?;
53 }
54 if flags.inst_wmma {
55 Self::compile_wmma_includes(f, flags)?;
56 }
57 Ok(())
58 }
59
60 fn compile_extensions(
61 f: &mut std::fmt::Formatter<'_>,
62 extensions: &[Self::Extension],
63 ) -> std::fmt::Result {
64 for extension in extensions {
65 match extension {
66 Extension::F162BF16 => format_f162bf16(f)?,
67 Extension::Max(var) => format_max::<Self>(f, var)?,
68 Extension::Min(var) => format_min::<Self>(f, var)?,
69 Extension::NoExtension => {}
70 Extension::Wmma(inst) => inst.format_wmma(f)?,
71 }
72 }
73 Ok(())
74 }
75
76 fn register_instruction_extension(
77 extensions: &mut Vec<Self::Extension>,
78 instruction: &Instruction<Self>,
79 ) {
80 let mut register_extension = |extension: Self::Extension| {
81 if !extensions.contains(&extension) {
82 extensions.push(extension);
83 }
84 };
85 #[allow(clippy::single_match)]
86 match instruction {
87 shared::Instruction::<Self>::Max(op) => {
88 register_extension(Extension::Max(*op.lhs.item().elem()));
89 }
90 shared::Instruction::<Self>::Min(op) => {
91 register_extension(Extension::Min(*op.lhs.item().elem()));
92 }
93 _ => {}
94 }
95 }
96
97 fn register_warp_instruction_extension(
98 extensions: &mut Vec<Self::Extension>,
99 instruction: &shared::WarpInstruction<Self>,
100 ) {
101 let mut register_extension = |extension: Self::Extension| {
102 if !extensions.contains(&extension) {
103 extensions.push(extension);
104 }
105 };
106
107 #[allow(clippy::single_match)]
108 match instruction {
109 shared::WarpInstruction::<Self>::ReduceMax { input, .. } => {
110 let input_item = input.item();
111 let input_elem = input_item.elem();
112 if *input_elem == Elem::<Self>::BF16 {
113 register_extension(Extension::F162BF16);
114 }
115 register_extension(Extension::Max(*input_elem));
116 }
117 shared::WarpInstruction::<Self>::ReduceMin { input, .. } => {
118 let input_item = input.item();
119 let input_elem = input_item.elem();
120 if *input_elem == Elem::<Self>::BF16 {
121 register_extension(Extension::F162BF16);
122 }
123 register_extension(Extension::Min(*input_elem));
124 }
125 shared::WarpInstruction::<Self>::ReduceProd { input, .. } => {
126 let input_item = input.item();
127 let input_elem = input_item.elem();
128 if *input_elem == Elem::<Self>::BF16 {
129 register_extension(Extension::F162BF16);
130 }
131 }
132 shared::WarpInstruction::<Self>::ReduceSum { input, .. } => {
133 let input_item = input.item();
134 let input_elem = input_item.elem();
135 if *input_elem == Elem::<Self>::BF16 {
136 register_extension(Extension::F162BF16);
137 }
138 }
139 _ => {}
140 }
141 }
142
143 fn register_wmma_instruction_extension(
144 extensions: &mut Vec<Self::Extension>,
145 instruction: &shared::WmmaInstruction<Self>,
146 ) {
147 if TypeId::of::<M>() == TypeId::of::<WmmaIntrinsicCompiler>() {
148 let extension = match instruction {
149 shared::WmmaInstruction::Fill { frag, .. } => {
150 Extension::Wmma(WmmaExtension::Fill(WmmaFill::new(variable_to_frag(frag))))
151 }
152 shared::WmmaInstruction::Load { frag, layout, .. } => Extension::Wmma(
153 WmmaExtension::Load(WmmaLoad::new(variable_to_frag(frag), *layout)),
154 ),
155 shared::WmmaInstruction::Execute {
156 frag_a,
157 frag_b,
158 frag_c,
159 frag_d,
160 warp_size: _,
161 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::new(
162 variable_to_frag(frag_a),
163 variable_to_frag(frag_b),
164 variable_to_frag(frag_c),
165 variable_to_frag(frag_d),
166 ))),
167 shared::WmmaInstruction::ExecuteManual {
168 shape,
169 frag_a,
170 frag_c,
171 ..
172 } => Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
173 *shape,
174 frag_a[0].elem(),
175 frag_c[0].elem(),
176 ))),
177 shared::WmmaInstruction::ExecuteScaled { .. } => {
178 unimplemented!("Not supported in HIP")
179 }
180 shared::WmmaInstruction::Store { frag, layout, .. } => Extension::Wmma(
181 WmmaExtension::Store(WmmaStore::new(variable_to_frag(frag), *layout)),
182 ),
183 shared::WmmaInstruction::Cast { input, output } => {
184 Extension::Wmma(WmmaExtension::Cast(WmmaCast::new(
185 variable_to_frag(input),
186 variable_to_frag(output),
187 )))
188 }
189 };
190
191 if !extensions.contains(&extension) {
192 extensions.push(extension);
193 }
194 } else if let shared::WmmaInstruction::ExecuteManual {
195 shape,
196 frag_a,
197 frag_c,
198 ..
199 } = instruction
200 {
201 let extension = Extension::Wmma(WmmaExtension::Execute(WmmaExecute::from_manual(
202 *shape,
203 frag_a[0].elem(),
204 frag_c[0].elem(),
205 )));
206
207 if !extensions.contains(&extension) {
208 extensions.push(extension);
209 }
210 }
211 }
212}
213
214impl<M: DialectWmmaCompiler<Self>> DialectTypes<Self> for HipDialect<M> {
217 fn item_can_be_optimized() -> bool {
218 false
220 }
221
222 fn compile_type_definitions(
223 f: &mut std::fmt::Formatter<'_>,
224 items: &HashSet<Item<Self>>,
225 _scalars: &[(Elem<Self>, usize)],
226 flags: &Flags,
227 ) -> std::fmt::Result {
228 shared::type_definitions::<Self>(f)?;
229 shared::type_vectorized_definitions::<Self>(f, items)?;
230
231 if flags.inst_wmma {
232 Self::compile_wmma_type_definitions(f, flags)?;
233 }
234
235 Ok(())
236 }
237
238 fn compile_elem(
239 f: &mut std::fmt::Formatter<'_>,
240 elem: &shared::Elem<Self>,
241 words: bool,
242 ) -> std::fmt::Result {
243 if words {
244 match elem {
245 shared::Elem::F32 => f.write_str("float"),
246 shared::Elem::F64 => f.write_str("double"),
247 shared::Elem::TF32 => f.write_str("float"),
248 shared::Elem::I8 => f.write_str("char"),
249 shared::Elem::I16 => f.write_str("short"),
250 shared::Elem::I32 => f.write_str("int"),
251 shared::Elem::I64 => f.write_str("long"),
252 shared::Elem::U8 => f.write_str("uchar"),
253 shared::Elem::U16 => f.write_str("ushort"),
254 shared::Elem::U32 => f.write_str("uint"),
255 shared::Elem::U64 => f.write_str("ulong"),
256 _ => Self::compile_elem(f, elem, false),
257 }
258 } else {
259 match elem {
260 shared::Elem::FP4(_)
261 | shared::Elem::FP4x2(_)
262 | shared::Elem::FP6(_)
263 | shared::Elem::FP6x2(_)
264 | shared::Elem::FP8(_)
265 | shared::Elem::FP8x2(_) => unimplemented!("FP4/FP6/FP8 not supported in HIP"),
266 shared::Elem::F16 => f.write_str("__half"),
267 shared::Elem::F16x2 => f.write_str("__half2"),
268 shared::Elem::F32 => f.write_str("float"),
269 shared::Elem::F64 => f.write_str("double"),
270 shared::Elem::BF16 => f.write_str("__bf16"),
271 shared::Elem::BF16x2 => f.write_str("__bf162"),
272 shared::Elem::TF32 => f.write_str("float"),
273 shared::Elem::I8 => f.write_str("int8"),
274 shared::Elem::I16 => f.write_str("int16"),
275 shared::Elem::I32 => f.write_str("int32"),
276 shared::Elem::I64 => f.write_str("int64"),
277 shared::Elem::U8 => f.write_str("uint8"),
278 shared::Elem::U16 => f.write_str("uint16"),
279 shared::Elem::U32 => f.write_str("uint32"),
280 shared::Elem::U64 => f.write_str("uint64"),
281 shared::Elem::Bool => f.write_str("bool"),
282 shared::Elem::Atomic(inner) => inner.fmt(f),
283 shared::Elem::_Dialect(_) => Ok(()),
284 }
285 }
286 }
287
288 fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<Self>) -> std::fmt::Result {
289 if 1 == item.vectorization {
290 return write!(f, "{}", item.elem);
291 }
292 if item.native {
293 Self::compile_elem(f, &item.elem, true)?;
295 write!(f, "{}", item.vectorization)
296 } else {
297 write!(f, "{}_{}", item.elem, item.vectorization)
298 }
299 }
300
301 fn compile_local_memory_qualifier(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 Ok(())
303 }
304}
305
306impl<M: DialectWmmaCompiler<Self>> DialectBindings<Self> for HipDialect<M> {
309 fn compile_kernel_signature(
310 f: &mut std::fmt::Formatter<'_>,
311 kernel_name: &str,
312 tensor_maps: &[Binding<Self>],
313 buffers: &[Binding<Self>],
314 scalars: &[(Elem<Self>, usize)],
315 flags: &Flags,
316 ) -> std::fmt::Result {
317 write!(
318 f,
319 "
320
321extern \"C\" __global__ void __launch_bounds__({}) {kernel_name}(
322",
323 flags.cube_dim.num_elems()
324 )?;
325 shared::compile_bindings::<Self>(f, tensor_maps, buffers, !scalars.is_empty(), flags)?;
326 shared::compile_scalars_dynamic::<Self>(f, scalars)?;
327 f.write_str("\n)")?;
328
329 Ok(())
330 }
331
332 fn compile_bindings_body(
333 f: &mut std::fmt::Formatter<'_>,
334 body: &shared::Body<Self>,
335 ) -> std::fmt::Result {
336 if !body.shared_memories.is_empty() {
337 let max_align = body
338 .shared_memories
339 .iter()
340 .map(|smem| smem.align)
341 .max()
342 .unwrap();
343 writeln!(
346 f,
347 "extern __shared__ __align__({max_align}) uchar dynamic_shared_mem[];"
348 )?;
349 }
350 Ok(())
351 }
352}
353
354impl<M: DialectWmmaCompiler<Self>> DialectCubeBuiltins<Self> for HipDialect<M> {}
357
358impl<M: DialectWmmaCompiler<Self>> DialectInstructions<Self> for HipDialect<M> {
361 fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 writeln!(f, "__syncthreads();\n")
363 }
364
365 fn compile_instruction_sync_warp(_f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 panic!("Sync warp is unimplemented on hip")
367 }
368
369 fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 writeln!(f, "__threadfence();")
371 }
372
373 fn compile_instruction_find_first_set<T: Component<Self>>(
375 f: &mut std::fmt::Formatter<'_>,
376 input: T,
377 out_elem: Elem<Self>,
378 ) -> std::fmt::Result {
379 write!(f, "{out_elem}(")?;
380 match input.elem() {
381 Elem::I32 | Elem::U32 => write!(f, "__ffs({input})"),
382 Elem::I64 | Elem::U64 => write!(f, "__ffsll({input})"),
383 _ => write!(f, "__ffs({}({input}))", Elem::<Self>::U32),
384 }?;
385 write!(f, ")")
386 }
387
388 fn compile_instruction_leading_zeros_scalar<T: Component<Self>>(
389 f: &mut std::fmt::Formatter<'_>,
390 input: T,
391 out_elem: Elem<Self>,
392 ) -> std::fmt::Result {
393 write!(f, "{out_elem}(")?;
394 match input.elem() {
395 Elem::I32 | Elem::U32 => write!(f, "__clz({input})"),
396 Elem::I64 | Elem::U64 => write!(f, "__clzll({input})"),
397 in_elem => write!(
398 f,
399 "__clz({}) - {}",
400 unary::zero_extend(input),
401 (size_of::<u32>() - in_elem.size()) * 8
402 ),
403 }?;
404 write!(f, ")")
405 }
406
407 fn compile_saturating_add(
408 _f: &mut std::fmt::Formatter<'_>,
409 _lhs: impl Display,
410 _rhs: impl Display,
411 _item: Item<Self>,
412 ) -> std::fmt::Result {
413 unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
414 }
415
416 fn compile_saturating_sub(
417 _f: &mut std::fmt::Formatter<'_>,
418 _lhs: impl Display,
419 _rhs: impl Display,
420 _item: Item<Self>,
421 ) -> std::fmt::Result {
422 unimplemented!("No native instruction exists, Should be replaced in a preprocessor");
423 }
424
425 fn compile_instruction_max_function_name(
427 f: &mut std::fmt::Formatter<'_>,
428 item: Item<Self>,
429 ) -> std::fmt::Result {
430 let max = match item.elem() {
431 Elem::F16 => "__hmax",
432 Elem::BF16 => "max_bfloat16",
433 _ => "max",
434 };
435 write!(f, "{max}")
436 }
437
438 fn compile_instruction_min_function_name(
439 f: &mut std::fmt::Formatter<'_>,
440 item: Item<Self>,
441 ) -> std::fmt::Result {
442 let min = match item.elem() {
443 Elem::F16 => "__hmin",
444 Elem::BF16 => "min_bfloat16",
445 _ => "min",
446 };
447 write!(f, "{min}")
448 }
449
450 fn compile_warp_shuffle(
452 f: &mut std::fmt::Formatter<'_>,
453 var: &str,
454 source: &str,
455 ) -> std::fmt::Result {
456 write!(f, "__shfl({var}, {source})")
457 }
458 fn compile_warp_shuffle_xor(
459 f: &mut std::fmt::Formatter<'_>,
460 var: &str,
461 elem: &Elem<Self>,
462 offset: &str,
463 ) -> std::fmt::Result {
464 match elem {
465 Elem::BF16 => write!(
466 f,
467 "half_to_bfloat16(__shfl_xor(reinterpret_cast<__half&>({var}), {offset}))"
468 ),
469 _ => write!(f, "__shfl_xor({var}, {offset})"),
470 }
471 }
472 fn compile_warp_shuffle_up(
473 f: &mut std::fmt::Formatter<'_>,
474 var: &str,
475 offset: &str,
476 ) -> std::fmt::Result {
477 write!(f, "__shfl_up({var}, {offset})")
478 }
479 fn compile_warp_shuffle_down(
480 f: &mut std::fmt::Formatter<'_>,
481 var: &str,
482 offset: &str,
483 ) -> std::fmt::Result {
484 write!(f, "__shfl_down({var}, {offset})")
485 }
486 fn compile_warp_all<T: Component<Self>>(
487 f: &mut std::fmt::Formatter<'_>,
488 input: &T,
489 ) -> std::fmt::Result {
490 let item = input.item();
491 let elem = item.elem;
492 write!(f, "static_cast<{elem}>(__all({input}))")
493 }
494 fn compile_warp_any<T: Component<Self>>(
495 f: &mut std::fmt::Formatter<'_>,
496 input: &T,
497 ) -> std::fmt::Result {
498 let item = input.item();
499 let elem = item.elem;
500 write!(f, "static_cast<{elem}>(__any({input}))")
501 }
502 fn compile_warp_ballot(
503 f: &mut std::fmt::Formatter<'_>,
504 input: &Variable<Self>,
505 out_elem: &Elem<Self>,
506 ) -> std::fmt::Result {
507 write!(f, "{out_elem}(__ballot({input}))")
508 }
509}
510
511impl<M: DialectWmmaCompiler<Self>> DialectWmmaCompiler<Self> for HipDialect<M> {
514 fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags) -> std::fmt::Result {
515 M::compile_wmma_includes(f, flags)
516 }
517
518 fn compile_wmma_type_definitions(
519 f: &mut std::fmt::Formatter<'_>,
520 flags: &Flags,
521 ) -> std::fmt::Result {
522 M::compile_wmma_type_definitions(f, flags)
523 }
524
525 fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 M::compile_wmma_local_variables(f)
527 }
528
529 fn compile_wmma_fragment_declaration(
530 f: &mut std::fmt::Formatter<'_>,
531 var: &Variable<Self>,
532 ) -> std::fmt::Result {
533 M::compile_wmma_fragment_declaration(f, var)
534 }
535
536 fn compile_wwma_fragment_ident(
537 f: &mut std::fmt::Formatter<'_>,
538 ident: &crate::shared::FragmentIdent<Self>,
539 ) -> std::fmt::Result {
540 M::compile_wwma_fragment_ident(f, ident)
541 }
542
543 fn compile_wmma_fragment_layout(
544 f: &mut std::fmt::Formatter<'_>,
545 layout: &crate::shared::FragmentLayout<Self>,
546 ) -> std::fmt::Result {
547 M::compile_wmma_fragment_layout(f, layout)
548 }
549
550 fn compile_wmma_fragment(
551 f: &mut std::fmt::Formatter<'_>,
552 fragment: &crate::shared::Fragment<Self>,
553 ) -> std::fmt::Result {
554 M::compile_wmma_fragment(f, fragment)
555 }
556
557 fn compile_wmma_instruction(
558 f: &mut std::fmt::Formatter<'_>,
559 instruction: &crate::shared::WmmaInstruction<Self>,
560 ) -> std::fmt::Result {
561 M::compile_wmma_instruction(f, instruction)
562 }
563
564 fn compile_manual_mma(
565 f: &mut std::fmt::Formatter<'_>,
566 mma: ManualMma<Self>,
567 ) -> std::fmt::Result {
568 M::compile_manual_mma(f, mma)
569 }
570
571 fn supported_wmma_combinations(
572 arch: &AMDArchitecture,
573 ) -> crate::shared::SupportedMmaCombinations {
574 M::supported_wmma_combinations(arch)
575 }
576
577 fn supported_mma_combinations(arch: &AMDArchitecture) -> shared::SupportedMmaCombinations {
578 M::supported_mma_combinations(arch)
579 }
580
581 fn compile_scaled_mma(
582 _f: &mut std::fmt::Formatter<'_>,
583 _mma: ManualMma<Self>,
584 _scales_a: Variable<Self>,
585 _scales_b: Variable<Self>,
586 _scales_factor: u32,
587 ) -> std::fmt::Result {
588 panic!("Scaled MMA not supporter in HIP")
589 }
590}
591
592impl<M: DialectWmmaCompiler<Self>> DialectProcessors<Self> for HipDialect<M> {
593 fn processors() -> Vec<Box<dyn Processor>> {
594 vec![
595 Box::new(HipMmaProcessor),
596 Box::new(SaturatingArithmeticProcessor::new(true)),
597 ]
598 }
599}