use std::process::{Command, Stdio};
use std::time::{Duration, Instant};
use thiserror::Error;
/// Errors that GPU detection can report.
#[derive(Debug, Error)]
pub enum GpuDetectionError {
/// No probe reported a GPU. `GpuDetector::detect` returns this once every
/// probe has either failed or come back with an empty device list.
#[error("No GPU detected")]
NoGpu,
/// A probe command could not be spawned or waited on, or it exited with a
/// non-success status. The payload is the I/O error text or the command's
/// captured stderr.
#[error("Detection command failed: {0}")]
CommandFailed(String),
/// A field in a probe's output could not be parsed. The payload is the
/// parser's error message.
#[error("Failed to parse GPU info: {0}")]
ParseError(String),
/// A probe command was still running when the detector's timeout elapsed and
/// was killed. The payload is the configured timeout, not the measured run time.
#[error("Detection timed out after {0:?}")]
Timeout(Duration),
}
/// Shorthand for a `std::result::Result` whose error type is [`GpuDetectionError`].
pub type Result<T> = std::result::Result<T, GpuDetectionError>;
/// Manufacturer of a detected GPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuVendor {
/// Assigned to devices reported by the `nvidia-smi` probe.
Nvidia,
/// Assigned to devices reported by the `rocm-smi` probe.
Amd,
/// Not assigned by any probe visible in this file.
Intel,
/// Assigned by the macOS probe.
Apple,
/// Used for the placeholder device that `GpuDetector::detect_or_default`
/// fabricates when detection fails.
Unknown,
}
/// Description of a single detected (or placeholder) GPU.
#[derive(Debug, Clone)]
pub struct GpuInfo {
/// Device index. Parsed from the `index` column of `nvidia-smi`; the other
/// code paths in this file always use 0.
pub device_id: u32,
/// Manufacturer of the device.
pub vendor: GpuVendor,
/// Human-readable device name.
pub name: String,
/// Total video memory in bytes.
pub vram_bytes: u64,
/// Free video memory in bytes.
pub vram_free_bytes: u64,
/// Value of the `compute_cap` column when the device came from `nvidia-smi`;
/// `None` in every other code path in this file.
pub compute_capability: Option<String>,
/// Value of the `driver_version` column when the device came from
/// `nvidia-smi`; `None` in every other code path in this file.
pub driver_version: Option<String>,
}
impl GpuInfo {
    /// Number of bytes in one GiB, as a float divisor.
    const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

    /// Total video memory in GiB (bytes divided by 1024^3).
    pub fn vram_gb(&self) -> f64 {
        let total = self.vram_bytes as f64;
        total / Self::BYTES_PER_GIB
    }

    /// Free video memory in GiB (bytes divided by 1024^3).
    pub fn vram_free_gb(&self) -> f64 {
        let free = self.vram_free_bytes as f64;
        free / Self::BYTES_PER_GIB
    }

    /// Fraction of video memory in use, between 0.0 and 1.0.
    ///
    /// Returns 0.0 when the total is zero. If the free figure exceeds the
    /// total, the used amount saturates at zero rather than wrapping.
    pub fn vram_utilization(&self) -> f32 {
        match self.vram_bytes {
            0 => 0.0,
            total => {
                let in_use = total.saturating_sub(self.vram_free_bytes);
                in_use as f32 / total as f32
            }
        }
    }
}
/// Outcome of a detection run: the devices found and how they were found.
#[derive(Debug, Clone)]
pub struct GpuDetectionResult {
/// Detected devices, in the order the probe reported them.
pub gpus: Vec<GpuInfo>,
/// Combined video memory in bytes, as filled in by whoever built this value
/// (the probes in this file store the sum of `vram_bytes` over `gpus`).
pub total_vram_bytes: u64,
/// Which probe produced this result.
pub detection_method: DetectionMethod,
}
impl GpuDetectionResult {
    /// Number of bytes in one GiB, as a float divisor.
    const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

