1use crate::{
2 Dialect,
3 shared::{Component, Elem, FP4Kind, FP6Kind, FP8Kind, Variable},
4};
5
/// Text of the `tma_load_im2col.cuh` file located next to this source file, embedded at compile
/// time through `include_str!`.
pub const TMA_LOAD_IM2COL: &str = include_str!("tma_load_im2col.cuh");
7
8#[allow(clippy::too_many_arguments)]
9pub fn mma_template<D: Dialect>(
10 a_elem: Elem<D>,
11 b_elem: Elem<D>,
12 cd_elem: Elem<D>,
13 k: u32,
14 n_a_registers: usize,
15 n_b_registers: usize,
16 n_c_registers: usize,
17 n_d_registers: usize,
18) -> String {
19 let a_ty = mma_ty(a_elem);
20 let b_ty = mma_ty(b_elem);
21 let cd_ty = mma_ty(cd_elem);
22
23 let ab_arg_ty = match a_elem {
24 Elem::F32 => &format!("{}", Elem::<D>::F32),
25 _ => &format!("{}", Elem::<D>::U32),
26 };
27 let cd_arg_ty = match cd_elem {
28 Elem::F32 => &format!("{}", Elem::<D>::F32),
29 _ => &format!("{}", Elem::<D>::U32),
30 };
31
32 let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const ®_a_{i}"));
33 let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const ®_b_{i}"));
34 let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const ®_c_{i}"));
35 let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} ®_d_{i}"));
36 let args = args_a
37 .chain(args_b)
38 .chain(args_c)
39 .chain(args_d)
40 .collect::<Vec<_>>()
41 .join(", ");
42
43 let kind = if is_fp6_fp4(a_elem) || is_fp6_fp4(b_elem) {
44 ".kind::f8f6f4"
45 } else {
46 ""
47 };
48
49 let mut idx = 0usize;
50
51 let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
52 let placeholders_d = format!("{{{placeholders_d}}}");
53
54 let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
55 let placeholders_a = format!("{{{placeholders_a}}}");
56
57 let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
58 let placeholders_b = format!("{{{placeholders_b}}}");
59
60 let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
61 let placeholders_c = format!("{{{placeholders_c}}}");
62
63 let params_out =
64 comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
65 let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
66 let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
67 let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
68 let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
69
70 format!(
71 r#"
72inline __device__ void
73__mma_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}) {{
74 asm volatile("mma.sync.aligned.m16n8k{k}.row.col{kind}.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}"
75 " {placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c};"
76 : {params_out}
77 : {params_in});
78 }}
79 "#
80 )
81}
82
83fn is_fp6_fp4<D: Dialect>(elem: Elem<D>) -> bool {
84 matches!(elem, Elem::<D>::FP4(_) | Elem::<D>::FP6(_))
85}
86
87#[allow(clippy::too_many_arguments)]
88pub fn mma_scaled_template<D: Dialect>(
89 a_elem: Elem<D>,
90 b_elem: Elem<D>,
91 cd_elem: Elem<D>,
92 k: u32,
93 n_a_registers: usize,
94 n_b_registers: usize,
95 n_c_registers: usize,
96 n_d_registers: usize,
97 scales_elem: Elem<D>,
98 scales_factor: u32,
99) -> String {
100 let a_ty = mma_ty(a_elem);
101 let b_ty = mma_ty(b_elem);
102 let cd_ty = mma_ty(cd_elem);
103 let s_ty = match scales_elem {
105 Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
106 Elem::FP8(FP8Kind::E4M3) => "ue4m3",
107 _ => panic!("Unsupported scales type"),
108 };
109
110 let kind = match scales_factor {
111 1 => "mxf8f6f4",
112 2 | 4 => "mxf4nvf4",
113 _ => panic!("Unsupported scales factor"),
114 };
115
116 let ab_arg_ty = match a_elem {
117 Elem::F32 => &format!("{}", Elem::<D>::F32),
118 _ => &format!("{}", Elem::<D>::U32),
119 };
120 let cd_arg_ty = match cd_elem {
121 Elem::F32 => &format!("{}", Elem::<D>::F32),
122 _ => &format!("{}", Elem::<D>::U32),
123 };
124
125 let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const ®_a_{i}"));
127 let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const ®_b_{i}"));
128 let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const ®_c_{i}"));
129 let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} ®_d_{i}"));
130 let args = args_a
131 .chain(args_b)
132 .chain(args_c)
133 .chain(args_d)
134 .collect::<Vec<_>>()
135 .join(", ");
136
137 let mut idx = 0usize;
138
139 let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
140 let placeholders_d = format!("{{{placeholders_d}}}");
141
142 let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
143 let placeholders_a = format!("{{{placeholders_a}}}");
144
145 let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
146 let placeholders_b = format!("{{{placeholders_b}}}");
147
148 let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
149 let placeholders_c = format!("{{{placeholders_c}}}");
150
151 let placeholder_scales_a = format!(
152 "{{{}}}, {{{}, {}}}",
153 placeholder(&mut idx),
154 placeholder(&mut idx),
155 placeholder(&mut idx)
156 );
157 let placeholder_scales_b = format!(
158 "{{{}}}, {{{}, {}}}",
159 placeholder(&mut idx),
160 placeholder(&mut idx),
161 placeholder(&mut idx)
162 );
163
164 let params_out =
165 comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
166 let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
167 let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
168 let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
169 let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
170
171 format!(
172 r#"
173inline __device__ void
174__mma_scaled_{scales_factor}x_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}, uint32 const &scales_a, uint32 const &scales_b) {{
175 static constexpr uint16 tidA = 0;
176 static constexpr uint16 bidA = 0;
177 static constexpr uint16 tidB = 0;
178 static constexpr uint16 bidB = 0;
179
180 asm volatile("mma.sync.aligned.kind::{kind}.block_scale.scale_vec::{scales_factor}X.m16n8k{k}.row.col.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}.{s_ty} "
181 "{placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c}, {placeholder_scales_a}, {placeholder_scales_b};"
182 : {params_out}
183 : {params_in}, "r"(scales_a), "h"(bidA), "h"(tidA), "r"(scales_b), "h"(bidB), "h"(tidB));
184 }}
185 "#
186 )
187}
188
/// Concatenates every string yielded by `it`, inserting `", "` between consecutive items.
///
/// An empty iterator produces an empty string.
pub(crate) fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    let mut joined = String::new();
    for (position, piece) in it.into_iter().enumerate() {
        // The separator goes before every item except the first one.
        if position > 0 {
            joined.push_str(", ");
        }
        joined.push_str(&piece);
    }
    joined
}
192
/// Returns the operand placeholder `%N` for the current value of `idx`, then advances `idx` by
/// one so that the next call yields the following placeholder.
fn placeholder(idx: &mut usize) -> String {
    let current = *idx;
    *idx = current + 1;
    format!("%{current}")
}
198
199fn as_reg<D: Dialect>(ident: &str, ty: Elem<D>, output: bool) -> String {
200 let ty = match ty {
201 Elem::F32 => "f",
202 Elem::F64 => "d",
203 Elem::U64 => "l",
204 _ => "r",
205 };
206 if output {
207 format!(r#""={ty}"({ident})"#)
208 } else {
209 format!(r#""{ty}"({ident})"#)
210 }
211}
212
213fn mma_ty<D: Dialect>(elem: Elem<D>) -> &'static str {
214 match elem {
215 Elem::TF32 => "tf32",
216 Elem::F32 => "f32",
217 Elem::F64 => "f64",
218 Elem::F16 => "f16",
219 Elem::BF16 => "bf16",
220 Elem::FP4(FP4Kind::E2M1) => "e2m1",
221 Elem::FP4x2(FP4Kind::E2M1) => "e2m1",
223 Elem::FP6(FP6Kind::E2M3) => "e2m3",
224 Elem::FP6(FP6Kind::E3M2) => "e3m2",
225 Elem::FP8(FP8Kind::E4M3) => "e4m3",
226 Elem::FP8(FP8Kind::E5M2) => "e5m2",
227 Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
228 Elem::I8 => "s8",
229 Elem::I16 => "s16",
230 Elem::I32 => "s32",
231 Elem::I64 => "s64",
232 Elem::U8 => "u8",
233 Elem::U16 => "u16",
234 Elem::U32 => "u32",
235 Elem::U64 => "u64",
236 Elem::Bool => "b1",
237 other => panic!("{other} not supported for MMA"),
238 }
239}
240
241pub fn ldmatrix_call<D: Dialect>(
242 output: &Variable<D>,
243 buffer: &Variable<D>,
244 offset: &Variable<D>,
245 line_size: &Option<u32>,
246 factor: &u32,
247 transpose: &bool,
248) -> String {
249 let elem = output.elem();
250 let width = 16 / output.elem().size();
251 let is_transposed = if *transpose { "_trans" } else { "" };
252 let regs =
253 comma_separated((0..*factor).map(|i| format!("reinterpret_cast<uint32&>({output}[{i}])")));
254 let buffer = if let Some(line_size) = *line_size {
255 let mut item = buffer.item();
256 item.vectorization = line_size as usize;
257 format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
258 } else {
259 buffer.fmt_ptr()
260 };
261
262 format!("__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
263}
264
265pub fn ldmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
266 let width = 16 / elem.size();
267 let arg_ty = Elem::<D>::U32;
268
269 let args_regs = (0..factor).map(|i| format!("{arg_ty} ®_{i}"));
270 let arg_addr = ["void const *row_addr".to_string()];
271 let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
272
273 let mut idx = 0usize;
274
275 let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
276 let placeholders_regs = format!("{{{placeholders_regs}}}");
277
278 let placeholder_addr = format!("[{}]", placeholder(&mut idx));
279
280 let params_regs = comma_separated((0..factor).map(|i| format!(r#""=r"(reg_{i})"#)));
281 let param_addr = r#""r"(addr)"#;
282
283 let is_transposed = if transpose { "_trans" } else { "" };
284 let transposed_arg = if transpose { ".trans" } else { "" };
285 let num = format!("x{factor}");
286
287 let ty = match elem.size() {
288 2 => "b16",
289 1 => "b8",
290 _ => unreachable!(),
291 };
292
293 format!(
294 r#"
295inline __device__ void
296__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
297 uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
298 asm volatile("ldmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
299 " {placeholders_regs}, {placeholder_addr};"
300 : {params_regs}
301 : {param_addr});
302 }}
303 "#
304 )
305}
306
307pub fn stmatrix_call<D: Dialect>(
308 registers: &Variable<D>,
309 buffer: &Variable<D>,
310 offset: &Variable<D>,
311 line_size: &Option<u32>,
312 factor: &u32,
313 transpose: &bool,
314) -> String {
315 let elem = registers.elem();
316 let width = 16 / registers.elem().size();
317 let is_transposed = if *transpose { "_trans" } else { "" };
318 let regs = comma_separated(
319 (0..*factor).map(|i| format!("reinterpret_cast<uint32&>({registers}[{i}])")),
320 );
321 let buffer = if let Some(line_size) = *line_size {
322 let mut item = buffer.item();
323 item.vectorization = line_size as usize;
324 format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
325 } else {
326 buffer.fmt_ptr()
327 };
328
329 format!("__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
330}
331
332pub fn stmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
333 let width = 16 / elem.size();
334 let arg_ty = Elem::<D>::U32;
335
336 let args_regs = (0..factor).map(|i| format!("{arg_ty} const ®_{i}"));
337 let arg_addr = ["void *row_addr".to_string()];
338 let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
339
340 let mut idx = 0usize;
341
342 let placeholder_addr = format!("[{}]", placeholder(&mut idx));
343
344 let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
345 let placeholders_regs = format!("{{{placeholders_regs}}}");
346
347 let params_regs = comma_separated((0..factor).map(|i| format!(r#""r"(reg_{i})"#)));
348 let param_addr = r#""r"(addr)"#;
349
350 let is_transposed = if transpose { "_trans" } else { "" };
351 let transposed_arg = if transpose { ".trans" } else { "" };
352 let num = format!("x{factor}");
353
354 let ty = match elem.size() {
355 2 => "b16",
356 1 => "b8",
357 _ => unreachable!(),
358 };
359
360 format!(
362 r#"
363inline __device__ void
364__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
365 uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
366 asm volatile("stmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
367 " {placeholder_addr}, {placeholders_regs};"
368 :: {param_addr}, {params_regs});
369 }}
370 "#
371 )
372}