1use cubecl_core::ir::{self as gpu};
2use cubecl_core::Feature;
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7use std::{fmt::Debug, marker::PhantomData};
8
9use super::{Component, Dialect, Elem, Variable};
10
/// List of matrix-multiply configurations reported by a compiler for an architecture.
///
/// Each entry is `(a, b, c, shapes)`: three element types followed by the list of
/// `(m, n, k)` dimension triples available for them. `register_wmma_features` turns
/// every `(a, b, c, m, n, k)` combination into one `Feature::Cmma`.
pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
12
/// Description of a target architecture, parsed from a string.
///
/// Parsing failures are reported as a plain `String` (see the `FromStr<Err = String>` bound).
pub trait Architecture: FromStr<Err = String> {
    /// Warp size of this architecture.
    // NOTE(review): presumably the number of threads per warp — confirm with implementors.
    fn warp_size(&self) -> u32;
    /// Whether this architecture is WMMA capable.
    fn is_wmma_capable(&self) -> bool;
    /// Whether this architecture is MFMA capable.
    fn is_mfma_capable(&self) -> bool;
}
18
/// Code-generation hooks for matrix-multiply (WMMA) support in a dialect `D`.
///
/// All functions are associated functions (no `self`): they write text into the
/// supplied formatter. The `Display` impls for [`FragmentIdent`], [`FragmentLayout`],
/// [`Fragment`] and [`WmmaInstruction`] in this file forward to the same-named
/// functions on the dialect.
pub trait WmmaCompiler<D: Dialect>:
    Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Architecture description consumed by [`Self::supported_wmma_combinations`].
    type Architecture: Architecture;

    /// Writes the includes required by this WMMA implementation.
    // NOTE(review): exact content is implementor-defined; not visible from this file.
    fn wmma_includes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes type definitions required by this WMMA implementation.
    // NOTE(review): exact content is implementor-defined; not visible from this file.
    fn deftypes(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes local variable declarations required by this WMMA implementation.
    // NOTE(review): exact content is implementor-defined; not visible from this file.
    fn local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;

    /// Writes the textual form of a fragment identifier.
    fn compile_fragment_ident(
        ident: &FragmentIdent<D>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;

    /// Writes the textual form of a fragment layout.
    fn compile_fragment_layout(
        layout: &FragmentLayout<D>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;

    /// Writes the type of a fragment.
    fn compile_fragment(
        fragment: &Fragment<D>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;

    /// Writes the statement(s) implementing a WMMA instruction.
    fn compile_instruction(
        instruction: &WmmaInstruction<D>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;

    /// Returns the element-type / shape combinations available on `arch`.
    fn supported_wmma_combinations(arch: &Self::Architecture) -> SupportedWmmaCombinations;
}
50
51pub fn register_wmma_features(
52 supported_combinations: SupportedWmmaCombinations,
53 properties: &mut DeviceProperties<Feature>,
54) {
55 for (i, o, c, tdims) in supported_combinations {
56 for (m, n, k) in tdims {
57 properties.register_feature(Feature::Cmma {
58 a: i,
59 b: o,
60 c,
61 m,
62 n,
63 k,
64 });
65 }
66 }
67}
68
/// Role of a fragment in a matrix-multiply operation.
///
/// The shared emitter in `wmma_api_base` renders these as `matrix_a`, `matrix_b`
/// and `accumulator` under the chosen namespace.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentIdent<D: Dialect> {
    /// Rendered as `matrix_a` by `wmma_api_base`.
    A,
    /// Rendered as `matrix_b` by `wmma_api_base`.
    B,
    /// Rendered as `accumulator` by `wmma_api_base`.
    Accumulator,
    /// Marker that only ties the enum to the dialect type `D`; `wmma_api_base` renders it as nothing.
    _Dialect(PhantomData<D>),
}
76
/// Memory layout attached to a fragment or to a load/store instruction.
///
/// The shared emitter in `wmma_api_base` renders these as `col_major` / `row_major`
/// in fragment types and as `mem_col_major` / `mem_row_major` in load/store calls.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum FragmentLayout<D: Dialect> {
    /// Column-major layout.
    ColMajor,
    /// Row-major layout.
    RowMajor,
    /// Marker that only ties the enum to the dialect type `D`; `wmma_api_base` renders it as nothing.
    _Dialect(PhantomData<D>),
}
83
/// Description of a matrix fragment type.
///
/// The shared emitter in `wmma_api_base` renders it as
/// `fragment<ident, m, n, k, elem[, layout]>` under the chosen namespace.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Fragment<D: Dialect> {
    /// Role of the fragment (A, B or accumulator).
    pub ident: FragmentIdent<D>,
    /// First dimension parameter of the fragment type.
    pub m: u8,
    /// Second dimension parameter of the fragment type.
    pub n: u8,
    /// Third dimension parameter of the fragment type.
    pub k: u8,
    /// Element type stored in the fragment.
    pub elem: Elem<D>,
    /// Optional layout; when `None` the layout parameter is omitted from the rendered type.
    pub layout: Option<FragmentLayout<D>>,
}
93
/// Matrix-multiply (WMMA) instruction to be rendered by the dialect.
///
/// The per-variant notes describe what the shared emitter
/// `wmma_api_base::compile_instruction` writes; other dialect implementations may differ.
#[derive(Debug, Clone, Copy)]
pub enum WmmaInstruction<D: Dialect> {
    /// Fills `frag` with `value` (emitted as `fill_fragment(frag, value)`).
    Fill {
        frag: Variable<D>,
        value: Variable<D>,
    },
    /// Loads `value` into `frag` with the given `stride`
    /// (emitted as `load_matrix_sync`; the layout argument is only passed when `Some`).
    Load {
        frag: Variable<D>,
        value: Variable<D>,
        stride: Variable<D>,
        layout: Option<FragmentLayout<D>>,
    },
    /// Emitted as `mma_sync(frag_d, frag_a, frag_b, frag_c)`.
    // NOTE(review): presumably computes frag_d = frag_a * frag_b + frag_c — confirm
    // against the target API. `warp_size` is ignored by the shared emitter.
    Execute {
        frag_a: Variable<D>,
        frag_b: Variable<D>,
        frag_c: Variable<D>,
        frag_d: Variable<D>,
        warp_size: u32,
    },
    /// Stores `frag` into `output` with the given `stride` and `layout`
    /// (emitted as `store_matrix_sync`).
    Store {
        output: Variable<D>,
        frag: Variable<D>,
        stride: Variable<D>,
        layout: FragmentLayout<D>,
    },
    /// Converts every element of the `input` fragment into the element type of the
    /// `output` fragment (emitted as an element-wise loop).
    Cast {
        input: Variable<D>,
        output: Variable<D>,
    },
}
132
133impl<D: Dialect> Display for FragmentLayout<D> {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 D::compile_fragment_layout(self, f)
136 }
137}
138
139impl<D: Dialect> Display for FragmentIdent<D> {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 D::compile_fragment_ident(self, f)
142 }
143}
144
145impl<D: Dialect> Display for Fragment<D> {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 D::compile_fragment(self, f)
148 }
149}
150
151impl<D: Dialect> Display for WmmaInstruction<D> {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 D::compile_instruction(self, f)
154 }
155}
156
157pub mod wmma_api_base {
158 use super::*;
159
160 pub fn compile_fragment_ident<D: Dialect>(
161 namespace: &str,
162 ident: &FragmentIdent<D>,
163 f: &mut std::fmt::Formatter<'_>,
164 ) -> std::fmt::Result {
165 match ident {
166 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
167 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
168 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
169 FragmentIdent::_Dialect(_) => Ok(()),
170 }
171 }
172
173 pub fn compile_fragment_layout<D: Dialect>(
174 namespace: &str,
175 layout: &FragmentLayout<D>,
176 f: &mut std::fmt::Formatter<'_>,
177 ) -> std::fmt::Result {
178 match layout {
179 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
180 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
181 FragmentLayout::_Dialect(_) => Ok(()),
182 }
183 }
184
185 pub fn compile_fragment<D: Dialect>(
186 namespace: &str,
187 fragment: &Fragment<D>,
188 f: &mut std::fmt::Formatter<'_>,
189 ) -> std::fmt::Result {
190 let elem = match fragment.elem {
191 Elem::TF32 => format!("{namespace}::precision::tf32"),
192 Elem::BF16 => {
193 if fragment.ident == FragmentIdent::Accumulator {
194 format!("{}", Elem::<D>::F16) } else {
196 format!("{}", fragment.elem)
197 }
198 }
199 elem => format!("{elem}"),
200 };
201 match fragment.layout {
202 Some(layout) => write!(
203 f,
204 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
205 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
206 ),
207 None => write!(
208 f,
209 "{namespace}::fragment<{}, {}, {}, {}, {}>",
210 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
211 ),
212 }
213 }
214
215 pub fn compile_instruction<D: Dialect>(
216 namespace: &str,
217 instruction: &WmmaInstruction<D>,
218 f: &mut std::fmt::Formatter<'_>,
219 ) -> std::fmt::Result {
220 match instruction {
221 WmmaInstruction::Fill { frag, value } => {
222 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
223 }
224
225 WmmaInstruction::Load {
226 frag,
227 value,
228 stride,
229 layout: None,
230 } => {
231 let item = value.item();
232 if item.vectorization > 1 {
233 let elem = item.elem;
234 writeln!(f, "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride});")
235 } else {
236 writeln!(
237 f,
238 "{namespace}::load_matrix_sync({frag}, {value}, {stride});"
239 )
240 }
241 }
242
243 WmmaInstruction::Load {
244 frag,
245 value,
246 stride,
247 layout: Some(layout),
248 } => {
249 let layout = match layout {
250 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
251 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
252 FragmentLayout::_Dialect(_) => "".to_string(),
253 };
254 let item = value.item();
255 if item.vectorization > 1 {
256 let elem = item.elem;
257 writeln!(f, "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride}, {layout});")
258 } else {
259 writeln!(
260 f,
261 "{namespace}::load_matrix_sync({frag}, {value}, {stride}, {layout});"
262 )
263 }
264 }
265
266 WmmaInstruction::Execute {
267 frag_a,
268 frag_b,
269 frag_c,
270 frag_d,
271 ..
272 } => writeln!(
273 f,
274 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
275 ),
276
277 WmmaInstruction::Store {
278 output,
279 frag,
280 stride,
281 layout,
282 } => {
283 let layout = match layout {
284 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
285 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
286 FragmentLayout::_Dialect(_) => "".to_string(),
287 };
288
289 let item = output.item();
290 let mut reinterpret_cast = item.vectorization > 1;
291 let elem = match item.elem {
292 Elem::BF16 => {
293 reinterpret_cast = true;
294 Elem::F16
295 }
296 _ => item.elem,
297 };
298 if reinterpret_cast {
299 writeln!(
300 f,
301 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output}), {frag}, {stride}, {layout});"
302 )
303 } else {
304 writeln!(
305 f,
306 "{namespace}::store_matrix_sync({output}, {frag}, {stride}, {layout});"
307 )
308 }
309 }
310 WmmaInstruction::Cast { input, output } => {
311 let ty = match output {
312 Variable::WmmaFragment { frag, .. } => frag.elem,
313 _ => panic!("Should be a fragment"),
314 };
315 match ty {
316 Elem::BF16 => {
317 let elem = Elem::<D>::F16;
318 writeln!(
319 f,
320 "for(int t=0; t<{input}.num_elements; t++) {{
321 {ty} elem = {ty}({input}.x[t]);
322 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
323 }}"
324 )
325 }
326 _ => {
327 writeln!(
328 f,
329 "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}"
330 )
331 }
332 }
333 }
334 }
335 }
336}