use std::fmt::Formatter;

use crate::{
    Dialect,
    hip::{HipDialect, arch::AMDArchitecture},
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
        wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent};
use cubecl_runtime::MmaConfig;

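/// WMMA compiler for HIP that lowers fragment operations to the AMD
/// `__builtin_amdgcn_wmma_*_16x16x16_*_w32` compiler intrinsics. Fragments are
/// represented as extended vector types (`half16_t`, `bhalf16_t`, `float8_t`)
/// and each operation is emitted as a small generated `__device__` helper.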
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}

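/// Extension generating a `__device__` helper that fills all 8 per-lane
/// elements of a fragment with a single scalar value.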
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    frag: Fragment<D>,
}

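/// Extension generating a `__device__` helper that loads a fragment from
/// memory given a base pointer and a stride. For accumulator fragments,
/// `layout` describes the layout of the source data and must be provided.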
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    layout: Option<FragmentLayout<D>>,
}

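/// Extension generating a `__device__` helper that stores the 8 per-lane
/// elements of an accumulator fragment to memory using the requested output
/// `layout` and a stride.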
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    layout: FragmentLayout<D>,
}

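/// Extension generating a `__device__` helper that computes `D = A * B + C`
/// through the `__builtin_amdgcn_wmma_<cd>_16x16x16_<ab>_w32` intrinsic
/// matching the fragment element types.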
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}

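/// Extension generating a `__device__` helper that copies the 8 per-lane
/// elements of one fragment into another of a different element type,
/// applying the accumulator index step when the output is an accumulator.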
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}

impl<D: Dialect> WmmaFill<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        write!(
            f,
            "
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
    #pragma unroll
    for (uint i = 0; i < 8; ++i) {{
        frag[i] = value;
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaLoad<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
        const uint index = {index_body};
        frag[i * {step}] = value_ptr[index];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaStore<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };

        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
        output_ptr[{output_idx}] = frag[{frag_idx}];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaExecute<D> {
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }

    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            _ => panic!(),
        };
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            _ => panic!(),
        };
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
",
            self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}

impl<D: Dialect> WmmaCast<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag_input.layout);
        let ident = frag_ident_str(&self.frag_input.ident);
        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
        let elem = self.frag_input.elem;
        let elem_out = self.frag_output.elem;

        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        let step = match output.ident {
            FragmentIdent::Accumulator => {
                get_output_accumulator_index_step(&self.frag_input.elem, &output)
            }
            _ => 1,
        };

        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        output[elemIdx * {step}] = input[elemIdx];
    }}
}}
"
        )
    }
}

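// Wires the extensions above into the HIP dialect: vector typedefs, fragment
// declarations, and the translation of each `WmmaInstruction` into a call to
// the corresponding generated helper.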
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error LdMatrix & StMatrix are not supported on HIP\n")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                if *warp_size != 32 {
                    f.write_str(
                        "#error Only warp size of 32 supported for Wmma::Execute on HIP\n",
                    )?;
                }

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        f.write_str("#error scaled mma not supported in HIP\n")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}

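/// Index step between consecutive useful elements of an accumulator fragment:
/// f16/bf16 accumulators only occupy every other slot of their 16-element
/// vector (step 2), while f32 accumulators are densely packed (step 1).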
fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);

    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => match output.elem {
            Elem::F16 | Elem::BF16 => 2,
            Elem::F32 => 1,
            other => panic!("unsupported format {other} for {output}"),
        },
        other => panic!("unsupported input format {other} for {output}"),
    }
}

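/// Emits a manual MMA as a single intrinsic call: the a/b/c fragment registers
/// are gathered into brace-initialized vector literals, the result is written
/// into a temporary vector, and the relevant lanes are then copied back into
/// `frag_d`.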
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    let name = extension.fn_name();

    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}

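/// Manual MMA combinations exposed through the intrinsics: 16x16x16 tiles with
/// f16 or bf16 inputs and an f32 accumulator, on WMMA-capable architectures.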
pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    const ENABLED: bool = true;

    if !ENABLED {
        return Vec::new();
    }

    let mut result: SupportedMmaCombinations = vec![];
    if arch.is_wmma_capable() {
        let types = vec![
            (
                gpu::ElemType::Float(gpu::FloatKind::F16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
            (
                gpu::ElemType::Float(gpu::FloatKind::BF16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
        ];
        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        });
        result.extend(combinations);
    }
    result
}

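/// Number of contiguous elements a single lane handles for manual MMA on
/// RDNA3: up to 16 for the A/B fragments (capped at a 16-byte access), and 1
/// for the accumulator.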
pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> u32 {
    let max_line_size = 16 / matrix.storage.size();
    match ident {
        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size) as u32,
        MatrixIdent::Accumulator => 1,
    }
}

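/// Snippet shared by the generated load/store helpers: the lane's index within
/// the 16-wide WMMA tile (`threadIdx.x % 16`).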
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";