use std::fmt::Formatter;

use crate::{
    Dialect,
    hip::{HipDialect, arch::AMDArchitecture},
    shared::{
        Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
        FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
        Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
        wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent};
use cubecl_runtime::MmaConfig;

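/// WMMA compiler for HIP that lowers fragment operations directly to the AMD
/// `__builtin_amdgcn_wmma_*` compiler intrinsics instead of the rocWMMA C++ API.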
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}

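/// Extension emitting a `__device__` helper that fills a fragment with a single value.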
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    frag: Fragment<D>,
}

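/// Extension emitting a `__device__` helper that loads a fragment from memory
/// with the given stride and (for accumulators) source layout.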
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    frag: Fragment<D>,
    layout: Option<FragmentLayout<D>>,
}

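/// Extension emitting a `__device__` helper that stores an accumulator fragment
/// to memory in the given output layout.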
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    frag: Fragment<D>,
    layout: FragmentLayout<D>,
}

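/// Extension emitting a `__device__` helper that performs the actual
/// `D = A * B + C` WMMA intrinsic call for a 16x16x16 tile.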
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}

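/// Extension emitting a `__device__` helper that copies one fragment into
/// another, element by element, converting between element types.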
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}

impl<D: Dialect> WmmaFill<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        write!(
            f,
            "
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
    #pragma unroll
    for (uint i = 0; i < 8; ++i) {{
        frag[i] = value;
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaLoad<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);

        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();

        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                let length = 16;
                let step = 1;
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };

        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint i = 0; i < {length}; ++i) {{
        const uint index = {index_body};
        frag[i * {step}] = value_ptr[index];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaStore<D> {
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;

        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };

        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
    {WMMA_LANE_DEF}

    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
        output_ptr[{output_idx}] = frag[{frag_idx}];
    }}
}}
"
        )
    }
}

impl<D: Dialect> WmmaExecute<D> {
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }

    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            _ => panic!(),
        };
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            _ => panic!(),
        };
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
    frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
",
            self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}

impl<D: Dialect> WmmaCast<D> {
    pub fn fn_name(&self) -> String {
        let layout = frag_layout_str(&self.frag_input.layout);
        let ident = frag_ident_str(&self.frag_input.ident);
        let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
        let elem = self.frag_input.elem;
        let elem_out = self.frag_output.elem;

        format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}")
    }

    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        let step = match output.ident {
            FragmentIdent::Accumulator => {
                get_output_accumulator_index_step(&self.frag_input.elem, &output)
            }
            _ => 1,
        };

        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
    #pragma unroll
    for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
        output[elemIdx * {step}] = input[elemIdx];
    }}
}}
"
        )
    }
}

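// Trait implementation wiring the extension helpers above into the HIP dialect:
// fragments are plain `ext_vector_type` vectors and every WMMA instruction is
// compiled into a call to one of the generated `__device__` helpers.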
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }

    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

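    // A and B operands map to 16-element vectors. Accumulators map to 16-element
    // vectors for f16/bf16 (only every other slot carries data, hence the step of 2
    // used by the load and cast helpers) and to 8-element vectors for f32.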
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }

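    // Lowers each WMMA IR instruction to a call to the matching extension helper;
    // the helper bodies themselves come from the `format_extension` methods above.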
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            WmmaInstruction::LdMatrix { .. } => {
                unimplemented!("Not supported in HIP")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                assert_eq!(*warp_size, 32, "Only warp size of 32 supported");

                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }

    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }

    fn compile_scaled_mma(
        _f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        unimplemented!("Not supported in HIP")
    }

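    // The intrinsics only handle 16x16x16 tiles: f16 inputs can accumulate in f16
    // or f32, while bf16 inputs accumulate in f32.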
    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }

    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}

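/// Index step between useful accumulator elements: half-precision accumulators
/// are 16 elements wide but only every other slot is used, so the step is 2,
/// while f32 accumulators are densely packed and use a step of 1.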
fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);

    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => {
            match output.elem {
                Elem::F16 | Elem::BF16 => 2,
                Elem::F32 => 1,
                other => panic!("unsupported format {other} for {output}"),
            }
        }
        other => panic!("unsupported format {other} for {input_elem}"),
    }
}

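/// Compiles a manual MMA: packs the per-lane register arrays into the intrinsic's
/// vector operands, invokes the 16x16x16 execute helper through a temporary
/// accumulator vector, then copies the result back into the destination registers.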
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());

    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;

    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();

    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };

    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    let frag_c = {
        let vec = frag_c.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };

    let name = extension.fn_name();

    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;

    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;

    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }

    Ok(())
}

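/// Manual MMA combinations mirror the WMMA ones: 16x16x16 tiles with f16 or bf16
/// inputs, always accumulating in f32.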
pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    const ENABLED: bool = true;

    if !ENABLED {
        return Vec::new();
    }

    let mut result: SupportedMmaCombinations = vec![];
    if arch.is_wmma_capable() {
        let types = vec![
            (
                gpu::ElemType::Float(gpu::FloatKind::F16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
            (
                gpu::ElemType::Float(gpu::FloatKind::BF16),
                gpu::ElemType::Float(gpu::FloatKind::F32),
            ),
        ];
        let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        });
        result.extend(combinations);
    }
    result
}

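/// Number of contiguous elements a single lane can access for manual MMA on RDNA3:
/// A and B operands allow vectorized accesses of up to 16 bytes, while accumulator
/// elements are strided and accessed one at a time.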
pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> u32 {
    let max_line_size = 16 / matrix.storage.size();
    match ident {
        MatrixIdent::A | MatrixIdent::B => 16.min(max_line_size) as u32,
        MatrixIdent::Accumulator => 1,
    }
}

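/// HIP snippet defining `wmmaLane`, the lane's index within the 16-wide tile,
/// shared by the generated load and store helpers.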
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";