Skip to main content

oxicuda_ptx/
arch.rs

1//! NVIDIA GPU architecture definitions and capability queries.
2//!
3//! This module provides [`SmVersion`] to identify target architectures from
4//! Turing (sm_75) through Blackwell (sm_120), and [`ArchCapabilities`] for
5//! querying hardware features such as tensor core availability, async copy
6//! support, and maximum thread counts.
7
8use std::fmt;
9
/// NVIDIA GPU Streaming Multiprocessor version.
///
/// Each variant corresponds to a PTX `.target` and determines the PTX ISA
/// version, the available instructions, and the hardware limits.
///
/// # Ordering
///
/// `SmVersion` derives `Ord` from declaration order, so newer architectures
/// compare greater than older ones (`Sm80 > Sm75`), and the
/// architecture-specific `Sm90a` sorts directly after `Sm90`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SmVersion {
    /// Turing (compute capability 7.5).
    Sm75,
    /// Ampere data-center, e.g. A100 (compute capability 8.0).
    Sm80,
    /// Ampere `GA10x` (compute capability 8.6).
    Sm86,
    /// Ada Lovelace (compute capability 8.9).
    Sm89,
    /// Hopper (compute capability 9.0).
    Sm90,
    /// Hopper with architecture-specific accelerated features (`sm_90a`).
    ///
    /// Devices report compute capability 9.0 for this target, so it is never
    /// returned by [`SmVersion::from_compute_capability`].
    Sm90a,
    /// Blackwell data-center, e.g. B200 (compute capability 10.0).
    Sm100,
    /// Blackwell consumer/workstation, e.g. GeForce RTX 50 series
    /// (compute capability 12.0).
    Sm120,
}

