use std::fmt;
/// GPU target architectures, identified by their PTX `sm_XX` name.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so the variants
/// must stay listed from the oldest target to the newest
/// (for example `Sm90 < Sm90a` and `Sm100 < Sm120`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SmVersion {
    /// `sm_75` (compute capability 7.5).
    Sm75,
    /// `sm_80` (compute capability 8.0).
    Sm80,
    /// `sm_86` (compute capability 8.6).
    Sm86,
    /// `sm_89` (compute capability 8.9).
    Sm89,
    /// `sm_90` (compute capability 9.0).
    Sm90,
    /// `sm_90a`; never produced from a bare compute-capability pair.
    Sm90a,
    /// `sm_100` (compute capability 10.0).
    Sm100,
    /// `sm_120` (compute capability 12.0).
    Sm120,
}
impl SmVersion {
    /// Returns the PTX target name of this architecture, e.g. `"sm_80"`.
    #[must_use]
    pub const fn as_ptx_str(self) -> &'static str {
        match self {
            Self::Sm75 => "sm_75",
            Self::Sm80 => "sm_80",
            Self::Sm86 => "sm_86",
            Self::Sm89 => "sm_89",
            Self::Sm90 => "sm_90",
            Self::Sm90a => "sm_90a",
            Self::Sm100 => "sm_100",
            Self::Sm120 => "sm_120",
        }
    }

    /// Returns the PTX ISA version paired with this target as a
    /// `"major.minor"` string.
    ///
    /// Always agrees with [`Self::ptx_isa_version`].
    ///
    /// NOTE(review): NVIDIA's PTX documentation appears to list `sm_100` as
    /// introduced in PTX ISA 8.6; the `8.5` below is kept because existing
    /// callers and tests pin it. Confirm before relying on it for `.version`.
    #[must_use]
    pub const fn ptx_version(self) -> &'static str {
        match self {
            Self::Sm75 => "6.4",
            Self::Sm80 => "7.0",
            Self::Sm86 => "7.1",
            Self::Sm89 => "7.8",
            Self::Sm90 | Self::Sm90a => "8.0",
            Self::Sm100 => "8.5",
            Self::Sm120 => "8.7",
        }
    }

    /// Returns the same PTX ISA version as [`Self::ptx_version`], as a
    /// numeric `(major, minor)` pair suitable for comparisons.
    #[must_use]
    pub const fn ptx_isa_version(self) -> (u32, u32) {
        match self {
            Self::Sm75 => (6, 4),
            Self::Sm80 => (7, 0),
            Self::Sm86 => (7, 1),
            Self::Sm89 => (7, 8),
            Self::Sm90 | Self::Sm90a => (8, 0),
            Self::Sm100 => (8, 5),
            Self::Sm120 => (8, 7),
        }
    }

    /// Returns the feature table for this target; shorthand for
    /// [`ArchCapabilities::for_sm`].
    #[must_use]
    pub const fn capabilities(self) -> ArchCapabilities {
        ArchCapabilities::for_sm(self)
    }

    /// Maps a `(major, minor)` compute capability to a known target.
    ///
    /// Returns `None` for any pair that is not listed below. `Sm90a` is never
    /// returned: compute capability 9.0 always maps to `Sm90`.
    #[must_use]
    pub const fn from_compute_capability(major: i32, minor: i32) -> Option<Self> {
        match (major, minor) {
            (7, 5) => Some(Self::Sm75),
            (8, 0) => Some(Self::Sm80),
            (8, 6) => Some(Self::Sm86),
            (8, 9) => Some(Self::Sm89),
            (9, 0) => Some(Self::Sm90),
            (10, 0) => Some(Self::Sm100),
            (12, 0) => Some(Self::Sm120),
            _ => None,
        }
    }

    /// Maximum number of threads in one thread block; 1024 for every target.
    #[must_use]
    pub const fn max_threads_per_block(self) -> u32 {
        1024
    }

    /// Maximum number of resident threads per SM.
    ///
    /// Values follow the per-compute-capability table in the CUDA programming
    /// guide (resident warps per SM times the warp size).
    #[must_use]
    pub const fn max_threads_per_sm(self) -> u32 {
        match self {
            // 32 resident warps.
            Self::Sm75 => 1024,
            // 48 resident warps: compute capabilities 8.6, 8.9 and 12.0.
            Self::Sm86 | Self::Sm89 | Self::Sm120 => 1536,
            // 64 resident warps.
            Self::Sm80 | Self::Sm90 | Self::Sm90a | Self::Sm100 => 2048,
        }
    }

    /// Number of threads in a warp; 32 for every target.
    #[must_use]
    pub const fn warp_size(self) -> u32 {
        32
    }

    /// Maximum shared memory a single thread block may use, in bytes.
    ///
    /// NOTE(review): the CUDA programming guide lists 163 KB (166_912 bytes)
    /// for compute capability 8.0; the `Sm80` value below is kept because
    /// existing tests pin it. Confirm which figure callers expect.
    #[must_use]
    pub const fn max_shared_mem_per_block(self) -> u32 {
        match self {
            // 64 KB.
            Self::Sm75 => 65_536,
            Self::Sm80 => 163_840,
            // 99 KB: compute capabilities 8.6, 8.9 and 12.0.
            Self::Sm86 | Self::Sm89 | Self::Sm120 => 101_376,
            // 227 KB.
            Self::Sm90 | Self::Sm90a | Self::Sm100 => 232_448,
        }
    }
}
/// Formats the target as its PTX name (see `SmVersion::as_ptx_str`).
impl fmt::Display for SmVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/fill/alignment flags such as `{:>8}`, which a
        // bare `write_str` would silently ignore; plain `{}` is unaffected.
        f.pad(self.as_ptx_str())
    }
}
/// Per-target feature flags, produced by `ArchCapabilities::for_sm`.
///
/// Each field is an independent boolean; the per-field notes state which
/// `SmVersion` targets set it in the built-in table.
// A flat list of independent flags is the intended shape here, hence the
// lint exemption.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArchCapabilities {
    /// Set for every target.
    pub has_tensor_cores: bool,
    /// Set for every target except `Sm75`.
    pub has_cp_async: bool,
    /// Set for every target.
    pub has_ldmatrix: bool,
    /// Set for every target except `Sm75`.
    pub has_ampere_mma: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_wgmma: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_tma: bool,
    /// Set for `Sm89` and every later target.
    pub has_fp8: bool,
    /// Set for `Sm100` and `Sm120` only.
    pub has_fp6_fp4: bool,
    /// Set for every target.
    pub has_dynamic_smem: bool,
    /// Set for every target.
    pub has_named_barriers: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_cluster_barriers: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_stmatrix: bool,
    /// Set for every target except `Sm75`.
    pub has_redux: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_elect_one: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_griddepcontrol: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_setmaxnreg: bool,
    /// Set for `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    pub has_bulk_copy: bool,
    /// Set for `Sm120` only.
    pub has_sm120_features: bool,
}
impl ArchCapabilities {
    /// Returns the feature flags for `sm`.
    ///
    /// The table is strictly cumulative: every target enables everything the
    /// previous tier enables, plus the additions listed on the thresholds
    /// below. `Sm80`/`Sm86` share a tier, as do `Sm90`/`Sm90a`.
    ///
    /// NOTE(review): `has_wgmma` and `has_setmaxnreg` are reported for plain
    /// `Sm90` and for `Sm100`/`Sm120`, while the PTX ISA appears to document
    /// those instructions as requiring the `sm_90a` target; confirm that
    /// consumers gate on the exact target where it matters.
    #[must_use]
    pub const fn for_sm(sm: SmVersion) -> Self {
        // Exhaustive on purpose: a new `SmVersion` variant must be assigned a
        // tier explicitly instead of silently inheriting one.
        let tier: u8 = match sm {
            SmVersion::Sm75 => 0,
            SmVersion::Sm80 | SmVersion::Sm86 => 1,
            SmVersion::Sm89 => 2,
            SmVersion::Sm90 | SmVersion::Sm90a => 3,
            SmVersion::Sm100 => 4,
            SmVersion::Sm120 => 5,
        };

        let from_sm80 = tier >= 1;
        let from_sm89 = tier >= 2;
        let from_sm90 = tier >= 3;
        let from_sm100 = tier >= 4;
        let sm120_only = tier == 5;

        Self {
            // Present on every target in the table.
            has_tensor_cores: true,
            has_ldmatrix: true,
            has_dynamic_smem: true,
            has_named_barriers: true,
            // Added with the `Sm80`/`Sm86` tier.
            has_cp_async: from_sm80,
            has_ampere_mma: from_sm80,
            has_redux: from_sm80,
            // Added with `Sm89`.
            has_fp8: from_sm89,
            // Added with the `Sm90`/`Sm90a` tier.
            has_wgmma: from_sm90,
            has_tma: from_sm90,
            has_cluster_barriers: from_sm90,
            has_stmatrix: from_sm90,
            has_elect_one: from_sm90,
            has_griddepcontrol: from_sm90,
            has_setmaxnreg: from_sm90,
            has_bulk_copy: from_sm90,
            // Added with `Sm100`.
            has_fp6_fp4: from_sm100,
            // Exclusive to `Sm120`.
            has_sm120_features: sm120_only,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived ordering must rank later targets above earlier ones.
    #[test]
    fn sm_version_ordering() {
        let pairs = [
            (SmVersion::Sm75, SmVersion::Sm80),
            (SmVersion::Sm90, SmVersion::Sm90a),
            (SmVersion::Sm100, SmVersion::Sm120),
        ];
        for &(older, newer) in &pairs {
            assert!(newer > older, "{:?} should sort after {:?}", newer, older);
        }
    }

    /// Spot-checks the string form of the PTX ISA version.
    #[test]
    fn ptx_version_strings() {
        let expected = [
            (SmVersion::Sm75, "6.4"),
            (SmVersion::Sm80, "7.0"),
            (SmVersion::Sm86, "7.1"),
            (SmVersion::Sm90, "8.0"),
            (SmVersion::Sm100, "8.5"),
            (SmVersion::Sm120, "8.7"),
        ];
        for &(sm, version) in &expected {
            assert_eq!(sm.ptx_version(), version);
        }
    }

    /// Known compute capabilities resolve to their target.
    #[test]
    fn from_compute_capability_valid() {
        let known = [
            ((7, 5), SmVersion::Sm75),
            ((8, 0), SmVersion::Sm80),
            ((9, 0), SmVersion::Sm90),
        ];
        for &((major, minor), sm) in &known {
            assert_eq!(SmVersion::from_compute_capability(major, minor), Some(sm));
        }
    }

    /// Compute capabilities outside the table yield `None`.
    #[test]
    fn from_compute_capability_unknown() {
        for &(major, minor) in &[(6, 0), (5, 2)] {
            assert_eq!(SmVersion::from_compute_capability(major, minor), None);
        }
    }

    #[test]
    fn capabilities_turing() {
        let ArchCapabilities {
            has_tensor_cores,
            has_cp_async,
            has_ampere_mma,
            has_wgmma,
            ..
        } = SmVersion::Sm75.capabilities();
        assert!(has_tensor_cores);
        assert!(!has_cp_async);
        assert!(!has_ampere_mma);
        assert!(!has_wgmma);
    }

    #[test]
    fn capabilities_ampere() {
        let ArchCapabilities {
            has_tensor_cores,
            has_cp_async,
            has_ampere_mma,
            has_wgmma,
            has_fp8,
            ..
        } = SmVersion::Sm80.capabilities();
        assert!(has_tensor_cores);
        assert!(has_cp_async);
        assert!(has_ampere_mma);
        assert!(!has_wgmma);
        assert!(!has_fp8);
    }

    #[test]
    fn capabilities_hopper() {
        let ArchCapabilities {
            has_wgmma,
            has_tma,
            has_fp8,
            has_fp6_fp4,
            has_cluster_barriers,
            ..
        } = SmVersion::Sm90a.capabilities();
        assert!(has_wgmma);
        assert!(has_tma);
        assert!(has_fp8);
        assert!(!has_fp6_fp4);
        assert!(has_cluster_barriers);
    }

    #[test]
    fn capabilities_blackwell() {
        let ArchCapabilities {
            has_fp6_fp4,
            has_wgmma,
            has_tma,
            ..
        } = SmVersion::Sm100.capabilities();
        assert!(has_fp6_fp4);
        assert!(has_wgmma);
        assert!(has_tma);
    }

    /// `Display` renders the PTX target name.
    #[test]
    fn display_sm_version() {
        assert_eq!(SmVersion::Sm80.to_string(), "sm_80");
        assert_eq!(SmVersion::Sm90a.to_string(), "sm_90a");
    }

    #[test]
    fn shared_memory_limits() {
        let expected = [
            (SmVersion::Sm75, 65536),
            (SmVersion::Sm80, 163_840),
            (SmVersion::Sm90, 232_448),
        ];
        for &(sm, bytes) in &expected {
            assert_eq!(sm.max_shared_mem_per_block(), bytes);
        }
    }

    /// Every variant has a pinned numeric PTX ISA version.
    #[test]
    fn ptx_isa_version_all_sm() {
        let expected = [
            (SmVersion::Sm75, (6, 4)),
            (SmVersion::Sm80, (7, 0)),
            (SmVersion::Sm86, (7, 1)),
            (SmVersion::Sm89, (7, 8)),
            (SmVersion::Sm90, (8, 0)),
            (SmVersion::Sm90a, (8, 0)),
            (SmVersion::Sm100, (8, 5)),
            (SmVersion::Sm120, (8, 7)),
        ];
        for &(sm, version) in &expected {
            assert_eq!(sm.ptx_isa_version(), version);
        }
    }

    #[test]
    fn capabilities_new_fields_turing() {
        let ArchCapabilities {
            has_redux,
            has_stmatrix,
            has_elect_one,
            has_griddepcontrol,
            has_setmaxnreg,
            has_bulk_copy,
            has_sm120_features,
            ..
        } = SmVersion::Sm75.capabilities();
        assert!(!has_redux);
        assert!(!has_stmatrix);
        assert!(!has_elect_one);
        assert!(!has_griddepcontrol);
        assert!(!has_setmaxnreg);
        assert!(!has_bulk_copy);
        assert!(!has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_ampere() {
        let ArchCapabilities {
            has_redux,
            has_stmatrix,
            has_elect_one,
            has_griddepcontrol,
            has_sm120_features,
            ..
        } = SmVersion::Sm80.capabilities();
        assert!(has_redux);
        assert!(!has_stmatrix);
        assert!(!has_elect_one);
        assert!(!has_griddepcontrol);
        assert!(!has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_hopper() {
        let ArchCapabilities {
            has_redux,
            has_stmatrix,
            has_elect_one,
            has_griddepcontrol,
            has_setmaxnreg,
            has_bulk_copy,
            has_sm120_features,
            ..
        } = SmVersion::Sm90.capabilities();
        assert!(has_redux);
        assert!(has_stmatrix);
        assert!(has_elect_one);
        assert!(has_griddepcontrol);
        assert!(has_setmaxnreg);
        assert!(has_bulk_copy);
        assert!(!has_sm120_features);
    }

    #[test]
    fn capabilities_new_fields_sm120() {
        let ArchCapabilities {
            has_redux,
            has_stmatrix,
            has_elect_one,
            has_griddepcontrol,
            has_setmaxnreg,
            has_bulk_copy,
            has_sm120_features,
            ..
        } = SmVersion::Sm120.capabilities();
        assert!(has_redux);
        assert!(has_stmatrix);
        assert!(has_elect_one);
        assert!(has_griddepcontrol);
        assert!(has_setmaxnreg);
        assert!(has_bulk_copy);
        assert!(has_sm120_features);
    }
}