Skip to main content

infernum_arbiter/
gpu.rs

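//! GPU detection and information gathering.
//!
//! Detects available GPUs and their capabilities via:
//! 1. `nvidia-smi` (NVIDIA GPUs)
//! 2. `rocm-smi` (AMD GPUs)
//! 3. `system_profiler` / `sysctl` (Apple Silicon, macOS only)
//!
//! If none of these reports a GPU, [`GpuDetector::detect`] returns
//! [`GpuDetectionError::NoGpu`]; [`GpuDetector::detect_or_default`] substitutes a
//! caller-supplied default instead.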
7
8use std::process::{Command, Stdio};
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
12/// Errors from GPU detection.
13#[derive(Debug, Error)]
14pub enum GpuDetectionError {
15    /// No GPU detected.
16    #[error("No GPU detected")]
17    NoGpu,
18
19    /// Detection command failed.
20    #[error("Detection command failed: {0}")]
21    CommandFailed(String),
22
23    /// Failed to parse GPU information.
24    #[error("Failed to parse GPU info: {0}")]
25    ParseError(String),
26
27    /// Timeout during detection.
28    #[error("Detection timed out after {0:?}")]
29    Timeout(Duration),
30}
31
32/// Result type for GPU detection operations.
33pub type Result<T> = std::result::Result<T, GpuDetectionError>;
34
/// Manufacturer of a detected GPU.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuVendor {
    /// NVIDIA.
    Nvidia,
    /// AMD.
    Amd,
    /// Intel.
    Intel,
    /// Apple Silicon integrated GPU.
    Apple,
    /// Manufacturer could not be determined.
    Unknown,
}
49
50/// Information about a detected GPU.
51#[derive(Debug, Clone)]
52pub struct GpuInfo {
53    /// Device index (0-based).
54    pub device_id: u32,
55
56    /// GPU vendor.
57    pub vendor: GpuVendor,
58
59    /// GPU name/model.
60    pub name: String,
61
62    /// Total VRAM in bytes.
63    pub vram_bytes: u64,
64
65    /// Free VRAM in bytes (at detection time).
66    pub vram_free_bytes: u64,
67
68    /// Compute capability (NVIDIA) or architecture info.
69    pub compute_capability: Option<String>,
70
71    /// Driver version.
72    pub driver_version: Option<String>,
73}
74
75impl GpuInfo {
76    /// Returns VRAM in gigabytes.
77    pub fn vram_gb(&self) -> f64 {
78        self.vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
79    }
80
81    /// Returns free VRAM in gigabytes.
82    pub fn vram_free_gb(&self) -> f64 {
83        self.vram_free_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
84    }
85
86    /// Returns VRAM utilization (0.0 - 1.0).
87    pub fn vram_utilization(&self) -> f32 {
88        if self.vram_bytes == 0 {
89            return 0.0;
90        }
91        let used = self.vram_bytes.saturating_sub(self.vram_free_bytes);
92        used as f32 / self.vram_bytes as f32
93    }
94}
95
96/// GPU detection result.
97#[derive(Debug, Clone)]
98pub struct GpuDetectionResult {
99    /// Detected GPUs.
100    pub gpus: Vec<GpuInfo>,
101
102    /// Total VRAM across all GPUs.
103    pub total_vram_bytes: u64,
104
105    /// Detection method used.
106    pub detection_method: DetectionMethod,
107}
108
109impl GpuDetectionResult {
110    /// Returns the primary (first) GPU.
111    pub fn primary(&self) -> Option<&GpuInfo> {
112        self.gpus.first()
113    }
114
115    /// Returns total VRAM in GB.
116    pub fn total_vram_gb(&self) -> f64 {
117        self.total_vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
118    }
119
120    /// Returns true if any GPU was detected.
121    pub fn has_gpu(&self) -> bool {
122        !self.gpus.is_empty()
123    }
124
125    /// Creates an empty result (no GPUs).
126    pub fn none() -> Self {
127        Self {
128            gpus: vec![],
129            total_vram_bytes: 0,
130            detection_method: DetectionMethod::None,
131        }
132    }
133}
134
/// Backend that produced a [`GpuDetectionResult`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DetectionMethod {
    /// NVIDIA `nvidia-smi`.
    NvidiaSmi,
    /// AMD `rocm-smi`.
    RocmSmi,
    /// Apple Metal (macOS system tools).
    AppleMetal,
    /// Generic system detection.
    System,
    /// No detection was performed (or it produced nothing).
    None,
}
149
/// Detects GPUs by trying several vendor tools in turn.
pub struct GpuDetector {
    /// Upper bound on how long any single detection command may run.
    timeout: Duration,
}
155
156impl Default for GpuDetector {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl GpuDetector {
163    /// Creates a new GPU detector with a 5-second timeout.
164    pub fn new() -> Self {
165        Self {
166            timeout: Duration::from_secs(5),
167        }
168    }
169
170    /// Creates a detector with custom timeout.
171    pub fn with_timeout(timeout: Duration) -> Self {
172        Self { timeout }
173    }
174
175    /// Runs a command with the configured timeout.
176    ///
177    /// Spawns the process and polls `try_wait` until it exits or the timeout
178    /// elapses. On timeout the child is killed and `GpuDetectionError::Timeout`
179    /// is returned.
180    fn run_with_timeout(&self, cmd: &mut Command) -> Result<std::process::Output> {
181        let mut child = cmd
182            .stdout(Stdio::piped())
183            .stderr(Stdio::piped())
184            .spawn()
185            .map_err(|e| GpuDetectionError::CommandFailed(e.to_string()))?;
186
187        let start = Instant::now();
188        loop {
189            match child.try_wait() {
190                Ok(Some(_)) => {
191                    // Child exited — collect stdout/stderr.
192                    return child
193                        .wait_with_output()
194                        .map_err(|e| GpuDetectionError::CommandFailed(e.to_string()));
195                },
196                Ok(None) => {
197                    if start.elapsed() >= self.timeout {
198                        let _ = child.kill();
199                        let _ = child.wait(); // Reap zombie.
200                        return Err(GpuDetectionError::Timeout(self.timeout));
201                    }
202                    std::thread::sleep(Duration::from_millis(50));
203                },
204                Err(e) => {
205                    return Err(GpuDetectionError::CommandFailed(e.to_string()));
206                },
207            }
208        }
209    }
210
211    /// Detects all available GPUs.
212    pub fn detect(&self) -> Result<GpuDetectionResult> {
213        // Try NVIDIA first (most common for ML)
214        if let Ok(result) = self.detect_nvidia() {
215            if result.has_gpu() {
216                return Ok(result);
217            }
218        }
219
220        // Try AMD
221        if let Ok(result) = self.detect_amd() {
222            if result.has_gpu() {
223                return Ok(result);
224            }
225        }
226
227        // Try Apple Metal on macOS
228        #[cfg(target_os = "macos")]
229        if let Ok(result) = self.detect_apple() {
230            if result.has_gpu() {
231                return Ok(result);
232            }
233        }
234
235        // No GPU found
236        Err(GpuDetectionError::NoGpu)
237    }
238
239    /// Detects GPUs with fallback to default config on failure.
240    pub fn detect_or_default(&self, default_vram_bytes: u64) -> GpuDetectionResult {
241        match self.detect() {
242            Ok(result) => result,
243            Err(_) => GpuDetectionResult {
244                gpus: vec![GpuInfo {
245                    device_id: 0,
246                    vendor: GpuVendor::Unknown,
247                    name: "Unknown GPU".to_string(),
248                    vram_bytes: default_vram_bytes,
249                    vram_free_bytes: default_vram_bytes,
250                    compute_capability: None,
251                    driver_version: None,
252                }],
253                total_vram_bytes: default_vram_bytes,
254                detection_method: DetectionMethod::None,
255            },
256        }
257    }
258
259    /// Detects NVIDIA GPUs using nvidia-smi.
260    fn detect_nvidia(&self) -> Result<GpuDetectionResult> {
261        let output = self.run_with_timeout(Command::new("nvidia-smi").args([
262            "--query-gpu=index,name,memory.total,memory.free,driver_version,compute_cap",
263            "--format=csv,noheader,nounits",
264        ]))?;
265
266        if !output.status.success() {
267            return Err(GpuDetectionError::CommandFailed(
268                String::from_utf8_lossy(&output.stderr).to_string(),
269            ));
270        }
271
272        let stdout = String::from_utf8_lossy(&output.stdout);
273        let mut gpus = Vec::new();
274        let mut total_vram = 0u64;
275
276        for line in stdout.lines() {
277            if line.trim().is_empty() {
278                continue;
279            }
280
281            let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
282            if parts.len() < 4 {
283                continue;
284            }
285
286            let device_id = parts[0]
287                .parse::<u32>()
288                .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
289
290            let name = parts[1].to_string();
291
292            // nvidia-smi reports memory in MiB
293            let vram_mib = parts[2]
294                .parse::<u64>()
295                .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
296            let vram_bytes = vram_mib * 1024 * 1024;
297
298            let vram_free_mib = parts[3]
299                .parse::<u64>()
300                .map_err(|e| GpuDetectionError::ParseError(e.to_string()))?;
301            let vram_free_bytes = vram_free_mib * 1024 * 1024;
302
303            let driver_version = parts.get(4).map(|s| s.to_string());
304            let compute_capability = parts.get(5).map(|s| s.to_string());
305
306            total_vram += vram_bytes;
307
308            gpus.push(GpuInfo {
309                device_id,
310                vendor: GpuVendor::Nvidia,
311                name,
312                vram_bytes,
313                vram_free_bytes,
314                compute_capability,
315                driver_version,
316            });
317        }
318
319        Ok(GpuDetectionResult {
320            gpus,
321            total_vram_bytes: total_vram,
322            detection_method: DetectionMethod::NvidiaSmi,
323        })
324    }
325
326    /// Detects AMD GPUs using rocm-smi.
327    fn detect_amd(&self) -> Result<GpuDetectionResult> {
328        let output = self.run_with_timeout(Command::new("rocm-smi").args([
329            "--showmeminfo",
330            "vram",
331            "--json",
332        ]))?;
333
334        if !output.status.success() {
335            return Err(GpuDetectionError::CommandFailed(
336                String::from_utf8_lossy(&output.stderr).to_string(),
337            ));
338        }
339
340        // ROCm SMI outputs JSON, but for now we'll return a simple fallback
341        // Full implementation would parse the JSON
342        let stdout = String::from_utf8_lossy(&output.stdout);
343
344        // Basic parsing - look for memory values
345        let mut gpus = Vec::new();
346        let mut total_vram = 0u64;
347
348        // Simplified: if rocm-smi succeeded, assume we have at least one AMD GPU
349        // A real implementation would parse the JSON properly
350        if stdout.contains("card") || stdout.contains("GPU") {
351            gpus.push(GpuInfo {
352                device_id: 0,
353                vendor: GpuVendor::Amd,
354                name: "AMD GPU".to_string(),
355                vram_bytes: 16 * 1024 * 1024 * 1024, // Default 16GB
356                vram_free_bytes: 16 * 1024 * 1024 * 1024,
357                compute_capability: None,
358                driver_version: None,
359            });
360            total_vram = 16 * 1024 * 1024 * 1024;
361        }
362
363        Ok(GpuDetectionResult {
364            gpus,
365            total_vram_bytes: total_vram,
366            detection_method: DetectionMethod::RocmSmi,
367        })
368    }
369
370    /// Detects Apple Metal GPUs on macOS.
371    #[cfg(target_os = "macos")]
372    fn detect_apple(&self) -> Result<GpuDetectionResult> {
373        // Use system_profiler to get GPU info
374        let output = self.run_with_timeout(
375            Command::new("system_profiler").args(["SPDisplaysDataType", "-json"]),
376        )?;
377
378        if !output.status.success() {
379            return Err(GpuDetectionError::CommandFailed(
380                String::from_utf8_lossy(&output.stderr).to_string(),
381            ));
382        }
383
384        // For Apple Silicon, unified memory is shared
385        // We'll estimate GPU portion as ~75% of total RAM
386        let sysctl_output =
387            self.run_with_timeout(Command::new("sysctl").args(["-n", "hw.memsize"]))?;
388
389        let total_ram = String::from_utf8_lossy(&sysctl_output.stdout)
390            .trim()
391            .parse::<u64>()
392            .unwrap_or(16 * 1024 * 1024 * 1024);
393
394        // Assume 75% of unified memory available for GPU
395        let gpu_memory = (total_ram as f64 * 0.75) as u64;
396
397        Ok(GpuDetectionResult {
398            gpus: vec![GpuInfo {
399                device_id: 0,
400                vendor: GpuVendor::Apple,
401                name: "Apple Silicon GPU".to_string(),
402                vram_bytes: gpu_memory,
403                vram_free_bytes: gpu_memory,
404                compute_capability: None,
405                driver_version: None,
406            }],
407            total_vram_bytes: gpu_memory,
408            detection_method: DetectionMethod::AppleMetal,
409        })
410    }
411}
412
413// ============================================================================
414// Tests
415// ============================================================================
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_info_vram_gb() {
        let info = GpuInfo {
            device_id: 0,
            vendor: GpuVendor::Nvidia,
            name: "Test GPU".to_string(),
            vram_bytes: 24 * 1024 * 1024 * 1024, // 24 GB
            vram_free_bytes: 20 * 1024 * 1024 * 1024,
            compute_capability: None,
            driver_version: None,
        };

        assert!((info.vram_gb() - 24.0).abs() < 0.01);
        assert!((info.vram_free_gb() - 20.0).abs() < 0.01);
    }

    #[test]
    fn test_gpu_info_utilization() {
        let info = GpuInfo {
            device_id: 0,
            vendor: GpuVendor::Nvidia,
            name: "Test GPU".to_string(),
            vram_bytes: 10 * 1024 * 1024 * 1024, // 10 GB total
            vram_free_bytes: 4 * 1024 * 1024 * 1024, // 4 GB free = 6 GB used
            compute_capability: None,
            driver_version: None,
        };

        // 6/10 = 60% utilization
        assert!((info.vram_utilization() - 0.6).abs() < 0.01);
    }

    #[test]
    fn test_gpu_info_utilization_zero_vram() {
        let info = GpuInfo {
            device_id: 0,
            vendor: GpuVendor::Unknown,
            name: "Test GPU".to_string(),
            vram_bytes: 0,
            vram_free_bytes: 0,
            compute_capability: None,
            driver_version: None,
        };

        // Should handle zero gracefully
        assert_eq!(info.vram_utilization(), 0.0);
    }

    #[test]
    fn test_detection_result_primary() {
        let result = GpuDetectionResult {
            gpus: vec![
                GpuInfo {
                    device_id: 0,
                    vendor: GpuVendor::Nvidia,
                    name: "GPU 0".to_string(),
                    vram_bytes: 24 * 1024 * 1024 * 1024,
                    vram_free_bytes: 24 * 1024 * 1024 * 1024,
                    compute_capability: Some("8.9".to_string()),
                    driver_version: None,
                },
                GpuInfo {
                    device_id: 1,
                    vendor: GpuVendor::Nvidia,
                    name: "GPU 1".to_string(),
                    vram_bytes: 24 * 1024 * 1024 * 1024,
                    vram_free_bytes: 24 * 1024 * 1024 * 1024,
                    compute_capability: Some("8.9".to_string()),
                    driver_version: None,
                },
            ],
            total_vram_bytes: 48 * 1024 * 1024 * 1024,
            detection_method: DetectionMethod::NvidiaSmi,
        };

        assert!(result.has_gpu());
        assert_eq!(result.primary().map(|g| g.device_id), Some(0));
        assert!((result.total_vram_gb() - 48.0).abs() < 0.01);
    }

    #[test]
    fn test_detection_result_none() {
        let result = GpuDetectionResult::none();

        assert!(!result.has_gpu());
        assert!(result.primary().is_none());
        assert_eq!(result.total_vram_bytes, 0);
    }

    #[test]
    fn test_detector_fallback_on_failure() {
        let detector = GpuDetector::new();
        let default_vram = 8 * 1024 * 1024 * 1024; // 8 GB

        let result = detector.detect_or_default(default_vram);

        // Should always return something, even if detection fails
        assert!(!result.gpus.is_empty());

        // If we're on a machine without a GPU, it should return the default
        if result.detection_method == DetectionMethod::None {
            assert_eq!(result.total_vram_bytes, default_vram);
        }
    }

    #[test]
    fn test_detector_nvidia_parsing() {
        // Test parsing of nvidia-smi output format
        let sample_line = "0, NVIDIA GeForce RTX 4090, 24564, 23000, 545.23.08, 8.9";
        let parts: Vec<&str> = sample_line.split(',').map(|s| s.trim()).collect();

        assert_eq!(parts[0], "0");
        assert_eq!(parts[1], "NVIDIA GeForce RTX 4090");
        assert_eq!(parts[2].parse::<u64>().ok(), Some(24564)); // MiB
        assert_eq!(parts[3].parse::<u64>().ok(), Some(23000)); // MiB free
        assert_eq!(parts[4], "545.23.08");
        assert_eq!(parts[5], "8.9");
    }

    #[test]
    fn test_gpu_vendor_equality() {
        assert_eq!(GpuVendor::Nvidia, GpuVendor::Nvidia);
        assert_ne!(GpuVendor::Nvidia, GpuVendor::Amd);
    }

    #[test]
    fn test_detector_with_timeout() {
        let detector = GpuDetector::with_timeout(Duration::from_secs(10));
        assert_eq!(detector.timeout, Duration::from_secs(10));
    }

    // Integration test that actually calls nvidia-smi (if available)
    #[test]
    fn test_nvidia_detection_real() {
        let detector = GpuDetector::new();

        // This test may pass or fail depending on whether nvidia-smi is available
        match detector.detect_nvidia() {
            Ok(result) => {
                // If NVIDIA GPUs are found, verify the data makes sense
                for gpu in &result.gpus {
                    assert_eq!(gpu.vendor, GpuVendor::Nvidia);
                    assert!(gpu.vram_bytes > 0);
                    assert!(gpu.vram_free_bytes <= gpu.vram_bytes);
                    assert!(!gpu.name.is_empty());
                }
                assert_eq!(result.detection_method, DetectionMethod::NvidiaSmi);
            },
            Err(GpuDetectionError::CommandFailed(_)) => {
                // nvidia-smi not available - acceptable in CI
            },
            Err(GpuDetectionError::Timeout(_)) => {
                // nvidia-smi present but hung (e.g. wedged driver) - an environment
                // condition, not a detection bug.
            },
            Err(e) => {
                panic!("Unexpected error: {}", e);
            },
        }
    }

    #[test]
    fn test_detect_all_graceful() {
        let detector = GpuDetector::new();

        // detect() should either succeed or return NoGpu error
        match detector.detect() {
            Ok(result) => {
                assert!(result.has_gpu());
                assert!(result.total_vram_bytes > 0);
            },
            Err(GpuDetectionError::NoGpu) => {
                // Expected on machines without GPU
            },
            Err(e) => {
                // Other errors should not occur during normal operation
                panic!("Unexpected detection error: {}", e);
            },
        }
    }
}