impl SmVersion {
    /// Returns the PTX `.target` string (e.g. `"sm_80"`, `"sm_90a"`).
    #[must_use]
    pub const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::Sm75 => "sm_75",
            Self::Sm80 => "sm_80",
            Self::Sm86 => "sm_86",
            Self::Sm89 => "sm_89",
            Self::Sm90 => "sm_90",
            Self::Sm90a => "sm_90a",
            Self::Sm100 => "sm_100",
            Self::Sm120 => "sm_120",
        }
    }

    /// Returns the PTX ISA version string for this architecture's `.version`
    /// directive (e.g. `"8.0"` for Hopper).
    ///
    /// Each value is no lower than the first PTX ISA release that accepts the
    /// corresponding `.target`. The result always agrees with
    /// [`SmVersion::ptx_isa_version`].
    #[must_use]
    pub const fn ptx_version(self) -> &'static str {
        match self {
            Self::Sm75 => "6.4",
            Self::Sm80 => "7.0",
            Self::Sm86 => "7.1",
            Self::Sm89 => "7.8",
            Self::Sm90 | Self::Sm90a => "8.0",
            // `sm_100` first appears in PTX ISA 8.6; an 8.5 header cannot target it.
            Self::Sm100 => "8.6",
            // `sm_120` first appears in PTX ISA 8.7.
            Self::Sm120 => "8.7",
        }
    }

    /// Returns the PTX ISA version as a `(major, minor)` pair.
    ///
    /// This is useful for programmatic version comparisons rather than
    /// string parsing. It mirrors [`SmVersion::ptx_version`] exactly.
    ///
    /// # Examples
    ///
    /// ```
    /// use oxicuda_ptx::arch::SmVersion;
    ///
    /// assert_eq!(SmVersion::Sm80.ptx_isa_version(), (7, 0));
    /// assert_eq!(SmVersion::Sm90.ptx_isa_version(), (8, 0));
    /// assert_eq!(SmVersion::Sm100.ptx_isa_version(), (8, 6));
    /// assert_eq!(SmVersion::Sm120.ptx_isa_version(), (8, 7));
    /// ```
    #[must_use]
    pub const fn ptx_isa_version(self) -> (u32, u32) {
        match self {
            Self::Sm75 => (6, 4),
            Self::Sm80 => (7, 0),
            Self::Sm86 => (7, 1),
            Self::Sm89 => (7, 8),
            Self::Sm90 | Self::Sm90a => (8, 0),
            Self::Sm100 => (8, 6),
            Self::Sm120 => (8, 7),
        }
    }

    /// Returns the architecture capabilities for this SM version.
    #[must_use]
    pub const fn capabilities(self) -> ArchCapabilities {
        ArchCapabilities::for_sm(self)
    }

    /// Converts a CUDA compute capability pair to an `SmVersion`.
    ///
    /// Returns `None` if the compute capability is not recognized. A device
    /// reporting 9.0 maps to [`SmVersion::Sm90`]; the architecture-specific
    /// [`SmVersion::Sm90a`] target must be chosen explicitly.
    ///
    /// # Examples
    ///
    /// ```
    /// use oxicuda_ptx::arch::SmVersion;
    ///
    /// assert_eq!(SmVersion::from_compute_capability(8, 0), Some(SmVersion::Sm80));
    /// assert_eq!(SmVersion::from_compute_capability(7, 5), Some(SmVersion::Sm75));
    /// assert_eq!(SmVersion::from_compute_capability(6, 0), None);
    /// ```
    #[must_use]
    pub const fn from_compute_capability(major: i32, minor: i32) -> Option<Self> {
        match (major, minor) {
            (7, 5) => Some(Self::Sm75),
            (8, 0) => Some(Self::Sm80),
            (8, 6) => Some(Self::Sm86),
            (8, 9) => Some(Self::Sm89),
            (9, 0) => Some(Self::Sm90),
            (10, 0) => Some(Self::Sm100),
            (12, 0) => Some(Self::Sm120),
            _ => None,
        }
    }

    /// Returns the maximum number of threads per block for this architecture.
    #[must_use]
    pub const fn max_threads_per_block(self) -> u32 {
        1024
    }

    /// Returns the maximum number of resident threads per SM.
    ///
    /// Values follow the CUDA Programming Guide's per-compute-capability
    /// table. The `GA10x` (8.6), Ada (8.9) and consumer Blackwell (12.0)
    /// parts hold 48 resident warps, unlike their 64-warp data-center
    /// counterparts.
    #[must_use]
    pub const fn max_threads_per_sm(self) -> u32 {
        match self {
            Self::Sm75 => 1024,
            Self::Sm86 | Self::Sm89 | Self::Sm120 => 1536,
            Self::Sm80 | Self::Sm90 | Self::Sm90a | Self::Sm100 => 2048,
        }
    }

    /// Returns the warp size for this architecture (always 32).
    #[must_use]
    pub const fn warp_size(self) -> u32 {
        32
    }

    /// Returns the maximum shared memory a single block may use, in bytes.
    ///
    /// This is the opt-in ceiling from the CUDA Programming Guide. Per that
    /// guide, using more than 48 KB requires raising the kernel's
    /// max-dynamic-shared-memory attribute before launch.
    #[must_use]
    pub const fn max_shared_mem_per_block(self) -> u32 {
        match self {
            // 64 KB
            Self::Sm75 => 65_536,
            // 163 KB (164 KB per SM, 1 KB reserved)
            Self::Sm80 => 166_912,
            // 99 KB on the non-data-center parts
            Self::Sm86 | Self::Sm89 | Self::Sm120 => 101_376,
            // 227 KB
            Self::Sm90 | Self::Sm90a | Self::Sm100 => 232_448,
        }
    }
}

impl fmt::Display for SmVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_ptx_str())
    }
}

