1use super::{Component, Dialect, Elem, FmtLeft, Variable};
2use cubecl_core::ir::{
3 DeviceProperties,
4 features::{MmaConfig, ScaledMmaConfig},
5};
6use std::{
7 fmt::{Debug, Display, Formatter},
8 marker::PhantomData,
9};
10
11pub type SupportedMmaCombinations = Vec<MmaConfig>;
12pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
13
/// Capability queries answered by a target architecture.
pub trait Architecture {
    /// Reports the warp size of this architecture.
    fn warp_size(&self) -> u32;

    /// Reports whether this architecture is WMMA capable.
    fn is_wmma_capable(&self) -> bool;

    /// Reports whether this architecture is MFMA capable.
    fn is_mfma_capable(&self) -> bool;

    /// Reports the architecture version.
    ///
    /// The provided implementation returns `0`; implementors may override it.
    fn get_version(&self) -> u32 {
        0
    }
}
22
23pub fn register_wmma_features(
24 supported_combinations: SupportedMmaCombinations,
25 properties: &mut DeviceProperties,
26) {
27 for config in supported_combinations {
28 properties.features.cmma.insert(config);
29 }
30}
31
32pub fn register_mma_features(
33 supported_combinations: SupportedMmaCombinations,
34 properties: &mut DeviceProperties,
35) {
36 for config in supported_combinations {
37 properties.features.mma.insert(config);
38 }
39}
40
41pub fn register_scaled_mma_features(
42 supported_combinations: SupportedScaledMmaCombinations,
43 properties: &mut DeviceProperties,
44) {
45 for config in supported_combinations {
46 properties.features.scaled_mma.insert(config);
47 }
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, Copy)]
51pub enum FragmentIdent<D: Dialect> {
52 A,
53 B,
54 Accumulator,
55 _Dialect(PhantomData<D>),
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Copy)]
59pub enum FragmentLayout<D: Dialect> {
60 ColMajor,
61 RowMajor,
62 _Dialect(PhantomData<D>),
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Copy)]
66pub struct Fragment<D: Dialect> {
67 pub ident: FragmentIdent<D>,
68 pub m: u32,
69 pub n: u32,
70 pub k: u32,
71 pub elem: Elem<D>,
72 pub layout: Option<FragmentLayout<D>>,
73}
74
75#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
76pub struct MmaShape<D: Dialect> {
77 pub m: u32,
78 pub n: u32,
79 pub k: u32,
80 _d: PhantomData<D>,
81}
82
83impl<D: Dialect> MmaShape<D> {
84 pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
85 match ident {
86 FragmentIdent::A => self.m * self.k,
87 FragmentIdent::B => self.k * self.n,
88 FragmentIdent::Accumulator => self.m * self.n,
89 _ => unimplemented!(),
90 }
91 }
92}
93
94#[derive(Debug, Clone, PartialEq)]
96pub enum WmmaInstruction<D: Dialect> {
97 Fill {
99 frag: Variable<D>,
100 value: Variable<D>,
101 },
102 Load {
104 frag: Variable<D>,
105 value: Variable<D>,
106 offset: Variable<D>,
107 stride: Variable<D>,
108 layout: Option<FragmentLayout<D>>,
109 },
110 Execute {
114 frag_a: Variable<D>,
115 frag_b: Variable<D>,
116 frag_c: Variable<D>,
117 frag_d: Variable<D>,
118 warp_size: u32,
119 },
120 ExecuteManual {
127 shape: MmaShape<D>,
128 frag_a: Variable<D>,
129 frag_b: Variable<D>,
130 frag_c: Variable<D>,
131 frag_d: Variable<D>,
132 },
133 ExecuteScaled {
140 shape: MmaShape<D>,
141 frag_a: Variable<D>,
142 frag_b: Variable<D>,
143 frag_c: Variable<D>,
144 frag_d: Variable<D>,
145
146 scales_a: Variable<D>,
147 scales_b: Variable<D>,
148 scales_factor: u32,
149 },
150 Store {
152 output: Variable<D>,
153 frag: Variable<D>,
154 stride: Variable<D>,
155 offset: Variable<D>,
156 layout: FragmentLayout<D>,
157 },
158 LdMatrix {
160 output: Variable<D>,
161 buffer: Variable<D>,
162 offset: Variable<D>,
163 line_size: Option<usize>,
164 factor: u32,
165 transpose: bool,
166 },
167 StMatrix {
169 registers: Variable<D>,
170 buffer: Variable<D>,
171 offset: Variable<D>,
172 line_size: Option<usize>,
173 factor: u32,
174 transpose: bool,
175 },
176 Cast {
178 input: Variable<D>,
179 output: Variable<D>,
180 },
181}
182
183impl<D: Dialect> Display for FragmentLayout<D> {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 D::compile_wmma_fragment_layout(f, self)
186 }
187}
188
189impl<D: Dialect> Display for FragmentIdent<D> {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 D::compile_wwma_fragment_ident(f, self)
192 }
193}
194
195impl<D: Dialect> Display for Fragment<D> {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 D::compile_wmma_fragment(f, self)
198 }
199}
200
201impl<D: Dialect> Display for WmmaInstruction<D> {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 D::compile_wmma_instruction(f, self)
204 }
205}
206
207pub mod wmma_api_base {
208 use crate::{
209 cuda::ptx::{ldmatrix_call, stmatrix_call},
210 shared::ManualMma,
211 };
212
213 use super::*;
214
215 pub fn compile_fragment_declaration<D: Dialect>(
216 f: &mut std::fmt::Formatter<'_>,
217 var: &Variable<D>,
218 ) -> std::fmt::Result {
219 match var {
220 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
221 _ => panic!("variable must be a fragment"),
222 }
223 }
224
225 pub fn compile_fragment_ident<D: Dialect>(
226 f: &mut std::fmt::Formatter<'_>,
227 namespace: &str,
228 ident: &FragmentIdent<D>,
229 ) -> std::fmt::Result {
230 match ident {
231 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
232 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
233 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
234 FragmentIdent::_Dialect(_) => Ok(()),
235 }
236 }
237
238 pub fn compile_fragment_layout<D: Dialect>(
239 f: &mut std::fmt::Formatter<'_>,
240 namespace: &str,
241 layout: &FragmentLayout<D>,
242 ) -> std::fmt::Result {
243 match layout {
244 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
245 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
246 FragmentLayout::_Dialect(_) => Ok(()),
247 }
248 }
249
250 pub fn compile_fragment<D: Dialect>(
251 f: &mut std::fmt::Formatter<'_>,
252 namespace: &str,
253 fragment: &Fragment<D>,
254 ) -> std::fmt::Result {
255 let elem = match fragment.elem {
256 Elem::TF32 => format!("{namespace}::precision::tf32"),
257 Elem::BF16 => {
258 if fragment.ident == FragmentIdent::Accumulator {
259 format!("{}", Elem::<D>::F16) } else {
261 format!("{}", fragment.elem)
262 }
263 }
264 elem => format!("{elem}"),
265 };
266 match fragment.layout {
267 Some(layout) => write!(
268 f,
269 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
270 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
271 ),
272 None => write!(
273 f,
274 "{namespace}::fragment<{}, {}, {}, {}, {}>",
275 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
276 ),
277 }
278 }
279
280 pub fn compile_instruction<D: Dialect>(
281 f: &mut std::fmt::Formatter<'_>,
282 namespace: &str,
283 instruction: &WmmaInstruction<D>,
284 ) -> std::fmt::Result {
285 match instruction {
286 WmmaInstruction::Fill { frag, value } => {
287 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
288 }
289 WmmaInstruction::Load {
290 frag,
291 value,
292 stride,
293 offset,
294 layout: None,
295 } => {
296 let item = value.item();
297 if item.vectorization > 1 {
298 let elem = item.elem;
299 let qualifier = value.const_qualifier();
300 writeln!(
301 f,
302 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
303 )
304 } else {
305 writeln!(
306 f,
307 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
308 )
309 }
310 }
311 WmmaInstruction::Load {
312 frag,
313 value,
314 offset,
315 stride,
316 layout: Some(layout),
317 } => {
318 let layout = match layout {
319 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
320 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
321 FragmentLayout::_Dialect(_) => "".to_string(),
322 };
323 let item = value.item();
324 if item.vectorization > 1 {
325 let elem = item.elem;
326 writeln!(
327 f,
328 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
329 )
330 } else {
331 writeln!(
332 f,
333 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
334 )
335 }
336 }
337 WmmaInstruction::LdMatrix {
338 output,
339 buffer,
340 offset,
341 line_size,
342 factor,
343 transpose,
344 } => f.write_str(&ldmatrix_call(
345 output, buffer, offset, line_size, factor, transpose,
346 )),
347 WmmaInstruction::StMatrix {
348 registers,
349 buffer,
350 offset,
351 line_size,
352 factor,
353 transpose,
354 } => f.write_str(&stmatrix_call(
355 registers, buffer, offset, line_size, factor, transpose,
356 )),
357 WmmaInstruction::Execute {
358 frag_a,
359 frag_b,
360 frag_c,
361 frag_d,
362 ..
363 } => writeln!(
364 f,
365 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
366 ),
367 WmmaInstruction::Store {
368 output,
369 frag,
370 stride,
371 offset,
372 layout,
373 } => {
374 let layout = match layout {
375 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
376 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
377 FragmentLayout::_Dialect(_) => "".to_string(),
378 };
379
380 let item = output.item();
381 let mut reinterpret_cast = item.vectorization > 1;
382 let elem = match item.elem {
383 Elem::BF16 => {
384 reinterpret_cast = true;
385 Elem::F16
386 }
387 _ => item.elem,
388 };
389 if reinterpret_cast {
390 writeln!(
391 f,
392 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
393 )
394 } else {
395 writeln!(
396 f,
397 "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
398 )
399 }
400 }
401 WmmaInstruction::Cast { input, output } => {
402 let ty = match output {
403 Variable::WmmaFragment { frag, .. } => frag.elem,
404 _ => panic!("Should be a fragment"),
405 };
406 match ty {
407 Elem::BF16 => {
408 let elem = Elem::<D>::F16;
409 write!(
410 f,
411 "// cast
412for(int t=0; t<{input}.num_elements; t++) {{
413 {ty} elem = {ty}({input}.x[t]);
414 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
415}}
416"
417 )
418 }
419 _ => {
420 write!(
421 f,
422 "// cast
423for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
424"
425 )
426 }
427 }
428 }
429 WmmaInstruction::ExecuteManual {
430 shape,
431 frag_a,
432 frag_b,
433 frag_c,
434 frag_d,
435 } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
436 WmmaInstruction::ExecuteScaled {
437 shape,
438 frag_a,
439 frag_b,
440 frag_c,
441 frag_d,
442 scales_a,
443 scales_b,
444 scales_factor,
445 } => D::compile_scaled_mma(
446 f,
447 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
448 *scales_a,
449 *scales_b,
450 *scales_factor,
451 ),
452 }
453 }
454}
455
456pub fn frag_as_ptr<D: Dialect>(
457 f: &mut Formatter<'_>,
458 frag: &Variable<D>,
459 offset: &Variable<D>,
460) -> Variable<D> {
461 let item = frag.item();
462 let mut frag_ptr = Variable::tmp_ptr(item);
463 if frag.is_const() {
464 frag_ptr.to_const();
465 }
466 let frag_ptr_out = frag_ptr.fmt_left();
467 writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
468
469 if item.vectorization > 1 {
470 let mut item_value = item;
471 item_value.vectorization = 1;
472 frag_ptr.reinterpret_ptr(f, item_value)
473 } else {
474 frag_ptr
475 }
476}
477
478pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
479 match frag {
480 FragmentIdent::A => "a",
481 FragmentIdent::B => "b",
482 FragmentIdent::Accumulator => "c",
483 FragmentIdent::_Dialect(_) => "d",
484 }
485}
486
487pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
488 match frag {
489 Some(layout) => match layout {
490 FragmentLayout::ColMajor => "col",
491 FragmentLayout::RowMajor => "row",
492 FragmentLayout::_Dialect(_) => "",
493 },
494 None => "",
495 }
496}
497
498pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
499 match frag {
500 Variable::WmmaFragment { frag, .. } => *frag,
501 _ => panic!(),
502 }
503}