    /// An empty result: no devices, zero memory, [`DetectionMethod::None`].
    pub fn none() -> Self {
        Self {
            gpus: Vec::new(),
            total_vram_bytes: 0,
            detection_method: DetectionMethod::None,
        }
    }

    /// Whether at least one device is present.
    pub fn has_gpu(&self) -> bool {
        self.primary().is_some()
    }

    /// The first device in the list, if there is one.
    pub fn primary(&self) -> Option<&GpuInfo> {
        self.gpus.first()
    }

    /// Combined video memory in GiB (`total_vram_bytes` divided by 1024^3).
    pub fn total_vram_gb(&self) -> f64 {
        let total = self.total_vram_bytes as f64;
        total / Self::BYTES_PER_GIB
    }
}
/// Identifies which probe produced a [`GpuDetectionResult`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DetectionMethod {
/// Devices came from the `nvidia-smi` command.
NvidiaSmi,
/// Devices came from the `rocm-smi` command.
RocmSmi,
/// Set by the macOS probe, which in this file shells out to
/// `system_profiler` and `sysctl`.
AppleMetal,
/// Not produced by any probe visible in this file.
System,
/// No probe succeeded; used by `GpuDetectionResult::none` and by the
/// placeholder result of `GpuDetector::detect_or_default`.
None,
}
/// Detects GPUs by running vendor command-line tools.
pub struct GpuDetector {
/// Maximum time a single probe command may run before it is killed.
timeout: Duration,
}
impl Default for GpuDetector {
fn default() -> Self {
Self::new()
}
}
impl GpuDetector {
pub fn new() -> Self {
Self {
timeout: Duration::from_secs(5),
}
}
pub fn with_timeout(timeout: Duration) -> Self {
Self { timeout }
}
fn run_with_timeout(&self, cmd: &mut Command) -> Result<std::process::Output> {
let mut child = cmd
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| GpuDetectionError::CommandFailed(e.to_string()))?;
let start = Instant::now();
loop {
match child.try_wait() {
Ok(Some(_)) => {
return child
.wait_with_output()
.map_err(|e| GpuDetectionError::CommandFailed(e.to_string()));
},
Ok(None) => {
if start.elapsed() >= self.timeout {
let _ = child.kill();
let _ = child.wait(); return Err(GpuDetectionError::Timeout(self.timeout));
}
std::thread::sleep(Duration::from_millis(50));
},
Err(e) => {
return Err(GpuDetectionError::CommandFailed(e.to_string()));
},
}
}
}
pub fn detect(&self) -> Result<GpuDetectionResult> {
if let Ok(result) = self.detect_nvidia() {
if result.has_gpu() {
return Ok(result);
}
}
if let Ok(result) = self.detect_amd() {
if result.has_gpu() {
return Ok(result);
}
}
#[cfg(target_os = "macos")]
if let Ok(result) = self.detect_apple() {
if result.has_gpu() {
return Ok(result);
}
}
Err(GpuDetectionError::NoGpu)
}
pub fn detect_or_default(&self, default_vram_bytes: u64) -> GpuDetectionResult {
match self.detect() {
Ok(result) => result,
Err(_) => GpuDetectionResult {
gpus: vec![GpuInfo {
device_id: 0,
vendor: GpuVendor::Unknown,
name: "Unknown GPU".to_string(),
vram_bytes: default_vram_bytes,
vram_free_bytes: default_vram_bytes,
compute_capability: None,
driver_version: None,
}],
total_vram_bytes: default_vram_bytes,
detection_method: DetectionMethod::None,
},
}
}
fn detect_nvidia(&self) -> Result<GpuDetectionResult> {
let output = self.run_with_timeout(Command::new("nvidia-smi").args([
"--query-gpu=index,name,memory.total,memory.free,driver_version,compute_cap",
"--format=csv,noheader,nounits",
]))?;
if !output.status.success() {
return Err(GpuDetectionError::CommandFailed(
String::from_utf8_lossy(&output.stderr).to_string(),
));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut gpus = Vec::new();
let mut total_vram = 0u64;
for line in stdout.lines() {
if line.trim().is_empty() {
continue;
}
let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
if parts.len() < 4 {
continue;
}
let device_id = parts[0]
.parse::<u32>()
.map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
let name = parts[1].to_string();
let vram_mib = parts[2]
.parse::<u64>()
.map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
let vram_bytes = vram_mib * 1024 * 1024;
let vram_free_mib = parts[3]
.parse::<u64>()
.map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
let vram_free_bytes = vram_free_mib * 1024 * 1024;
let driver_version = parts.get(4).map(|s| s.to_string());
let compute_capability = parts.get(5).map(|s| s.to_string());
total_vram += vram_bytes;
gpus.push(GpuInfo {
device_id,
vendor: GpuVendor::Nvidia,
name,
vram_bytes,
vram_free_bytes,
compute_capability,
driver_version,
});
}
Ok(GpuDetectionResult {
gpus,
total_vram_bytes: total_vram,
detection_method: DetectionMethod::NvidiaSmi,
})
}
fn detect_amd(&self) -> Result<GpuDetectionResult> {
let output = self.run_with_timeout(Command::new("rocm-smi").args([
"--showmeminfo",
"vram",
"--json",
]))?;
if !output.status.success() {
return Err(GpuDetectionError::CommandFailed(
String::from_utf8_lossy(&output.stderr).to_string(),
));
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut gpus = Vec::new();
let mut total_vram = 0u64;
if stdout.contains("card") || stdout.contains("GPU") {
gpus.push(GpuInfo {
device_id: 0,
vendor: GpuVendor::Amd,
name: "AMD GPU".to_string(),
vram_bytes: 16 * 1024 * 1024 * 1024, vram_free_bytes: 16 * 1024 * 1024 * 1024,
compute_capability: None,
driver_version: None,
});
total_vram = 16 * 1024 * 1024 * 1024;
}
Ok(GpuDetectionResult {
gpus,
total_vram_bytes: total_vram,
detection_method: DetectionMethod::RocmSmi,
})
}
#[cfg(target_os = "macos")]
fn detect_apple(&self) -> Result<GpuDetectionResult> {
let output = self.run_with_timeout(
Command::new("system_profiler").args(["SPDisplaysDataType", "-json"]),
)?;
if !output.status.success() {
return Err(GpuDetectionError::CommandFailed(
String::from_utf8_lossy(&output.stderr).to_string(),
));
}
let sysctl_output =
self.run_with_timeout(Command::new("sysctl").args(["-n", "hw.memsize"]))?;
let total_ram = String::from_utf8_lossy(&sysctl_output.stdout)
.trim()
.parse::<u64>()
.unwrap_or(16 * 1024 * 1024 * 1024);
let gpu_memory = (total_ram as f64 * 0.75) as u64;
Ok(GpuDetectionResult {
gpus: vec![GpuInfo {
device_id: 0,
vendor: GpuVendor::Apple,
name: "Apple Silicon GPU".to_string(),
vram_bytes: gpu_memory,
vram_free_bytes: gpu_memory,
compute_capability: None,
driver_version: None,
}],
total_vram_bytes: gpu_memory,
detection_method: DetectionMethod::AppleMetal,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes in one GiB; keeps the fixtures readable.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// Builds a fixture device with no compute capability or driver version.
    fn fixture(device_id: u32, vendor: GpuVendor, name: &str, total: u64, free: u64) -> GpuInfo {
        GpuInfo {
            device_id,
            vendor,
            name: name.to_string(),
            vram_bytes: total,
            vram_free_bytes: free,
            compute_capability: None,
            driver_version: None,
        }
    }

    #[test]
    fn test_gpu_info_vram_gb() {
        let info = fixture(0, GpuVendor::Nvidia, "Test GPU", 24 * GIB, 20 * GIB);
        assert!((info.vram_gb() - 24.0).abs() < 0.01);
        assert!((info.vram_free_gb() - 20.0).abs() < 0.01);
    }

    #[test]
    fn test_gpu_info_utilization() {
        let info = fixture(0, GpuVendor::Nvidia, "Test GPU", 10 * GIB, 4 * GIB);
        assert!((info.vram_utilization() - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_gpu_info_utilization_zero_vram() {
        let info = fixture(0, GpuVendor::Unknown, "Test GPU", 0, 0);
        assert_eq!(info.vram_utilization(), 0.0);
    }

    #[test]
    fn test_detection_result_primary() {
        // Two identical 24 GiB cards that differ only in index and name.
        let gpus: Vec<GpuInfo> = (0..2u32)
            .map(|id| GpuInfo {
                compute_capability: Some("8.9".to_string()),
                ..fixture(id, GpuVendor::Nvidia, &format!("GPU {}", id), 24 * GIB, 24 * GIB)
            })
            .collect();
        let result = GpuDetectionResult {
            gpus,
            total_vram_bytes: 48 * GIB,
            detection_method: DetectionMethod::NvidiaSmi,
        };
        assert!(result.has_gpu());
        assert_eq!(result.primary().map(|g| g.device_id), Some(0));
        assert!((result.total_vram_gb() - 48.0).abs() < 0.01);
    }

    #[test]
    fn test_detection_result_none() {
        let result = GpuDetectionResult::none();
        assert!(!result.has_gpu());
        assert!(result.primary().is_none());
        assert_eq!(result.total_vram_bytes, 0);
    }

    #[test]
    fn test_detector_fallback_on_failure() {
        let default_vram = 8 * GIB;
        let result = GpuDetector::new().detect_or_default(default_vram);
        assert!(!result.gpus.is_empty());
        // Only the placeholder path guarantees the total equals the default.
        if result.detection_method == DetectionMethod::None {
            assert_eq!(result.total_vram_bytes, default_vram);
        }
    }

    #[test]
    fn test_detector_nvidia_parsing() {
        let sample_line = "0, NVIDIA GeForce RTX 4090, 24564, 23000, 545.23.08, 8.9";
        let fields: Vec<&str> = sample_line.split(',').map(str::trim).collect();
        assert_eq!(fields[0], "0");
        assert_eq!(fields[1], "NVIDIA GeForce RTX 4090");
        assert_eq!(fields[2].parse::<u64>().ok(), Some(24564));
        assert_eq!(fields[3].parse::<u64>().ok(), Some(23000));
        assert_eq!(fields[4], "545.23.08");
        assert_eq!(fields[5], "8.9");
    }

    #[test]
    fn test_gpu_vendor_equality() {
        assert_eq!(GpuVendor::Nvidia, GpuVendor::Nvidia);
        assert_ne!(GpuVendor::Nvidia, GpuVendor::Amd);
    }

    #[test]
    fn test_detector_with_timeout() {
        let ten_seconds = Duration::from_secs(10);
        let detector = GpuDetector::with_timeout(ten_seconds);
        assert_eq!(detector.timeout, ten_seconds);
    }

    #[test]
    fn test_nvidia_detection_real() {
        // A failing command (for example, nvidia-smi not installed) is acceptable here.
        let result = match GpuDetector::new().detect_nvidia() {
            Ok(result) => result,
            Err(GpuDetectionError::CommandFailed(_)) => return,
            Err(e) => panic!("Unexpected error: {}", e),
        };
        for gpu in &result.gpus {
            assert_eq!(gpu.vendor, GpuVendor::Nvidia);
            assert!(gpu.vram_bytes > 0);
            assert!(gpu.vram_free_bytes <= gpu.vram_bytes);
            assert!(!gpu.name.is_empty());
        }
        assert_eq!(result.detection_method, DetectionMethod::NvidiaSmi);
    }

    #[test]
    fn test_detect_all_graceful() {
        // Finding no GPU at all is an accepted outcome.
        let result = match GpuDetector::new().detect() {
            Ok(result) => result,
            Err(GpuDetectionError::NoGpu) => return,
            Err(e) => panic!("Unexpected detection error: {}", e),
        };
        assert!(result.has_gpu());
        assert!(result.total_vram_bytes > 0);
    }
}