1use std::fmt;
9
/// GPU target architectures, identified by their PTX `sm_*` target name.
///
/// Variants are declared from oldest to newest, and the derived `Ord` /
/// `PartialOrd` follow that declaration order. As a consequence `Sm90a`
/// compares greater than `Sm90`, and `Sm120` greater than `Sm100`.
/// Do not reorder the variants: comparisons elsewhere rely on this.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SmVersion {
    /// PTX target `sm_75` (Turing).
    Sm75,
    /// PTX target `sm_80` (Ampere).
    Sm80,
    /// PTX target `sm_86`.
    Sm86,
    /// PTX target `sm_89`.
    Sm89,
    /// PTX target `sm_90` (Hopper).
    Sm90,
    /// PTX target `sm_90a`, the architecture-specific flavour of `sm_90`.
    Sm90a,
    /// PTX target `sm_100` (Blackwell).
    Sm100,
    /// PTX target `sm_120`.
    Sm120,
}
38
39impl SmVersion {
40 #[must_use]
42 pub const fn as_ptx_str(self) -> &'static str {
43 match self {
44 Self::Sm75 => "sm_75",
45 Self::Sm80 => "sm_80",
46 Self::Sm86 => "sm_86",
47 Self::Sm89 => "sm_89",
48 Self::Sm90 => "sm_90",
49 Self::Sm90a => "sm_90a",
50 Self::Sm100 => "sm_100",
51 Self::Sm120 => "sm_120",
52 }
53 }
54
55 #[must_use]
60 pub const fn ptx_version(self) -> &'static str {
61 match self {
62 Self::Sm75 => "6.4",
63 Self::Sm80 => "7.0",
64 Self::Sm86 => "7.1",
65 Self::Sm89 => "7.8",
66 Self::Sm90 | Self::Sm90a => "8.0",
67 Self::Sm100 => "8.5",
68 Self::Sm120 => "8.7",
69 }
70 }
71
72 #[must_use]
87 pub const fn ptx_isa_version(self) -> (u32, u32) {
88 match self {
89 Self::Sm75 => (6, 4),
90 Self::Sm80 => (7, 0),
91 Self::Sm86 => (7, 1),
92 Self::Sm89 => (7, 8),
93 Self::Sm90 | Self::Sm90a => (8, 0),
94 Self::Sm100 => (8, 5),
95 Self::Sm120 => (8, 7),
96 }
97 }
98
99 #[must_use]
101 pub const fn capabilities(self) -> ArchCapabilities {
102 ArchCapabilities::for_sm(self)
103 }
104
105 #[must_use]
119 pub const fn from_compute_capability(major: i32, minor: i32) -> Option<Self> {
120 match (major, minor) {
121 (7, 5) => Some(Self::Sm75),
122 (8, 0) => Some(Self::Sm80),
123 (8, 6) => Some(Self::Sm86),
124 (8, 9) => Some(Self::Sm89),
125 (9, 0) => Some(Self::Sm90),
126 (10, 0) => Some(Self::Sm100),
127 (12, 0) => Some(Self::Sm120),
128 _ => None,
129 }
130 }
131
132 #[must_use]
134 pub const fn max_threads_per_block(self) -> u32 {
135 1024
136 }
137
138 #[must_use]
140 pub const fn max_threads_per_sm(self) -> u32 {
141 match self {
142 Self::Sm75 => 1024,
143 Self::Sm89 => 1536,
144 Self::Sm80 | Self::Sm86 | Self::Sm90 | Self::Sm90a | Self::Sm100 | Self::Sm120 => 2048,
145 }
146 }
147
148 #[must_use]
150 pub const fn warp_size(self) -> u32 {
151 32
152 }
153
154 #[must_use]
156 pub const fn max_shared_mem_per_block(self) -> u32 {
157 match self {
158 Self::Sm75 => 65536,
159 Self::Sm80 | Self::Sm86 => 163_840,
160 Self::Sm89 => 101_376,
161 Self::Sm90 | Self::Sm90a | Self::Sm100 | Self::Sm120 => 232_448,
162 }
163 }
164}
165
166impl fmt::Display for SmVersion {
167 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168 f.write_str(self.as_ptx_str())
169 }
170}
171
/// Per-architecture feature flags.
///
/// Every field is an independent yes/no answer to "does this target provide
/// feature X?". Values are normally obtained from `ArchCapabilities::for_sm`
/// or `SmVersion::capabilities` rather than built by hand.
// Each flag is a separate capability bit, so a struct of bools is the
// intended shape here; the clippy lint is silenced deliberately.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ArchCapabilities {
    /// Tensor cores are available.
    pub has_tensor_cores: bool,
    /// `cp.async` is available.
    pub has_cp_async: bool,
    /// `ldmatrix` is available.
    pub has_ldmatrix: bool,
    /// The Ampere-generation `mma` forms are available.
    pub has_ampere_mma: bool,
    /// `wgmma` is available.
    pub has_wgmma: bool,
    /// TMA is available.
    pub has_tma: bool,
    /// FP8 types are available.
    pub has_fp8: bool,
    /// FP6 and FP4 types are available.
    pub has_fp6_fp4: bool,
    /// Dynamic shared memory is available.
    pub has_dynamic_smem: bool,
    /// Named barriers are available.
    pub has_named_barriers: bool,
    /// Cluster-level barriers are available.
    pub has_cluster_barriers: bool,
    /// `stmatrix` is available.
    pub has_stmatrix: bool,
    /// `redux` is available.
    pub has_redux: bool,
    /// `elect.one` is available.
    pub has_elect_one: bool,
    /// `griddepcontrol` is available.
    pub has_griddepcontrol: bool,
    /// `setmaxnreg` is available.
    pub has_setmaxnreg: bool,
    /// Bulk copy is available.
    pub has_bulk_copy: bool,
    /// Features specific to `sm_120` are available.
    pub has_sm120_features: bool,
}
217
218impl ArchCapabilities {
219 #[must_use]
221 #[allow(clippy::too_many_lines)]
222 pub const fn for_sm(sm: SmVersion) -> Self {
223 match sm {
224 SmVersion::Sm75 => Self {
225 has_tensor_cores: true,
226 has_cp_async: false,
227 has_ldmatrix: true,
228 has_ampere_mma: false,
229 has_wgmma: false,
230 has_tma: false,
231 has_fp8: false,
232 has_fp6_fp4: false,
233 has_dynamic_smem: true,
234 has_named_barriers: true,
235 has_cluster_barriers: false,
236 has_stmatrix: false,
237 has_redux: false,
238 has_elect_one: false,
239 has_griddepcontrol: false,
240 has_setmaxnreg: false,
241 has_bulk_copy: false,
242 has_sm120_features: false,
243 },
244 SmVersion::Sm80 | SmVersion::Sm86 => Self {
245 has_tensor_cores: true,
246 has_cp_async: true,
247 has_ldmatrix: true,
248 has_ampere_mma: true,
249 has_wgmma: false,
250 has_tma: false,
251 has_fp8: false,
252 has_fp6_fp4: false,
253 has_dynamic_smem: true,
254 has_named_barriers: true,
255 has_cluster_barriers: false,
256 has_stmatrix: false,
257 has_redux: true,
258 has_elect_one: false,
259 has_griddepcontrol: false,
260 has_setmaxnreg: false,
261 has_bulk_copy: false,
262 has_sm120_features: false,
263 },
264 SmVersion::Sm89 => Self {
265 has_tensor_cores: true,
266 has_cp_async: true,
267 has_ldmatrix: true,
268 has_ampere_mma: true,
269 has_wgmma: false,
270 has_tma: false,
271 has_fp8: true,
272 has_fp6_fp4: false,
273 has_dynamic_smem: true,
274 has_named_barriers: true,
275 has_cluster_barriers: false,
276 has_stmatrix: false,
277 has_redux: true,
278 has_elect_one: false,
279 has_griddepcontrol: false,
280 has_setmaxnreg: false,
281 has_bulk_copy: false,
282 has_sm120_features: false,
283 },
284 SmVersion::Sm90 | SmVersion::Sm90a => Self {
285 has_tensor_cores: true,
286 has_cp_async: true,
287 has_ldmatrix: true,
288 has_ampere_mma: true,
289 has_wgmma: true,
290 has_tma: true,
291 has_fp8: true,
292 has_fp6_fp4: false,
293 has_dynamic_smem: true,
294 has_named_barriers: true,
295 has_cluster_barriers: true,
296 has_stmatrix: true,
297 has_redux: true,
298 has_elect_one: true,
299 has_griddepcontrol: true,
300 has_setmaxnreg: true,
301 has_bulk_copy: true,
302 has_sm120_features: false,
303 },
304 SmVersion::Sm100 => Self {
305 has_tensor_cores: true,
306 has_cp_async: true,
307 has_ldmatrix: true,
308 has_ampere_mma: true,
309 has_wgmma: true,
310 has_tma: true,
311 has_fp8: true,
312 has_fp6_fp4: true,
313 has_dynamic_smem: true,
314 has_named_barriers: true,
315 has_cluster_barriers: true,
316 has_stmatrix: true,
317 has_redux: true,
318 has_elect_one: true,
319 has_griddepcontrol: true,
320 has_setmaxnreg: true,
321 has_bulk_copy: true,
322 has_sm120_features: false,
323 },
324 SmVersion::Sm120 => Self {
325 has_tensor_cores: true,
326 has_cp_async: true,
327 has_ldmatrix: true,
328 has_ampere_mma: true,
329 has_wgmma: true,
330 has_tma: true,
331 has_fp8: true,
332 has_fp6_fp4: true,
333 has_dynamic_smem: true,
334 has_named_barriers: true,
335 has_cluster_barriers: true,
336 has_stmatrix: true,
337 has_redux: true,
338 has_elect_one: true,
339 has_griddepcontrol: true,
340 has_setmaxnreg: true,
341 has_bulk_copy: true,
342 has_sm120_features: true,
343 },
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    //! Unit tests for the architecture tables.
    //!
    //! Most checks are table-driven: each row pairs an input with the value
    //! the lookup is expected to return.

    use super::*;

    /// Newer architectures must compare greater than older ones.
    #[test]
    fn sm_version_ordering() {
        let newer_older = [
            (SmVersion::Sm80, SmVersion::Sm75),
            (SmVersion::Sm90a, SmVersion::Sm90),
            (SmVersion::Sm120, SmVersion::Sm100),
        ];
        for &(newer, older) in &newer_older {
            assert!(newer > older);
        }
    }

    /// `ptx_version` returns the expected `"major.minor"` strings.
    #[test]
    fn ptx_version_strings() {
        let expected = [
            (SmVersion::Sm75, "6.4"),
            (SmVersion::Sm80, "7.0"),
            (SmVersion::Sm86, "7.1"),
            (SmVersion::Sm90, "8.0"),
            (SmVersion::Sm100, "8.5"),
            (SmVersion::Sm120, "8.7"),
        ];
        for &(sm, version) in &expected {
            assert_eq!(sm.ptx_version(), version);
        }
    }

    /// Known compute capabilities resolve to their architecture.
    #[test]
    fn from_compute_capability_valid() {
        let known = [
            ((7, 5), SmVersion::Sm75),
            ((8, 0), SmVersion::Sm80),
            ((9, 0), SmVersion::Sm90),
        ];
        for &((major, minor), sm) in &known {
            assert_eq!(SmVersion::from_compute_capability(major, minor), Some(sm));
        }
    }

    /// Compute capabilities outside the table resolve to `None`.
    #[test]
    fn from_compute_capability_unknown() {
        for &(major, minor) in &[(6, 0), (5, 2)] {
            assert_eq!(SmVersion::from_compute_capability(major, minor), None);
        }
    }

    /// Turing has tensor cores but none of the later copy/MMA features.
    #[test]
    fn capabilities_turing() {
        let caps = SmVersion::Sm75.capabilities();
        assert!(caps.has_tensor_cores);
        assert_eq!(
            [caps.has_cp_async, caps.has_ampere_mma, caps.has_wgmma],
            [false; 3]
        );
    }

    /// Ampere adds `cp.async` and the Ampere MMA, but not WGMMA or FP8.
    #[test]
    fn capabilities_ampere() {
        let caps = SmVersion::Sm80.capabilities();
        assert_eq!(
            [caps.has_tensor_cores, caps.has_cp_async, caps.has_ampere_mma],
            [true; 3]
        );
        assert_eq!([caps.has_wgmma, caps.has_fp8], [false; 2]);
    }

    /// Hopper adds WGMMA, TMA, FP8 and cluster barriers, but not FP6/FP4.
    #[test]
    fn capabilities_hopper() {
        let caps = SmVersion::Sm90a.capabilities();
        assert_eq!(
            [
                caps.has_wgmma,
                caps.has_tma,
                caps.has_fp8,
                caps.has_cluster_barriers,
            ],
            [true; 4]
        );
        assert!(!caps.has_fp6_fp4);
    }

    /// Blackwell adds FP6/FP4 on top of the Hopper feature set.
    #[test]
    fn capabilities_blackwell() {
        let caps = SmVersion::Sm100.capabilities();
        assert_eq!([caps.has_fp6_fp4, caps.has_wgmma, caps.has_tma], [true; 3]);
    }

    /// `Display` prints the PTX target name.
    #[test]
    fn display_sm_version() {
        assert_eq!(SmVersion::Sm80.to_string(), "sm_80");
        assert_eq!(SmVersion::Sm90a.to_string(), "sm_90a");
    }

    /// Per-block shared memory limits for a few representative targets.
    #[test]
    fn shared_memory_limits() {
        let expected = [
            (SmVersion::Sm75, 65536),
            (SmVersion::Sm80, 163_840),
            (SmVersion::Sm90, 232_448),
        ];
        for &(sm, bytes) in &expected {
            assert_eq!(sm.max_shared_mem_per_block(), bytes);
        }
    }

    /// `ptx_isa_version` returns the expected pair for every variant.
    #[test]
    fn ptx_isa_version_all_sm() {
        let expected = [
            (SmVersion::Sm75, (6, 4)),
            (SmVersion::Sm80, (7, 0)),
            (SmVersion::Sm86, (7, 1)),
            (SmVersion::Sm89, (7, 8)),
            (SmVersion::Sm90, (8, 0)),
            (SmVersion::Sm90a, (8, 0)),
            (SmVersion::Sm100, (8, 5)),
            (SmVersion::Sm120, (8, 7)),
        ];
        for &(sm, version) in &expected {
            assert_eq!(sm.ptx_isa_version(), version);
        }
    }

    /// None of the newer flags are set on Turing.
    #[test]
    fn capabilities_new_fields_turing() {
        let caps = SmVersion::Sm75.capabilities();
        assert_eq!(
            [
                caps.has_redux,
                caps.has_stmatrix,
                caps.has_elect_one,
                caps.has_griddepcontrol,
                caps.has_setmaxnreg,
                caps.has_bulk_copy,
                caps.has_sm120_features,
            ],
            [false; 7]
        );
    }

    /// Ampere gains `redux` only.
    #[test]
    fn capabilities_new_fields_ampere() {
        let caps = SmVersion::Sm80.capabilities();
        assert!(caps.has_redux);
        assert_eq!(
            [
                caps.has_stmatrix,
                caps.has_elect_one,
                caps.has_griddepcontrol,
                caps.has_sm120_features,
            ],
            [false; 4]
        );
    }

    /// Hopper gains every newer flag except the `sm_120`-specific one.
    #[test]
    fn capabilities_new_fields_hopper() {
        let caps = SmVersion::Sm90.capabilities();
        assert_eq!(
            [
                caps.has_redux,
                caps.has_stmatrix,
                caps.has_elect_one,
                caps.has_griddepcontrol,
                caps.has_setmaxnreg,
                caps.has_bulk_copy,
            ],
            [true; 6]
        );
        assert!(!caps.has_sm120_features);
    }

    /// `sm_120` has every newer flag set.
    #[test]
    fn capabilities_new_fields_sm120() {
        let caps = SmVersion::Sm120.capabilities();
        assert_eq!(
            [
                caps.has_redux,
                caps.has_stmatrix,
                caps.has_elect_one,
                caps.has_griddepcontrol,
                caps.has_setmaxnreg,
                caps.has_bulk_copy,
                caps.has_sm120_features,
            ],
            [true; 7]
        );
    }
}