1use cubecl_core::Feature;
2use cubecl_core::ir::{self as gpu};
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::{fmt::Debug, marker::PhantomData};
6
7use super::{Component, Dialect, Elem, Variable};
8
9pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
10
/// Description of a target architecture.
pub trait Architecture {
    /// Warp size reported by this architecture.
    fn warp_size(&self) -> u32;

    /// Whether this architecture reports WMMA capability.
    fn is_wmma_capable(&self) -> bool;

    /// Whether this architecture reports MFMA capability.
    fn is_mfma_capable(&self) -> bool;

    /// Version of the architecture.
    ///
    /// Implementors that do not override this method report `0`.
    fn get_version(&self) -> u32 {
        0
    }
}
19
20pub fn register_wmma_features(
21 supported_combinations: SupportedWmmaCombinations,
22 properties: &mut DeviceProperties<Feature>,
23) {
24 for (i, o, c, tdims) in supported_combinations {
25 for (m, n, k) in tdims {
26 properties.register_feature(Feature::Cmma {
27 a: i,
28 b: o,
29 c,
30 m,
31 n,
32 k,
33 });
34 }
35 }
36}
37
38#[derive(Debug, Clone, PartialEq, Eq, Copy)]
39pub enum FragmentIdent<D: Dialect> {
40 A,
41 B,
42 Accumulator,
43 _Dialect(PhantomData<D>),
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Copy)]
47pub enum FragmentLayout<D: Dialect> {
48 ColMajor,
49 RowMajor,
50 _Dialect(PhantomData<D>),
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Copy)]
54pub struct Fragment<D: Dialect> {
55 pub ident: FragmentIdent<D>,
56 pub m: u8,
57 pub n: u8,
58 pub k: u8,
59 pub elem: Elem<D>,
60 pub layout: Option<FragmentLayout<D>>,
61}
62
63#[derive(Debug, Clone, Copy)]
65pub enum WmmaInstruction<D: Dialect> {
66 Fill {
68 frag: Variable<D>,
69 value: Variable<D>,
70 },
71 Load {
73 frag: Variable<D>,
74 value: Variable<D>,
75 offset: Variable<D>,
76 stride: Variable<D>,
77 layout: Option<FragmentLayout<D>>,
78 },
79 Execute {
83 frag_a: Variable<D>,
84 frag_b: Variable<D>,
85 frag_c: Variable<D>,
86 frag_d: Variable<D>,
87 warp_size: u32,
88 },
89 Store {
91 output: Variable<D>,
92 frag: Variable<D>,
93 stride: Variable<D>,
94 offset: Variable<D>,
95 layout: FragmentLayout<D>,
96 },
97 Cast {
99 input: Variable<D>,
100 output: Variable<D>,
101 },
102}
103
104impl<D: Dialect> Display for FragmentLayout<D> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 D::compile_wmma_fragment_layout(f, self)
107 }
108}
109
110impl<D: Dialect> Display for FragmentIdent<D> {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 D::compile_wwma_fragment_ident(f, self)
113 }
114}
115
116impl<D: Dialect> Display for Fragment<D> {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 D::compile_wmma_fragment(f, self)
119 }
120}
121
122impl<D: Dialect> Display for WmmaInstruction<D> {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 D::compile_wmma_instruction(f, self)
125 }
126}
127
128pub mod wmma_api_base {
129 use super::*;
130
131 pub fn compile_fragment_declaration<D: Dialect>(
132 f: &mut std::fmt::Formatter<'_>,
133 var: &Variable<D>,
134 ) -> std::fmt::Result {
135 match var {
136 Variable::WmmaFragment { frag, .. } => writeln!(f, "{frag} {var};"),
137 _ => panic!("variable must be a fragment"),
138 }
139 }
140
141 pub fn compile_fragment_ident<D: Dialect>(
142 f: &mut std::fmt::Formatter<'_>,
143 namespace: &str,
144 ident: &FragmentIdent<D>,
145 ) -> std::fmt::Result {
146 match ident {
147 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
148 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
149 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
150 FragmentIdent::_Dialect(_) => Ok(()),
151 }
152 }
153
154 pub fn compile_fragment_layout<D: Dialect>(
155 f: &mut std::fmt::Formatter<'_>,
156 namespace: &str,
157 layout: &FragmentLayout<D>,
158 ) -> std::fmt::Result {
159 match layout {
160 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
161 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
162 FragmentLayout::_Dialect(_) => Ok(()),
163 }
164 }
165
166 pub fn compile_fragment<D: Dialect>(
167 f: &mut std::fmt::Formatter<'_>,
168 namespace: &str,
169 fragment: &Fragment<D>,
170 ) -> std::fmt::Result {
171 let elem = match fragment.elem {
172 Elem::TF32 => format!("{namespace}::precision::tf32"),
173 Elem::BF16 => {
174 if fragment.ident == FragmentIdent::Accumulator {
175 format!("{}", Elem::<D>::F16) } else {
177 format!("{}", fragment.elem)
178 }
179 }
180 elem => format!("{elem}"),
181 };
182 match fragment.layout {
183 Some(layout) => write!(
184 f,
185 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
186 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
187 ),
188 None => write!(
189 f,
190 "{namespace}::fragment<{}, {}, {}, {}, {}>",
191 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
192 ),
193 }
194 }
195
196 pub fn compile_instruction<D: Dialect>(
197 f: &mut std::fmt::Formatter<'_>,
198 namespace: &str,
199 instruction: &WmmaInstruction<D>,
200 ) -> std::fmt::Result {
201 match instruction {
202 WmmaInstruction::Fill { frag, value } => {
203 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
204 }
205
206 WmmaInstruction::Load {
207 frag,
208 value,
209 stride,
210 offset,
211 layout: None,
212 } => {
213 let item = value.item();
214 if item.vectorization > 1 {
215 let elem = item.elem;
216 let qualifier = value.const_qualifier();
217 writeln!(
218 f,
219 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem}{qualifier}*>({value} + {offset}), {stride});"
220 )
221 } else {
222 writeln!(
223 f,
224 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride});"
225 )
226 }
227 }
228
229 WmmaInstruction::Load {
230 frag,
231 value,
232 offset,
233 stride,
234 layout: Some(layout),
235 } => {
236 let layout = match layout {
237 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
238 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
239 FragmentLayout::_Dialect(_) => "".to_string(),
240 };
241 let item = value.item();
242 if item.vectorization > 1 {
243 let elem = item.elem;
244 writeln!(
245 f,
246 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value} + {offset}), {stride}, {layout});"
247 )
248 } else {
249 writeln!(
250 f,
251 "{namespace}::load_matrix_sync({frag}, {value} + {offset}, {stride}, {layout});"
252 )
253 }
254 }
255
256 WmmaInstruction::Execute {
257 frag_a,
258 frag_b,
259 frag_c,
260 frag_d,
261 ..
262 } => writeln!(
263 f,
264 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
265 ),
266
267 WmmaInstruction::Store {
268 output,
269 frag,
270 stride,
271 offset,
272 layout,
273 } => {
274 let layout = match layout {
275 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
276 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
277 FragmentLayout::_Dialect(_) => "".to_string(),
278 };
279
280 let item = output.item();
281 let mut reinterpret_cast = item.vectorization > 1;
282 let elem = match item.elem {
283 Elem::BF16 => {
284 reinterpret_cast = true;
285 Elem::F16
286 }
287 _ => item.elem,
288 };
289 if reinterpret_cast {
290 writeln!(
291 f,
292 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output} + {offset}), {frag}, {stride}, {layout});"
293 )
294 } else {
295 writeln!(
296 f,
297 "{namespace}::store_matrix_sync({output} + {offset}, {frag}, {stride}, {layout});"
298 )
299 }
300 }
301 WmmaInstruction::Cast { input, output } => {
302 let ty = match output {
303 Variable::WmmaFragment { frag, .. } => frag.elem,
304 _ => panic!("Should be a fragment"),
305 };
306 match ty {
307 Elem::BF16 => {
308 let elem = Elem::<D>::F16;
309 write!(
310 f,
311 "// cast
312for(int t=0; t<{input}.num_elements; t++) {{
313 {ty} elem = {ty}({input}.x[t]);
314 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
315}}
316"
317 )
318 }
319 _ => {
320 write!(
321 f,
322 "// cast
323for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}
324"
325 )
326 }
327 }
328 }
329 }
330 }
331}