1use std::process::{Command, Stdio};
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
/// Errors that GPU detection can report.
#[derive(Debug, Error)]
pub enum GpuDetectionError {
    /// No detection method reported a GPU.
    #[error("No GPU detected")]
    NoGpu,

    /// A detection command could not be spawned or waited on, or it exited unsuccessfully.
    /// The payload is the underlying error text or the command's stderr.
    #[error("Detection command failed: {0}")]
    CommandFailed(String),

    /// Output of a detection command could not be parsed; the payload describes the failure.
    #[error("Failed to parse GPU info: {0}")]
    ParseError(String),

    /// A detection command was still running when the configured timeout (the payload)
    /// elapsed.
    #[error("Detection timed out after {0:?}")]
    Timeout(Duration),
}
31
/// Result alias used throughout this module; the error type is always [`GpuDetectionError`].
pub type Result<T> = std::result::Result<T, GpuDetectionError>;
34
/// Manufacturer of a detected GPU.
///
/// `Hash` is derived alongside the equality traits so the vendor can be used directly as a
/// `HashMap`/`HashSet` key (for example to group GPUs by vendor).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuVendor {
    /// NVIDIA device.
    Nvidia,
    /// AMD device.
    Amd,
    /// Intel device.
    Intel,
    /// Apple device.
    Apple,
    /// Vendor could not be determined (used for placeholder GPUs).
    Unknown,
}
49
50#[derive(Debug, Clone)]
52pub struct GpuInfo {
53 pub device_id: u32,
55
56 pub vendor: GpuVendor,
58
59 pub name: String,
61
62 pub vram_bytes: u64,
64
65 pub vram_free_bytes: u64,
67
68 pub compute_capability: Option<String>,
70
71 pub driver_version: Option<String>,
73}
74
75impl GpuInfo {
76 pub fn vram_gb(&self) -> f64 {
78 self.vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
79 }
80
81 pub fn vram_free_gb(&self) -> f64 {
83 self.vram_free_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
84 }
85
86 pub fn vram_utilization(&self) -> f32 {
88 if self.vram_bytes == 0 {
89 return 0.0;
90 }
91 let used = self.vram_bytes.saturating_sub(self.vram_free_bytes);
92 used as f32 / self.vram_bytes as f32
93 }
94}
95
96#[derive(Debug, Clone)]
98pub struct GpuDetectionResult {
99 pub gpus: Vec<GpuInfo>,
101
102 pub total_vram_bytes: u64,
104
105 pub detection_method: DetectionMethod,
107}
108
109impl GpuDetectionResult {
110 pub fn primary(&self) -> Option<&GpuInfo> {
112 self.gpus.first()
113 }
114
115 pub fn total_vram_gb(&self) -> f64 {
117 self.total_vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
118 }
119
120 pub fn has_gpu(&self) -> bool {
122 !self.gpus.is_empty()
123 }
124
125 pub fn none() -> Self {
127 Self {
128 gpus: vec![],
129 total_vram_bytes: 0,
130 detection_method: DetectionMethod::None,
131 }
132 }
133}
134
/// Identifies which detection path produced a [`GpuDetectionResult`].
///
/// `Hash` is derived alongside the equality traits so the method can be used directly as a
/// `HashMap`/`HashSet` key (for example when tallying results per method).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectionMethod {
    /// Data came from the `nvidia-smi` command.
    NvidiaSmi,
    /// Data came from the `rocm-smi` command.
    RocmSmi,
    /// Data came from the macOS detection path.
    AppleMetal,
    /// Generic system-level detection.
    System,
    /// No detection method produced the data (empty or placeholder result).
    None,
}
149
150pub struct GpuDetector {
152 timeout: Duration,
154}
155
156impl Default for GpuDetector {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162impl GpuDetector {
163 pub fn new() -> Self {
165 Self {
166 timeout: Duration::from_secs(5),
167 }
168 }
169
170 pub fn with_timeout(timeout: Duration) -> Self {
172 Self { timeout }
173 }
174
175 fn run_with_timeout(&self, cmd: &mut Command) -> Result<std::process::Output> {
181 let mut child = cmd
182 .stdout(Stdio::piped())
183 .stderr(Stdio::piped())
184 .spawn()
185 .map_err(|e| GpuDetectionError::CommandFailed(e.to_string()))?;
186
187 let start = Instant::now();
188 loop {
189 match child.try_wait() {
190 Ok(Some(_)) => {
191 return child
193 .wait_with_output()
194 .map_err(|e| GpuDetectionError::CommandFailed(e.to_string()));
195 },
196 Ok(None) => {
197 if start.elapsed() >= self.timeout {
198 let _ = child.kill();
199 let _ = child.wait(); return Err(GpuDetectionError::Timeout(self.timeout));
201 }
202 std::thread::sleep(Duration::from_millis(50));
203 },
204 Err(e) => {
205 return Err(GpuDetectionError::CommandFailed(e.to_string()));
206 },
207 }
208 }
209 }
210
211 pub fn detect(&self) -> Result<GpuDetectionResult> {
213 if let Ok(result) = self.detect_nvidia() {
215 if result.has_gpu() {
216 return Ok(result);
217 }
218 }
219
220 if let Ok(result) = self.detect_amd() {
222 if result.has_gpu() {
223 return Ok(result);
224 }
225 }
226
227 #[cfg(target_os = "macos")]
229 if let Ok(result) = self.detect_apple() {
230 if result.has_gpu() {
231 return Ok(result);
232 }
233 }
234
235 Err(GpuDetectionError::NoGpu)
237 }
238
239 pub fn detect_or_default(&self, default_vram_bytes: u64) -> GpuDetectionResult {
241 match self.detect() {
242 Ok(result) => result,
243 Err(_) => GpuDetectionResult {
244 gpus: vec![GpuInfo {
245 device_id: 0,
246 vendor: GpuVendor::Unknown,
247 name: "Unknown GPU".to_string(),
248 vram_bytes: default_vram_bytes,
249 vram_free_bytes: default_vram_bytes,
250 compute_capability: None,
251 driver_version: None,
252 }],
253 total_vram_bytes: default_vram_bytes,
254 detection_method: DetectionMethod::None,
255 },
256 }
257 }
258
259 fn detect_nvidia(&self) -> Result<GpuDetectionResult> {
261 let output = self.run_with_timeout(Command::new("nvidia-smi").args([
262 "--query-gpu=index,name,memory.total,memory.free,driver_version,compute_cap",
263 "--format=csv,noheader,nounits",
264 ]))?;
265
266 if !output.status.success() {
267 return Err(GpuDetectionError::CommandFailed(
268 String::from_utf8_lossy(&output.stderr).to_string(),
269 ));
270 }
271
272 let stdout = String::from_utf8_lossy(&output.stdout);
273 let mut gpus = Vec::new();
274 let mut total_vram = 0u64;
275
276 for line in stdout.lines() {
277 if line.trim().is_empty() {
278 continue;
279 }
280
281 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
282 if parts.len() < 4 {
283 continue;
284 }
285
286 let device_id = parts[0]
287 .parse::<u32>()
288 .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
289
290 let name = parts[1].to_string();
291
292 let vram_mib = parts[2]
294 .parse::<u64>()
295 .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
296 let vram_bytes = vram_mib * 1024 * 1024;
297
298 let vram_free_mib = parts[3]
299 .parse::<u64>()
300 .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
301 let vram_free_bytes = vram_free_mib * 1024 * 1024;
302
303 let driver_version = parts.get(4).map(|s| s.to_string());
304 let compute_capability = parts.get(5).map(|s| s.to_string());
305
306 total_vram += vram_bytes;
307
308 gpus.push(GpuInfo {
309 device_id,
310 vendor: GpuVendor::Nvidia,
311 name,
312 vram_bytes,
313 vram_free_bytes,
314 compute_capability,
315 driver_version,
316 });
317 }
318
319 Ok(GpuDetectionResult {
320 gpus,
321 total_vram_bytes: total_vram,
322 detection_method: DetectionMethod::NvidiaSmi,
323 })
324 }
325
326 fn detect_amd(&self) -> Result<GpuDetectionResult> {
328 let output = self.run_with_timeout(Command::new("rocm-smi").args([
329 "--showmeminfo",
330 "vram",
331 "--json",
332 ]))?;
333
334 if !output.status.success() {
335 return Err(GpuDetectionError::CommandFailed(
336 String::from_utf8_lossy(&output.stderr).to_string(),
337 ));
338 }
339
340 let stdout = String::from_utf8_lossy(&output.stdout);
343
344 let mut gpus = Vec::new();
346 let mut total_vram = 0u64;
347
348 if stdout.contains("card") || stdout.contains("GPU") {
351 gpus.push(GpuInfo {
352 device_id: 0,
353 vendor: GpuVendor::Amd,
354 name: "AMD GPU".to_string(),
355 vram_bytes: 16 * 1024 * 1024 * 1024, vram_free_bytes: 16 * 1024 * 1024 * 1024,
357 compute_capability: None,
358 driver_version: None,
359 });
360 total_vram = 16 * 1024 * 1024 * 1024;
361 }
362
363 Ok(GpuDetectionResult {
364 gpus,
365 total_vram_bytes: total_vram,
366 detection_method: DetectionMethod::RocmSmi,
367 })
368 }
369
370 #[cfg(target_os = "macos")]
372 fn detect_apple(&self) -> Result<GpuDetectionResult> {
373 let output = self.run_with_timeout(
375 Command::new("system_profiler").args(["SPDisplaysDataType", "-json"]),
376 )?;
377
378 if !output.status.success() {
379 return Err(GpuDetectionError::CommandFailed(
380 String::from_utf8_lossy(&output.stderr).to_string(),
381 ));
382 }
383
384 let sysctl_output =
387 self.run_with_timeout(Command::new("sysctl").args(["-n", "hw.memsize"]))?;
388
389 let total_ram = String::from_utf8_lossy(&sysctl_output.stdout)
390 .trim()
391 .parse::<u64>()
392 .unwrap_or(16 * 1024 * 1024 * 1024);
393
394 let gpu_memory = (total_ram as f64 * 0.75) as u64;
396
397 Ok(GpuDetectionResult {
398 gpus: vec![GpuInfo {
399 device_id: 0,
400 vendor: GpuVendor::Apple,
401 name: "Apple Silicon GPU".to_string(),
402 vram_bytes: gpu_memory,
403 vram_free_bytes: gpu_memory,
404 compute_capability: None,
405 driver_version: None,
406 }],
407 total_vram_bytes: gpu_memory,
408 detection_method: DetectionMethod::AppleMetal,
409 })
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// One GiB in bytes; keeps the fixtures below readable.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// Builds a GPU fixture with no compute capability or driver version.
    fn sample_gpu(
        device_id: u32,
        vendor: GpuVendor,
        name: &str,
        vram_bytes: u64,
        vram_free_bytes: u64,
    ) -> GpuInfo {
        GpuInfo {
            device_id,
            vendor,
            name: name.to_string(),
            vram_bytes,
            vram_free_bytes,
            compute_capability: None,
            driver_version: None,
        }
    }

    /// Total and free memory are converted from bytes to GiB.
    #[test]
    fn test_gpu_info_vram_gb() {
        let info = sample_gpu(0, GpuVendor::Nvidia, "Test GPU", 24 * GIB, 20 * GIB);

        assert!((info.vram_gb() - 24.0).abs() < 0.01);
        assert!((info.vram_free_gb() - 20.0).abs() < 0.01);
    }

    /// 4 GiB free out of 10 GiB means 60% in use.
    #[test]
    fn test_gpu_info_utilization() {
        let info = sample_gpu(0, GpuVendor::Nvidia, "Test GPU", 10 * GIB, 4 * GIB);

        assert!((info.vram_utilization() - 0.6).abs() < 0.01);
    }

    /// A GPU reporting zero total memory must not divide by zero.
    #[test]
    fn test_gpu_info_utilization_zero_vram() {
        let info = sample_gpu(0, GpuVendor::Unknown, "Test GPU", 0, 0);

        assert_eq!(info.vram_utilization(), 0.0);
    }

    /// The primary GPU is the first entry; the total is reported in GiB.
    #[test]
    fn test_detection_result_primary() {
        let gpus = (0..2u32)
            .map(|id| GpuInfo {
                compute_capability: Some("8.9".to_string()),
                ..sample_gpu(
                    id,
                    GpuVendor::Nvidia,
                    &format!("GPU {}", id),
                    24 * GIB,
                    24 * GIB,
                )
            })
            .collect();
        let result = GpuDetectionResult {
            gpus,
            total_vram_bytes: 48 * GIB,
            detection_method: DetectionMethod::NvidiaSmi,
        };

        assert!(result.has_gpu());
        assert_eq!(result.primary().map(|g| g.device_id), Some(0));
        assert!((result.total_vram_gb() - 48.0).abs() < 0.01);
    }

    /// The empty result has no GPUs and no memory.
    #[test]
    fn test_detection_result_none() {
        let result = GpuDetectionResult::none();

        assert!(!result.has_gpu());
        assert!(result.primary().is_none());
        assert_eq!(result.total_vram_bytes, 0);
    }

    /// `detect_or_default` always yields at least one GPU; when nothing was detected the
    /// placeholder carries the requested default memory.
    #[test]
    fn test_detector_fallback_on_failure() {
        let default_vram = 8 * GIB;
        let result = GpuDetector::new().detect_or_default(default_vram);

        assert!(!result.gpus.is_empty());

        // Only the placeholder path is predictable; real hardware reports its own size.
        if result.detection_method == DetectionMethod::None {
            assert_eq!(result.total_vram_bytes, default_vram);
        }
    }

    /// Splitting a representative `nvidia-smi` CSV row yields the expected trimmed fields.
    #[test]
    fn test_detector_nvidia_parsing() {
        let sample_line = "0, NVIDIA GeForce RTX 4090, 24564, 23000, 545.23.08, 8.9";
        let parts: Vec<&str> = sample_line.split(',').map(str::trim).collect();

        assert_eq!(parts[0], "0");
        assert_eq!(parts[1], "NVIDIA GeForce RTX 4090");
        assert_eq!(parts[2].parse::<u64>().ok(), Some(24564));
        assert_eq!(parts[3].parse::<u64>().ok(), Some(23000));
        assert_eq!(parts[4], "545.23.08");
        assert_eq!(parts[5], "8.9");
    }

    /// Vendors compare by variant.
    #[test]
    fn test_gpu_vendor_equality() {
        assert_eq!(GpuVendor::Nvidia, GpuVendor::Nvidia);
        assert_ne!(GpuVendor::Nvidia, GpuVendor::Amd);
    }

    /// A custom timeout is stored as given.
    #[test]
    fn test_detector_with_timeout() {
        let ten_seconds = Duration::from_secs(10);
        let detector = GpuDetector::with_timeout(ten_seconds);
        assert_eq!(detector.timeout, ten_seconds);
    }

    /// Runs the real `nvidia-smi` probe. A command failure (e.g. the tool is not installed)
    /// is accepted; any other error fails the test.
    #[test]
    fn test_nvidia_detection_real() {
        let result = match GpuDetector::new().detect_nvidia() {
            Ok(result) => result,
            Err(GpuDetectionError::CommandFailed(_)) => return,
            Err(e) => panic!("Unexpected error: {}", e),
        };

        for gpu in &result.gpus {
            assert_eq!(gpu.vendor, GpuVendor::Nvidia);
            assert!(gpu.vram_bytes > 0);
            assert!(gpu.vram_free_bytes <= gpu.vram_bytes);
            assert!(!gpu.name.is_empty());
        }
        assert_eq!(result.detection_method, DetectionMethod::NvidiaSmi);
    }

    /// Full detection either finds a GPU with memory or reports `NoGpu`; any other error
    /// fails the test.
    #[test]
    fn test_detect_all_graceful() {
        let result = match GpuDetector::new().detect() {
            Ok(result) => result,
            Err(GpuDetectionError::NoGpu) => return,
            Err(e) => panic!("Unexpected detection error: {}", e),
        };

        assert!(result.has_gpu());
        assert!(result.total_vram_bytes > 0);
    }
}