1use std::fmt::Formatter;
2
3use crate::{
4 Dialect,
5 cuda::ptx::comma_separated,
6 hip::{HipDialect, arch::AMDArchitecture},
7 shared::{
8 Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
9 FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
10 Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
11 wmma_api_base,
12 },
13};
14use cubecl_core::ir::{self as gpu};
15use cubecl_runtime::MmaConfig;
16
/// WMMA compiler for the HIP dialect.
///
/// Its [`DialectWmmaCompiler`] implementation lowers WMMA instructions to calls of
/// generated `wmma_*` device helpers, the execute helper wrapping a
/// `__builtin_amdgcn_wmma_*` intrinsic.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}
19
/// Describes the generated device helper that fills a fragment with one value.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    /// Fragment written by the helper; it also determines the helper's name.
    frag: Fragment<D>,
}
24
/// Describes the generated device helper that loads a fragment from memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    /// Fragment that receives the loaded values.
    frag: Fragment<D>,
    /// Layout passed with the load instruction. It must be `Some` when `frag` is an
    /// accumulator, otherwise generating the helper panics.
    layout: Option<FragmentLayout<D>>,
}
30
/// Describes the generated device helper that stores a fragment to memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    /// Fragment whose values are written out.
    frag: Fragment<D>,
    /// Layout used to compute the index into the output pointer.
    layout: FragmentLayout<D>,
}
36
/// Describes the generated device helper that runs the matrix multiply-accumulate
/// `frag_d = intrinsic(frag_a, frag_b, frag_c)`.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    /// First input operand.
    frag_a: Fragment<D>,
    /// Second input operand.
    frag_b: Fragment<D>,
    /// Accumulator input operand.
    frag_c: Fragment<D>,
    /// Fragment that receives the result.
    frag_d: Fragment<D>,
}
44
/// Describes the generated device helper that copies one fragment into another of a
/// different element type.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    /// Fragment that is read.
    frag_input: Fragment<D>,
    /// Fragment that is written.
    frag_output: Fragment<D>,
}
50
51impl<D: Dialect> WmmaFill<D> {
52 pub fn fn_name(&self) -> String {
53 let layout = frag_layout_str(&self.frag.layout);
54 let ident = frag_ident_str(&self.frag.ident);
55 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
56 let elem = self.frag.elem;
57
58 format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}",)
59 }
60
61 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
62 let elem = self.frag.elem;
63 let frag = self.frag;
64 let name = self.fn_name();
65
66 write!(
67 f,
68 "
69// Fill the fragment.
70__device__ void {name}({frag}& frag, {elem} value) {{
71 #pragma unroll
72 for (uint i = 0; i < 8; ++i) {{
73 frag[i] = value;
74 }}
75}}
76 "
77 )
78 }
79}
80
81impl<D: Dialect> WmmaLoad<D> {
82 pub fn fn_name(&self) -> String {
83 let layout_frag = frag_layout_str(&self.frag.layout);
84 let layout = frag_layout_str(&self.layout);
85 let ident = frag_ident_str(&self.frag.ident);
86 let elem = self.frag.elem;
87 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
88
89 format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
90 }
91
92 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
117 let elem = self.frag.elem;
118 let frag = self.frag;
119 let name = self.fn_name();
120
121 let (index_body, length, step) = match frag.ident {
122 FragmentIdent::A | FragmentIdent::B => {
123 let length = 16;
124 let step = 1;
125 let index = if (frag.ident == FragmentIdent::A
128 && frag.layout.unwrap() == FragmentLayout::ColMajor)
129 || (frag.ident == FragmentIdent::B
130 && frag.layout.unwrap() == FragmentLayout::RowMajor)
131 {
132 "i * stride + wmmaLane".to_string()
133 } else {
134 "i + wmmaLane * stride".to_string()
135 };
136 (index, length, step)
137 }
138 FragmentIdent::Accumulator => {
139 let length = 8;
140 let step = get_output_accumulator_index_step(&elem, &frag);
141 let index = match self.layout {
142 Some(FragmentLayout::ColMajor) => {
143 "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
144 }
145 Some(FragmentLayout::RowMajor) => {
146 "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
147 }
148 _ => panic!(
149 "cannot load data to an accumulator without knowing the layout of the data"
150 ),
151 };
152 (index, length, step)
153 }
154 other => panic!("unknown matrix identifier {other}"),
155 };
156
157 write!(
158 f,
159 "
160// Load the fragment.
161__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
162 {WMMA_LANE_DEF}
163
164 #pragma unroll
165 for (uint i = 0; i < {length}; ++i) {{
166 const uint index = {index_body};
167 frag[i * {step}] = value_ptr[index];
168 }}
169}}
170 "
171 )
172 }
173}
174
175impl<D: Dialect> WmmaStore<D> {
176 pub fn fn_name(&self) -> String {
177 let layout_frag = frag_layout_str(&self.frag.layout);
178 let layout_option = Some(self.layout);
179 let layout = frag_layout_str(&layout_option);
180 let ident = frag_ident_str(&self.frag.ident);
181 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
182 let elem = self.frag.elem;
183
184 format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
185 }
186
187 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
188 let elem = self.frag.elem;
189 let frag = self.frag;
190 let name = self.fn_name();
191 let frag_idx = match elem {
195 Elem::F16 | Elem::BF16 => "elemIdx * 2",
196 Elem::F32 => "elemIdx",
197 other => {
198 panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
199 }
200 };
201 let output_idx = match self.layout {
203 FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
204 FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
205 FragmentLayout::_Dialect(_) => String::new(),
206 };
207
208 write!(
209 f,
210 "
211// Store the fragment.
212__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
213 {WMMA_LANE_DEF}
214
215 #pragma unroll
216 for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
217 const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
218 output_ptr[{output_idx}] = frag[{frag_idx}];
219 }}
220}}
221 "
222 )
223 }
224}
225
226impl<D: Dialect> WmmaExecute<D> {
227 pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
228 let frag_a = Fragment {
229 ident: FragmentIdent::A,
230 m: shape.m,
231 n: shape.n,
232 k: shape.k,
233 elem: ab_elem,
234 layout: Some(FragmentLayout::ColMajor),
235 };
236 let frag_b = Fragment {
237 ident: FragmentIdent::B,
238 layout: Some(FragmentLayout::RowMajor),
239 ..frag_a
240 };
241 let frag_cd = Fragment {
242 ident: FragmentIdent::Accumulator,
243 elem: cd_elem,
244 ..frag_b
245 };
246 WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
247 }
248
249 pub fn fn_name(&self) -> String {
250 format!(
251 "wmma_execute_16x16x16_{}_{}",
252 self.frag_a.elem, self.frag_c.elem
253 )
254 }
255
256 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
257 let name = self.fn_name();
258 let ab_format = match self.frag_a.elem {
259 Elem::F32 => "f32",
260 Elem::BF16 => "bf16",
261 Elem::F16 => "f16",
262 _ => panic!(),
263 };
264 let (cd_format, opsel) = match self.frag_c.elem {
265 Elem::F32 => ("f32", ""),
266 Elem::BF16 => ("bf16", ", false"),
267 Elem::F16 => ("f16", ", false"),
268 _ => panic!(),
269 };
270 let warp_size = 32;
271 write!(
272 f,
273 "
274// Execute wmma.
275__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
276 frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
277}}
278 ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
279 )
280 }
281}
282
283impl<D: Dialect> WmmaCast<D> {
284 pub fn fn_name(&self) -> String {
285 let layout = frag_layout_str(&self.frag_input.layout);
286 let ident = frag_ident_str(&self.frag_input.ident);
287 let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
288 let elem = self.frag_input.elem;
289 let elem_out = self.frag_output.elem;
290
291 format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}",)
292 }
293
294 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295 let input = self.frag_input;
296 let output = self.frag_output;
297 let name = self.fn_name();
298 let step = match output.ident {
299 FragmentIdent::Accumulator => {
300 get_output_accumulator_index_step(&self.frag_input.elem, &output)
301 }
302 _ => 1,
303 };
304
305 write!(
306 f,
307 "
308// Cast the fragment.
309__device__ void {name}({input}& input, {output}& output) {{
310 #pragma unroll
311 for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
312 output[elemIdx * {step}] = input[elemIdx];
313 }}
314}}
315 "
316 )
317 }
318}
319
320impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
321 fn compile_wmma_type_definitions(
322 f: &mut std::fmt::Formatter<'_>,
323 flags: &Flags,
324 ) -> std::fmt::Result {
325 if flags.elem_bf16 {
326 f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
327 f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
328 }
329 if flags.elem_f16 {
330 f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
331 f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
332 }
333 f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
334 }
335
336 fn compile_wmma_fragment_declaration(
337 f: &mut std::fmt::Formatter<'_>,
338 var: &crate::shared::Variable<HipDialect<Self>>,
339 ) -> std::fmt::Result {
340 wmma_api_base::compile_fragment_declaration(f, var)
341 }
342
343 fn compile_wmma_fragment(
344 f: &mut std::fmt::Formatter<'_>,
345 fragment: &Fragment<HipDialect<Self>>,
346 ) -> std::fmt::Result {
347 match fragment.ident {
348 FragmentIdent::A | FragmentIdent::B => match fragment.elem {
349 Elem::F16 => write!(f, "half16_t"),
350 Elem::BF16 => write!(f, "bhalf16_t"),
351 other => panic!("unsupported type {other} for {fragment}"),
352 },
353 FragmentIdent::Accumulator => match fragment.elem {
354 Elem::F16 => write!(f, "half16_t"),
355 Elem::BF16 => write!(f, "bhalf16_t"),
356 Elem::F32 => write!(f, "float8_t"),
357 other => panic!("unsupported type {other} for {fragment}"),
358 },
359 FragmentIdent::_Dialect(_) => Ok(()),
360 }
361 }
362
363 fn compile_wmma_instruction(
364 f: &mut std::fmt::Formatter<'_>,
365 instruction: &WmmaInstruction<HipDialect<Self>>,
366 ) -> std::fmt::Result {
367 match instruction {
368 WmmaInstruction::Fill { frag, value } => {
369 let extension = WmmaFill::new(match frag {
370 Variable::WmmaFragment { frag, .. } => *frag,
371 _ => panic!(),
372 });
373 let name = extension.fn_name();
374 writeln!(f, "{name}({frag}, {value});")
375 }
376 WmmaInstruction::Load {
377 frag,
378 value,
379 layout,
380 offset,
381 stride,
382 } => {
383 let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
384 let name = extension.fn_name();
385 let value_ptr = frag_as_ptr(f, value, offset);
386 writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
387 }
388 WmmaInstruction::Execute {
389 frag_a,
390 frag_b,
391 frag_c,
392 frag_d,
393 warp_size,
394 } => {
395 assert_eq!(*warp_size, 32, "Only warp size of 32 supported");
396
397 let extension = WmmaExecute::new(
398 variable_to_frag(frag_a),
399 variable_to_frag(frag_b),
400 variable_to_frag(frag_c),
401 variable_to_frag(frag_d),
402 );
403 let name = extension.fn_name();
404 writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
405 }
406 WmmaInstruction::ExecuteManual {
407 shape,
408 frag_a,
409 frag_b,
410 frag_c,
411 frag_d,
412 } => {
413 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
414 }
415 WmmaInstruction::ExecuteScaled {
416 shape,
417 frag_a,
418 frag_b,
419 frag_c,
420 frag_d,
421 scales_a,
422 scales_b,
423 scales_factor,
424 } => Self::compile_scaled_mma(
425 f,
426 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
427 *scales_a,
428 *scales_b,
429 *scales_factor,
430 ),
431 WmmaInstruction::Store {
432 output,
433 frag,
434 layout,
435 offset,
436 stride,
437 } => {
438 let extension = WmmaStore::new(variable_to_frag(frag), *layout);
439 let name = extension.fn_name();
440 let output_ptr = frag_as_ptr(f, output, offset);
441 writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
442 }
443 WmmaInstruction::Cast { input, output } => {
444 let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
445 let name = extension.fn_name();
446 writeln!(f, "{name}({input}, {output});")
447 }
448 }
449 }
450
451 fn compile_manual_mma(
452 f: &mut std::fmt::Formatter<'_>,
453 mma: ManualMma<HipDialect<Self>>,
454 ) -> std::fmt::Result {
455 compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
456 }
457
458 fn compile_scaled_mma(
459 _f: &mut std::fmt::Formatter<'_>,
460 _mma: ManualMma<HipDialect<Self>>,
461 _scales_a: Variable<HipDialect<Self>>,
462 _scales_b: Variable<HipDialect<Self>>,
463 _scales_factor: u32,
464 ) -> std::fmt::Result {
465 unimplemented!("Not supported in HIP")
466 }
467
468 fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
469 let mut result: SupportedMmaCombinations = vec![];
471 if arch.is_wmma_capable() {
472 let types = vec![
474 (
475 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
479 (
480 gpu::ElemType::Float(gpu::FloatKind::F16),
481 gpu::ElemType::Float(gpu::FloatKind::F16),
482 gpu::ElemType::Float(gpu::FloatKind::F32),
483 ),
484 (
485 gpu::ElemType::Float(gpu::FloatKind::BF16),
486 gpu::ElemType::Float(gpu::FloatKind::BF16),
487 gpu::ElemType::Float(gpu::FloatKind::F32),
488 ),
489 ];
490 let combinations: SupportedMmaCombinations = types
491 .into_iter()
492 .map(|(a, b, c)| MmaConfig {
493 a_type: a.into(),
494 b_type: b.into(),
495 cd_type: c.into(),
496 m: 16,
497 n: 16,
498 k: 16,
499 })
500 .collect();
501 result.extend(combinations);
502 }
503 result
504 }
505
506 fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
507 supported_mma_combinations(arch)
508 }
509}
510
511fn get_output_accumulator_index_step<D: Dialect>(
512 input_elem: &Elem<D>,
513 output: &Fragment<D>,
514) -> u32 {
515 assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);
523
524 match input_elem {
525 Elem::F16 | Elem::BF16 | Elem::F32 => {
526 match output.elem {
527 Elem::F16 | Elem::BF16 => 2,
529 Elem::F32 => 1,
531 other => panic!("unsupported format {other} for {output}"),
532 }
533 }
534 other => panic!("unsupported format {other} for {input_elem}"),
535 }
536}
537
538pub(super) fn compile_manual_mma<D: Dialect>(
539 f: &mut std::fmt::Formatter<'_>,
540 shape: MmaShape<D>,
541 frag_a: &[Variable<D>],
542 frag_b: &[Variable<D>],
543 frag_c: &[Variable<D>],
544 frag_d: &Variable<D>,
545) -> std::fmt::Result {
546 let extension = WmmaExecute::from_manual(shape, frag_a[0].elem(), frag_c[0].elem());
547 let frag_d_len = frag_c.len();
548
549 let frag_a = comma_separated(
550 frag_a
551 .iter()
552 .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
553 .map(|it| format!("{it}")),
554 );
555 let frag_b = comma_separated(
556 frag_b
557 .iter()
558 .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
559 .map(|it| format!("{it}")),
560 );
561 let frag_c = comma_separated(
562 frag_c
563 .iter()
564 .flat_map(|it| (0..it.item().vectorization).map(|i| it.index(i)))
565 .map(|it| format!("{it}")),
566 );
567
568 let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
570
571 let name = extension.fn_name();
572 writeln!(f, "{ty} {frag_d_tmp} = {ty}{{}};", ty = extension.frag_d)?;
573 writeln!(
574 f,
575 "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
576 extension.frag_a, extension.frag_b, extension.frag_c
577 )?;
578
579 for i in 0..frag_d_len {
580 writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i}];")?;
581 }
582
583 Ok(())
584}
585
586pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
587 let mut result: SupportedMmaCombinations = vec![];
590 if arch.is_wmma_capable() {
591 let types = vec![
593 (
594 gpu::ElemType::Float(gpu::FloatKind::F16),
595 gpu::ElemType::Float(gpu::FloatKind::F32),
596 ),
597 (
598 gpu::ElemType::Float(gpu::FloatKind::BF16),
599 gpu::ElemType::Float(gpu::FloatKind::F32),
600 ),
601 ];
602 let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
603 a_type: ab_elem.into(),
604 b_type: ab_elem.into(),
605 cd_type: cd_elem.into(),
606 m: 16,
607 n: 16,
608 k: 16,
609 });
610 result.extend(combinations);
611 }
612 result
613}
614
/// C++ statement spliced into the generated load and store helpers; it defines
/// `wmmaLane` as `threadIdx.x % 16`.
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";