cubecl_cpp/cuda/ptx/
mod.rs

1use crate::{
2    Dialect,
3    shared::{Elem, FP4Kind, FP6Kind, FP8Kind},
4};
5
/// Text of the sibling file `tma_load_im2col.cuh`, embedded at compile time via `include_str!`.
///
/// NOTE(review): going by the name this is the CUDA helper source for TMA im2col loads; the
/// header itself is not visible from this module, so confirm before relying on that.
pub const TMA_LOAD_IM2COL: &str = include_str!("tma_load_im2col.cuh");
7
8#[allow(clippy::too_many_arguments)]
9pub fn mma_template<D: Dialect>(
10    a_elem: Elem<D>,
11    b_elem: Elem<D>,
12    cd_elem: Elem<D>,
13    k: u32,
14    n_a_registers: usize,
15    n_b_registers: usize,
16    n_c_registers: usize,
17    n_d_registers: usize,
18) -> String {
19    let a_ty = mma_ty(a_elem);
20    let b_ty = mma_ty(b_elem);
21    let cd_ty = mma_ty(cd_elem);
22
23    let ab_arg_ty = match a_elem {
24        Elem::F32 => &format!("{}", Elem::<D>::F32),
25        _ => &format!("{}", Elem::<D>::U32),
26    };
27    let cd_arg_ty = match cd_elem {
28        Elem::F32 => &format!("{}", Elem::<D>::F32),
29        _ => &format!("{}", Elem::<D>::U32),
30    };
31
32    let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const &reg_a_{i}"));
33    let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const &reg_b_{i}"));
34    let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const &reg_c_{i}"));
35    let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} &reg_d_{i}"));
36    let args = args_a
37        .chain(args_b)
38        .chain(args_c)
39        .chain(args_d)
40        .collect::<Vec<_>>()
41        .join(", ");
42
43    let kind = if is_fp6_fp4(a_elem) || is_fp6_fp4(b_elem) {
44        ".kind::f8f6f4"
45    } else {
46        ""
47    };
48
49    let mut idx = 0usize;
50
51    let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
52    let placeholders_d = format!("{{{placeholders_d}}}");
53
54    let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
55    let placeholders_a = format!("{{{placeholders_a}}}");
56
57    let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
58    let placeholders_b = format!("{{{placeholders_b}}}");
59
60    let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
61    let placeholders_c = format!("{{{placeholders_c}}}");
62
63    let params_out =
64        comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
65    let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
66    let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
67    let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
68    let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
69
70    format!(
71        r#"
72inline __device__ void
73__mma_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}) {{
74  asm volatile("mma.sync.aligned.m16n8k{k}.row.col{kind}.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}"
75               " {placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c};"
76               : {params_out}
77               : {params_in});
78    }}
79    "#
80    )
81}
82
83fn is_fp6_fp4<D: Dialect>(elem: Elem<D>) -> bool {
84    matches!(elem, Elem::<D>::FP4(_) | Elem::<D>::FP6(_))
85}
86
87#[allow(clippy::too_many_arguments)]
88pub fn mma_scaled_template<D: Dialect>(
89    a_elem: Elem<D>,
90    b_elem: Elem<D>,
91    cd_elem: Elem<D>,
92    k: u32,
93    n_a_registers: usize,
94    n_b_registers: usize,
95    n_c_registers: usize,
96    n_d_registers: usize,
97    scales_elem: Elem<D>,
98    scales_factor: u32,
99) -> String {
100    let a_ty = mma_ty(a_elem);
101    let b_ty = mma_ty(b_elem);
102    let cd_ty = mma_ty(cd_elem);
103    // Needs custom mapping because of the ignored sign bit
104    let s_ty = match scales_elem {
105        Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
106        Elem::FP8(FP8Kind::E4M3) => "ue4m3",
107        _ => panic!("Unsupported scales type"),
108    };
109
110    let kind = match scales_factor {
111        1 => "mxf8f6f4",
112        2 | 4 => "mxf4nvf4",
113        _ => panic!("Unsupported scales factor"),
114    };
115
116    let ab_arg_ty = match a_elem {
117        Elem::F32 => &format!("{}", Elem::<D>::F32),
118        _ => &format!("{}", Elem::<D>::U32),
119    };
120    let cd_arg_ty = match cd_elem {
121        Elem::F32 => &format!("{}", Elem::<D>::F32),
122        _ => &format!("{}", Elem::<D>::U32),
123    };
124
125    // Note: Scaled MMA actually requires float registers for C/D, unlike normal MMA
126    let args_a = (0..n_a_registers).map(|i| format!("{ab_arg_ty} const &reg_a_{i}"));
127    let args_b = (0..n_b_registers).map(|i| format!("{ab_arg_ty} const &reg_b_{i}"));
128    let args_c = (0..n_c_registers).map(|i| format!("{cd_arg_ty} const &reg_c_{i}"));
129    let args_d = (0..n_d_registers).map(|i| format!("{cd_arg_ty} &reg_d_{i}"));
130    let args = args_a
131        .chain(args_b)
132        .chain(args_c)
133        .chain(args_d)
134        .collect::<Vec<_>>()
135        .join(", ");
136
137    let mut idx = 0usize;
138
139    let placeholders_d = comma_separated((0..n_d_registers).map(|_| placeholder(&mut idx)));
140    let placeholders_d = format!("{{{placeholders_d}}}");
141
142    let placeholders_a = comma_separated((0..n_a_registers).map(|_| placeholder(&mut idx)));
143    let placeholders_a = format!("{{{placeholders_a}}}");
144
145    let placeholders_b = comma_separated((0..n_b_registers).map(|_| placeholder(&mut idx)));
146    let placeholders_b = format!("{{{placeholders_b}}}");
147
148    let placeholders_c = comma_separated((0..n_c_registers).map(|_| placeholder(&mut idx)));
149    let placeholders_c = format!("{{{placeholders_c}}}");
150
151    let placeholder_scales_a = format!(
152        "{{{}}}, {{{}, {}}}",
153        placeholder(&mut idx),
154        placeholder(&mut idx),
155        placeholder(&mut idx)
156    );
157    let placeholder_scales_b = format!(
158        "{{{}}}, {{{}, {}}}",
159        placeholder(&mut idx),
160        placeholder(&mut idx),
161        placeholder(&mut idx)
162    );
163
164    let params_out =
165        comma_separated((0..n_d_registers).map(|i| as_reg(&format!("reg_d_{i}"), cd_elem, true)));
166    let params_a = (0..n_a_registers).map(|i| as_reg(&format!("reg_a_{i}"), a_elem, false));
167    let params_b = (0..n_b_registers).map(|i| as_reg(&format!("reg_b_{i}"), b_elem, false));
168    let params_c = (0..n_c_registers).map(|i| as_reg(&format!("reg_c_{i}"), cd_elem, false));
169    let params_in = comma_separated(params_a.chain(params_b).chain(params_c));
170
171    format!(
172        r#"
173inline __device__ void
174__mma_scaled_{scales_factor}x_m16n8k{k}_{a_elem}_{b_elem}_{cd_elem}({args}, uint32 const &scales_a, uint32 const &scales_b) {{
175    static constexpr uint16 tidA = 0;
176    static constexpr uint16 bidA = 0;
177    static constexpr uint16 tidB = 0;
178    static constexpr uint16 bidB = 0;
179
180    asm volatile("mma.sync.aligned.kind::{kind}.block_scale.scale_vec::{scales_factor}X.m16n8k{k}.row.col.{cd_ty}.{a_ty}.{b_ty}.{cd_ty}.{s_ty} "
181               "{placeholders_d}, {placeholders_a}, {placeholders_b}, {placeholders_c}, {placeholder_scales_a}, {placeholder_scales_b};"
182               : {params_out}
183               : {params_in}, "r"(scales_a), "h"(bidA), "h"(tidA), "r"(scales_b), "h"(bidB), "h"(tidB));
184    }}
185    "#
186    )
187}
188
/// Joins the yielded strings with `", "` between consecutive items.
///
/// Returns an empty string when `it` yields nothing, and the sole item unchanged when it yields
/// exactly one.
pub(crate) fn comma_separated(it: impl IntoIterator<Item = String>) -> String {
    let mut items = it.into_iter();
    // Seed the output with the first item so separators only appear between items.
    let Some(mut joined) = items.next() else {
        return String::new();
    };
    for item in items {
        joined.push_str(", ");
        joined.push_str(&item);
    }
    joined
}
192
/// Returns the operand placeholder `%N` for the current value `N` of `idx`, then advances
/// `idx` by one so the next call yields the following operand number.
fn placeholder(idx: &mut usize) -> String {
    let current = *idx;
    *idx = current + 1;
    format!("%{current}")
}
198
199fn as_reg<D: Dialect>(ident: &str, ty: Elem<D>, output: bool) -> String {
200    let ty = match ty {
201        Elem::F32 => "f",
202        Elem::F64 => "d",
203        Elem::U64 => "l",
204        _ => "r",
205    };
206    if output {
207        format!(r#""={ty}"({ident})"#)
208    } else {
209        format!(r#""{ty}"({ident})"#)
210    }
211}
212
213fn mma_ty<D: Dialect>(elem: Elem<D>) -> &'static str {
214    match elem {
215        Elem::TF32 => "tf32",
216        Elem::F32 => "f32",
217        Elem::F64 => "f64",
218        Elem::F16 => "f16",
219        Elem::BF16 => "bf16",
220        Elem::FP4(FP4Kind::E2M1) => "e2m1",
221        // For packed MMA this will always exist as fp4x2, since 4-bit values can't exist
222        Elem::FP4x2(FP4Kind::E2M1) => "e2m1",
223        Elem::FP6(FP6Kind::E2M3) => "e2m3",
224        Elem::FP6(FP6Kind::E3M2) => "e3m2",
225        Elem::FP8(FP8Kind::E4M3) => "e4m3",
226        Elem::FP8(FP8Kind::E5M2) => "e5m2",
227        Elem::FP8(FP8Kind::UE8M0) => "ue8m0",
228        Elem::I8 => "s8",
229        Elem::I16 => "s16",
230        Elem::I32 => "s32",
231        Elem::I64 => "s64",
232        Elem::U8 => "u8",
233        Elem::U16 => "u16",
234        Elem::U32 => "u32",
235        Elem::U64 => "u64",
236        Elem::Bool => "b1",
237        other => panic!("{other} not supported for MMA"),
238    }
239}