use libloading::Symbol;
use once_cell::sync::OnceCell;
use rattler_conda_types::Version;
use std::process::Command;
use std::{
mem::MaybeUninit,
os::raw::{c_int, c_uint, c_ulong},
str::FromStr,
};
/// Returns the CUDA driver version available on this machine, or `None` if
/// none could be detected.
///
/// The expensive detection performed by [`detect_cuda_version`] is run once and
/// memoised in a process-wide cell; later calls hand back a clone of that
/// cached result.
pub fn cuda_version() -> Option<Version> {
    static CACHED: OnceCell<Option<Version>> = OnceCell::new();
    let detected: &Option<Version> = CACHED.get_or_init(detect_cuda_version);
    detected.clone()
}
/// Detects the CUDA driver version without any caching.
///
/// On musl targets the `nvidia-smi` command-line tool is queried; on every
/// other target NVML is loaded dynamically instead.
/// NOTE(review): presumably musl uses `nvidia-smi` because dynamic library
/// loading is unreliable in (statically linked) musl builds — confirm.
pub fn detect_cuda_version() -> Option<Version> {
    match cfg!(target_env = "musl") {
        true => detect_cuda_version_via_nvidia_smi(),
        false => detect_cuda_version_via_nvml(),
    }
}
/// Detects the CUDA driver version through NVML (the NVIDIA Management Library).
///
/// The first library from [`nvml_library_paths`] that can be opened is used.
/// The `_v2` entry points are preferred, with a fallback to the unversioned
/// symbols. NVML is initialised only for the duration of the query and shut
/// down again afterwards.
///
/// Returns `None` when no library loads, a required symbol is missing, NVML
/// initialisation or the version query reports a non-zero status, or the
/// reported value cannot be parsed as a [`Version`].
pub fn detect_cuda_version_via_nvml() -> Option<Version> {
    type NvmlStatusFn = unsafe extern "C" fn() -> c_int;
    type NvmlDriverVersionFn = unsafe extern "C" fn(*mut c_int) -> c_int;

    let nvml = nvml_library_paths()
        .iter()
        .copied()
        .find_map(|candidate| unsafe { libloading::Library::new(candidate) }.ok())?;

    // SAFETY (all lookups below): the declared signatures are assumed to match
    // the NVML C prototypes for these symbols.
    let init: Symbol<'_, NvmlStatusFn> = match unsafe { nvml.get(b"nvmlInit_v2\0") } {
        Ok(symbol) => symbol,
        Err(_) => unsafe { nvml.get(b"nvmlInit\0") }.ok()?,
    };
    let shutdown: Symbol<'_, NvmlStatusFn> = unsafe { nvml.get(b"nvmlShutdown\0") }.ok()?;
    let query_driver_version: Symbol<'_, NvmlDriverVersionFn> =
        match unsafe { nvml.get(b"nvmlSystemGetCudaDriverVersion_v2\0") } {
            Ok(symbol) => symbol,
            Err(_) => unsafe { nvml.get(b"nvmlSystemGetCudaDriverVersion\0") }.ok()?,
        };

    let init_status = unsafe { init() };
    if init_status != 0 {
        return None;
    }

    let mut raw_version = MaybeUninit::<c_int>::uninit();
    let query_status = unsafe { query_driver_version(raw_version.as_mut_ptr()) };
    // Balance the successful init regardless of the query outcome; the
    // shutdown status itself is deliberately ignored.
    let _ = unsafe { shutdown() };

    if query_status != 0 {
        return None;
    }

    // SAFETY: the query reported success, so the out-parameter was written.
    let encoded = unsafe { raw_version.assume_init() };
    let (major, minor) = (encoded / 1000, (encoded % 1000) / 10);
    Version::from_str(&format!("{major}.{minor}")).ok()
}
/// Candidate file names / absolute paths for the NVML shared library on the
/// current platform, listed in the order they should be tried.
///
/// Bare file names are left to the platform's dynamic-loader search; the
/// absolute paths cover common driver install locations (including WSL).
/// Unsupported platforms get an empty list.
fn nvml_library_paths() -> &'static [&'static str] {
    #[cfg(target_os = "macos")]
    const CANDIDATES: &[&str] = &[
        "libnvidia-ml.1.dylib",
        "libnvidia-ml.dylib",
        "/usr/local/cuda/lib/libnvidia-ml.1.dylib",
        "/usr/local/cuda/lib/libnvidia-ml.dylib",
    ];

    #[cfg(target_os = "linux")]
    const CANDIDATES: &[&str] = &[
        "libnvidia-ml.so.1",
        "libnvidia-ml.so",
        "/usr/lib64/nvidia/libnvidia-ml.so.1",
        "/usr/lib64/nvidia/libnvidia-ml.so",
        "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
        "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so",
        "/usr/lib/wsl/lib/libnvidia-ml.so.1",
        "/usr/lib/wsl/lib/libnvidia-ml.so",
    ];

    #[cfg(windows)]
    const CANDIDATES: &[&str] = &["nvml.dll"];

    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    const CANDIDATES: &[&str] = &[];

    CANDIDATES
}
/// Detects the CUDA driver version by dynamically loading the CUDA driver
/// library (see [`cuda_library_paths`]) and calling `cuDriverGetVersion`.
///
/// Returns `None` when no driver library can be loaded, a required symbol is
/// missing, `cuInit` or `cuDriverGetVersion` reports a non-zero status, or the
/// reported value cannot be parsed as a [`Version`].
pub fn detect_cuda_version_via_libcuda() -> Option<Version> {
    // `CUresult` is a C enum and therefore `int`-sized. Declaring it as
    // `c_ulong` (64 bits on LP64 targets) mismatches the C ABI and makes the
    // `!= 0` check depend on return-register bits the callee never sets.
    type CuInitFn = unsafe extern "C" fn(c_uint) -> c_int;
    type CuDriverGetVersionFn = unsafe extern "C" fn(*mut c_int) -> c_int;

    let cuda_library = cuda_library_paths()
        .iter()
        .find_map(|path| unsafe { libloading::Library::new(*path).ok() })?;

    // SAFETY: the declared signatures match the CUDA driver API prototypes
    // `CUresult cuInit(unsigned int)` and `CUresult cuDriverGetVersion(int *)`.
    let cu_init: Symbol<'_, CuInitFn> = unsafe { cuda_library.get(b"cuInit\0") }.ok()?;
    let cu_driver_get_version: Symbol<'_, CuDriverGetVersionFn> =
        unsafe { cuda_library.get(b"cuDriverGetVersion\0") }.ok()?;

    if unsafe { cu_init(0) } != 0 {
        return None;
    }

    // A zero-initialised out-parameter avoids `MaybeUninit`/`assume_init`
    // entirely; it is only read after the call reports success.
    let mut version: c_int = 0;
    if unsafe { cu_driver_get_version(&mut version) } != 0 {
        return None;
    }

    // The driver encodes the version as `1000 * major + 10 * minor`.
    Version::from_str(&format!("{}.{}", version / 1000, (version % 1000) / 10)).ok()
}
/// Candidate file names / absolute paths for the CUDA driver library on the
/// current platform, listed in the order they should be tried.
///
/// Bare file names are left to the platform's dynamic-loader search; the
/// absolute paths cover common driver install locations (including WSL).
/// Unsupported platforms get an empty list.
fn cuda_library_paths() -> &'static [&'static str] {
    #[cfg(target_os = "macos")]
    const CANDIDATES: &[&str] = &[
        "libcuda.1.dylib",
        "libcuda.dylib",
        "/usr/local/cuda/lib/libcuda.1.dylib",
        "/usr/local/cuda/lib/libcuda.dylib",
    ];

    #[cfg(target_os = "linux")]
    const CANDIDATES: &[&str] = &[
        "libcuda.so.1",
        "libcuda.so",
        "/usr/lib64/nvidia/libcuda.so.1",
        "/usr/lib64/nvidia/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
    ];

    #[cfg(windows)]
    const CANDIDATES: &[&str] = &["nvcuda.dll"];

    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
    const CANDIDATES: &[&str] = &[];

    CANDIDATES
}
/// Detects the CUDA driver version by running `nvidia-smi` in XML mode and
/// extracting the text of its `<cuda_version>` element.
///
/// `CUDA_VISIBLE_DEVICES` is removed from the child's environment before it is
/// spawned. Returns `None` when the command cannot be run, the element is
/// absent, or its contents do not parse as a [`Version`].
fn detect_cuda_version_via_nvidia_smi() -> Option<Version> {
    // Compiled lazily on first use and reused for the rest of the process.
    static VERSION_TAG: once_cell::sync::Lazy<regex::Regex> = once_cell::sync::Lazy::new(|| {
        regex::Regex::new("<cuda_version>(.*)<\\/cuda_version>").unwrap()
    });

    let smi = Command::new("nvidia-smi")
        .args(["--query", "-u", "-x"])
        .env_remove("CUDA_VISIBLE_DEVICES")
        .output()
        .ok()?;
    let stdout = String::from_utf8_lossy(&smi.stdout);

    VERSION_TAG
        .captures(&stdout)
        .and_then(|caps| caps.get(1))
        .and_then(|inner| Version::from_str(inner.as_str()).ok())
}
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: NVML-based detection must not crash, whether or not a GPU is present.
    #[test]
    pub fn doesnt_crash() {
        let version = detect_cuda_version_via_nvml();
        println!("Cuda {version:?}");
    }

    /// Smoke test: `nvidia-smi`-based detection must not crash, even if the tool is missing.
    #[test]
    pub fn doesnt_crash_nvidia_smi() {
        let version = detect_cuda_version_via_nvidia_smi();
        println!("Cuda {version:?}");
    }

    /// Smoke test: libcuda-based detection must not crash, even without a driver installed.
    #[test]
    pub fn doesnt_crash_libcuda() {
        let version = detect_cuda_version_via_libcuda();
        println!("Cuda {version:?}");
    }

    /// The cached entry point must return the same result on every call.
    #[test]
    pub fn cached_version_is_stable() {
        let first = cuda_version();
        let second = cuda_version();
        assert_eq!(first, second);
    }
}