1use std::fmt::Formatter;
2
3use crate::{
4 Dialect,
5 hip::{HipDialect, arch::AMDArchitecture},
6 shared::{
7 Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
8 FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
9 Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
10 wmma_api_base,
11 },
12};
13use cubecl_core::ir::{self as gpu};
14use cubecl_runtime::MmaConfig;
15
/// HIP WMMA compiler that lowers fragment instructions to generated `__device__`
/// helper functions wrapping the `__builtin_amdgcn_wmma_*` builtins.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}
18
/// Generated helper that sets every entry of a fragment to a single value.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    /// Fragment being filled; its description determines the helper name and signature.
    frag: Fragment<D>,
}
23
/// Generated helper that loads a fragment from memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    /// Destination fragment.
    frag: Fragment<D>,
    /// Layout of the source memory; only consulted when loading an accumulator
    /// (A/B loads use the fragment's own layout).
    layout: Option<FragmentLayout<D>>,
}
29
/// Generated helper that writes a fragment back to memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    /// Source fragment.
    frag: Fragment<D>,
    /// Layout of the destination memory.
    layout: FragmentLayout<D>,
}
35
/// Generated helper computing `frag_d = frag_a * frag_b + frag_c` with a WMMA builtin.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    /// Left-hand operand (A).
    frag_a: Fragment<D>,
    /// Right-hand operand (B).
    frag_b: Fragment<D>,
    /// Accumulator input (C).
    frag_c: Fragment<D>,
    /// Accumulator output (D).
    frag_d: Fragment<D>,
}
43
/// Generated helper converting one fragment into another fragment of a different element type.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    /// Fragment read from.
    frag_input: Fragment<D>,
    /// Fragment written to.
    frag_output: Fragment<D>,
}
49
50impl<D: Dialect> WmmaFill<D> {
51 pub fn fn_name(&self) -> String {
52 let layout = frag_layout_str(&self.frag.layout);
53 let ident = frag_ident_str(&self.frag.ident);
54 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
55 let elem = self.frag.elem;
56
57 format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}",)
58 }
59
60 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61 let elem = self.frag.elem;
62 let frag = self.frag;
63 let name = self.fn_name();
64
65 write!(
66 f,
67 "
68// Fill the fragment.
69__device__ void {name}({frag}& frag, {elem} value) {{
70 #pragma unroll
71 for (uint i = 0; i < 8; ++i) {{
72 frag[i] = value;
73 }}
74}}
75 "
76 )
77 }
78}
79
80impl<D: Dialect> WmmaLoad<D> {
81 pub fn fn_name(&self) -> String {
82 let layout_frag = frag_layout_str(&self.frag.layout);
83 let layout = frag_layout_str(&self.layout);
84 let ident = frag_ident_str(&self.frag.ident);
85 let elem = self.frag.elem;
86 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
87
88 format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
89 }
90
91 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
116 let elem = self.frag.elem;
117 let frag = self.frag;
118 let name = self.fn_name();
119
120 let (index_body, length, step) = match frag.ident {
121 FragmentIdent::A | FragmentIdent::B => {
122 let length = 16;
123 let step = 1;
124 let index = if (frag.ident == FragmentIdent::A
127 && frag.layout.unwrap() == FragmentLayout::ColMajor)
128 || (frag.ident == FragmentIdent::B
129 && frag.layout.unwrap() == FragmentLayout::RowMajor)
130 {
131 "i * stride + wmmaLane".to_string()
132 } else {
133 "i + wmmaLane * stride".to_string()
134 };
135 (index, length, step)
136 }
137 FragmentIdent::Accumulator => {
138 let length = 8;
139 let step = get_output_accumulator_index_step(&elem, &frag);
140 let index = match self.layout {
141 Some(FragmentLayout::ColMajor) => {
142 "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
143 }
144 Some(FragmentLayout::RowMajor) => {
145 "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
146 }
147 _ => panic!(
148 "cannot load data to an accumulator without knowing the layout of the data"
149 ),
150 };
151 (index, length, step)
152 }
153 other => panic!("unknown matrix identifier {other}"),
154 };
155
156 write!(
157 f,
158 "
159// Load the fragment.
160__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
161 {WMMA_LANE_DEF}
162
163 #pragma unroll
164 for (uint i = 0; i < {length}; ++i) {{
165 const uint index = {index_body};
166 frag[i * {step}] = value_ptr[index];
167 }}
168}}
169 "
170 )
171 }
172}
173
174impl<D: Dialect> WmmaStore<D> {
175 pub fn fn_name(&self) -> String {
176 let layout_frag = frag_layout_str(&self.frag.layout);
177 let layout_option = Some(self.layout);
178 let layout = frag_layout_str(&layout_option);
179 let ident = frag_ident_str(&self.frag.ident);
180 let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
181 let elem = self.frag.elem;
182
183 format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
184 }
185
186 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187 let elem = self.frag.elem;
188 let frag = self.frag;
189 let name = self.fn_name();
190 let frag_idx = match elem {
194 Elem::F16 | Elem::BF16 => "elemIdx * 2",
195 Elem::F32 => "elemIdx",
196 other => {
197 panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
198 }
199 };
200 let output_idx = match self.layout {
202 FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
203 FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
204 FragmentLayout::_Dialect(_) => String::new(),
205 };
206
207 write!(
208 f,
209 "
210// Store the fragment.
211__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
212 {WMMA_LANE_DEF}
213
214 #pragma unroll
215 for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
216 const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
217 output_ptr[{output_idx}] = frag[{frag_idx}];
218 }}
219}}
220 "
221 )
222 }
223}
224
225impl<D: Dialect> WmmaExecute<D> {
226 pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
227 let frag_a = Fragment {
228 ident: FragmentIdent::A,
229 m: shape.m,
230 n: shape.n,
231 k: shape.k,
232 elem: ab_elem,
233 layout: Some(FragmentLayout::ColMajor),
234 };
235 let frag_b = Fragment {
236 ident: FragmentIdent::B,
237 layout: Some(FragmentLayout::RowMajor),
238 ..frag_a
239 };
240 let frag_cd = Fragment {
241 ident: FragmentIdent::Accumulator,
242 elem: cd_elem,
243 ..frag_b
244 };
245 WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
246 }
247
248 pub fn fn_name(&self) -> String {
249 format!(
250 "wmma_execute_16x16x16_{}_{}",
251 self.frag_a.elem, self.frag_c.elem
252 )
253 }
254
255 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 let name = self.fn_name();
257 let ab_format = match self.frag_a.elem {
258 Elem::F32 => "f32",
259 Elem::BF16 => "bf16",
260 Elem::F16 => "f16",
261 _ => panic!(),
262 };
263 let (cd_format, opsel) = match self.frag_c.elem {
264 Elem::F32 => ("f32", ""),
265 Elem::BF16 => ("bf16", ", false"),
266 Elem::F16 => ("f16", ", false"),
267 _ => panic!(),
268 };
269 let warp_size = 32;
270 write!(
271 f,
272 "
273// Execute wmma.
274__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
275 frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
276}}
277 ", self.frag_a, self.frag_b, self.frag_c, self.frag_d
278 )
279 }
280}
281
282impl<D: Dialect> WmmaCast<D> {
283 pub fn fn_name(&self) -> String {
284 let layout = frag_layout_str(&self.frag_input.layout);
285 let ident = frag_ident_str(&self.frag_input.ident);
286 let (m, n, k) = (self.frag_input.m, self.frag_input.n, self.frag_input.k);
287 let elem = self.frag_input.elem;
288 let elem_out = self.frag_output.elem;
289
290 format!("wmma_cast_{elem}_to_{elem_out}_{ident}_{m}x{n}x{k}_{layout}",)
291 }
292
293 pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
294 let input = self.frag_input;
295 let output = self.frag_output;
296 let name = self.fn_name();
297 let step = match output.ident {
298 FragmentIdent::Accumulator => {
299 get_output_accumulator_index_step(&self.frag_input.elem, &output)
300 }
301 _ => 1,
302 };
303
304 write!(
305 f,
306 "
307// Cast the fragment.
308__device__ void {name}({input}& input, {output}& output) {{
309 #pragma unroll
310 for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
311 output[elemIdx * {step}] = input[elemIdx];
312 }}
313}}
314 "
315 )
316 }
317}
318
319impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
320 fn compile_wmma_type_definitions(
321 f: &mut std::fmt::Formatter<'_>,
322 flags: &Flags,
323 ) -> std::fmt::Result {
324 if flags.elem_bf16 {
325 f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
326 f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
327 }
328 if flags.elem_f16 {
329 f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
330 f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
331 }
332 f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
333 }
334
335 fn compile_wmma_fragment_declaration(
336 f: &mut std::fmt::Formatter<'_>,
337 var: &crate::shared::Variable<HipDialect<Self>>,
338 ) -> std::fmt::Result {
339 wmma_api_base::compile_fragment_declaration(f, var)
340 }
341
342 fn compile_wmma_fragment(
343 f: &mut std::fmt::Formatter<'_>,
344 fragment: &Fragment<HipDialect<Self>>,
345 ) -> std::fmt::Result {
346 match fragment.ident {
347 FragmentIdent::A | FragmentIdent::B => match fragment.elem {
348 Elem::F16 => write!(f, "half16_t"),
349 Elem::BF16 => write!(f, "bhalf16_t"),
350 other => panic!("unsupported type {other} for {fragment}"),
351 },
352 FragmentIdent::Accumulator => match fragment.elem {
353 Elem::F16 => write!(f, "half16_t"),
354 Elem::BF16 => write!(f, "bhalf16_t"),
355 Elem::F32 => write!(f, "float8_t"),
356 other => panic!("unsupported type {other} for {fragment}"),
357 },
358 FragmentIdent::_Dialect(_) => Ok(()),
359 }
360 }
361
362 fn compile_wmma_instruction(
363 f: &mut std::fmt::Formatter<'_>,
364 instruction: &WmmaInstruction<HipDialect<Self>>,
365 ) -> std::fmt::Result {
366 match instruction {
367 WmmaInstruction::Fill { frag, value } => {
368 let extension = WmmaFill::new(match frag {
369 Variable::WmmaFragment { frag, .. } => *frag,
370 _ => panic!(),
371 });
372 let name = extension.fn_name();
373 writeln!(f, "{name}({frag}, {value});")
374 }
375 WmmaInstruction::Load {
376 frag,
377 value,
378 layout,
379 offset,
380 stride,
381 } => {
382 let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
383 let name = extension.fn_name();
384 let value_ptr = frag_as_ptr(f, value, offset);
385 writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
386 }
387 WmmaInstruction::LdMatrix { .. } => {
388 unimplemented!("Not supported in HIP")
389 }
390 WmmaInstruction::Execute {
391 frag_a,
392 frag_b,
393 frag_c,
394 frag_d,
395 warp_size,
396 } => {
397 assert_eq!(*warp_size, 32, "Only warp size of 32 supported");
398
399 let extension = WmmaExecute::new(
400 variable_to_frag(frag_a),
401 variable_to_frag(frag_b),
402 variable_to_frag(frag_c),
403 variable_to_frag(frag_d),
404 );
405 let name = extension.fn_name();
406 writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
407 }
408 WmmaInstruction::ExecuteManual {
409 shape,
410 frag_a,
411 frag_b,
412 frag_c,
413 frag_d,
414 } => {
415 Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
416 }
417 WmmaInstruction::ExecuteScaled {
418 shape,
419 frag_a,
420 frag_b,
421 frag_c,
422 frag_d,
423 scales_a,
424 scales_b,
425 scales_factor,
426 } => Self::compile_scaled_mma(
427 f,
428 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
429 *scales_a,
430 *scales_b,
431 *scales_factor,
432 ),
433 WmmaInstruction::Store {
434 output,
435 frag,
436 layout,
437 offset,
438 stride,
439 } => {
440 let extension = WmmaStore::new(variable_to_frag(frag), *layout);
441 let name = extension.fn_name();
442 let output_ptr = frag_as_ptr(f, output, offset);
443 writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
444 }
445 WmmaInstruction::Cast { input, output } => {
446 let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
447 let name = extension.fn_name();
448 writeln!(f, "{name}({input}, {output});")
449 }
450 }
451 }
452
453 fn compile_manual_mma(
454 f: &mut std::fmt::Formatter<'_>,
455 mma: ManualMma<HipDialect<Self>>,
456 ) -> std::fmt::Result {
457 compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
458 }
459
460 fn compile_scaled_mma(
461 _f: &mut std::fmt::Formatter<'_>,
462 _mma: ManualMma<HipDialect<Self>>,
463 _scales_a: Variable<HipDialect<Self>>,
464 _scales_b: Variable<HipDialect<Self>>,
465 _scales_factor: u32,
466 ) -> std::fmt::Result {
467 unimplemented!("Not supported in HIP")
468 }
469
470 fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
471 let mut result: SupportedMmaCombinations = vec![];
473 if arch.is_wmma_capable() {
474 let types = vec![
476 (
477 gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), gpu::ElemType::Float(gpu::FloatKind::F16), ),
481 (
482 gpu::ElemType::Float(gpu::FloatKind::F16),
483 gpu::ElemType::Float(gpu::FloatKind::F16),
484 gpu::ElemType::Float(gpu::FloatKind::F32),
485 ),
486 (
487 gpu::ElemType::Float(gpu::FloatKind::BF16),
488 gpu::ElemType::Float(gpu::FloatKind::BF16),
489 gpu::ElemType::Float(gpu::FloatKind::F32),
490 ),
491 ];
492 let combinations: SupportedMmaCombinations = types
493 .into_iter()
494 .map(|(a, b, c)| MmaConfig {
495 a_type: a.into(),
496 b_type: b.into(),
497 cd_type: c.into(),
498 m: 16,
499 n: 16,
500 k: 16,
501 })
502 .collect();
503 result.extend(combinations);
504 }
505 result
506 }
507
508 fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
509 supported_mma_combinations(arch)
510 }
511}
512
513fn get_output_accumulator_index_step<D: Dialect>(
514 input_elem: &Elem<D>,
515 output: &Fragment<D>,
516) -> u32 {
517 assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);
525
526 match input_elem {
527 Elem::F16 | Elem::BF16 | Elem::F32 => {
528 match output.elem {
529 Elem::F16 | Elem::BF16 => 2,
531 Elem::F32 => 1,
533 other => panic!("unsupported format {other} for {output}"),
534 }
535 }
536 other => panic!("unsupported format {other} for {input_elem}"),
537 }
538}
539
540pub(super) fn compile_manual_mma<D: Dialect>(
541 f: &mut std::fmt::Formatter<'_>,
542 shape: MmaShape<D>,
543 frag_a: &Variable<D>,
544 frag_b: &Variable<D>,
545 frag_c: &Variable<D>,
546 frag_d: &Variable<D>,
547) -> std::fmt::Result {
548 let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());
549
550 let d_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
551
552 let frag_d_len = d_elems as usize / (32 / frag_d.elem().unpacked().size_bits());
553
554 let frag_a_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
556 let frag_b_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
557 let frag_c_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
558 let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
559
560 let name = extension.fn_name();
562 writeln!(f, "{ty} {frag_a_tmp};", ty = extension.frag_a)?;
563 writeln!(f, "memcpy(&{frag_a_tmp}, {frag_a}, sizeof({frag_a_tmp}));")?;
564 writeln!(f, "{ty} {frag_b_tmp};", ty = extension.frag_b)?;
565 writeln!(f, "memcpy(&{frag_b_tmp}, {frag_b}, sizeof({frag_b_tmp}));")?;
566 writeln!(f, "{ty} {frag_c_tmp};", ty = extension.frag_c)?;
567 writeln!(f, "memcpy(&{frag_c_tmp}, {frag_c}, sizeof({frag_c_tmp}));")?;
568 writeln!(f, "{ty} {frag_d_tmp} = {ty}{{}};", ty = extension.frag_d)?;
569 writeln!(
570 f,
571 "{name}({frag_a_tmp}, {frag_b_tmp}, {frag_c_tmp}, {frag_d_tmp});"
572 )?;
573
574 for _ in 0..frag_d_len {
575 writeln!(f, "memcpy({frag_d}, &{frag_d_tmp}, sizeof({frag_d_tmp}));")?;
576 }
577
578 Ok(())
579}
580
581pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
582 const ENABLED: bool = true;
584
585 if !ENABLED {
586 return Vec::new();
587 }
588
589 let mut result: SupportedMmaCombinations = vec![];
592 if arch.is_wmma_capable() {
593 let types = vec![
595 (
596 gpu::ElemType::Float(gpu::FloatKind::F16),
597 gpu::ElemType::Float(gpu::FloatKind::F32),
598 ),
599 (
600 gpu::ElemType::Float(gpu::FloatKind::BF16),
601 gpu::ElemType::Float(gpu::FloatKind::F32),
602 ),
603 ];
604 let combinations = types.into_iter().map(|(ab_elem, cd_elem)| MmaConfig {
605 a_type: ab_elem.into(),
606 b_type: ab_elem.into(),
607 cd_type: cd_elem.into(),
608 m: 16,
609 n: 16,
610 k: 16,
611 });
612 result.extend(combinations);
613 }
614 result
615}
616
/// HIP statement declaring `wmmaLane` as the thread's `threadIdx.x` modulo 16.
/// It is spliced verbatim into the generated load/store helpers.
const WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";