1use core::fmt;
12
/// Element types this module can describe to CUTLASS.
///
/// Every variant has a C++ type name ([`CutlassDtype::as_cutlass_type`]), a
/// short lowercase tag ([`CutlassDtype::short_name`]) and a width in bits
/// ([`CutlassDtype::size_bits`]). The per-variant notes below repeat the
/// C++ name and width those methods report.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CutlassDtype {
    /// 32-bit float; C++ name `float`.
    F32,
    /// 64-bit float; C++ name `double`.
    F64,
    /// 16-bit float; C++ name `cutlass::half_t`.
    F16,
    /// 16-bit float; C++ name `cutlass::bfloat16_t`.
    Bf16,
    /// 8-bit float; C++ name `cutlass::float_e4m3_t`.
    F8E4m3,
    /// 8-bit float; C++ name `cutlass::float_e5m2_t`.
    F8E5m2,
    /// 4-bit float; C++ name `cutlass::float_e2m1_t`.
    F4E2m1,
    /// 8-bit signed integer; C++ name `int8_t`.
    I8,
    /// 32-bit signed integer; C++ name `int32_t`.
    I32,
    /// 8-bit unsigned integer; C++ name `uint8_t`.
    U8,
}
38
39impl CutlassDtype {
40 pub fn as_cutlass_type(self) -> &'static str {
43 match self {
44 CutlassDtype::F32 => "float",
45 CutlassDtype::F64 => "double",
46 CutlassDtype::F16 => "cutlass::half_t",
47 CutlassDtype::Bf16 => "cutlass::bfloat16_t",
48 CutlassDtype::F8E4m3 => "cutlass::float_e4m3_t",
49 CutlassDtype::F8E5m2 => "cutlass::float_e5m2_t",
50 CutlassDtype::F4E2m1 => "cutlass::float_e2m1_t",
51 CutlassDtype::I8 => "int8_t",
52 CutlassDtype::I32 => "int32_t",
53 CutlassDtype::U8 => "uint8_t",
54 }
55 }
56
57 pub fn short_name(self) -> &'static str {
59 match self {
60 CutlassDtype::F32 => "f32",
61 CutlassDtype::F64 => "f64",
62 CutlassDtype::F16 => "f16",
63 CutlassDtype::Bf16 => "bf16",
64 CutlassDtype::F8E4m3 => "f8e4m3",
65 CutlassDtype::F8E5m2 => "f8e5m2",
66 CutlassDtype::F4E2m1 => "f4e2m1",
67 CutlassDtype::I8 => "i8",
68 CutlassDtype::I32 => "i32",
69 CutlassDtype::U8 => "u8",
70 }
71 }
72
73 pub fn size_bits(self) -> u32 {
75 match self {
76 CutlassDtype::F64 => 64,
77 CutlassDtype::F32 | CutlassDtype::I32 => 32,
78 CutlassDtype::F16 | CutlassDtype::Bf16 => 16,
79 CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 | CutlassDtype::I8 | CutlassDtype::U8 => 8,
80 CutlassDtype::F4E2m1 => 4,
81 }
82 }
83}
84
85impl fmt::Display for CutlassDtype {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 f.write_str(self.short_name())
88 }
89}
90
/// GPU target architectures known to this module.
///
/// Each variant has a short name ([`SmArch::short_name`], e.g. `sm_90a`) and
/// a `--gpu-architecture=compute_*` option string ([`SmArch::nvrtc_flag`]).
/// The `supports_*` predicates report which features each target has.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// `sm_80` / `compute_80`.
    Sm80,
    /// `sm_86` / `compute_86`.
    Sm86,
    /// `sm_89` / `compute_89`.
    Sm89,
    /// `sm_90` / `compute_90`.
    Sm90,
    /// `sm_90a` / `compute_90a`.
    Sm90a,
    /// `sm_100` / `compute_100`.
    Sm100,
    /// `sm_120` / `compute_120`.
    Sm120,
}
106
107impl SmArch {
108 pub fn nvrtc_flag(self) -> &'static str {
109 match self {
110 SmArch::Sm80 => "--gpu-architecture=compute_80",
111 SmArch::Sm86 => "--gpu-architecture=compute_86",
112 SmArch::Sm89 => "--gpu-architecture=compute_89",
113 SmArch::Sm90 => "--gpu-architecture=compute_90",
114 SmArch::Sm90a => "--gpu-architecture=compute_90a",
115 SmArch::Sm100 => "--gpu-architecture=compute_100",
116 SmArch::Sm120 => "--gpu-architecture=compute_120",
117 }
118 }
119
120 pub fn short_name(self) -> &'static str {
121 match self {
122 SmArch::Sm80 => "sm_80",
123 SmArch::Sm86 => "sm_86",
124 SmArch::Sm89 => "sm_89",
125 SmArch::Sm90 => "sm_90",
126 SmArch::Sm90a => "sm_90a",
127 SmArch::Sm100 => "sm_100",
128 SmArch::Sm120 => "sm_120",
129 }
130 }
131
132 pub fn supports_fp8(self) -> bool {
133 matches!(
134 self,
135 SmArch::Sm89 | SmArch::Sm90 | SmArch::Sm90a | SmArch::Sm100 | SmArch::Sm120
136 )
137 }
138
139 pub fn supports_fp4(self) -> bool {
140 matches!(self, SmArch::Sm100 | SmArch::Sm120)
141 }
142
143 pub fn supports_persistent_kernels(self) -> bool {
144 matches!(
145 self,
146 SmArch::Sm90 | SmArch::Sm90a | SmArch::Sm100 | SmArch::Sm120
147 )
148 }
149}
150
151impl fmt::Display for SmArch {
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 f.write_str(self.short_name())
154 }
155}
156
157pub trait GemmSupported: Copy + Send + Sync + 'static {
164 const DTYPE: CutlassDtype;
165}
166
167impl GemmSupported for f32 {
168 const DTYPE: CutlassDtype = CutlassDtype::F32;
169}
170impl GemmSupported for f64 {
171 const DTYPE: CutlassDtype = CutlassDtype::F64;
172}
173impl GemmSupported for i8 {
174 const DTYPE: CutlassDtype = CutlassDtype::I8;
175}
176impl GemmSupported for i32 {
177 const DTYPE: CutlassDtype = CutlassDtype::I32;
178}
179impl GemmSupported for u8 {
180 const DTYPE: CutlassDtype = CutlassDtype::U8;
181}
182
183#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
188#[repr(transparent)]
189pub struct F16(pub u16);
190impl GemmSupported for F16 {
191 const DTYPE: CutlassDtype = CutlassDtype::F16;
192}
193
194#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
196#[repr(transparent)]
197pub struct Bf16(pub u16);
198impl GemmSupported for Bf16 {
199 const DTYPE: CutlassDtype = CutlassDtype::Bf16;
200}
201
202#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
204#[repr(transparent)]
205pub struct F8E4m3(pub u8);
206impl GemmSupported for F8E4m3 {
207 const DTYPE: CutlassDtype = CutlassDtype::F8E4m3;
208}
209
210#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
212#[repr(transparent)]
213pub struct F8E5m2(pub u8);
214impl GemmSupported for F8E5m2 {
215 const DTYPE: CutlassDtype = CutlassDtype::F8E5m2;
216}
217
218#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
221#[repr(transparent)]
222pub struct F4E2m1(pub u8);
223impl GemmSupported for F4E2m1 {
224 const DTYPE: CutlassDtype = CutlassDtype::F4E2m1;
225}
226
227pub fn is_supported_for(dtype: CutlassDtype, arch: SmArch) -> bool {
231 match dtype {
232 CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 => arch.supports_fp8(),
233 CutlassDtype::F4E2m1 => arch.supports_fp4(),
234 _ => true,
235 }
236}
237
238pub fn is_fp8_supported(arch: SmArch) -> bool {
240 arch.supports_fp8()
241}
242
243pub fn is_fp4_supported(arch: SmArch) -> bool {
245 arch.supports_fp4()
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks each `SmArch::supports_*` predicate on one supporting and
    /// one non-supporting architecture.
    #[test]
    fn arch_capability_predicates() {
        let fp8_cases = [
            (SmArch::Sm80, false),
            (SmArch::Sm89, true),
            (SmArch::Sm90a, true),
        ];
        for (arch, expected) in fp8_cases {
            assert_eq!(arch.supports_fp8(), expected, "fp8 on {}", arch);
        }

        let fp4_cases = [(SmArch::Sm100, true), (SmArch::Sm89, false)];
        for (arch, expected) in fp4_cases {
            assert_eq!(arch.supports_fp4(), expected, "fp4 on {}", arch);
        }

        let persistent_cases = [(SmArch::Sm90a, true), (SmArch::Sm80, false)];
        for (arch, expected) in persistent_cases {
            assert_eq!(
                arch.supports_persistent_kernels(),
                expected,
                "persistent kernels on {}",
                arch
            );
        }
    }

    /// No two dtypes may share a short name.
    #[test]
    fn dtype_short_names_unique() {
        let all = [
            CutlassDtype::F32,
            CutlassDtype::F64,
            CutlassDtype::F16,
            CutlassDtype::Bf16,
            CutlassDtype::F8E4m3,
            CutlassDtype::F8E5m2,
            CutlassDtype::F4E2m1,
            CutlassDtype::I8,
            CutlassDtype::I32,
            CutlassDtype::U8,
        ];
        // Compare each entry against everything listed before it.
        for (idx, dt) in all.iter().enumerate() {
            let clash = all[..idx]
                .iter()
                .any(|earlier| earlier.short_name() == dt.short_name());
            assert!(!clash);
        }
    }

    /// Samples the dtype/arch support matrix: one always-supported dtype and
    /// both sides of the fp8 and fp4 gates.
    #[test]
    fn is_supported_for_matrix() {
        let cases = [
            (CutlassDtype::F32, SmArch::Sm80, true),
            (CutlassDtype::F8E4m3, SmArch::Sm80, false),
            (CutlassDtype::F8E4m3, SmArch::Sm90a, true),
            (CutlassDtype::F4E2m1, SmArch::Sm89, false),
            (CutlassDtype::F4E2m1, SmArch::Sm100, true),
        ];
        for (dtype, arch, expected) in cases {
            assert_eq!(
                is_supported_for(dtype, arch),
                expected,
                "{} on {}",
                dtype,
                arch
            );
        }
    }
}