use core::fmt;
/// A CUDA version packed into one integer as `major * 1000 + minor * 10`.
///
/// CUDA 12.6, for example, is stored as `12060`. Equality, ordering and
/// hashing are all derived from the packed integer, so two versions compare
/// by `(major, minor)` as long as `minor` stays below 100.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
pub struct CudaVersion(u32);

impl CudaVersion {
    /// CUDA 11.4.
    pub const CUDA_11_4: Self = Self::from_major_minor(11, 4);
    /// CUDA 11.8.
    pub const CUDA_11_8: Self = Self::from_major_minor(11, 8);
    /// CUDA 12.0.
    pub const CUDA_12_0: Self = Self::from_major_minor(12, 0);
    /// CUDA 12.3.
    pub const CUDA_12_3: Self = Self::from_major_minor(12, 3);
    /// CUDA 12.6.
    pub const CUDA_12_6: Self = Self::from_major_minor(12, 6);
    /// CUDA 12.8.
    pub const CUDA_12_8: Self = Self::from_major_minor(12, 8);
    /// CUDA 13.0.
    pub const CUDA_13_0: Self = Self::from_major_minor(13, 0);
    /// Alias for [`Self::CUDA_11_4`].
    ///
    /// NOTE(review): presumably the oldest version this crate targets —
    /// confirm against the callers that consult it.
    pub const FLOOR: Self = Self::CUDA_11_4;

    /// Packs `major` and `minor` into a version.
    ///
    /// The inputs are not validated: a `minor` of 100 or more spills into the
    /// digits read back by [`Self::major`], and inputs large enough to
    /// overflow `u32` follow Rust's ordinary arithmetic-overflow rules.
    #[inline]
    pub const fn from_major_minor(major: u32, minor: u32) -> Self {
        Self(major * 1000 + minor * 10)
    }

    /// Wraps an already-packed integer without any validation.
    #[inline]
    pub const fn from_raw(raw: u32) -> Self {
        Self(raw)
    }

    /// Returns the packed integer (`major * 1000 + minor * 10`).
    #[inline]
    pub const fn raw(self) -> u32 {
        self.0
    }

    /// Returns the major component, i.e. the packed value divided by 1000.
    #[inline]
    pub const fn major(self) -> u32 {
        self.0 / 1000
    }

    /// Returns the minor component.
    ///
    /// The last decimal digit of the packed value is discarded.
    #[inline]
    pub const fn minor(self) -> u32 {
        (self.0 % 1000) / 10
    }

    /// Returns `true` when this version is `major.minor` or newer.
    ///
    /// The comparison is made on the packed integers. The threshold is built
    /// with saturating arithmetic, so out-of-range arguments clamp to
    /// `u32::MAX` instead of panicking (overflow-checked builds) or wrapping
    /// around to a small threshold that would spuriously pass.
    #[inline]
    pub const fn at_least(self, major: u32, minor: u32) -> bool {
        let threshold = major
            .saturating_mul(1000)
            .saturating_add(minor.saturating_mul(10));
        self.0 >= threshold
    }
}
/// Renders the version as `CUDA <major>.<minor>`, e.g. `CUDA 12.6`.
impl fmt::Display for CudaVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (major, minor) = (self.major(), self.minor());
        write!(f, "CUDA {}.{}", major, minor)
    }
}
/// A CUDA capability that is only available from a particular
/// [`CudaVersion`] onward; see [`Feature::required_version`].
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum Feature {
    /// Gated on CUDA 11.2.
    StreamOrderedAllocator,
    /// Gated on CUDA 10.2.
    VirtualMemoryManagement,
    /// Gated on CUDA 12.0.
    LibraryManagement,
    /// Gated on CUDA 12.0.
    GreenContexts,
    /// Gated on CUDA 12.0.
    MulticastObjects,
    /// Gated on CUDA 12.3.
    GraphConditionalNodes,
    /// Gated on CUDA 12.8.
    GraphSwitchNodes,
    /// Gated on CUDA 12.0.
    CudaLaunchKernelEx,
    /// Gated on CUDA 12.0.
    NvJitLink,
    /// Gated on CUDA 11.8.
    TensorMapObjects,
    /// Gated on CUDA 12.0.
    RuntimeLogBuffer,
    /// Gated on CUDA 12.0.
    CudaInitDevice,
    /// Gated on CUDA 13.1.
    RuntimeGreenContexts,
}

