cortenforge_tools/
gpu_probe.rs

1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5    pub available: bool,
6    pub utilization: Option<f64>,
7    pub mem_used_mb: Option<u64>,
8    pub vendor: Option<String>,
9    pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13    pub fn unavailable() -> Self {
14        Self {
15            available: false,
16            utilization: None,
17            mem_used_mb: None,
18            vendor: None,
19            device_name: None,
20        }
21    }
22}
23
24pub trait GpuProbe {
25    fn status(&self) -> GpuStatus;
26}
27
28pub struct MacGpuProbe;
29
30#[cfg(feature = "gpu-nvidia")]
31pub struct NvidiaGpuProbe;
32
33pub struct FallbackGpuProbe;
34
35#[cfg(target_os = "linux")]
36pub struct LinuxGpuProbe;
37
38#[cfg(target_os = "windows")]
39pub struct WindowsGpuProbe;
40
41#[cfg(all(target_os = "windows", feature = "gpu-windows"))]
42pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45    fn status(&self) -> GpuStatus {
46        #[cfg(target_os = "macos")]
47        {
48            // TODO(#19): Implement Metal-based GPU stats for macOS
49            GpuStatus {
50                available: true,
51                utilization: Some(0.0),
52                mem_used_mb: None,
53                vendor: None,
54                device_name: None,
55            }
56        }
57
58        #[cfg(not(target_os = "macos"))]
59        {
60            GpuStatus::unavailable()
61        }
62    }
63}
64
#[cfg(feature = "gpu-nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Query the first NVIDIA device (index 0) through NVML.
    ///
    /// Returns [`GpuStatus::unavailable`] if NVML cannot be initialised or
    /// device 0 cannot be opened. Individual metrics NVML fails to report are
    /// left as `None`. Used memory is converted from bytes to MiB.
    fn status(&self) -> GpuStatus {
        use nvml_wrapper::Nvml;

        // Any failure before the device handle exists collapses to `None`.
        let query = || -> Option<GpuStatus> {
            let nvml = Nvml::init().ok()?;
            let device = nvml.device_by_index(0).ok()?;
            Some(GpuStatus {
                available: true,
                utilization: device.utilization_rates().ok().map(|rates| rates.gpu as f64),
                mem_used_mb: device.memory_info().ok().map(|mem| mem.used / (1024 * 1024)),
                vendor: Some("NVIDIA".to_string()),
                device_name: device.name().ok(),
            })
        };

        query().unwrap_or_else(GpuStatus::unavailable)
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99    fn status(&self) -> GpuStatus {
100        GpuStatus::unavailable()
101    }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106    fn status(&self) -> GpuStatus {
107        #[cfg(feature = "gpu-nvidia")]
108        {
109            let status = NvidiaGpuProbe.status();
110            if status.available {
111                return status;
112            }
113        }
114
115        if let Some(status) = nvidia_smi_status() {
116            return status;
117        }
118        if let Some(status) = amd_status() {
119            return status;
120        }
121        if let Some(status) = intel_status() {
122            return status;
123        }
124
125        GpuStatus::unavailable()
126    }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Try the feature-enabled Windows backends in priority order.
    ///
    /// With `gpu-nvidia`, NVML is queried first. If it reports no available
    /// GPU (e.g. a machine with only an AMD card), the AMD probe is tried
    /// next when `gpu-windows` is enabled. When neither backend finds a GPU
    /// (or neither feature is enabled), [`GpuStatus::unavailable`] is returned.
    fn status(&self) -> GpuStatus {
        #[cfg(feature = "gpu-nvidia")]
        {
            let nvml_status = NvidiaGpuProbe.status();
            if nvml_status.available {
                return nvml_status;
            }
        }

        #[cfg(feature = "gpu-windows")]
        {
            let amd_status = AmdWindowsProbe.status();
            if amd_status.available {
                return amd_status;
            }
        }

        GpuStatus::unavailable()
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu-windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Report AMD GPU status on Windows via the WMI helper, or
    /// [`GpuStatus::unavailable`] when the helper yields nothing.
    fn status(&self) -> GpuStatus {
        // TODO(#20): Implement WMI/DirectX GPU stats for Windows AMD
        windows_wmi_status().unwrap_or_else(GpuStatus::unavailable)
    }
}
159
/// Placeholder for a WMI-backed GPU query; currently always returns `None`.
///
/// This is gated on the `gpu-windows` feature as well as the OS, because its
/// only caller (`AmdWindowsProbe`) is. Gating on the OS alone would leave the
/// function unused, and reported as dead code, in Windows builds without
/// that feature.
#[cfg(all(target_os = "windows", feature = "gpu-windows"))]
fn windows_wmi_status() -> Option<GpuStatus> {
    None
}
164
165pub fn platform_probe() -> Box<dyn GpuProbe> {
166    #[cfg(target_os = "windows")]
167    {
168        return Box::new(WindowsGpuProbe);
169    }
170
171    #[cfg(target_os = "linux")]
172    {
173        Box::new(LinuxGpuProbe)
174    }
175
176    #[cfg(all(
177        feature = "gpu-nvidia",
178        not(target_os = "windows"),
179        not(target_os = "linux"),
180        not(target_os = "macos")
181    ))]
182    {
183        return Box::new(NvidiaGpuProbe);
184    }
185
186    #[cfg(target_os = "macos")]
187    {
188        Box::new(MacGpuProbe)
189    }
190
191    #[cfg(all(
192        not(target_os = "macos"),
193        not(target_os = "windows"),
194        not(target_os = "linux"),
195        not(feature = "gpu-nvidia")
196    ))]
197    {
198        return Box::new(FallbackGpuProbe);
199    }
200}
201
202#[cfg(target_os = "linux")]
203fn nvidia_smi_status() -> Option<GpuStatus> {
204    use std::process::Command;
205
206    let output = Command::new("nvidia-smi")
207        .args([
208            "--query-gpu=name,utilization.gpu,memory.used",
209            "--format=csv,noheader,nounits",
210        ])
211        .output()
212        .ok()?;
213    if !output.status.success() {
214        return None;
215    }
216    let line = String::from_utf8_lossy(&output.stdout);
217    let mut parts = line.lines().next()?.split(',');
218    let name = parts.next()?.trim().to_string();
219    let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
220    let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
221
222    Some(GpuStatus {
223        available: true,
224        utilization,
225        mem_used_mb,
226        vendor: Some("NVIDIA".to_string()),
227        device_name: Some(name),
228    })
229}
230
231#[cfg(target_os = "linux")]
232fn amd_status() -> Option<GpuStatus> {
233    use std::process::Command;
234
235    if let Ok(output) = Command::new("rocm-smi")
236        .args(["--showuse", "--json"])
237        .output()
238    {
239        if output.status.success() {
240            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
241                if let Some(pct) = val
242                    .get("card")
243                    .and_then(|c| c.get(0))
244                    .and_then(|c| c.get("GPU use (%)"))
245                    .and_then(|v| v.as_f64())
246                {
247                    return Some(GpuStatus {
248                        available: true,
249                        utilization: Some(pct),
250                        mem_used_mb: amd_mem_used_mb(),
251                        vendor: Some("AMD".to_string()),
252                        device_name: None,
253                    });
254                }
255            }
256        }
257    }
258
259    let output = Command::new("rocm-smi")
260        .arg("--showuse")
261        .output()
262        .or_else(|_| Command::new("radeontop").arg("--help").output())
263        .ok()?;
264    if !output.status.success() {
265        return None;
266    }
267    let text = String::from_utf8_lossy(&output.stdout);
268    for line in text.lines() {
269        if line.to_ascii_lowercase().contains("gpu use") {
270            if let Some(pct) = line
271                .split('%')
272                .next()
273                .and_then(|s| s.split_whitespace().last())
274                .and_then(|n| n.parse::<f64>().ok())
275            {
276                return Some(GpuStatus {
277                    available: true,
278                    utilization: Some(pct),
279                    mem_used_mb: amd_mem_used_mb(),
280                    vendor: Some("AMD".to_string()),
281                    device_name: None,
282                });
283            }
284        }
285    }
286    None
287}
288
289#[cfg(target_os = "linux")]
290fn amd_mem_used_mb() -> Option<u64> {
291    use std::process::Command;
292
293    if let Ok(output) = Command::new("rocm-smi")
294        .args(["--showmeminfo", "vram", "--json"])
295        .output()
296    {
297        if output.status.success() {
298            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
299                if let Some(used) = val
300                    .get("card")
301                    .and_then(|c| c.get(0))
302                    .and_then(|c| c.get("vram"))
303                    .and_then(|v| v.get("used (B)"))
304                    .and_then(|v| v.as_u64())
305                {
306                    return Some(used / (1024 * 1024));
307                }
308            }
309        }
310    }
311    None
312}
313
314#[cfg(target_os = "linux")]
315fn intel_status() -> Option<GpuStatus> {
316    use std::process::Command;
317
318    let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
319    if !output.status.success() {
320        return None;
321    }
322    let text = String::from_utf8_lossy(&output.stdout);
323    let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
324    for line in text.lines() {
325        if line.to_ascii_lowercase().contains("render/3d") {
326            if let Some(pct) = line
327                .split('%')
328                .next()
329                .and_then(|s| s.split_whitespace().last())
330                .and_then(|n| n.parse::<f64>().ok())
331            {
332                return Some(GpuStatus {
333                    available: true,
334                    utilization: Some(pct),
335                    mem_used_mb,
336                    vendor: Some("Intel".to_string()),
337                    device_name: None,
338                });
339            }
340        }
341    }
342    None
343}
344
345#[cfg(target_os = "linux")]
346fn intel_mem_json(data: &[u8]) -> Option<u64> {
347    let val: serde_json::Value = serde_json::from_slice(data).ok()?;
348    find_mem_value(&val)
349}
350
351#[cfg(target_os = "linux")]
352fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
353    match val {
354        serde_json::Value::Number(n) => n.as_u64(),
355        serde_json::Value::Object(map) => {
356            for (k, v) in map {
357                let key = k.to_ascii_lowercase();
358                if key.contains("mem") {
359                    if let Some(n) = v.as_u64() {
360                        if n > 10_000 {
361                            return Some(n / (1024 * 1024));
362                        }
363                        return Some(n);
364                    }
365                    if let Some(f) = v.as_f64() {
366                        if f > 10_000.0 {
367                            return Some((f / 1024.0 / 1024.0) as u64);
368                        }
369                        return Some(f as u64);
370                    }
371                }
372                if let Some(found) = find_mem_value(v) {
373                    return Some(found);
374                }
375            }
376            None
377        }
378        serde_json::Value::Array(arr) => {
379            for v in arr {
380                if let Some(found) = find_mem_value(v) {
381                    return Some(found);
382                }
383            }
384            None
385        }
386        _ => None,
387    }
388}
389
/// Heuristically extract a memory figure (MiB) from plain-text tool output.
///
/// Scans lines whose lowercase form contains `"mem"`. From the first such
/// line that has a number, it takes the last whitespace-separated token that
/// parses as a number once trailing `%`, `m`, `M`, `b`, `B` characters are
/// stripped. Values above 10 000 are treated as bytes and converted to MiB.
/// Fractional values are truncated.
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    text.lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace()
                .filter_map(|word| {
                    word.trim_end_matches(['%', 'm', 'M', 'b', 'B'])
                        .parse::<f64>()
                        .ok()
                })
                .last()
        })
        .map(|num| {
            if num > 10_000.0 {
                (num / 1024.0 / 1024.0) as u64
            } else {
                num as u64
            }
        })
}
410
411pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
412    serde_json::to_value(status).unwrap_or(serde_json::Value::Null)
413}
414
415pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
416    serde_json::to_writer(std::io::stdout(), status)
417}
418
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};

    /// Keys every serialized `GpuStatus` object must contain.
    const EXPECTED_KEYS: [&str; 5] = [
        "available",
        "utilization",
        "mem_used_mb",
        "vendor",
        "device_name",
    ];

    /// Test double that hands back a fixed, preconfigured status.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    #[test]
    fn status_serializes_with_expected_keys() {
        let json = to_json_value(&GpuStatus::unavailable());
        let fields = json.as_object().expect("status must be a JSON object");
        for key in EXPECTED_KEYS {
            assert!(fields.contains_key(key), "missing key: {key}");
        }
    }

    #[test]
    fn fallback_reports_unavailable() {
        let GpuStatus {
            available,
            utilization,
            mem_used_mb,
            ..
        } = FallbackGpuProbe.status();
        assert!(!available);
        assert!(utilization.is_none());
        assert!(mem_used_mb.is_none());
    }

    #[test]
    fn mock_probe_serializes_full_status() {
        let mock = MockProbe {
            status: GpuStatus {
                available: true,
                utilization: Some(55.5),
                mem_used_mb: Some(2048),
                vendor: Some("MockGPU".to_string()),
                device_name: Some("MockDevice".to_string()),
            },
        };

        let json = to_json_value(&mock.status());
        let fields = json.as_object().expect("status must be a JSON object");
        let field = |name: &str| fields.get(name);

        assert_eq!(
            field("available").and_then(serde_json::Value::as_bool),
            Some(true)
        );
        assert_eq!(
            field("utilization").and_then(serde_json::Value::as_f64),
            Some(55.5)
        );
        assert_eq!(
            field("mem_used_mb").and_then(serde_json::Value::as_u64),
            Some(2048)
        );
        assert_eq!(
            field("vendor").and_then(serde_json::Value::as_str),
            Some("MockGPU")
        );
        assert_eq!(
            field("device_name").and_then(serde_json::Value::as_str),
            Some("MockDevice")
        );
    }
}