1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
7pub type SupportedMmaCombinations = Vec<MmaConfig>;
8pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Describes a compilation target architecture.
///
/// The precise meaning of each capability flag is defined by the implementors,
/// which are not visible from this file.
pub trait Architecture {
    /// Warp size reported by the architecture.
    fn warp_size(&self) -> u32;

    /// Whether the architecture reports itself as WMMA capable.
    fn is_wmma_capable(&self) -> bool;

    /// Whether the architecture reports itself as MFMA capable.
    fn is_mfma_capable(&self) -> bool;

    /// Architecture version. Implementors that do not override this method
    /// report `0`.
    fn get_version(&self) -> u32 {
        0
    }
}
18
19pub fn register_wmma_features(
20 supported_combinations: SupportedMmaCombinations,
21 properties: &mut DeviceProperties,
22) {
23 for config in supported_combinations {
24 properties.features.cmma.insert(config);
25 }
26}
27
28pub fn register_mma_features(
29 supported_combinations: SupportedMmaCombinations,
30 properties: &mut DeviceProperties,
31) {
32 for config in supported_combinations {
33 properties.features.mma.insert(config);
34 }
35}
36
37pub fn register_scaled_mma_features(
38 supported_combinations: SupportedScaledMmaCombinations,
39 properties: &mut DeviceProperties,
40) {
41 for config in supported_combinations {
42 properties.features.scaled_mma.insert(config);
43 }
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Copy)]
47pub enum FragmentIdent<D: Dialect> {
48 A,
49 B,
50 Accumulator,
51 _Dialect(PhantomData<D>),
52}
53
54#[derive(Debug, Clone, PartialEq, Eq, Copy)]
55pub enum FragmentLayout<D: Dialect> {
56 ColMajor,
57 RowMajor,
58 _Dialect(PhantomData<D>),
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Copy)]
62pub struct Fragment<D: Dialect> {
63 pub ident: FragmentIdent<D>,
64 pub m: u32,
65 pub n: u32,
66 pub k: u32,
67 pub elem: Elem<D>,
68 pub layout: Option<FragmentLayout<D>>,
69}
70
71#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
72pub struct MmaShape<D: Dialect> {
73 pub m: u32,
74 pub n: u32,
75 pub k: u32,
76 _d: PhantomData<D>,
77}
78
79impl<D: Dialect> MmaShape<D> {
80 pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81 match ident {
82 FragmentIdent::A => self.m * self.k,
83 FragmentIdent::B => self.k * self.n,
84 FragmentIdent::Accumulator => self.m * self.n,
85 _ => unimplemented!(),
86 }
87 }
88}
89
90#[derive(Debug, Clone, PartialEq)]
92pub enum WmmaInstruction<D: Dialect> {
93 Fill {
95 frag: Variable<D>,
96 value: Variable<D>,
97 },
98 Load {
100 frag: Variable<D>,
101 value: Variable<D>,
102 offset: Variable<D>,
103 stride: Variable<D>,
104 layout: Option<FragmentLayout<D>>,
105 },
106 Execute {
110 frag_a: Variable<D>,
111 frag_b: Variable<D>,
112 frag_c: Variable<D>,
113 frag_d: Variable<D>,
114 warp_size: u32,
115 },
116 ExecuteManual {
123 shape: MmaShape<D>,
124 frag_a: Variable<D>,
125 frag_b: Variable<D>,
126 frag_c: Variable<D>,
127 frag_d: Variable<D>,
128 },
129 ExecuteScaled {
136 shape: MmaShape<D>,
137 frag_a: Variable<D>,
138 frag_b: Variable<D>,
139 frag_c: Variable<D>,
140 frag_d: Variable<D>,
141
142 scales_a: Variable<D>,
143 scales_b: Variable<D>,
144 scales_factor: u32,
145 },
146 Store {
148 output: Variable<D>,
149 frag: Variable<D>,
150 stride: Variable<D>,
151 offset: Variable<D>,
152 layout: FragmentLayout<D>,
153 },
154 LdMatrix {
156 output: Variable<D>,
157 buffer: Variable<D>,
158 offset: Variable<D>,
159 line_size: Option<u32>,
160 factor: u32,
161 transpose: bool,
162 },
163 Cast {
165 input: Variable<D>,
166 output: Variable<D>,
167 },
168}
169
170impl<D: Dialect> Display for FragmentLayout<D> {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 D::compile_wmma_fragment_layout(f, self)
173 }
174}
175
176impl<D: Dialect> Display for FragmentIdent<D> {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 D::compile_wwma_fragment_ident(f, self)
179 }
180}
181
182impl<D: Dialect> Display for Fragment<D> {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 D::compile_wmma_fragment(f, self)
185 }
186}
187
188impl<D: Dialect> Display for WmmaInstruction<D> {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 D::compile_wmma_instruction(f, self)
191 }
192}
193
194pub mod wmma_api_base {
195 use crate::{cuda::ptx::ldmatrix_call, shared::ManualMma};
196
197 use super::*;
198
199 pub fn compile_fragment_declaration<D: Dialect>(
200 f: &mut std::fmt::Formatter<'_>,
201 var: &Variable<D>,
202 ) -> std::fmt::Result {
203 match var {
204 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
205 _ => panic!("variable must be a fragment"),
206 }
207 }
208
209 pub fn compile_fragment_ident<D: Dialect>(
210 f: &mut std::fmt::Formatter<'_>,
211 namespace: &str,
212 ident: &FragmentIdent<D>,
213 ) -> std::fmt::Result {
214 match ident {
215 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
216 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
217 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
218 FragmentIdent::_Dialect(_) => Ok(()),
219 }
220 }
221
222 pub fn compile_fragment_layout<D: Dialect>(
223 f: &mut std::fmt::Formatter<'_>,
224 namespace: &str,
225 layout: &FragmentLayout<D>,
226 ) -> std::fmt::Result {
227 match layout {
228 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
229 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
230 FragmentLayout::_Dialect(_) => Ok(()),
231 }
232 }
233
234 pub fn compile_fragment<D: Dialect>(
235 f: &mut std::fmt::Formatter<'_>,
236 namespace: &str,
237 fragment: &Fragment<D>,
238 ) -> std::fmt::Result {
239 let elem = match fragment.elem {
240 Elem::TF32 => format!("{namespace}::precision::tf32"),
241 Elem::BF16 => {
242 if fragment.ident == FragmentIdent::Accumulator {
243 format!("{}", Elem::<D>::F16) } else {
245 format!("{}", fragment.elem)
246 }
247 }
248 elem => format!("{elem}"),
249 };
250 match fragment.layout {
251 Some(layout) => write!(
252 f,
253 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
254 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
255 ),
256 None => write!(
257 f,
258 "{namespace}::fragment<{}, {}, {}, {}, {}>",
259 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
260 ),
261 }
262 }
263
264 pub fn compile_instruction<D: Dialect>(
265 f: &mut std::fmt::Formatter<'_>,
266 namespace: &str,
267 instruction: &WmmaInstruction<D>,
268 ) -> std::fmt::Result {
269 match instruction {
270 WmmaInstruction::Fill { frag, value } => {
271 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
272 }
273 WmmaInstruction::Load {
274 frag,
275 value,
276 stride,
277 offset,
278 layout: None,
279 } => {
280 let item = value.item();
281 if item.vectorization > 1 {
282 let elem = item.elem;
283 let qualifier = value.const_qualifier();
284 writeln!(
285 f,
286 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
287 )
288 } else {
289 writeln!(
290 f,
291 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
292 )
293 }
294 }
295 WmmaInstruction::Load {
296 frag,
297 value,
298 offset,
299 stride,
300 layout: Some(layout),
301 } => {
302 let layout = match layout {
303 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
304 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
305 FragmentLayout::_Dialect(_) => "".to_string(),
306 };
307 let item = value.item();
308 if item.vectorization > 1 {
309 let elem = item.elem;
310 writeln!(
311 f,
312 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
313 )
314 } else {
315 writeln!(
316 f,
317 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
318 )
319 }
320 }
321 WmmaInstruction::LdMatrix {
322 output,
323 buffer,
324 offset,
325 line_size,
326 factor,
327 transpose,
328 } => f.write_str(&ldmatrix_call(
329 output, buffer, offset, line_size, factor, transpose,
330 )),
331 WmmaInstruction::Execute {
332 frag_a,
333 frag_b,
334 frag_c,
335 frag_d,
336 ..
337 } => writeln!(
338 f,
339 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
340 ),
341 WmmaInstruction::Store {
342 output,
343 frag,
344 stride,
345 offset,
346 layout,
347 } => {
348 let layout = match layout {
349 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
350 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
351 FragmentLayout::_Dialect(_) => "".to_string(),
352 };
353
354 let item = output.item();
355 let mut reinterpret_cast = item.vectorization > 1;
356 let elem = match item.elem {
357 Elem::BF16 => {
358 reinterpret_cast = true;
359 Elem::F16
360 }
361 _ => item.elem,
362 };
363 if reinterpret_cast {
364 writeln!(
365 f,
366 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
367 )
368 } else {
369 writeln!(
370 f,
371 "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
372 )
373 }
374 }
375 WmmaInstruction::Cast { input, output } => {
376 let ty = match output {
377 Variable::WmmaFragment { frag, .. } => frag.elem,
378 _ => panic!("Should be a fragment"),
379 };
380 match ty {
381 Elem::BF16 => {
382 let elem = Elem::<D>::F16;
383 write!(
384 f,
385 "// cast
386for(int t=0; t<{input}.num_elements; t++) {{
387 {ty} elem = {ty}({input}.x[t]);
388 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
389}}
390"
391 )
392 }
393 _ => {
394 write!(
395 f,
396 "// cast
397for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
398"
399 )
400 }
401 }
402 }
403 WmmaInstruction::ExecuteManual {
404 shape,
405 frag_a,
406 frag_b,
407 frag_c,
408 frag_d,
409 } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
410 WmmaInstruction::ExecuteScaled {
411 shape,
412 frag_a,
413 frag_b,
414 frag_c,
415 frag_d,
416 scales_a,
417 scales_b,
418 scales_factor,
419 } => D::compile_scaled_mma(
420 f,
421 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
422 *scales_a,
423 *scales_b,
424 *scales_factor,
425 ),
426 }
427 }
428}
429
430pub fn frag_as_ptr<D: Dialect>(
431 f: &mut Formatter<'_>,
432 frag: &Variable<D>,
433 offset: &Variable<D>,
434) -> Variable<D> {
435 let item = frag.item();
436 let mut frag_ptr = Variable::tmp_ptr(item);
437 if frag.is_const() {
438 frag_ptr.to_const();
439 }
440 let frag_ptr_out = frag_ptr.fmt_left();
441 writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
442
443 if item.vectorization > 1 {
444 let mut item_value = item;
445 item_value.vectorization = 1;
446 frag_ptr.reinterpret_ptr(f, item_value)
447 } else {
448 frag_ptr
449 }
450}
451
452pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
453 match frag {
454 FragmentIdent::A => "a",
455 FragmentIdent::B => "b",
456 FragmentIdent::Accumulator => "c",
457 FragmentIdent::_Dialect(_) => "d",
458 }
459}
460
461pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
462 match frag {
463 Some(layout) => match layout {
464 FragmentLayout::ColMajor => "col",
465 FragmentLayout::RowMajor => "row",
466 FragmentLayout::_Dialect(_) => "",
467 },
468 None => "",
469 }
470}
471
472pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
473 match frag {
474 Variable::WmmaFragment { frag, .. } => *frag,
475 _ => panic!(),
476 }
477}