Skip to main content

entrenar/monitor/gpu/
monitor.rs

//! GPU monitor that collects metrics.
//!
//! Uses NVML when available (feature `nvml`); otherwise the monitor reports no devices.
//! Canned metrics for tests are available via `GpuMonitor::mock`.
4
5use super::{GpuMetrics, GpuProcess};
6
7#[cfg(feature = "nvml")]
8use nvml_wrapper::{enum_wrappers::device::TemperatureSensor, Nvml};
9
10#[cfg(feature = "nvml")]
11use std::fs;
12
13/// GPU monitor that collects metrics
14///
15/// When compiled with `nvml` feature, uses real NVIDIA NVML for hardware metrics.
16/// Otherwise provides mock mode for testing.
17#[derive(Debug)]
18pub struct GpuMonitor {
19    /// Number of detected devices
20    num_devices: u32,
21    /// Mock mode for testing
22    mock_mode: bool,
23    /// Mock metrics generator
24    mock_metrics: Vec<GpuMetrics>,
25    /// NVML instance (when feature enabled)
26    #[cfg(feature = "nvml")]
27    nvml: Option<Nvml>,
28}
29
30impl GpuMonitor {
31    /// Create a new GPU monitor
32    ///
33    /// Attempts to initialize NVML if feature enabled, falls back gracefully.
34    #[cfg(feature = "nvml")]
35    pub fn new() -> Result<Self, String> {
36        match Nvml::init() {
37            Ok(nvml) => {
38                let num_devices = nvml.device_count().unwrap_or(0);
39                Ok(Self {
40                    num_devices,
41                    mock_mode: false,
42                    mock_metrics: Vec::new(),
43                    nvml: Some(nvml),
44                })
45            }
46            Err(e) => {
47                eprintln!("[GpuMonitor] NVML init failed: {e}, using mock mode");
48                Ok(Self { num_devices: 0, mock_mode: false, mock_metrics: Vec::new(), nvml: None })
49            }
50        }
51    }
52
53    /// Create a new GPU monitor (non-NVML fallback)
54    #[cfg(not(feature = "nvml"))]
55    pub fn new() -> Result<Self, String> {
56        // Without NVML feature, return empty (graceful degradation)
57        Ok(Self { num_devices: 0, mock_mode: false, mock_metrics: Vec::new() })
58    }
59
60    /// Create a mock GPU monitor for testing
61    pub fn mock(num_devices: u32) -> Self {
62        let mock_metrics = (0..num_devices).map(GpuMetrics::mock).collect();
63        Self {
64            num_devices,
65            mock_mode: true,
66            mock_metrics,
67            #[cfg(feature = "nvml")]
68            nvml: None,
69        }
70    }
71
72    /// Get number of detected devices
73    pub fn num_devices(&self) -> u32 {
74        self.num_devices
75    }
76
77    /// Check if running in mock mode
78    pub fn is_mock(&self) -> bool {
79        self.mock_mode
80    }
81
82    /// Sample current GPU metrics
83    #[cfg(feature = "nvml")]
84    pub fn sample(&self) -> Vec<GpuMetrics> {
85        contract_pre_sample!();
86        if self.mock_mode {
87            return self.mock_metrics.clone();
88        }
89
90        let Some(nvml) = &self.nvml else {
91            return Vec::new();
92        };
93
94        let mut metrics = Vec::with_capacity(self.num_devices as usize);
95
96        for i in 0..self.num_devices {
97            let Ok(device) = nvml.device_by_index(i) else {
98                continue;
99            };
100
101            let name = device.name().unwrap_or_else(|_err| format!("GPU {i}"));
102
103            // Utilization rates
104            let (utilization_percent, memory_utilization_percent) =
105                device.utilization_rates().map_or((0, 0), |rates| (rates.gpu, rates.memory));
106
107            // Memory info
108            let (memory_used_mb, memory_total_mb) = device
109                .memory_info()
110                .map_or((0, 0), |mem| (mem.used / (1024 * 1024), mem.total / (1024 * 1024)));
111
112            // Temperature
113            let temperature_celsius = device.temperature(TemperatureSensor::Gpu).unwrap_or(0);
114
115            // Power
116            let power_watts = device.power_usage().map_or(0.0, |mw| mw as f32 / 1000.0);
117            let power_limit_watts =
118                device.enforced_power_limit().map_or(0.0, |mw| mw as f32 / 1000.0);
119
120            // Clocks
121            let clock_mhz = device
122                .clock_info(nvml_wrapper::enum_wrappers::device::Clock::Graphics)
123                .unwrap_or(0);
124            let memory_clock_mhz =
125                device.clock_info(nvml_wrapper::enum_wrappers::device::Clock::Memory).unwrap_or(0);
126
127            // PCIe throughput
128            let pcie_tx_kbps = u64::from(
129                device
130                    .pcie_throughput(nvml_wrapper::enum_wrappers::device::PcieUtilCounter::Send)
131                    .unwrap_or(0),
132            );
133            let pcie_rx_kbps = u64::from(
134                device
135                    .pcie_throughput(nvml_wrapper::enum_wrappers::device::PcieUtilCounter::Receive)
136                    .unwrap_or(0),
137            );
138
139            // Fan speed (may not be available on all GPUs)
140            let fan_speed_percent = device.fan_speed(0).unwrap_or(0);
141
142            // Collect running compute processes
143            let processes = Self::collect_gpu_processes(&device);
144
145            metrics.push(GpuMetrics {
146                device_id: i,
147                name,
148                utilization_percent,
149                memory_used_mb,
150                memory_total_mb,
151                memory_utilization_percent,
152                temperature_celsius,
153                power_watts,
154                power_limit_watts,
155                clock_mhz,
156                memory_clock_mhz,
157                pcie_tx_kbps,
158                pcie_rx_kbps,
159                fan_speed_percent,
160                processes,
161            });
162        }
163
164        metrics
165    }
166
167    /// Sample current GPU metrics (non-NVML fallback)
168    #[cfg(not(feature = "nvml"))]
169    pub fn sample(&self) -> Vec<GpuMetrics> {
170        if self.mock_mode {
171            return self.mock_metrics.clone();
172        }
173
174        // Production implementation would query NVML here
175        Vec::new()
176    }
177
178    /// Sample with simulated variation (for testing)
179    pub fn sample_with_variation(&mut self, variation: f32) -> Vec<GpuMetrics> {
180        if !self.mock_mode {
181            return Vec::new();
182        }
183
184        self.mock_metrics
185            .iter()
186            .map(|base| {
187                let mut m = base.clone();
188                let var = (variation * 10.0) as i32;
189                m.utilization_percent = (m.utilization_percent as i32 + var).clamp(0, 100) as u32;
190                m.temperature_celsius =
191                    (m.temperature_celsius as i32 + var / 2).clamp(30, 100) as u32;
192                m.power_watts = (m.power_watts + variation * 20.0).clamp(0.0, m.power_limit_watts);
193                m
194            })
195            .collect()
196    }
197
198    /// Set mock metrics (for testing specific scenarios)
199    pub fn set_mock_metrics(&mut self, metrics: Vec<GpuMetrics>) {
200        self.mock_metrics = metrics;
201        self.num_devices = self.mock_metrics.len() as u32;
202        self.mock_mode = true;
203    }
204
205    /// Collect GPU processes from NVML and enrich with /proc data
206    #[cfg(feature = "nvml")]
207    fn collect_gpu_processes(device: &nvml_wrapper::Device<'_>) -> Vec<GpuProcess> {
208        use nvml_wrapper::enums::device::UsedGpuMemory;
209
210        let mut processes = Vec::new();
211
212        // Helper to extract memory from UsedGpuMemory enum
213        let extract_memory = |mem: UsedGpuMemory| -> u64 {
214            match mem {
215                UsedGpuMemory::Used(bytes) => bytes / (1024 * 1024),
216                UsedGpuMemory::Unavailable => 0,
217            }
218        };
219
220        // Get compute processes (CUDA apps)
221        if let Ok(compute_procs) = device.running_compute_processes() {
222            for proc in compute_procs {
223                let pid = proc.pid;
224                let gpu_memory_mb = extract_memory(proc.used_gpu_memory);
225
226                // Read /proc/PID/exe for full path
227                let exe_path = fs::read_link(format!("/proc/{pid}/exe"))
228                    .map_or_else(|_| format!("[pid {pid}]"), |p| p.to_string_lossy().to_string());
229
230                // Read /proc/PID/stat for CPU and memory
231                let (cpu_percent, rss_mb) = Self::read_proc_stats(pid);
232
233                processes.push(GpuProcess { pid, exe_path, gpu_memory_mb, cpu_percent, rss_mb });
234            }
235        }
236
237        // Also check graphics processes
238        if let Ok(graphics_procs) = device.running_graphics_processes() {
239            for proc in graphics_procs {
240                // Skip if already in compute list
241                if processes.iter().any(|p| p.pid == proc.pid) {
242                    continue;
243                }
244
245                let pid = proc.pid;
246                let gpu_memory_mb = extract_memory(proc.used_gpu_memory);
247
248                let exe_path = fs::read_link(format!("/proc/{pid}/exe"))
249                    .map_or_else(|_| format!("[pid {pid}]"), |p| p.to_string_lossy().to_string());
250
251                let (cpu_percent, rss_mb) = Self::read_proc_stats(pid);
252
253                processes.push(GpuProcess { pid, exe_path, gpu_memory_mb, cpu_percent, rss_mb });
254            }
255        }
256
257        processes
258    }
259
260    /// Read CPU% and RSS from /proc/PID/stat and /proc/PID/statm
261    #[cfg(feature = "nvml")]
262    fn read_proc_stats(pid: u32) -> (f32, u64) {
263        // Read RSS from /proc/PID/statm (second field, in pages)
264        let rss_mb = fs::read_to_string(format!("/proc/{pid}/statm"))
265            .ok()
266            .and_then(|s| s.split_whitespace().nth(1)?.parse::<u64>().ok())
267            .map_or(0, |pages| pages * 4096 / (1024 * 1024));
268
269        // CPU% would require sampling over time - approximate from /proc/PID/stat
270        // For now, read utime+stime and estimate based on uptime
271        let cpu_percent = fs::read_to_string(format!("/proc/{pid}/stat"))
272            .ok()
273            .and_then(|s| {
274                let fields: Vec<&str> = s.split_whitespace().collect();
275                if fields.len() > 14 {
276                    let utime: u64 = fields[13].parse().ok()?;
277                    let stime: u64 = fields[14].parse().ok()?;
278                    let total_ticks = utime + stime;
279                    // Rough approximation: assume 100 ticks/sec, sample over 1 sec
280                    // This is imprecise but gives an order of magnitude
281                    Some((total_ticks as f32 / 100.0).min(100.0))
282                } else {
283                    None
284                }
285            })
286            .unwrap_or(0.0);
287
288        (cpu_percent, rss_mb)
289    }
290
291    /// Collect GPU processes (non-NVML fallback)
292    #[cfg(not(feature = "nvml"))]
293    #[allow(dead_code)]
294    fn collect_gpu_processes(_device: &()) -> Vec<GpuProcess> {
295        Vec::new()
296    }
297}
298
299impl Default for GpuMonitor {
300    fn default() -> Self {
301        Self::new().unwrap_or_else(|_err| Self::mock(0))
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_monitor_new() {
        let monitor = GpuMonitor::new();
        assert!(monitor.is_ok());
    }

    #[test]
    fn test_gpu_monitor_mock() {
        let monitor = GpuMonitor::mock(2);
        assert_eq!(monitor.num_devices(), 2);
        assert!(monitor.is_mock());
    }

    #[test]
    fn test_gpu_monitor_sample_mock() {
        let monitor = GpuMonitor::mock(2);
        let metrics = monitor.sample();
        assert_eq!(metrics.len(), 2);
        assert_eq!(metrics[0].device_id, 0);
        assert_eq!(metrics[1].device_id, 1);
    }

    #[test]
    fn test_gpu_monitor_sample_with_variation() {
        let mut monitor = GpuMonitor::mock(1);
        let base = monitor.sample()[0].utilization_percent;

        let varied = monitor.sample_with_variation(1.0);
        assert_eq!(varied.len(), 1);
        // variation 1.0 shifts utilization by +10 (clamped to 100), so the value
        // must change unless it is already pinned at the ceiling.
        assert!(varied[0].utilization_percent != base || base == 100);
    }

    #[test]
    fn test_gpu_monitor_zero_variation_keeps_utilization() {
        let mut monitor = GpuMonitor::mock(0);
        monitor.set_mock_metrics(vec![GpuMetrics {
            device_id: 5,
            utilization_percent: 99,
            ..Default::default()
        }]);

        let varied = monitor.sample_with_variation(0.0);
        assert_eq!(varied.len(), 1);
        assert_eq!(varied[0].device_id, 5);
        assert_eq!(varied[0].utilization_percent, 99);
    }

    #[test]
    fn test_gpu_monitor_set_mock_metrics() {
        let mut monitor = GpuMonitor::mock(0);
        monitor.set_mock_metrics(vec![GpuMetrics {
            device_id: 5,
            utilization_percent: 99,
            ..Default::default()
        }]);

        let metrics = monitor.sample();
        assert_eq!(metrics.len(), 1);
        assert_eq!(metrics[0].device_id, 5);
        assert_eq!(metrics[0].utilization_percent, 99);
    }

    #[test]
    fn test_gpu_monitor_default() {
        let monitor = GpuMonitor::default();
        // Should either be real or mock, but should work
        let _ = monitor.num_devices();
    }

    #[test]
    fn test_gpu_monitor_non_mock_sample() {
        // Create non-mock monitor but call sample (should return empty or real data)
        let monitor = GpuMonitor::new().expect("operation should succeed");
        let metrics = monitor.sample();
        // Just verify it doesn't crash - may be empty if no NVML
        let _ = metrics;
    }

    #[test]
    fn test_gpu_monitor_non_mock_sample_with_variation() {
        let mut monitor = GpuMonitor::new().expect("operation should succeed");
        // `new()` never enables mock mode (even when NVML is unavailable), and
        // variation sampling only produces data for mock monitors.
        assert!(!monitor.is_mock());
        let metrics = monitor.sample_with_variation(1.0);
        assert!(metrics.is_empty());
    }

    #[cfg(feature = "nvml")]
    #[test]
    fn test_gpu_monitor_nvml_sample() {
        // Test real NVML sampling if available
        let monitor = GpuMonitor::new().expect("operation should succeed");
        if monitor.num_devices() > 0 {
            let metrics = monitor.sample();
            assert!(!metrics.is_empty());
            // Verify basic sanity
            for m in &metrics {
                assert!(m.utilization_percent <= 100);
                assert!(m.temperature_celsius < 150);
            }
        }
    }
}