cubecl_cpp/cuda/ptx/
mma.rs

1use crate::{
2    Dialect,
3    shared::{Component, Elem, FP4Kind, FP6Kind, FP8Kind, Variable},
4};
5
6#[allow(clippy::too_many_arguments)]
7pub fn mma_template<D: Dialect>(
8    a_elem: Elem<D>,
9    b_elem: Elem<D>,
10    cd_elem: Elem<D>,
11    k: u32,
12    n_a_registers: usize,
13    n_b_registers: usize,
14    n_c_registers: usize,
15    n_d_registers: usize,
16) -> String {
17    let a_ty = mma_ty(a_elem);
18    let b_ty = mma_ty(b_elem);
19    let cd_ty = mma_ty(cd_elem);
20
21    let ab_arg_ty = match a_elem {
22        Elem::F32 => &format!("{}", Elem::<D>::F32),
23        _ => &format!("{}", Elem::<D>::U32),
24    };
25    let cd_arg_ty = match cd_elem {
26        Elem::F32 => &format!("{}", Elem::<D>::F32),
27        _ => &format!("{}", Elem::<D>::U32),
28    };
29
30    let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const &reg_a_{i}"));
31    let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const &reg_b_{i}"));
32    let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const &reg_c_{i}"));
33    let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} &reg_d_{i}"));
34    let args = args_a
35        .chain(args_b)
36        .chain(args_c)
37        .chain(args_d)
38        .collect::<Vec<_>>()
39        .join(", ");
40
41    let kind = if is_fp6_fp4(a_elem) || is_fp6_fp4(b_elem) {
42        ".kind::f8f6f4"
43    } else {
44        ""
45    };
46
47    let mut idx = 0usize;
48
49    let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
50    let placeholders_d = format!("{{{placeholders_d}}}");
51
52    let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
53    let placeholders_a = format!("{{{placeholders_a}}}");
54
55    let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
56    let placeholders_b = format!("{{{placeholders_b}}}");
57
58    let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
59    let placeholders_c = format!("{{{placeholders_c}}}");
60
61    let params_out =
62        comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
63    let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
64    let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
65    let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
66    let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
67
68    format!(
69        r#"
70inline __device__ void
71__mma_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}) {{
72  asm volatile("mma.sync.aligned.m16n8k{k}.row.col{kind}.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}"
73               " {placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c};"
74               : {params_out}
75               : {params_in});
76    }}
77    "#
78    )
79}
80
81fn is_fp6_fp4<D: Dialect>(elem: Elem<D>) -> bool {
82    matches!(elem, Elem::<D>::FP4(_) | Elem::<D>::FP6(_))
83}
84
85#[allow(clippy::too_many_arguments)]
86pub fn mma_scaled_template<D: Dialect>(
87    a_elem: Elem<D>,
88    b_elem: Elem<D>,
89    cd_elem: Elem<D>,
90    k: u32,
91    n_a_registers: usize,
92    n_b_registers: usize,
93    n_c_registers: usize,
94    n_d_registers: usize,
95    scales_elem: Elem<D>,
96    scales_factor: u32,
97) -> String {
98    let a_ty = mma_ty(a_elem);
99    let b_ty = mma_ty(b_elem);
100    let cd_ty = mma_ty(cd_elem);
101    // Needs custom mapping because of the ignored sign bit
102    let s_ty = match scales_elem {
103        Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
104        Elem::FP8(FP8Kind::E4M3) => "ue4m3",
105        _ => panic!("Unsupported scales type"),
106    };
107
108    let kind = match scales_factor {
109        1 => "mxf8f6f4",
110        2 | 4 => "mxf4nvf4",
111        _ => panic!("Unsupported scales factor"),
112    };
113
114    let ab_arg_ty = match a_elem {
115        Elem::F32 => &format!("{}", Elem::<D>::F32),
116        _ => &format!("{}", Elem::<D>::U32),
117    };
118    let cd_arg_ty = match cd_elem {
119        Elem::F32 => &format!("{}", Elem::<D>::F32),
120        _ => &format!("{}", Elem::<D>::U32),
121    };
122
123    // Note: Scaled MMA actually requires float registers for C/D, unlike normal MMA
124    let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const &reg_a_{i}"));
125    let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const &reg_b_{i}"));
126    let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const &reg_c_{i}"));
127    let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} &reg_d_{i}"));
128    let args = args_a
129        .chain(args_b)
130        .chain(args_c)
131        .chain(args_d)
132        .collect::<Vec<_>>()
133        .join(", ");
134
135    let mut idx = 0usize;
136
137    let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
138    let placeholders_d = format!("{{{placeholders_d}}}");
139
140    let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
141    let placeholders_a = format!("{{{placeholders_a}}}");
142
143    let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
144    let placeholders_b = format!("{{{placeholders_b}}}");
145
146    let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
147    let placeholders_c = format!("{{{placeholders_c}}}");
148
149    let placeholder_scales_a = format!(
150        "{{{}}}, {{{}, {}}}",
151        placeholder(&mut idx),
152        placeholder(&mut idx),
153        placeholder(&mut idx)
154    );
155    let placeholder_scales_b = format!(
156        "{{{}}}, {{{}, {}}}",
157        placeholder(&mut idx),
158        placeholder(&mut idx),
159        placeholder(&mut idx)
160    );
161
162    let params_out =
163        comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
164    let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
165    let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
166    let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
167    let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
168
169    format!(
170        r#"
171inline __device__ void
172__mma_scaled_{scales_factor}x_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}, uint32 const &scales_a, uint32 const &scales_b) {{
173    static constexpr uint16 tidA = 0;
174    static constexpr uint16 bidA = 0;
175    static constexpr uint16 tidB = 0;
176    static constexpr uint16 bidB = 0;
177
178    asm volatile("mma.sync.aligned.kind::{kind}.block_scale.scale_vec::{scales_factor}X.m16n8k{k}.row.col.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}.{s_ty} "
179               "{placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c}, {placeholder_scales_a}, {placeholder_scales_b};"
180               : {params_out}
181               : {params_in}, "r"(scales_a), "h"(bidA), "h"(tidA), "r"(scales_b), "h"(bidB), "h"(tidB));
182    }}
183    "#
184    )
185}
186
/// Concatenates the strings yielded by `it`, inserting `", "` between consecutive items.
///
/// An empty input yields an empty string.
pub(crate) fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    let mut joined = String::new();
    for (position, piece) in it.into_iter().enumerate() {
        // The separator goes before every item except the first.
        if position > 0 {
            joined.push_str(", ");
        }
        joined.push_str(&piece);
    }
    joined
}
190
/// Returns the asm operand placeholder `%N` for the current value of `idx`, then advances
/// `idx` by one so the next call yields the following operand number.
fn placeholder(idx: &mut usize) -> String {
    let current = *idx;
    *idx = current + 1;
    format!("%{current}")
}
196
197fn as_reg<D: Dialect>(ident: &str, ty: Elem<D>, output: bool) -> String {
198    let ty = match ty {
199        Elem::F32 => "f",
200        Elem::F64 => "d",
201        Elem::U64 => "l",
202        _ => "r",
203    };
204    if output {
205        format!(r#""={ty}"({ident})"#)
206    } else {
207        format!(r#""{ty}"({ident})"#)
208    }
209}
210
211fn mma_ty<D: Dialect>(elem: Elem<D>) -> &'static str {
212    match elem {
213        Elem::TF32 => "tf32",
214        Elem::F32 => "f32",
215        Elem::F64 => "f64",
216        Elem::F16 => "f16",
217        Elem::BF16 => "bf16",
218        Elem::FP4(FP4Kind::E2M1) => "e2m1",
219        // For packed MMA this will always exist as fp4x2, since 4-bit values can't exist
220        Elem::FP4x2(FP4Kind::E2M1) => "e2m1",
221        Elem::FP6(FP6Kind::E2M3) => "e2m3",
222        Elem::FP6(FP6Kind::E3M2) => "e3m2",
223        Elem::FP8(FP8Kind::E4M3) => "e4m3",
224        Elem::FP8(FP8Kind::E5M2) => "e5m2",
225        Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
226        Elem::I8 => "s8",
227        Elem::I16 => "s16",
228        Elem::I32 => "s32",
229        Elem::I64 => "s64",
230        Elem::U8 => "u8",
231        Elem::U16 => "u16",
232        Elem::U32 => "u32",
233        Elem::U64 => "u64",
234        Elem::Bool => "b1",
235        other => panic!("{other} not supported for MMA"),
236    }
237}
238
239pub fn ldmatrix_call<D: Dialect>(
240    output: &Variable<D>,
241    buffer: &Variable<D>,
242    offset: &Variable<D>,
243    line_size: &Option<u32>,
244    factor: &u32,
245    transpose: &bool,
246) -> String {
247    let elem = output.elem();
248    let width = 16 / output.elem().size();
249    let is_transposed = if *transpose { "_trans" } else { "" };
250    let regs =
251        comma_separated((0..*factor).map(|i| format!("reinterpret_cast<uint32&>({output}[{i}])")));
252    let buffer = if let Some(line_size) = *line_size {
253        let mut item = buffer.item();
254        item.vectorization = line_size as usize;
255        format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
256    } else {
257        buffer.fmt_ptr()
258    };
259
260    format!("__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
261}
262
263pub fn ldmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
264    let width = 16 / elem.size();
265    let arg_ty = Elem::<D>::U32;
266
267    let args_regs = (0..factor).map(|i| format!("{arg_ty} &reg_{i}"));
268    let arg_addr = ["void const *row_addr".to_string()];
269    let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
270
271    let mut idx = 0usize;
272
273    let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
274    let placeholders_regs = format!("{{{placeholders_regs}}}");
275
276    let placeholder_addr = format!("[{}]", placeholder(&mut idx));
277
278    let params_regs = comma_separated((0..factor).map(|i| format!(r#""=r"(reg_{i})"#)));
279    let param_addr = r#""r"(addr)"#;
280
281    let is_transposed = if transpose { "_trans" } else { "" };
282    let transposed_arg = if transpose { ".trans" } else { "" };
283    let num = format!("x{factor}");
284
285    let ty = match elem.size() {
286        2 => "b16",
287        1 => "b8",
288        _ => unreachable!(),
289    };
290
291    format!(
292        r#"
293inline __device__ void
294__ldmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
295  uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
296  asm volatile("ldmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
297               " {placeholders_regs}, {placeholder_addr};"
298               : {params_regs}
299               : {param_addr});
300    }}
301    "#
302    )
303}
304
305pub fn stmatrix_call<D: Dialect>(
306    registers: &Variable<D>,
307    buffer: &Variable<D>,
308    offset: &Variable<D>,
309    line_size: &Option<u32>,
310    factor: &u32,
311    transpose: &bool,
312) -> String {
313    let elem = registers.elem();
314    let width = 16 / registers.elem().size();
315    let is_transposed = if *transpose { "_trans" } else { "" };
316    let regs = comma_separated(
317        (0..*factor).map(|i| format!("reinterpret_cast<uint32&>({registers}[{i}])")),
318    );
319    let buffer = if let Some(line_size) = *line_size {
320        let mut item = buffer.item();
321        item.vectorization = line_size as usize;
322        format!("reinterpret_cast<{item}*>({})", buffer.fmt_ptr())
323    } else {
324        buffer.fmt_ptr()
325    };
326
327    format!("__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({regs}, {buffer} + {offset});\n")
328}
329
330pub fn stmatrix_template<D: Dialect>(elem: Elem<D>, factor: u32, transpose: bool) -> String {
331    let width = 16 / elem.size();
332    let arg_ty = Elem::<D>::U32;
333
334    let args_regs = (0..factor).map(|i| format!("{arg_ty} const &reg_{i}"));
335    let arg_addr = ["void *row_addr".to_string()];
336    let args = args_regs.chain(arg_addr).collect::<Vec<_>>().join(", ");
337
338    let mut idx = 0usize;
339
340    let placeholder_addr = format!("[{}]", placeholder(&mut idx));
341
342    let placeholders_regs = comma_separated((0..factor).map(|_| placeholder(&mut idx)));
343    let placeholders_regs = format!("{{{placeholders_regs}}}");
344
345    let params_regs = comma_separated((0..factor).map(|i| format!(r#""r"(reg_{i})"#)));
346    let param_addr = r#""r"(addr)"#;
347
348    let is_transposed = if transpose { "_trans" } else { "" };
349    let transposed_arg = if transpose { ".trans" } else { "" };
350    let num = format!("x{factor}");
351
352    let ty = match elem.size() {
353        2 => "b16",
354        1 => "b8",
355        _ => unreachable!(),
356    };
357
358    // Note: smem technically an input
359    format!(
360        r#"
361inline __device__ void
362__stmatrix_m{width}n8_{elem}_{factor}x{is_transposed}({args}) {{
363  uint32 addr = static_cast<uint32>(__cvta_generic_to_shared(row_addr));
364  asm volatile("stmatrix.sync.aligned.m8n{width}.{num}{transposed_arg}.shared::cta.{ty}"
365               " {placeholder_addr}, {placeholders_regs};"
366               :: {param_addr}, {params_regs});
367    }}
368    "#
369    )
370}