1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
7pub type SupportedMmaCombinations = Vec<MmaConfig>;
8pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Description of a target architecture queried during code generation.
pub trait Architecture {
    /// Returns the warp size of this architecture.
    fn warp_size(&self) -> u32;

    /// Returns `true` when the architecture is WMMA capable.
    fn is_wmma_capable(&self) -> bool;

    /// Returns `true` when the architecture is MFMA capable.
    fn is_mfma_capable(&self) -> bool;

    /// Returns the architecture version.
    ///
    /// The default implementation returns `0`; implementors override it when they
    /// have a meaningful version number to report.
    fn get_version(&self) -> u32 {
        0
    }
}
18
19pub fn register_wmma_features(
20 supported_combinations: SupportedMmaCombinations,
21 properties: &mut DeviceProperties,
22) {
23 for config in supported_combinations {
24 properties.features.cmma.insert(config);
25 }
26}
27
28pub fn register_mma_features(
29 supported_combinations: SupportedMmaCombinations,
30 properties: &mut DeviceProperties,
31) {
32 for config in supported_combinations {
33 properties.features.mma.insert(config);
34 }
35}
36
37pub fn register_scaled_mma_features(
38 supported_combinations: SupportedScaledMmaCombinations,
39 properties: &mut DeviceProperties,
40) {
41 for config in supported_combinations {
42 properties.features.scaled_mma.insert(config);
43 }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Copy)]
47pub enum FragmentIdent<D: Dialect> {
48 A,
49 B,
50 Accumulator,
51 _Dialect(PhantomData<D>),
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Copy)]
55pub enum FragmentLayout<D: Dialect> {
56 ColMajor,
57 RowMajor,
58 _Dialect(PhantomData<D>),
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Copy)]
62pub struct Fragment<D: Dialect> {
63 pub ident: FragmentIdent<D>,
64 pub m: u32,
65 pub n: u32,
66 pub k: u32,
67 pub elem: Elem<D>,
68 pub layout: Option<FragmentLayout<D>>,
69}
70
71#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
72pub struct MmaShape<D: Dialect> {
73 pub m: u32,
74 pub n: u32,
75 pub k: u32,
76 _d: PhantomData<D>,
77}
78
79impl<D: Dialect> MmaShape<D> {
80 pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81 match ident {
82 FragmentIdent::A => self.m * self.k,
83 FragmentIdent::B => self.k * self.n,
84 FragmentIdent::Accumulator => self.m * self.n,
85 _ => unimplemented!(),
86 }
87 }
88}
89
90#[derive(Debug, Clone, PartialEq)]
92pub enum WmmaInstruction<D: Dialect> {
93 Fill {
95 frag: Variable<D>,
96 value: Variable<D>,
97 },
98 Load {
100 frag: Variable<D>,
101 value: Variable<D>,
102 offset: Variable<D>,
103 stride: Variable<D>,
104 layout: Option<FragmentLayout<D>>,
105 },
106 Execute {
110 frag_a: Variable<D>,
111 frag_b: Variable<D>,
112 frag_c: Variable<D>,
113 frag_d: Variable<D>,
114 warp_size: u32,
115 },
116 ExecuteManual {
123 shape: MmaShape<D>,
124 frag_a: Variable<D>,
125 frag_b: Variable<D>,
126 frag_c: Variable<D>,
127 frag_d: Variable<D>,
128 },
129 ExecuteScaled {
136 shape: MmaShape<D>,
137 frag_a: Variable<D>,
138 frag_b: Variable<D>,
139 frag_c: Variable<D>,
140 frag_d: Variable<D>,
141
142 scales_a: Variable<D>,
143 scales_b: Variable<D>,
144 scales_factor: u32,
145 },
146 Store {
148 output: Variable<D>,
149 frag: Variable<D>,
150 stride: Variable<D>,
151 offset: Variable<D>,
152 layout: FragmentLayout<D>,
153 },
154 LdMatrix {
156 output: Variable<D>,
157 buffer: Variable<D>,
158 offset: Variable<D>,
159 line_size: Option<u32>,
160 factor: u32,
161 transpose: bool,
162 },
163 StMatrix {
165 registers: Variable<D>,
166 buffer: Variable<D>,
167 offset: Variable<D>,
168 line_size: Option<u32>,
169 factor: u32,
170 transpose: bool,
171 },
172 Cast {
174 input: Variable<D>,
175 output: Variable<D>,
176 },
177}
178
179impl<D: Dialect> Display for FragmentLayout<D> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 D::compile_wmma_fragment_layout(f, self)
182 }
183}
184
185impl<D: Dialect> Display for FragmentIdent<D> {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 D::compile_wwma_fragment_ident(f, self)
188 }
189}
190
191impl<D: Dialect> Display for Fragment<D> {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 D::compile_wmma_fragment(f, self)
194 }
195}
196
197impl<D: Dialect> Display for WmmaInstruction<D> {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 D::compile_wmma_instruction(f, self)
200 }
201}
202
203pub mod wmma_api_base {
204 use crate::{
205 cuda::ptx::{ldmatrix_call, stmatrix_call},
206 shared::ManualMma,
207 };
208
209 use super::*;
210
211 pub fn compile_fragment_declaration<D: Dialect>(
212 f: &mut std::fmt::Formatter<'_>,
213 var: &Variable<D>,
214 ) -> std::fmt::Result {
215 match var {
216 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
217 _ => panic!("variable must be a fragment"),
218 }
219 }
220
221 pub fn compile_fragment_ident<D: Dialect>(
222 f: &mut std::fmt::Formatter<'_>,
223 namespace: &str,
224 ident: &FragmentIdent<D>,
225 ) -> std::fmt::Result {
226 match ident {
227 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
228 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
229 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
230 FragmentIdent::_Dialect(_) => Ok(()),
231 }
232 }
233
234 pub fn compile_fragment_layout<D: Dialect>(
235 f: &mut std::fmt::Formatter<'_>,
236 namespace: &str,
237 layout: &FragmentLayout<D>,
238 ) -> std::fmt::Result {
239 match layout {
240 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
241 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
242 FragmentLayout::_Dialect(_) => Ok(()),
243 }
244 }
245
246 pub fn compile_fragment<D: Dialect>(
247 f: &mut std::fmt::Formatter<'_>,
248 namespace: &str,
249 fragment: &Fragment<D>,
250 ) -> std::fmt::Result {
251 let elem = match fragment.elem {
252 Elem::TF32 => format!("{namespace}::precision::tf32"),
253 Elem::BF16 => {
254 if fragment.ident == FragmentIdent::Accumulator {
255 format!("{}", Elem::<D>::F16) } else {
257 format!("{}", fragment.elem)
258 }
259 }
260 elem => format!("{elem}"),
261 };
262 match fragment.layout {
263 Some(layout) => write!(
264 f,
265 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
266 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
267 ),
268 None => write!(
269 f,
270 "{namespace}::fragment<{}, {}, {}, {}, {}>",
271 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
272 ),
273 }
274 }
275
276 pub fn compile_instruction<D: Dialect>(
277 f: &mut std::fmt::Formatter<'_>,
278 namespace: &str,
279 instruction: &WmmaInstruction<D>,
280 ) -> std::fmt::Result {
281 match instruction {
282 WmmaInstruction::Fill { frag, value } => {
283 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
284 }
285 WmmaInstruction::Load {
286 frag,
287 value,
288 stride,
289 offset,
290 layout: None,
291 } => {
292 let item = value.item();
293 if item.vectorization > 1 {
294 let elem = item.elem;
295 let qualifier = value.const_qualifier();
296 writeln!(
297 f,
298 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
299 )
300 } else {
301 writeln!(
302 f,
303 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
304 )
305 }
306 }
307 WmmaInstruction::Load {
308 frag,
309 value,
310 offset,
311 stride,
312 layout: Some(layout),
313 } => {
314 let layout = match layout {
315 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
316 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
317 FragmentLayout::_Dialect(_) => "".to_string(),
318 };
319 let item = value.item();
320 if item.vectorization > 1 {
321 let elem = item.elem;
322 writeln!(
323 f,
324 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
325 )
326 } else {
327 writeln!(
328 f,
329 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
330 )
331 }
332 }
333 WmmaInstruction::LdMatrix {
334 output,
335 buffer,
336 offset,
337 line_size,
338 factor,
339 transpose,
340 } => f.write_str(&ldmatrix_call(
341 output, buffer, offset, line_size, factor, transpose,
342 )),
343 WmmaInstruction::StMatrix {
344 registers,
345 buffer,
346 offset,
347 line_size,
348 factor,
349 transpose,
350 } => f.write_str(&stmatrix_call(
351 registers, buffer, offset, line_size, factor, transpose,
352 )),
353 WmmaInstruction::Execute {
354 frag_a,
355 frag_b,
356 frag_c,
357 frag_d,
358 ..
359 } => writeln!(
360 f,
361 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
362 ),
363 WmmaInstruction::Store {
364 output,
365 frag,
366 stride,
367 offset,
368 layout,
369 } => {
370 let layout = match layout {
371 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
372 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
373 FragmentLayout::_Dialect(_) => "".to_string(),
374 };
375
376 let item = output.item();
377 let mut reinterpret_cast = item.vectorization > 1;
378 let elem = match item.elem {
379 Elem::BF16 => {
380 reinterpret_cast = true;
381 Elem::F16
382 }
383 _ => item.elem,
384 };
385 if reinterpret_cast {
386 writeln!(
387 f,
388 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
389 )
390 } else {
391 writeln!(
392 f,
393 "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
394 )
395 }
396 }
397 WmmaInstruction::Cast { input, output } => {
398 let ty = match output {
399 Variable::WmmaFragment { frag, .. } => frag.elem,
400 _ => panic!("Should be a fragment"),
401 };
402 match ty {
403 Elem::BF16 => {
404 let elem = Elem::<D>::F16;
405 write!(
406 f,
407 "// cast
408for(int t=0; t<{input}.num_elements; t++) {{
409 {ty} elem = {ty}({input}.x[t]);
410 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
411}}
412"
413 )
414 }
415 _ => {
416 write!(
417 f,
418 "// cast
419for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
420"
421 )
422 }
423 }
424 }
425 WmmaInstruction::ExecuteManual {
426 shape,
427 frag_a,
428 frag_b,
429 frag_c,
430 frag_d,
431 } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
432 WmmaInstruction::ExecuteScaled {
433 shape,
434 frag_a,
435 frag_b,
436 frag_c,
437 frag_d,
438 scales_a,
439 scales_b,
440 scales_factor,
441 } => D::compile_scaled_mma(
442 f,
443 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
444 *scales_a,
445 *scales_b,
446 *scales_factor,
447 ),
448 }
449 }
450}
451
452pub fn frag_as_ptr<D: Dialect>(
453 f: &mut Formatter<'_>,
454 frag: &Variable<D>,
455 offset: &Variable<D>,
456) -> Variable<D> {
457 let item = frag.item();
458 let mut frag_ptr = Variable::tmp_ptr(item);
459 if frag.is_const() {
460 frag_ptr.to_const();
461 }
462 let frag_ptr_out = frag_ptr.fmt_left();
463 writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
464
465 if item.vectorization > 1 {
466 let mut item_value = item;
467 item_value.vectorization = 1;
468 frag_ptr.reinterpret_ptr(f, item_value)
469 } else {
470 frag_ptr
471 }
472}
473
474pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
475 match frag {
476 FragmentIdent::A => "a",
477 FragmentIdent::B => "b",
478 FragmentIdent::Accumulator => "c",
479 FragmentIdent::_Dialect(_) => "d",
480 }
481}
482
483pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
484 match frag {
485 Some(layout) => match layout {
486 FragmentLayout::ColMajor => "col",
487 FragmentLayout::RowMajor => "row",
488 FragmentLayout::_Dialect(_) => "",
489 },
490 None => "",
491 }
492}
493
494pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
495 match frag {
496 Variable::WmmaFragment { frag, .. } => *frag,
497 _ => panic!(),
498 }
499}