impl Feature {
    /// Returns the oldest [`CudaVersion`] on which this feature is treated
    /// as available.
    ///
    /// Arms are grouped by version, oldest first.
    pub const fn required_version(self) -> CudaVersion {
        // NOTE(review): the floors below are carried over unchanged. Some may
        // be earlier than the release that actually introduced the API (the
        // driver green-context API, for instance, appears to date from
        // CUDA 12.4) — verify against NVIDIA's release notes.
        match self {
            Self::VirtualMemoryManagement => CudaVersion::from_major_minor(10, 2),
            Self::StreamOrderedAllocator => CudaVersion::from_major_minor(11, 2),
            Self::TensorMapObjects => CudaVersion::CUDA_11_8,
            Self::LibraryManagement
            | Self::GreenContexts
            | Self::MulticastObjects
            | Self::CudaLaunchKernelEx
            | Self::NvJitLink
            | Self::RuntimeLogBuffer
            | Self::CudaInitDevice => CudaVersion::CUDA_12_0,
            Self::GraphConditionalNodes => CudaVersion::CUDA_12_3,
            Self::GraphSwitchNodes => CudaVersion::CUDA_12_8,
            Self::RuntimeGreenContexts => CudaVersion::from_major_minor(13, 1),
        }
    }
}
/// Returns `true` when `version` is at least the version `feature` requires.
///
/// The check compares the packed integers of `version` and
/// `feature.required_version()`.
#[inline]
pub const fn supports(version: CudaVersion, feature: Feature) -> bool {
    let floor = feature.required_version();
    floor.raw() <= version.raw()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packing 12.6 yields 12060 and both components read back unchanged.
    #[test]
    fn encode_decode() {
        let version = CudaVersion::from_major_minor(12, 6);
        assert_eq!((version.major(), version.minor()), (12, 6));
        assert_eq!(version.raw(), 12060);
    }

    /// The derived ordering follows the version number.
    #[test]
    fn ordering_is_by_version() {
        assert!(CudaVersion::CUDA_12_0 > CudaVersion::CUDA_11_4);
        assert!(CudaVersion::CUDA_13_0 > CudaVersion::CUDA_12_0);
    }

    /// `at_least` accepts an older threshold and rejects a newer one.
    #[test]
    fn at_least() {
        let newer = CudaVersion::CUDA_12_6;
        let older = CudaVersion::CUDA_11_4;
        assert!(newer.at_least(12, 0));
        assert!(!older.at_least(12, 0));
    }

    /// `supports` flips exactly at each feature's required version.
    #[test]
    fn feature_gating() {
        let cases = [
            (CudaVersion::CUDA_12_0, Feature::GreenContexts, true),
            (CudaVersion::CUDA_11_8, Feature::GreenContexts, false),
            (CudaVersion::CUDA_12_8, Feature::GraphSwitchNodes, true),
            (CudaVersion::CUDA_12_6, Feature::GraphSwitchNodes, false),
        ];
        for &(version, feature, expected) in cases.iter() {
            assert_eq!(
                supports(version, feature),
                expected,
                "{} vs {:?}",
                version,
                feature
            );
        }
    }

    /// `FLOOR` is the same value as `CUDA_11_4`.
    #[test]
    fn floor_is_consistent() {
        assert_eq!(CudaVersion::CUDA_11_4, CudaVersion::FLOOR);
    }
}