1use cubecl_core::Feature;
2use cubecl_core::ir::{self as gpu};
3use cubecl_runtime::DeviceProperties;
4use std::fmt::Display;
5use std::str::FromStr;
6use std::{fmt::Debug, marker::PhantomData};
7
8use super::{Component, Dialect, Elem, Variable};
9
10pub type SupportedWmmaCombinations = Vec<(gpu::Elem, gpu::Elem, gpu::Elem, Vec<(u8, u8, u8)>)>;
11
/// Description of a compilation target architecture.
///
/// Implementors must be parseable from a string (`FromStr`), reporting parse
/// failures as a plain `String` message.
pub trait Architecture: FromStr<Err = String> {
    /// Returns the warp size of this architecture.
    fn warp_size(&self) -> u32;

    /// Returns `true` when this architecture reports WMMA support.
    fn is_wmma_capable(&self) -> bool;

    /// Returns `true` when this architecture reports MFMA support.
    fn is_mfma_capable(&self) -> bool;
}
17
18pub fn register_wmma_features(
19 supported_combinations: SupportedWmmaCombinations,
20 properties: &mut DeviceProperties<Feature>,
21) {
22 for (i, o, c, tdims) in supported_combinations {
23 for (m, n, k) in tdims {
24 properties.register_feature(Feature::Cmma {
25 a: i,
26 b: o,
27 c,
28 m,
29 n,
30 k,
31 });
32 }
33 }
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Copy)]
37pub enum FragmentIdent<D: Dialect> {
38 A,
39 B,
40 Accumulator,
41 _Dialect(PhantomData<D>),
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Copy)]
45pub enum FragmentLayout<D: Dialect> {
46 ColMajor,
47 RowMajor,
48 _Dialect(PhantomData<D>),
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Copy)]
52pub struct Fragment<D: Dialect> {
53 pub ident: FragmentIdent<D>,
54 pub m: u8,
55 pub n: u8,
56 pub k: u8,
57 pub elem: Elem<D>,
58 pub layout: Option<FragmentLayout<D>>,
59}
60
61#[derive(Debug, Clone, Copy)]
63pub enum WmmaInstruction<D: Dialect> {
64 Fill {
66 frag: Variable<D>,
67 value: Variable<D>,
68 },
69 Load {
71 frag: Variable<D>,
72 value: Variable<D>,
73 stride: Variable<D>,
74 layout: Option<FragmentLayout<D>>,
75 },
76 Execute {
80 frag_a: Variable<D>,
81 frag_b: Variable<D>,
82 frag_c: Variable<D>,
83 frag_d: Variable<D>,
84 warp_size: u32,
85 },
86 Store {
88 output: Variable<D>,
89 frag: Variable<D>,
90 stride: Variable<D>,
91 layout: FragmentLayout<D>,
92 },
93 Cast {
95 input: Variable<D>,
96 output: Variable<D>,
97 },
98}
99
100impl<D: Dialect> Display for FragmentLayout<D> {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 D::compile_fragment_layout(self, f)
103 }
104}
105
106impl<D: Dialect> Display for FragmentIdent<D> {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 D::compile_fragment_ident(self, f)
109 }
110}
111
112impl<D: Dialect> Display for Fragment<D> {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 D::compile_fragment(self, f)
115 }
116}
117
118impl<D: Dialect> Display for WmmaInstruction<D> {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 D::compile_instruction(self, f)
121 }
122}
123
124pub mod wmma_api_base {
125 use super::*;
126
127 pub fn compile_fragment_ident<D: Dialect>(
128 namespace: &str,
129 ident: &FragmentIdent<D>,
130 f: &mut std::fmt::Formatter<'_>,
131 ) -> std::fmt::Result {
132 match ident {
133 FragmentIdent::A => write!(f, "{namespace}::matrix_a"),
134 FragmentIdent::B => write!(f, "{namespace}::matrix_b"),
135 FragmentIdent::Accumulator => write!(f, "{namespace}::accumulator"),
136 FragmentIdent::_Dialect(_) => Ok(()),
137 }
138 }
139
140 pub fn compile_fragment_layout<D: Dialect>(
141 namespace: &str,
142 layout: &FragmentLayout<D>,
143 f: &mut std::fmt::Formatter<'_>,
144 ) -> std::fmt::Result {
145 match layout {
146 FragmentLayout::ColMajor => f.write_str(format!("{namespace}::col_major").as_str()),
147 FragmentLayout::RowMajor => f.write_str(format!("{namespace}::row_major").as_str()),
148 FragmentLayout::_Dialect(_) => Ok(()),
149 }
150 }
151
152 pub fn compile_fragment<D: Dialect>(
153 namespace: &str,
154 fragment: &Fragment<D>,
155 f: &mut std::fmt::Formatter<'_>,
156 ) -> std::fmt::Result {
157 let elem = match fragment.elem {
158 Elem::TF32 => format!("{namespace}::precision::tf32"),
159 Elem::BF16 => {
160 if fragment.ident == FragmentIdent::Accumulator {
161 format!("{}", Elem::<D>::F16) } else {
163 format!("{}", fragment.elem)
164 }
165 }
166 elem => format!("{elem}"),
167 };
168 match fragment.layout {
169 Some(layout) => write!(
170 f,
171 "{namespace}::fragment<{}, {}, {}, {}, {}, {}>",
172 fragment.ident, fragment.m, fragment.n, fragment.k, elem, layout
173 ),
174 None => write!(
175 f,
176 "{namespace}::fragment<{}, {}, {}, {}, {}>",
177 fragment.ident, fragment.m, fragment.n, fragment.k, elem,
178 ),
179 }
180 }
181
182 pub fn compile_instruction<D: Dialect>(
183 namespace: &str,
184 instruction: &WmmaInstruction<D>,
185 f: &mut std::fmt::Formatter<'_>,
186 ) -> std::fmt::Result {
187 match instruction {
188 WmmaInstruction::Fill { frag, value } => {
189 writeln!(f, "{namespace}::fill_fragment({frag}, {value});")
190 }
191
192 WmmaInstruction::Load {
193 frag,
194 value,
195 stride,
196 layout: None,
197 } => {
198 let item = value.item();
199 if item.vectorization > 1 {
200 let elem = item.elem;
201 writeln!(
202 f,
203 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride});"
204 )
205 } else {
206 writeln!(
207 f,
208 "{namespace}::load_matrix_sync({frag}, {value}, {stride});"
209 )
210 }
211 }
212
213 WmmaInstruction::Load {
214 frag,
215 value,
216 stride,
217 layout: Some(layout),
218 } => {
219 let layout = match layout {
220 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
221 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
222 FragmentLayout::_Dialect(_) => "".to_string(),
223 };
224 let item = value.item();
225 if item.vectorization > 1 {
226 let elem = item.elem;
227 writeln!(
228 f,
229 "{namespace}::load_matrix_sync({frag}, reinterpret_cast<{elem} *>({value}), {stride}, {layout});"
230 )
231 } else {
232 writeln!(
233 f,
234 "{namespace}::load_matrix_sync({frag}, {value}, {stride}, {layout});"
235 )
236 }
237 }
238
239 WmmaInstruction::Execute {
240 frag_a,
241 frag_b,
242 frag_c,
243 frag_d,
244 ..
245 } => writeln!(
246 f,
247 "{namespace}::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});"
248 ),
249
250 WmmaInstruction::Store {
251 output,
252 frag,
253 stride,
254 layout,
255 } => {
256 let layout = match layout {
257 FragmentLayout::ColMajor => format!("{namespace}::mem_col_major"),
258 FragmentLayout::RowMajor => format!("{namespace}::mem_row_major"),
259 FragmentLayout::_Dialect(_) => "".to_string(),
260 };
261
262 let item = output.item();
263 let mut reinterpret_cast = item.vectorization > 1;
264 let elem = match item.elem {
265 Elem::BF16 => {
266 reinterpret_cast = true;
267 Elem::F16
268 }
269 _ => item.elem,
270 };
271 if reinterpret_cast {
272 writeln!(
273 f,
274 "{namespace}::store_matrix_sync(reinterpret_cast<{elem} *>({output}), {frag}, {stride}, {layout});"
275 )
276 } else {
277 writeln!(
278 f,
279 "{namespace}::store_matrix_sync({output}, {frag}, {stride}, {layout});"
280 )
281 }
282 }
283 WmmaInstruction::Cast { input, output } => {
284 let ty = match output {
285 Variable::WmmaFragment { frag, .. } => frag.elem,
286 _ => panic!("Should be a fragment"),
287 };
288 match ty {
289 Elem::BF16 => {
290 let elem = Elem::<D>::F16;
291 writeln!(
292 f,
293 "for(int t=0; t<{input}.num_elements; t++) {{
294 {ty} elem = {ty}({input}.x[t]);
295 {output}.x[t] = *reinterpret_cast<{elem} *>(&elem);
296 }}"
297 )
298 }
299 _ => {
300 writeln!(
301 f,
302 "for(int t=0; t<{input}.num_elements; t++) {{ {output}.x[t] = {ty}({input}.x[t]); }}"
303 )
304 }
305 }
306 }
307 }
308 }
309}