use libloading::{Library, Symbol};
use once_cell::sync::OnceCell;
use rattler_conda_types::Version;
use std::process::Command;
use std::{
mem::MaybeUninit,
os::raw::{c_int, c_uint, c_ulong},
str::FromStr,
};
/// Returns `true` when `s` consists of exactly two non-empty runs of ASCII digits separated
/// by a single `.` (for example `"8.6"` or `"12.0"`).
///
/// Anything else — a missing or extra component, an empty component, or a non-digit
/// character — yields `false`.
pub(crate) fn is_valid_cuda_version_format(s: &str) -> bool {
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    s.split_once('.')
        .is_some_and(|(major, minor)| is_number(major) && is_number(minor))
}
/// Compute capability of a CUDA device, written `major.minor` (for example 8.6).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CudaArchInfo {
    /// Major compute-capability number.
    pub major: u32,
    /// Minor compute-capability number.
    pub minor: u32,
}
/// CUDA details detected on the current machine.
///
/// Each field is `None` when that piece of information could not be determined.
#[derive(Clone, Debug)]
pub struct CudaInfo {
    /// CUDA version reported by the installed driver.
    pub version: Option<Version>,
    /// Lowest compute capability among the devices that could be queried.
    pub arch_info: Option<CudaArchInfo>,
}
/// Returns the CUDA information detected for this machine.
///
/// Detection happens lazily on the first call; every later call returns the same cached
/// value.
pub fn cuda_info() -> &'static CudaInfo {
    static CACHED_INFO: OnceCell<CudaInfo> = OnceCell::new();
    CACHED_INFO.get_or_init(detect_cuda_info)
}
/// Returns an owned copy of the cached CUDA driver version, if one was detected.
pub fn cuda_version() -> Option<Version> {
    let info = cuda_info();
    info.version.as_ref().cloned()
}
/// Returns an owned copy of the cached minimum compute capability, if one was detected.
pub fn cuda_arch() -> Option<CudaArchInfo> {
    let info = cuda_info();
    info.arch_info.as_ref().cloned()
}
/// Performs the (uncached) detection behind [`cuda_info`].
///
/// Non-musl builds load the CUDA driver library directly, which yields both the driver
/// version and the compute capability. musl builds only ask `nvidia-smi` for the version,
/// so `arch_info` is always `None` there.
fn detect_cuda_info() -> CudaInfo {
    if !cfg!(target_env = "musl") {
        return detect_cuda_info_via_libcuda();
    }

    let version = detect_cuda_version_via_nvidia_smi();
    CudaInfo {
        version,
        arch_info: None,
    }
}
fn detect_cuda_info_via_libcuda() -> CudaInfo {
let cuda_library = match cuda_library_paths()
.iter()
.find_map(|path| unsafe { Library::new(*path).ok() })
{
Some(lib) => lib,
None => {
return CudaInfo {
version: None,
arch_info: None,
};
}
};
let cu_init: Symbol<'_, unsafe extern "C" fn(c_uint) -> c_ulong> =
match unsafe { cuda_library.get(b"cuInit\0") } {
Ok(init) => init,
Err(_) => {
return CudaInfo {
version: None,
arch_info: None,
};
}
};
if unsafe { cu_init(0) } != 0 {
return CudaInfo {
version: None,
arch_info: None,
};
}
let version = detect_cuda_version_from_library(&cuda_library);
let arch_info = detect_cuda_arch_from_library(&cuda_library);
CudaInfo { version, arch_info }
}
/// Queries the CUDA driver version through `cuDriverGetVersion`.
///
/// The driver encodes the version as `1000 * major + 10 * minor` (e.g. `12020` for 12.2).
/// It is decoded into a `major.minor` [`Version`]. Returns `None` if the symbol cannot be
/// resolved, the call fails, or the decoded string does not parse.
fn detect_cuda_version_from_library(cuda_library: &Library) -> Option<Version> {
    // `CUresult` is a 32-bit C enum; declare it as `c_int` rather than the
    // platform-width `c_ulong`.
    // SAFETY: matches the cuda.h prototype `CUresult cuDriverGetVersion(int* driverVersion)`.
    let cu_driver_get_version: Symbol<'_, unsafe extern "C" fn(*mut c_int) -> c_int> =
        unsafe { cuda_library.get(b"cuDriverGetVersion\0") }.ok()?;

    let mut encoded: c_int = 0;
    // SAFETY: `encoded` is a valid, writable `int` for the duration of the call.
    if unsafe { cu_driver_get_version(&mut encoded) } != 0 {
        return None;
    }

    let (major, minor) = (encoded / 1000, (encoded % 1000) / 10);
    Version::from_str(&format!("{major}.{minor}")).ok()
}
fn detect_cuda_arch_from_library(cuda_library: &Library) -> Option<CudaArchInfo> {
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: c_int = 75;
const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: c_int = 76;
let cu_device_get_count: Symbol<'_, unsafe extern "C" fn(*mut c_int) -> c_ulong> =
unsafe { cuda_library.get(b"cuDeviceGetCount\0") }.ok()?;
let cu_device_get: Symbol<'_, unsafe extern "C" fn(*mut c_int, c_int) -> c_ulong> =
unsafe { cuda_library.get(b"cuDeviceGet\0") }.ok()?;
let cu_device_get_attribute: Symbol<
'_,
unsafe extern "C" fn(*mut c_int, c_int, c_int) -> c_ulong,
> = unsafe { cuda_library.get(b"cuDeviceGetAttribute\0") }.ok()?;
let mut device_count = MaybeUninit::uninit();
if unsafe { cu_device_get_count(device_count.as_mut_ptr()) } != 0 {
return None;
}
let device_count = unsafe { device_count.assume_init() };
if device_count == 0 {
return None;
}
let mut min_arch: Option<CudaArchInfo> = None;
for device_idx in 0..device_count {
let mut device = MaybeUninit::uninit();
if unsafe { cu_device_get(device.as_mut_ptr(), device_idx) } != 0 {
continue;
}
let device = unsafe { device.assume_init() };
let mut cc_major = MaybeUninit::uninit();
if unsafe {
cu_device_get_attribute(
cc_major.as_mut_ptr(),
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
device,
)
} != 0
{
continue;
}
let cc_major = unsafe { cc_major.assume_init() } as u32;
let mut cc_minor = MaybeUninit::uninit();
if unsafe {
cu_device_get_attribute(
cc_minor.as_mut_ptr(),
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
device,
)
} != 0
{
continue;
}
let cc_minor = unsafe { cc_minor.assume_init() } as u32;
let is_new_minimum = min_arch.as_ref().is_none_or(|min| {
cc_major < min.major || (cc_major == min.major && cc_minor < min.minor)
});
if is_new_minimum {
min_arch = Some(CudaArchInfo {
major: cc_major,
minor: cc_minor,
});
}
}
min_arch
}
/// Detects the CUDA driver version afresh on every call (no caching).
///
/// musl builds shell out to `nvidia-smi`; all other builds query NVML directly.
pub fn detect_cuda_version() -> Option<Version> {
    match cfg!(target_env = "musl") {
        true => detect_cuda_version_via_nvidia_smi(),
        false => detect_cuda_version_via_nvml(),
    }
}
pub fn detect_cuda_version_via_nvml() -> Option<Version> {
let library = nvml_library_paths()
.iter()
.find_map(|path| unsafe { libloading::Library::new(*path).ok() })?;
let nvml_init: Symbol<'_, unsafe extern "C" fn() -> c_int> = unsafe {
library
.get(b"nvmlInit_v2\0")
.or_else(|_| library.get(b"nvmlInit\0"))
}
.ok()?;
let nvml_shutdown: Symbol<'_, unsafe extern "C" fn() -> c_int> =
unsafe { library.get(b"nvmlShutdown\0") }.ok()?;
let nvml_system_get_cuda_driver_version: Symbol<'_, unsafe extern "C" fn(*mut c_int) -> c_int> =
unsafe {
library
.get(b"nvmlSystemGetCudaDriverVersion_v2\0")
.or_else(|_| library.get(b"nvmlSystemGetCudaDriverVersion\0"))
}
.ok()?;
if unsafe { nvml_init() } != 0 {
return None;
}
let mut cuda_driver_version = MaybeUninit::uninit();
let result = unsafe { nvml_system_get_cuda_driver_version(cuda_driver_version.as_mut_ptr()) };
let _ = unsafe { nvml_shutdown() };
if result != 0 {
return None;
}
let version = unsafe { cuda_driver_version.assume_init() };
Version::from_str(&format!("{}.{}", version / 1000, (version % 1000) / 10)).ok()
}
/// Candidate names/paths of the NVML shared library for the current platform, in the
/// order they should be tried.
///
/// Bare file names are resolved through the dynamic loader's search path. The absolute
/// paths are fallbacks for common driver install locations, including the Debian-style
/// multiarch directories for both x86_64 and aarch64. On platforms other than Linux,
/// macOS and Windows the list is empty.
fn nvml_library_paths() -> &'static [&'static str] {
    #[cfg(target_os = "macos")]
    static FILENAMES: &[&str] = &[
        "libnvidia-ml.1.dylib",
        "libnvidia-ml.dylib",
        "/usr/local/cuda/lib/libnvidia-ml.1.dylib",
        "/usr/local/cuda/lib/libnvidia-ml.dylib",
    ];
    #[cfg(target_os = "linux")]
    static FILENAMES: &[&str] = &[
        "libnvidia-ml.so.1",
        "libnvidia-ml.so",
        "/usr/lib64/nvidia/libnvidia-ml.so.1",
        "/usr/lib64/nvidia/libnvidia-ml.so",
        "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
        "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so",
        // aarch64 hosts (e.g. Arm servers, Jetson) use this multiarch directory instead.
        "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so.1",
        "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so",
        "/usr/lib/wsl/lib/libnvidia-ml.so.1",
        "/usr/lib/wsl/lib/libnvidia-ml.so",
    ];
    #[cfg(windows)]
    static FILENAMES: &[&str] = &["nvml.dll"];
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    static FILENAMES: &[&str] = &[];
    FILENAMES
}
pub fn detect_cuda_version_via_libcuda() -> Option<Version> {
let cuda_library = cuda_library_paths()
.iter()
.find_map(|path| unsafe { libloading::Library::new(*path).ok() })?;
let cu_init: Symbol<'_, unsafe extern "C" fn(c_uint) -> c_ulong> =
unsafe { cuda_library.get(b"cuInit\0") }.ok()?;
let cu_driver_get_version: Symbol<'_, unsafe extern "C" fn(*mut c_int) -> c_ulong> =
unsafe { cuda_library.get(b"cuDriverGetVersion\0") }.ok()?;
if unsafe { cu_init(0) } != 0 {
return None;
}
let mut version_int = MaybeUninit::uninit();
if unsafe { cu_driver_get_version(version_int.as_mut_ptr()) != 0 } {
return None;
}
let version = unsafe { version_int.assume_init() };
Version::from_str(&format!("{}.{}", version / 1000, (version % 1000) / 10)).ok()
}
/// Candidate names/paths of the CUDA driver library for the current platform, in the
/// order they should be tried.
///
/// Bare file names are resolved through the dynamic loader's search path. The absolute
/// paths are fallbacks for common driver install locations, including the Debian-style
/// multiarch directories for both x86_64 and aarch64. On platforms other than Linux,
/// macOS and Windows the list is empty.
fn cuda_library_paths() -> &'static [&'static str] {
    #[cfg(target_os = "macos")]
    static FILENAMES: &[&str] = &[
        "libcuda.1.dylib",
        "libcuda.dylib",
        "/usr/local/cuda/lib/libcuda.1.dylib",
        "/usr/local/cuda/lib/libcuda.dylib",
    ];
    #[cfg(target_os = "linux")]
    static FILENAMES: &[&str] = &[
        "libcuda.so.1",
        "libcuda.so",
        "/usr/lib64/nvidia/libcuda.so.1",
        "/usr/lib64/nvidia/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        // aarch64 hosts (e.g. Arm servers, Jetson) use this multiarch directory instead.
        "/usr/lib/aarch64-linux-gnu/libcuda.so.1",
        "/usr/lib/aarch64-linux-gnu/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
    ];
    #[cfg(windows)]
    static FILENAMES: &[&str] = &["nvcuda.dll"];
    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    static FILENAMES: &[&str] = &[];
    FILENAMES
}
/// Detects the CUDA version by parsing the XML report printed by `nvidia-smi`.
///
/// The command run is `nvidia-smi --query -u -x`, with `CUDA_VISIBLE_DEVICES` removed from
/// its environment. The version is the text captured between `<cuda_version>` and
/// `</cuda_version>` in the first regex match. Returns `None` if the command cannot be
/// spawned, no such tag is found, or the captured text is not a valid [`Version`].
fn detect_cuda_version_via_nvidia_smi() -> Option<Version> {
    static CUDA_VERSION_RE: once_cell::sync::Lazy<regex::Regex> =
        once_cell::sync::Lazy::new(|| {
            regex::Regex::new(r"<cuda_version>(.*)<\/cuda_version>").unwrap()
        });

    let report = Command::new("nvidia-smi")
        .args(["--query", "-u", "-x"])
        .env_remove("CUDA_VISIBLE_DEVICES")
        .output()
        .ok()?;
    let stdout = String::from_utf8_lossy(&report.stdout);

    CUDA_VERSION_RE
        .captures(&stdout)
        .and_then(|caps| caps.get(1))
        .and_then(|m| Version::from_str(m.as_str()).ok())
}
// Tests for CUDA detection. Except for the format-validation test, these are smoke tests:
// they only print what was detected, because whether an NVIDIA driver/GPU is present
// depends on the machine running the tests. They fail only if detection panics.
#[cfg(test)]
mod test {
use super::*;
// NVML-based detection must not panic, with or without a driver installed.
#[test]
pub fn doesnt_crash() {
let version = detect_cuda_version_via_nvml();
println!("Cuda {version:?}");
}
// `nvidia-smi`-based detection must not panic, even when the binary is absent.
#[test]
pub fn doesnt_crash_nvidia_smi() {
let version = detect_cuda_version_via_nvidia_smi();
println!("Cuda {version:?}");
}
// Exercises the cached `cuda_info()` path and prints the compute capability if found.
#[test]
pub fn test_cuda_info() {
let info = cuda_info();
println!("CUDA Info: {info:?}");
if let Some(ref arch) = info.arch_info {
println!(" Compute capability: {}.{}", arch.major, arch.minor);
}
}
// Exercises the `cuda_arch()` accessor over the cached info.
#[test]
pub fn test_cuda_arch() {
let arch = cuda_arch();
println!("CUDA Arch: {arch:?}");
}
// `is_valid_cuda_version_format` accepts exactly `<digits>.<digits>` and rejects
// missing/extra components, empty components, non-digits and other separators.
#[test]
fn test_is_valid_cuda_version_format() {
assert!(is_valid_cuda_version_format("8.6"));
assert!(is_valid_cuda_version_format("7.5"));
assert!(is_valid_cuda_version_format("10.2"));
assert!(is_valid_cuda_version_format("0.0"));
assert!(is_valid_cuda_version_format("12.0"));
assert!(!is_valid_cuda_version_format("8"));
assert!(!is_valid_cuda_version_format("8.6.1"));
assert!(!is_valid_cuda_version_format("8.6.1.0"));
assert!(!is_valid_cuda_version_format(""));
assert!(!is_valid_cuda_version_format(".6"));
assert!(!is_valid_cuda_version_format("8."));
assert!(!is_valid_cuda_version_format("."));
assert!(!is_valid_cuda_version_format("8.6a"));
assert!(!is_valid_cuda_version_format("a.6"));
assert!(!is_valid_cuda_version_format("8.b"));
assert!(!is_valid_cuda_version_format("eight.six"));
assert!(!is_valid_cuda_version_format("8-6"));
assert!(!is_valid_cuda_version_format("8_6"));
}
}