1use crate::{
2 Dialect,
3 shared::{Component, Elem, FP4Kind, FP6Kind, FP8Kind, Variable},
4};
5
6#[allow(clippy::too_many_arguments)]
7pub fn mma_template<D: Dialect>(
8 a_elem: Elem<D>,
9 b_elem: Elem<D>,
10 cd_elem: Elem<D>,
11 k: u32,
12 n_a_registers: usize,
13 n_b_registers: usize,
14 n_c_registers: usize,
15 n_d_registers: usize,
16) -> String {
17 let a_ty = mma_ty(a_elem);
18 let b_ty = mma_ty(b_elem);
19 let cd_ty = mma_ty(cd_elem);
20
21 let ab_arg_ty = match a_elem {
22 Elem::F32 => &format!("{}", Elem::<D>::F32),
23 _ => &format!("{}", Elem::<D>::U32),
24 };
25 let cd_arg_ty = match cd_elem {
26 Elem::F32 => &format!("{}", Elem::<D>::F32),
27 _ => &format!("{}", Elem::<D>::U32),
28 };
29
30 let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const ®_a_{i}"));
31 let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const ®_b_{i}"));
32 let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const ®_c_{i}"));
33 let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} ®_d_{i}"));
34 let args = args_a
35 .chain(args_b)
36 .chain(args_c)
37 .chain(args_d)
38 .collect::<Vec<_>>()
39 .join(", ");
40
41 let kind = if is_fp6_fp4(a_elem) || is_fp6_fp4(b_elem) {
42 ".kind::f8f6f4"
43 } else {
44 ""
45 };
46
47 let mut idx = 0usize;
48
49 let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
50 let placeholders_d = format!("{{{placeholders_d}}}");
51
52 let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
53 let placeholders_a = format!("{{{placeholders_a}}}");
54
55 let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
56 let placeholders_b = format!("{{{placeholders_b}}}");
57
58 let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
59 let placeholders_c = format!("{{{placeholders_c}}}");
60
61 let params_out =
62 comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
63 let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
64 let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
65 let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
66 let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
67
68 format!(
69 r#"
70inline __device__ void
71__mma_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}) {{
72 asm volatile("mma.sync.aligned.m16n8k{k}.row.col{kind}.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}"
73 " {placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c};"
74 : {params_out}
75 : {params_in});
76 }}
77 "#
78 )
79}
80
81fn is_fp6_fp4<D: Dialect>(elem: Elem<D>) -> bool {
82 matches!(elem, Elem::<D>::FP4(_) | Elem::<D>::FP6(_))
83}
84
85#[allow(clippy::too_many_arguments)]
86pub fn mma_scaled_template<D: Dialect>(
87 a_elem: Elem<D>,
88 b_elem: Elem<D>,
89 cd_elem: Elem<D>,
90 k: u32,
91 n_a_registers: usize,
92 n_b_registers: usize,
93 n_c_registers: usize,
94 n_d_registers: usize,
95 scales_elem: Elem<D>,
96 scales_factor: u32,
97) -> String {
98 let a_ty = mma_ty(a_elem);
99 let b_ty = mma_ty(b_elem);
100 let cd_ty = mma_ty(cd_elem);
101 let s_ty = match scales_elem {
103 Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
104 Elem::FP8(FP8Kind::E4M3) => "ue4m3",
105 _ => panic!("Unsupported scales type"),
106 };
107
108 let kind = match scales_factor {
109 1 => "mxf8f6f4",
110 2 | 4 => "mxf4nvf4",
111 _ => panic!("Unsupported scales factor"),
112 };
113
114 let ab_arg_ty = match a_elem {
115 Elem::F32 => &format!("{}", Elem::<D>::F32),
116 _ => &format!("{}", Elem::<D>::U32),
117 };
118 let cd_arg_ty = match cd_elem {
119 Elem::F32 => &format!("{}", Elem::<D>::F32),
120 _ => &format!("{}", Elem::<D>::U32),
121 };
122
123 let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const ®_a_{i}"));
125 let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const ®_b_{i}"));
126 let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const ®_c_{i}"));
127 let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} ®_d_{i}"));
128 let args = args_a
129 .chain(args_b)
130 .chain(args_c)
131 .chain(args_d)
132 .collect::<Vec<_>>()
133 .join(", ");
134
135 let mut idx = 0usize;
136
137 let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
138 let placeholders_d = format!("{{{placeholders_d}}}");
139
140 let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
141 let placeholders_a = format!("{{{placeholders_a}}}");
142
143 let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
144 let placeholders_b = format!("{{{placeholders_b}}}");
145
146 let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
147 let placeholders_c = format!("{{{placeholders_c}}}");
148
149 let placeholder_scales_a = format!(
150 "{{{}}}, {{{}, {}}}",
151 placeholder(&mut idx),
152 placeholder(&mut idx),
153 placeholder(&mut idx)
154 );
155 let placeholder_scales_b = format!(
156 "{{{}}}, {{{}, {}}}",
157 placeholder(&mut idx),
158 placeholder(&mut idx),
159 placeholder(&mut idx)
160 );
161
162 let params_out =
163 comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
164 let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
165 let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
166 let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
167 let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
168
169 format!(
170 r#"
171inline __device__ void
172__mma_scaled_{scales_factor}x_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}, uint32 const &scales_a, uint32 const &scales_b) {{
173 static constexpr uint16 tidA = 0;
174 static constexpr uint16 bidA = 0;
175 static constexpr uint16 tidB = 0;
176 static constexpr uint16 bidB = 0;
177
178 asm volatile("mma.sync.aligned.kind::{kind}.block_scale.scale_vec::{scales_factor}X.m16n8k{k}.row.col.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}.{s_ty} "
179 "{placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c}, {placeholder_scales_a}, {placeholder_scales_b};"
180 : {params_out}
181 : {params_in}, "r"(scales_a), "h"(bidA), "h"(tidA), "r"(scales_b), "h"(bidB), "h"(tidB));
182 }}
183 "#
184 )
185}
186
/// Joins the yielded strings with `", "` between consecutive items.
///
/// There is no leading or trailing separator; an empty iterator produces an empty string.
pub(crate) fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    it.into_iter()
        .enumerate()
        .fold(String::new(), |mut joined, (position, piece)| {
            if position != 0 {
                joined.push_str(", ");
            }
            joined.push_str(&piece);
            joined
        })
}
192 let placeholder = format!("%{idx}");
193 *idx += 1;
194 placeholder
195}
196
197fn as_reg<D: Dialect>(ident: &str, ty: Elem<D>, output: bool) -> String {
198 let ty = match ty {
199 Elem::F32 => "f",
200 Elem::F64 => "d",
201 Elem::U64 => "l",
202 _ => "r",
203 };
204 if output {
205 format!(r#""={ty}"({ident})"#)
206 } else {
207 format!(r#""{ty}"({ident})"#)
208 }
209}
210
211fn mma_ty<D: Dialect>(elem: Elem<D>) -> &'static str {
212 match elem {
213 Elem::TF32 => "tf32",
214 Elem::F32 => "f32",
215 Elem::F64 => "f64",
216 Elem::F16 => "f16",
217 Elem::BF16 => "bf16",
218 Elem::FP4(FP4Kind::E2M1) => "e2m1",
219 Elem::FP4x2(FP4Kind::E2M1) => "e2m1",
221 Elem::FP6(FP6Kind::E2M3) => "e2m3",
222 Elem::FP6(FP6Kind::E3M2) => "e3m2",
223 Elem::FP8(FP8Kind::E4M3) => "e4m3",
224 Elem::FP8(FP8Kind::E5M2) => "e5m2",
225 Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
226 Elem::I8 => "s8",
227 Elem::I16 => "s16",
228 Elem::I32 => "s32",
229 Elem::I64 => "s64",
230 Elem::U8 => "u8",
231 Elem::U16 => "u16",
232 Elem::U32 => "u32",
233 Elem::U64 => "u64",
234 Elem::Bool => "b1",
235 other => panic!("{other} not supported for MMA"),
236 }
237}
238
239pub fn ldmatrix_call<D: Dialect>(
240 output: &Variable<D>,
241 buffer: &Variable<D>,
242 offset: &Variable<D>,
243 line_size: &Option<u32>,
244 factor: &u32,
245 transpose: &bool,
246) -> String {
247 let elem = output.elem();
248 let width = 16 / output.elem().size();
249 let is_transposed = if *transpose { "_trans" } else { "" };
250 let regs =
251 comma_separated((0..*factor).map(|i| format!("reinterpret_cast<uint32&>({output}[{i}])")));
252 let buffer = if let Some(line_size) = *line_size {
253 let mut item = buffer.item();
254 item.vectorization = line_size as usize;
255 format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
256 } else {
257 buffer.fmt_ptr()
258 };
259
260 format!("__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
261}
262
263pub fn ldmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
264 let width = 16 / elem.size();
265 let arg_ty = Elem::<D>::U32;
266
267 let args_regs = (0..factor).map(|i| format!("{arg_ty} ®_{i}"));
268 let arg_addr = ["void const *row_addr".to_string()];
269 let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
270
271 let mut idx = 0usize;
272
273 let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
274 let placeholders_regs = format!("{{{placeholders_regs}}}");
275
276 let placeholder_addr = format!("[{}]", placeholder(&mut idx));
277
278 let params_regs = comma_separated((0..factor).map(|i| format!(r#""=r"(reg_{i})"#)));
279 let param_addr = r#""r"(addr)"#;
280
281 let is_transposed = if transpose { "_trans" } else { "" };
282 let transposed_arg = if transpose { ".trans" } else { "" };
283 let num = format!("x{factor}");
284
285 let ty = match elem.size() {
286 2 => "b16",
287 1 => "b8",
288 _ => unreachable!(),
289 };
290
291 format!(
292 r#"
293inline __device__ void
294__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
295 uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
296 asm volatile("ldmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
297 " {placeholders_regs}, {placeholder_addr};"
298 : {params_regs}
299 : {param_addr});
300 }}
301 "#
302 )
303}
304
305pub fn stmatrix_call<D: Dialect>(
306 registers: &Variable<D>,
307 buffer: &Variable<D>,
308 offset: &Variable<D>,
309 line_size: &Option<u32>,
310 factor: &u32,
311 transpose: &bool,
312) -> String {
313 let elem = registers.elem();
314 let width = 16 / registers.elem().size();
315 let is_transposed = if *transpose { "_trans" } else { "" };
316 let regs = comma_separated(
317 (0..*factor).map(|i| format!("reinterpret_cast<uint32&>({registers}[{i}])")),
318 );
319 let buffer = if let Some(line_size) = *line_size {
320 let mut item = buffer.item();
321 item.vectorization = line_size as usize;
322 format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
323 } else {
324 buffer.fmt_ptr()
325 };
326
327 format!("__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
328}
329
330pub fn stmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
331 let width = 16 / elem.size();
332 let arg_ty = Elem::<D>::U32;
333
334 let args_regs = (0..factor).map(|i| format!("{arg_ty} const ®_{i}"));
335 let arg_addr = ["void *row_addr".to_string()];
336 let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
337
338 let mut idx = 0usize;
339
340 let placeholder_addr = format!("[{}]", placeholder(&mut idx));
341
342 let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
343 let placeholders_regs = format!("{{{placeholders_regs}}}");
344
345 let params_regs = comma_separated((0..factor).map(|i| format!(r#""r"(reg_{i})"#)));
346 let param_addr = r#""r"(addr)"#;
347
348 let is_transposed = if transpose { "_trans" } else { "" };
349 let transposed_arg = if transpose { ".trans" } else { "" };
350 let num = format!("x{factor}");
351
352 let ty = match elem.size() {
353 2 => "b16",
354 1 => "b8",
355 _ => unreachable!(),
356 };
357
358 format!(
360 r#"
361inline __device__ void
362__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
363 uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
364 asm volatile("stmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
365 " {placeholder_addr}, {placeholders_regs};"
366 :: {param_addr}, {params_regs});
367 }}
368 "#
369 )
370}