/// Hardware capabilities for a specific GPU architecture.
///
/// Query this struct to determine whether a given feature (tensor cores,
/// async copy, TMA, etc.) is available on the target architecture before
/// emitting instructions that require it.
// Each flag is an independent, publicly readable feature bit; collapsing
// them into a bitflags type would break the public field API.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArchCapabilities {
    /// Whether the architecture supports `mma.sync` tensor core instructions.
    pub has_tensor_cores: bool,
    /// Whether `cp.async` (asynchronous global-to-shared copy) is supported.
    pub has_cp_async: bool,
    /// Whether `ldmatrix` (warp-cooperative shared memory load) is supported.
    pub has_ldmatrix: bool,
    /// Whether `mma.sync.aligned.m16n8k16` (Ampere MMA shapes) is supported.
    pub has_ampere_mma: bool,
    /// Whether WGMMA (warp-group MMA, Hopper) instructions are supported.
    ///
    /// NOTE(review): the PTX ISA lists `wgmma.mma_async` as requiring the
    /// `sm_90a` target. This flag is also set for `Sm90`, `Sm100` and `Sm120`;
    /// confirm against the emitter before relying on it.
    pub has_wgmma: bool,
    /// Whether TMA (Tensor Memory Accelerator, Hopper) is supported.
    pub has_tma: bool,
    /// Whether FP8 (E4M3/E5M2) data types are supported.
    pub has_fp8: bool,
    /// Whether FP6/FP4 narrow floating-point types are supported (Blackwell).
    pub has_fp6_fp4: bool,
    /// Whether dynamic shared memory (`extern __shared__`) is supported.
    pub has_dynamic_smem: bool,
    /// Whether `bar.sync` with named barriers is supported.
    pub has_named_barriers: bool,
    /// Whether `fence.mbarrier` and related cluster barriers are supported.
    pub has_cluster_barriers: bool,
    /// Whether `stmatrix` (store matrix to shared memory) is supported (SM >= 90).
    pub has_stmatrix: bool,
    /// Whether `redux.sync` (warp-level reduction) is supported (SM >= 80).
    pub has_redux: bool,
    /// Whether `elect.sync` (warp leader election) is supported (SM >= 90).
    pub has_elect_one: bool,
    /// Whether `griddepcontrol` (grid dependency control) is supported (SM >= 90).
    pub has_griddepcontrol: bool,
    /// Whether `setmaxnreg` (set max register count) is supported (SM >= 90).
    ///
    /// NOTE(review): the PTX ISA gates `setmaxnreg` on architecture-specific
    /// (`a`) targets such as `sm_90a`; confirm before emitting it for `sm_90`.
    pub has_setmaxnreg: bool,
    /// Whether bulk async copy operations are supported (SM >= 90).
    pub has_bulk_copy: bool,
    /// Whether features specific to the `sm_120` target (consumer/workstation
    /// Blackwell) are available.
    pub has_sm120_features: bool,
}

impl ArchCapabilities {
    /// Turing (`sm_75`): tensor cores and `ldmatrix`, but no async copy.
    const TURING: Self = Self {
        has_tensor_cores: true,
        has_cp_async: false,
        has_ldmatrix: true,
        has_ampere_mma: false,
        has_wgmma: false,
        has_tma: false,
        has_fp8: false,
        has_fp6_fp4: false,
        has_dynamic_smem: true,
        has_named_barriers: true,
        has_cluster_barriers: false,
        has_stmatrix: false,
        has_redux: false,
        has_elect_one: false,
        has_griddepcontrol: false,
        has_setmaxnreg: false,
        has_bulk_copy: false,
        has_sm120_features: false,
    };

    /// Ampere (`sm_80`, `sm_86`): Turing plus `cp.async`, Ampere MMA shapes
    /// and `redux.sync`.
    const AMPERE: Self = Self {
        has_cp_async: true,
        has_ampere_mma: true,
        has_redux: true,
        ..Self::TURING
    };

    /// Ada Lovelace (`sm_89`): Ampere plus FP8.
    const ADA: Self = Self {
        has_fp8: true,
        ..Self::AMPERE
    };

    /// Hopper (`sm_90`, `sm_90a`): Ada plus the warp-group / TMA / cluster
    /// feature set.
    const HOPPER: Self = Self {
        has_wgmma: true,
        has_tma: true,
        has_cluster_barriers: true,
        has_stmatrix: true,
        has_elect_one: true,
        has_griddepcontrol: true,
        has_setmaxnreg: true,
        has_bulk_copy: true,
        ..Self::ADA
    };

    /// Blackwell data-center (`sm_100`): Hopper plus FP6/FP4.
    const BLACKWELL: Self = Self {
        has_fp6_fp4: true,
        ..Self::HOPPER
    };

    /// Blackwell consumer/workstation (`sm_120`): Blackwell plus the
    /// `sm_120`-specific feature flag.
    const BLACKWELL_SM120: Self = Self {
        has_sm120_features: true,
        ..Self::BLACKWELL
    };

