1use cubecl_runtime::{DeviceProperties, MmaConfig, ScaledMmaConfig};
2use std::fmt::{Display, Formatter};
3use std::{fmt::Debug, marker::PhantomData};
4
5use super::{Component, Dialect, Elem, FmtLeft, Variable};
6
/// List of (non-scaled) MMA configurations, consumed by `register_wmma_features`
/// and `register_mma_features`.
pub type SupportedMmaCombinations = Vec<MmaConfig>;
/// List of scaled MMA configurations, consumed by `register_scaled_mma_features`.
pub type SupportedScaledMmaCombinations = Vec<ScaledMmaConfig>;
9
/// Capabilities that a target GPU architecture exposes to the code generator.
pub trait Architecture {
    /// Architecture version number.
    ///
    /// Implementors that do not override this method report `0`.
    fn get_version(&self) -> u32 {
        0
    }

    /// Number of threads in a warp on this architecture.
    fn warp_size(&self) -> u32;

    /// Whether WMMA instructions are available.
    fn is_wmma_capable(&self) -> bool;

    /// Whether MFMA instructions are available.
    fn is_mfma_capable(&self) -> bool;
}
18
19pub fn register_wmma_features(
20 supported_combinations: SupportedMmaCombinations,
21 properties: &mut DeviceProperties,
22) {
23 for config in supported_combinations {
24 properties.features.cmma.insert(config);
25 }
26}
27
28pub fn register_mma_features(
29 supported_combinations: SupportedMmaCombinations,
30 properties: &mut DeviceProperties,
31) {
32 for config in supported_combinations {
33 properties.features.mma.insert(config);
34 }
35}
36
37pub fn register_scaled_mma_features(
38 supported_combinations: SupportedScaledMmaCombinations,
39 properties: &mut DeviceProperties,
40) {
41 for config in supported_combinations {
42 properties.features.scaled_mma.insert(config);
43 }
44}
45
/// Which operand of a matrix multiply-accumulate a fragment holds.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// Operand `A` (`m * k` elements according to `MmaShape::num_elems`).
    A,
    /// Operand `B` (`k * n` elements according to `MmaShape::num_elems`).
    B,
    /// Accumulator / result operand (`m * n` elements according to `MmaShape::num_elems`).
    Accumulator,
    /// Marker variant that only carries the dialect type parameter; it does not name an
    /// operand (the base compiler renders it as an empty string).
    _Dialect(PhantomData<D>),
}
53
/// Storage order of a fragment's matrix.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major order.
    ColMajor,
    /// Row-major order.
    RowMajor,
    /// Marker variant that only carries the dialect type parameter; the base compiler
    /// renders it as an empty string.
    _Dialect(PhantomData<D>),
}
60
/// Type description of a WMMA fragment: operand role, tile dimensions, element type
/// and optional layout.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Operand role of the fragment.
    pub ident: FragmentIdent<D>,
    /// Tile dimension `m`.
    pub m: u32,
    /// Tile dimension `n`.
    pub n: u32,
    /// Tile dimension `k`.
    pub k: u32,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Layout of the fragment; when `None` the layout parameter is omitted from the
    /// emitted fragment type.
    pub layout: Option<FragmentLayout<D>>,
}
70
/// Tile dimensions (`m`, `n`, `k`) of a matrix multiply-accumulate instruction.
// NOTE(review): field order defines the argument order of the derived `new` constructor.
#[derive(new, Debug, Clone, PartialEq, Eq, Copy)]
pub struct MmaShape<D: Dialect> {
    /// Rows of `A` and of the accumulator.
    pub m: u32,
    /// Columns of `B` and of the accumulator.
    pub n: u32,
    /// Shared (reduction) dimension of `A` and `B`.
    pub k: u32,
    _d: PhantomData<D>,
}
78
79impl<D: Dialect> MmaShape<D> {
80 pub fn num_elems(&self, ident: FragmentIdent<D>) -> u32 {
81 match ident {
82 FragmentIdent::A => self.m * self.k,
83 FragmentIdent::B => self.k * self.n,
84 FragmentIdent::Accumulator => self.m * self.n,
85 _ => unimplemented!(),
86 }
87 }
88}
89
/// A fragment-level matrix instruction, rendered to source code by
/// `Dialect::compile_wmma_instruction`.
#[derive(Debug, Clone, PartialEq)]
pub enum WmmaInstruction<D: Dialect> {
    /// Set every element of `frag` to `value` (`fill_fragment` in the base API).
    Fill {
        frag: Variable<D>,
        value: Variable<D>,
    },
    /// Load `frag` from memory starting at `value + offset`, passing `stride` to the load.
    /// When `layout` is `Some`, a memory-layout argument is also passed.
    Load {
        frag: Variable<D>,
        value: Variable<D>,
        offset: Variable<D>,
        stride: Variable<D>,
        layout: Option<FragmentLayout<D>>,
    },
    /// Fragment multiply-accumulate; the base API emits `mma_sync(frag_d, frag_a, frag_b, frag_c)`.
    Execute {
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
        // NOTE(review): not read by `wmma_api_base::compile_instruction` (matched with `..`);
        // presumably consumed by other dialects — confirm.
        warp_size: u32,
    },
    /// Multiply-accumulate of the given `shape` whose operands are lists of variables;
    /// compiled through `Dialect::compile_manual_mma`.
    ExecuteManual {
        shape: MmaShape<D>,
        frag_a: Vec<Variable<D>>,
        frag_b: Vec<Variable<D>>,
        frag_c: Vec<Variable<D>>,
        frag_d: Variable<D>,
    },
    /// Like `ExecuteManual`, with additional scale operands for `A` and `B`;
    /// compiled through `Dialect::compile_scaled_mma`.
    ExecuteScaled {
        shape: MmaShape<D>,
        frag_a: Vec<Variable<D>>,
        frag_b: Vec<Variable<D>>,
        frag_c: Vec<Variable<D>>,
        frag_d: Variable<D>,

        scales_a: Variable<D>,
        scales_b: Variable<D>,
        scales_factor: u32,
    },
    /// Store `frag` to memory starting at `output + offset`, passing `stride` and `layout`.
    Store {
        output: Variable<D>,
        frag: Variable<D>,
        stride: Variable<D>,
        offset: Variable<D>,
        layout: FragmentLayout<D>,
    },
    /// Element-wise conversion of fragment `input` into fragment `output`
    /// (whose element type decides the target type).
    Cast {
        input: Variable<D>,
        output: Variable<D>,
    },
}
160
161impl<D: Dialect> Display for FragmentLayout<D> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 D::compile_wmma_fragment_layout(f, self)
164 }
165}
166
167impl<D: Dialect> Display for FragmentIdent<D> {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 D::compile_wwma_fragment_ident(f, self)
170 }
171}
172
173impl<D: Dialect> Display for Fragment<D> {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 D::compile_wmma_fragment(f, self)
176 }
177}
178
179impl<D: Dialect> Display for WmmaInstruction<D> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 D::compile_wmma_instruction(f, self)
182 }
183}
184
185pub mod wmma_api_base {
186 use crate::shared::ManualMma;
187
188 use super::*;
189
190 pub fn compile_fragment_declaration<D: Dialect>(
191 f: &mut std::fmt::Formatter<'_>,
192 var: &Variable<D>,
193 ) -> std::fmt::Result {
194 match var {
195 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
196 _ => panic!("variable must be a fragment"),
197 }
198 }
199
200 pub fn compile_fragment_ident<D: Dialect>(
201 f: &mut std::fmt::Formatter<'_>,
202 namespace: &str,
203 ident: &FragmentIdent<D>,
204 ) -> std::fmt::Result {
205 match ident {
206 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
207 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
208 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
209 FragmentIdent::_Dialect(_) => Ok(()),
210 }
211 }
212
213 pub fn compile_fragment_layout<D: Dialect>(
214 f: &mut std::fmt::Formatter<'_>,
215 namespace: &str,
216 layout: &FragmentLayout<D>,
217 ) -> std::fmt::Result {
218 match layout {
219 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
220 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
221 FragmentLayout::_Dialect(_) => Ok(()),
222 }
223 }
224
225 pub fn compile_fragment<D: Dialect>(
226 f: &mut std::fmt::Formatter<'_>,
227 namespace: &str,
228 fragment: &Fragment<D>,
229 ) -> std::fmt::Result {
230 let elem = match fragment.elem {
231 Elem::TF32 => format!("{namespace}::precision::tf32"),
232 Elem::BF16 => {
233 if fragment.ident == FragmentIdent::Accumulator {
234 format!("{}", Elem::<D>::F16) } else {
236 format!("{}", fragment.elem)
237 }
238 }
239 elem => format!("{elem}"),
240 };
241 match fragment.layout {
242 Some(layout) => write!(
243 f,
244 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
245 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
246 ),
247 None => write!(
248 f,
249 "{namespace}::fragment<{}, {}, {}, {}, {}>",
250 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
251 ),
252 }
253 }
254
255 pub fn compile_instruction<D: Dialect>(
256 f: &mut std::fmt::Formatter<'_>,
257 namespace: &str,
258 instruction: &WmmaInstruction<D>,
259 ) -> std::fmt::Result {
260 match instruction {
261 WmmaInstruction::Fill { frag, value } => {
262 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
263 }
264 WmmaInstruction::Load {
265 frag,
266 value,
267 stride,
268 offset,
269 layout: None,
270 } => {
271 let item = value.item();
272 if item.vectorization > 1 {
273 let elem = item.elem;
274 let qualifier = value.const_qualifier();
275 writeln!(
276 f,
277 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
278 )
279 } else {
280 writeln!(
281 f,
282 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
283 )
284 }
285 }
286 WmmaInstruction::Load {
287 frag,
288 value,
289 offset,
290 stride,
291 layout: Some(layout),
292 } => {
293 let layout = match layout {
294 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
295 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
296 FragmentLayout::_Dialect(_) => "".to_string(),
297 };
298 let item = value.item();
299 if item.vectorization > 1 {
300 let elem = item.elem;
301 writeln!(
302 f,
303 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
304 )
305 } else {
306 writeln!(
307 f,
308 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
309 )
310 }
311 }
312 WmmaInstruction::Execute {
313 frag_a,
314 frag_b,
315 frag_c,
316 frag_d,
317 ..
318 } => writeln!(
319 f,
320 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
321 ),
322 WmmaInstruction::Store {
323 output,
324 frag,
325 stride,
326 offset,
327 layout,
328 } => {
329 let layout = match layout {
330 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
331 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
332 FragmentLayout::_Dialect(_) => "".to_string(),
333 };
334
335 let item = output.item();
336 let mut reinterpret_cast = item.vectorization > 1;
337 let elem = match item.elem {
338 Elem::BF16 => {
339 reinterpret_cast = true;
340 Elem::F16
341 }
342 _ => item.elem,
343 };
344 if reinterpret_cast {
345 writeln!(
346 f,
347 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
348 )
349 } else {
350 writeln!(
351 f,
352 "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
353 )
354 }
355 }
356 WmmaInstruction::Cast { input, output } => {
357 let ty = match output {
358 Variable::WmmaFragment { frag, .. } => frag.elem,
359 _ => panic!("Should be a fragment"),
360 };
361 match ty {
362 Elem::BF16 => {
363 let elem = Elem::<D>::F16;
364 write!(
365 f,
366 "// cast
367for(int t=0; t<{input}.num_elements; t++) {{
368 {ty} elem = {ty}({input}.x[t]);
369 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
370}}
371"
372 )
373 }
374 _ => {
375 write!(
376 f,
377 "// cast
378for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
379"
380 )
381 }
382 }
383 }
384 WmmaInstruction::ExecuteManual {
385 shape,
386 frag_a,
387 frag_b,
388 frag_c,
389 frag_d,
390 } => D::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d)),
391 WmmaInstruction::ExecuteScaled {
392 shape,
393 frag_a,
394 frag_b,
395 frag_c,
396 frag_d,
397 scales_a,
398 scales_b,
399 scales_factor,
400 } => D::compile_scaled_mma(
401 f,
402 ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
403 *scales_a,
404 *scales_b,
405 *scales_factor,
406 ),
407 }
408 }
409}
410
411pub fn frag_as_ptr<D: Dialect>(
412 f: &mut Formatter<'_>,
413 frag: &Variable<D>,
414 offset: &Variable<D>,
415) -> Variable<D> {
416 let item = frag.item();
417 let mut frag_ptr = Variable::tmp_ptr(item);
418 if frag.is_const() {
419 frag_ptr.to_const();
420 }
421 let frag_ptr_out = frag_ptr.fmt_left();
422 writeln!(f, "{frag_ptr_out} = {frag} + {offset};").unwrap();
423
424 if item.vectorization > 1 {
425 let mut item_value = item;
426 item_value.vectorization = 1;
427 frag_ptr.reinterpret_ptr(f, item_value)
428 } else {
429 frag_ptr
430 }
431}
432
433pub fn frag_ident_str<D: Dialect>(frag: &FragmentIdent<D>) -> &str {
434 match frag {
435 FragmentIdent::A => "a",
436 FragmentIdent::B => "b",
437 FragmentIdent::Accumulator => "c",
438 FragmentIdent::_Dialect(_) => "d",
439 }
440}
441
442pub fn frag_layout_str<D: Dialect>(frag: &Option<FragmentLayout<D>>) -> &str {
443 match frag {
444 Some(layout) => match layout {
445 FragmentLayout::ColMajor => "col",
446 FragmentLayout::RowMajor => "row",
447 FragmentLayout::_Dialect(_) => "",
448 },
449 None => "",
450 }
451}
452
453pub fn variable_to_frag<D: Dialect>(frag: &Variable<D>) -> Fragment<D> {
454 match frag {
455 Variable::WmmaFragment { frag, .. } => *frag,
456 _ => panic!(),
457 }
458}