use std::fmt::Formatter;

use crate::{
    Dialect,
    hip::{HipDialect, arch::AMDArchitecture},
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
        wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent, features::MmaConfig};

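/// Compiles CubeCL WMMA instructions directly to the `__builtin_amdgcn_wmma_*` 16x16x16
/// intrinsics (warp size 32), generating small HIP device helpers for fill, load, store,
/// cast and execute.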
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}

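/// Generates the HIP device helper that fills a fragment with a single value
/// (8 registers per lane).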
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    frag: Fragment<D>,
}

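/// Generates the HIP device helper that loads a fragment from a strided pointer.
/// For accumulator fragments the source data `layout` must be provided explicitly.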
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    layout: Option<FragmentLayout<D>>,
}

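/// Generates the HIP device helper that stores an accumulator fragment to a strided
/// output pointer in the requested layout.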
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    layout: FragmentLayout<D>,
}

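/// Generates the HIP device helper that performs one 16x16x16 `D = A * B + C` WMMA via
/// the `__builtin_amdgcn_wmma_*` intrinsic.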
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}

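/// Generates the HIP device helper that copies a fragment element-wise into a fragment of
/// a different element type, spacing writes by the accumulator index step when needed.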
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}

impl<D: Dialect> WmmaFill<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        write!(
            f,
            "
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
    #pragma unroll
    for (uint i = 0; i < 8; ++i) {{
        frag[i] = value;
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaLoad<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
        const uint index = {index_body};
        frag[i * {step}] = value_ptr[index];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaStore<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };

        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
        output_ptr[{output_idx}] = frag[{frag_idx}];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaExecute<D> {
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }

    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            _ => panic!(),
        };
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            _ => panic!(),
        };
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
", self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}

impl<D: Dialect> WmmaCast<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag_input.layout);
        let ident = frag_ident_str(&self.frag_input.ident);
        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
        let elem = self.frag_input.elem;
        let elem_out = self.frag_output.elem;

        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        let step = match output.ident {
            FragmentIdent::Accumulator => {
                get_output_accumulator_index_step(&self.frag_input.elem, &output)
            }
            _ => 1,
        };

        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        output[elemIdx * {step}] = input[elemIdx];
    }}
}}
"
        )
    }
}

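/// Intrinsic-based WMMA lowering for the HIP dialect: vector typedefs, fragment types, and
/// translation of each `WmmaInstruction` into calls to the generated device helpers.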
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<HipDialect<Self>>,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error LdMatrix & StMatrix are not supported on HIP\n")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                if *warp_size != 32 {
                    f.write_str(
                        "#error Only warp size of 32 supported for Wmma::Execute on HIP\n",
                    )?;
                }

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        f.write_str("#error scaled mma not supported in HIP\n")
    }

    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}

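/// Index step between consecutive used elements of an accumulator fragment: f16/bf16
/// accumulators occupy every other slot of a 16-wide register (step 2), while f32
/// accumulators are packed densely (step 1).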
fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);

    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => match output.elem {
            Elem::F16 | Elem::BF16 => 2,
            Elem::F32 => 1,
            other => panic!("unsupported format {other} for {output}"),
        },
        other => panic!("unsupported input format {other} for {output}"),
    }
}

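/// Emits a manual (non-fragment-API) 16x16x16 WMMA call: the A/B/C operands are gathered
/// element by element into brace-initialized vector literals, the intrinsic writes into a
/// temporary output vector, and the used slots are scattered back into `frag_d`.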
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    // Number of accumulator elements owned by each of the 32 lanes.
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    // Flatten a (possibly vectorized) variable into a comma-separated element list.
    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    let name = extension.fn_name();

    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    // Copy the used slots of the temporary output back into the destination fragment.
    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}

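/// Manual MMA combinations shared by the HIP dialects: 16x16x16 with f16 or bf16 inputs
/// and f32 accumulation, on WMMA-capable architectures.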
pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    const ENABLED: bool = true;

    if !ENABLED {
        return Vec::new();
    }

    let mut result: SupportedMmaCombinations = vec![];
    if arch.is_wmma_capable() {
        let types = vec![
            (
                gpu::ElemType::Float(gpu::FloatKind::F16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
            (
                gpu::ElemType::Float(gpu::FloatKind::BF16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
        ];
        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        });
        result.extend(combinations);
    }
    result
}

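/// Maximum number of contiguous elements a lane accesses for a matrix on RDNA3: A/B can
/// use up to a full 16-byte line (capped at 16 elements), while accumulator accesses are
/// one element at a time.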
pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> usize {
    let max_line_size = 16 / matrix.storage.size();
    match ident {
        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size),
        MatrixIdent::Accumulator => 1,
    }
}

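/// HIP snippet defining `wmmaLane` (the lane's index within a 16-element row or column of
/// the WMMA tile), interpolated into the generated load/store helpers.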
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";