use std::ffi::{CStr, c_char, c_int, c_uint, c_void};
use std::ptr;
use super::types::{GpuDevice, GpuVendor};
// Minimal hand-written bindings for the CUDA driver entry points that `detect_nvidia`
// resolves by name at runtime; no CUDA headers or link-time dependency are involved.

/// Status code returned by every driver call below; `detect_nvidia` treats `0` as success.
type CUresult = c_int;
/// Device handle written by `cuDeviceGet` and passed to `cuDeviceGetName`.
type CUdevice = c_int;
/// Signature used for the `cuInit` symbol (single flags argument).
type FnCuInit = unsafe extern "C" fn(c_uint) -> CUresult;
/// Signature used for the `cuDeviceGetCount` symbol (out-pointer receiving the device count).
type FnCuDeviceGetCount = unsafe extern "C" fn(*mut c_int) -> CUresult;
/// Signature used for the `cuDeviceGet` symbol (out-pointer for the handle, device ordinal).
type FnCuDeviceGet = unsafe extern "C" fn(*mut CUdevice, c_int) -> CUresult;
/// Signature used for the `cuDeviceGetName` symbol (output buffer, buffer length, device handle).
type FnCuDeviceGetName = unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUresult;
pub(super) fn detect_nvidia() -> Vec<GpuDevice> {
let lib = unsafe { libloading::Library::new("libcuda.so") }
.or_else(|_| unsafe { libloading::Library::new("libcuda.so.1") })
.or_else(|_| unsafe { libloading::Library::new("nvcuda.dll") });
let Ok(lib) = lib else { return Vec::new() };
unsafe {
let cu_init: libloading::Symbol<FnCuInit> = match lib.get(b"cuInit") {
Ok(f) => f,
Err(_) => return Vec::new(),
};
if cu_init(0) != 0 {
return Vec::new();
}
let cu_device_get_count: libloading::Symbol<FnCuDeviceGetCount> =
match lib.get(b"cuDeviceGetCount") {
Ok(f) => f,
Err(_) => return Vec::new(),
};
let mut count: c_int = 0;
if cu_device_get_count(&mut count) != 0 || count <= 0 {
return Vec::new();
}
let cu_device_get: libloading::Symbol<FnCuDeviceGet> = match lib.get(b"cuDeviceGet") {
Ok(f) => f,
Err(_) => return Vec::new(),
};
let cu_device_get_name: libloading::Symbol<FnCuDeviceGetName> =
match lib.get(b"cuDeviceGetName") {
Ok(f) => f,
Err(_) => return Vec::new(),
};
let mut devices = Vec::with_capacity(count as usize);
for ordinal in 0..count {
let mut dev: CUdevice = 0;
if cu_device_get(&mut dev, ordinal) != 0 {
continue;
}
let mut name_buf = [0i8; 256];
let name = if cu_device_get_name(
name_buf.as_mut_ptr() as *mut c_char,
name_buf.len() as c_int,
dev,
) == 0
{
CStr::from_ptr(name_buf.as_ptr() as *const c_char)
.to_string_lossy()
.into_owned()
} else {
format!("NVIDIA GPU {ordinal}")
};
let nvml_lookup = nvidia_nvml_lookup(ordinal as u32);
let generation = nvidia_generation_from_name(&name);
devices.push(GpuDevice {
vendor: GpuVendor::Nvidia,
name,
index: ordinal as u32,
vendor_index: ordinal as u32,
generation,
pci_id: nvml_lookup.pci_id,
vram_mib: nvml_lookup.vram_mib,
serial: nvml_lookup.serial,
host_pci_address: nvml_lookup.host_pci_address,
vendor_id_hex: "0x10de".into(),
});
}
let _ = ptr::null::<c_void>();
devices
}
}
/// Initialises NVML, retrying once with an explicit library file name.
///
/// The default `Nvml::init()` is attempted first. If it fails, its error is discarded and
/// initialisation is retried through the builder with the library path set to
/// `libnvidia-ml.so.1`; the result of that second attempt is what the caller receives.
pub(super) fn init_nvml_with_fallback(
) -> Result<nvml_wrapper::Nvml, nvml_wrapper::error::NvmlError> {
    nvml_wrapper::Nvml::init().or_else(|_default_lookup_error| {
        let versioned_soname = std::ffi::OsStr::new("libnvidia-ml.so.1");
        nvml_wrapper::Nvml::builder()
            .lib_path(versioned_soname)
            .init()
    })
}
/// Per-device details collected from NVML by `nvidia_nvml_lookup`.
///
/// The `Default` value (empty strings, zero, `None`) is what callers receive when NVML
/// cannot be initialised or the device cannot be opened.
#[derive(Debug, Clone, Default)]
struct NvmlLookup {
    /// `"0x{upper16:04x}:0x{lower16:04x}"` built from the two halves of NVML's
    /// `pci_device_id`; empty when PCI info could not be read.
    pci_id: String,
    /// Total memory reported by NVML divided by 1024 twice; `0` when unavailable.
    vram_mib: u64,
    /// Trimmed board serial; `None` when NVML reports an error, an empty string or `"0"`.
    serial: Option<String>,
    /// NVML bus id with leading `'0'` characters and then leading `':'` characters removed
    /// (the untouched bus id if that leaves nothing); empty when PCI info could not be read.
    host_pci_address: String,
}
/// Collects PCI identity, memory size and serial number for the NVML device at `ordinal`.
///
/// Every query is best-effort: if NVML cannot be initialised or the device cannot be
/// opened, an all-default [`NvmlLookup`] is returned, and each individual query that fails
/// leaves its own field at the empty/zero/`None` value. Only a failed serial query is
/// logged (at debug level).
fn nvidia_nvml_lookup(ordinal: u32) -> NvmlLookup {
    let Ok(nvml) = init_nvml_with_fallback() else {
        return NvmlLookup::default();
    };
    let Ok(device) = nvml.device_by_index(ordinal) else {
        return NvmlLookup::default();
    };

    // PCI identity: both strings stay empty when the query fails.
    let (pci_id, host_pci_address) = device
        .pci_info()
        .map(|info| {
            let upper_half = info.pci_device_id >> 16;
            let lower_half = info.pci_device_id & 0xFFFF;
            let pci_id = format!("0x{:04x}:0x{:04x}", upper_half, lower_half);

            // Strip leading zeros, then the separator they expose; keep the raw bus id if
            // nothing would remain.
            let stripped = info
                .bus_id
                .trim_start_matches('0')
                .trim_start_matches(':');
            let host_pci_address = if stripped.is_empty() {
                info.bus_id.clone()
            } else {
                stripped.to_string()
            };

            (pci_id, host_pci_address)
        })
        .unwrap_or_default();

    let vram_mib = device
        .memory_info()
        .map(|memory| memory.total / 1024 / 1024)
        .unwrap_or(0);

    let serial = match device.serial() {
        // Blank and "0" serials are treated as "no serial".
        Ok(raw) => Some(raw.trim())
            .filter(|candidate| !candidate.is_empty() && *candidate != "0")
            .map(str::to_string),
        Err(e) => {
            tracing::debug!(error = %e, ordinal, "nvml serial unavailable");
            None
        }
    };

    NvmlLookup {
        pci_id,
        vram_mib,
        serial,
        host_pci_address,
    }
}
/// Guesses the NVIDIA architecture generation from a device's product name.
///
/// Matching is case-insensitive and substring-based, in three steps:
///
/// 1. An architecture word spelled out in the name ("... Ada Generation",
///    "... Blackwell Workstation Edition") decides the result. This has to run before the
///    numeric heuristics, otherwise "RTX 5000 Ada Generation" would be mistaken for a
///    GeForce RTX 50-series (Blackwell) part.
/// 2. The "Quadro RTX" product line is Turing, even though its model numbers (4000, 5000)
///    look like later GeForce series.
/// 3. Model-number and product-code markers are tried from newest to oldest generation.
///
/// Returns `"Unknown"` when nothing matches.
fn nvidia_generation_from_name(name: &str) -> String {
    // (substring, generation) — architecture words that appear verbatim in product names.
    const ARCHITECTURE_WORDS: &[(&str, &str)] = &[
        ("blackwell", "Blackwell"),
        ("hopper", "Hopper"),
        ("ada", "Ada Lovelace"),
        ("ampere", "Ampere"),
        ("turing", "Turing"),
        ("pascal", "Pascal"),
    ];

    // (generation, substrings) — evaluated in order; the first group with a hit wins.
    const MODEL_MARKERS: &[(&str, &[&str])] = &[
        ("Turing", &["quadro rtx"]),
        (
            "Blackwell",
            &[
                "rtx 50", "5050", "5060", "5070", "5080", "5090", "b100", "b200", "gb200",
            ],
        ),
        ("Hopper", &["h100", "h200"]),
        (
            "Ada Lovelace",
            &["rtx 40", "4060", "4070", "4080", "4090", "l4", "l40"],
        ),
        (
            "Ampere",
            &["rtx 30", "3050", "3060", "3070", "3080", "3090", "a10", "a100"],
        ),
        // The leading space in " t4" avoids matching unrelated names that merely end in "t4".
        ("Turing", &["rtx 20", "2060", "2070", "2080", " t4"]),
        (
            "Pascal",
            &["gtx 10", "1050", "1060", "1070", "1080", "p100", "p40"],
        ),
    ];

    let n = name.to_lowercase();

    if let Some(&(_, generation)) = ARCHITECTURE_WORDS
        .iter()
        .find(|&&(word, _)| n.contains(word))
    {
        return generation.into();
    }

    MODEL_MARKERS
        .iter()
        .find(|&&(_, markers)| markers.iter().any(|&marker| n.contains(marker)))
        .map_or("Unknown", |&(generation, _)| generation)
        .into()
}