    /// Returns the capabilities for the given SM version.
    ///
    /// Each generation's table is derived from the previous one, so a
    /// feature enabled for an older architecture stays enabled for newer ones.
    #[must_use]
    pub const fn for_sm(sm: SmVersion) -> Self {
        match sm {
            SmVersion::Sm75 => Self::TURING,
            SmVersion::Sm80 | SmVersion::Sm86 => Self::AMPERE,
            SmVersion::Sm89 => Self::ADA,
            SmVersion::Sm90 | SmVersion::Sm90a => Self::HOPPER,
            SmVersion::Sm100 => Self::BLACKWELL,
            SmVersion::Sm120 => Self::BLACKWELL_SM120,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order.
    const ALL: [SmVersion; 8] = [
        SmVersion::Sm75,
        SmVersion::Sm80,
        SmVersion::Sm86,
        SmVersion::Sm89,
        SmVersion::Sm90,
        SmVersion::Sm90a,
        SmVersion::Sm100,
        SmVersion::Sm120,
    ];

    #[test]
    fn sm_version_ordering() {
        assert!(SmVersion::Sm80 > SmVersion::Sm75);
        assert!(SmVersion::Sm90a > SmVersion::Sm90);
        assert!(SmVersion::Sm120 > SmVersion::Sm100);
        assert!(ALL.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn ptx_version_strings() {
        assert_eq!(SmVersion::Sm75.ptx_version(), "6.4");
        assert_eq!(SmVersion::Sm80.ptx_version(), "7.0");
        assert_eq!(SmVersion::Sm86.ptx_version(), "7.1");
        assert_eq!(SmVersion::Sm89.ptx_version(), "7.8");
        assert_eq!(SmVersion::Sm90.ptx_version(), "8.0");
        assert_eq!(SmVersion::Sm90a.ptx_version(), "8.0");
        assert_eq!(SmVersion::Sm100.ptx_version(), "8.6");
        assert_eq!(SmVersion::Sm120.ptx_version(), "8.7");
    }

    #[test]
    fn ptx_version_matches_isa_tuple() {
        for &sm in &ALL {
            let (major, minor) = sm.ptx_isa_version();
            assert_eq!(sm.ptx_version(), format!("{major}.{minor}"), "{sm}");
        }
    }

    #[test]
    fn from_compute_capability_valid() {
        assert_eq!(
            SmVersion::from_compute_capability(7, 5),
            Some(SmVersion::Sm75)
        );
        assert_eq!(
            SmVersion::from_compute_capability(8, 0),
            Some(SmVersion::Sm80)
        );
        assert_eq!(
            SmVersion::from_compute_capability(8, 6),
            Some(SmVersion::Sm86)
        );
        assert_eq!(
            SmVersion::from_compute_capability(8, 9),
            Some(SmVersion::Sm89)
        );
        assert_eq!(
            SmVersion::from_compute_capability(9, 0),
            Some(SmVersion::Sm90)
        );
        assert_eq!(
            SmVersion::from_compute_capability(10, 0),
            Some(SmVersion::Sm100)
        );
        assert_eq!(
            SmVersion::from_compute_capability(12, 0),
            Some(SmVersion::Sm120)
        );
    }

    #[test]
    fn from_compute_capability_unknown() {
        assert_eq!(SmVersion::from_compute_capability(6, 0), None);
        assert_eq!(SmVersion::from_compute_capability(5, 2), None);
    }

    #[test]
    fn capabilities_turing() {
        let caps = SmVersion::Sm75.capabilities();
        assert!(caps.has_tensor_cores);
        assert!(!caps.has_cp_async);
        assert!(!caps.has_ampere_mma);
        assert!(!caps.has_wgmma);
    }

    #[test]
    fn capabilities_ampere() {
        let caps = SmVersion::Sm80.capabilities();
        assert!(caps.has_tensor_cores);
        assert!(caps.has_cp_async);
        assert!(caps.has_ampere_mma);
        assert!(!caps.has_wgmma);
        assert!(!caps.has_fp8);
    }

    #[test]
    fn capabilities_hopper() {
        let caps = SmVersion::Sm90a.capabilities();
        assert!(caps.has_wgmma);
        assert!(caps.has_tma);
        assert!(caps.has_fp8);
        assert!(!caps.has_fp6_fp4);
        assert!(caps.has_cluster_barriers);
    }

    #[test]
    fn capabilities_blackwell() {
        let caps = SmVersion::Sm100.capabilities();
        assert!(caps.has_fp6_fp4);
        assert!(caps.has_wgmma);
        assert!(caps.has_tma);
    }

    #[test]
    fn display_sm_version() {
        assert_eq!(format!("{}", SmVersion::Sm80), "sm_80");
        assert_eq!(format!("{}", SmVersion::Sm90a), "sm_90a");
    }

    #[test]
    fn shared_memory_limits() {
        assert_eq!(SmVersion::Sm75.max_shared_mem_per_block(), 65_536);
        assert_eq!(SmVersion::Sm80.max_shared_mem_per_block(), 166_912);
        assert_eq!(SmVersion::Sm86.max_shared_mem_per_block(), 101_376);
        assert_eq!(SmVersion::Sm89.max_shared_mem_per_block(), 101_376);
        assert_eq!(SmVersion::Sm90.max_shared_mem_per_block(), 232_448);
        assert_eq!(SmVersion::Sm100.max_shared_mem_per_block(), 232_448);
        assert_eq!(SmVersion::Sm120.max_shared_mem_per_block(), 101_376);
    }

    #[test]
    fn threads_per_sm_limits() {
        assert_eq!(SmVersion::Sm75.max_threads_per_sm(), 1024);
        assert_eq!(SmVersion::Sm80.max_threads_per_sm(), 2048);
        assert_eq!(SmVersion::Sm86.max_threads_per_sm(), 1536);
        assert_eq!(SmVersion::Sm89.max_threads_per_sm(), 1536);
        assert_eq!(SmVersion::Sm90.max_threads_per_sm(), 2048);
        assert_eq!(SmVersion::Sm100.max_threads_per_sm(), 2048);
        assert_eq!(SmVersion::Sm120.max_threads_per_sm(), 1536);
    }

    #[test]
    fn ptx_isa_version_all_sm() {
        assert_eq!(SmVersion::Sm75.ptx_isa_version(), (6, 4));
        assert_eq!(SmVersion::Sm80.ptx_isa_version(), (7, 0));
        assert_eq!(SmVersion::Sm86.ptx_isa_version(), (7, 1));
        assert_eq!(SmVersion::Sm89.ptx_isa_version(), (7, 8));
        assert_eq!(SmVersion::Sm90.ptx_isa_version(), (8, 0));
        assert_eq!(SmVersion::Sm90a.ptx_isa_version(), (8, 0));
        assert_eq!(SmVersion::Sm100.ptx_isa_version(), (8, 6));
        assert_eq!(SmVersion::Sm120.ptx_isa_version(), (8, 7));
    }

    #[test]
    fn capabilities_new_fields_turing() {
        let caps = SmVersion::Sm75.capabilities();
        assert!(!caps.has_redux);
        assert!(!caps.has_stmatrix);
        assert!(!caps.has_elect_one);
        assert!(!caps.has_griddepcontrol);
        assert!(!caps.has_setmaxnreg);
        assert!(!caps.has_bulk_copy);
        assert!(!caps.has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_ampere() {
        let caps = SmVersion::Sm80.capabilities();
        assert!(caps.has_redux);
        assert!(!caps.has_stmatrix);
        assert!(!caps.has_elect_one);
        assert!(!caps.has_griddepcontrol);
        assert!(!caps.has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_hopper() {
        let caps = SmVersion::Sm90.capabilities();
        assert!(caps.has_redux);
        assert!(caps.has_stmatrix);
        assert!(caps.has_elect_one);
        assert!(caps.has_griddepcontrol);
        assert!(caps.has_setmaxnreg);
        assert!(caps.has_bulk_copy);
        assert!(!caps.has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_sm120() {
        let caps = SmVersion::Sm120.capabilities();
        assert!(caps.has_redux);
        assert!(caps.has_stmatrix);
        assert!(caps.has_elect_one);
        assert!(caps.has_griddepcontrol);
        assert!(caps.has_setmaxnreg);
        assert!(caps.has_bulk_copy);
        assert!(caps.has_sm120_features);
    }
}