cortenforge_tools/
gpu_probe.rs

1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5    pub available: bool,
6    pub utilization: Option<f64>,
7    pub mem_used_mb: Option<u64>,
8    pub vendor: Option<String>,
9    pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13    pub fn unavailable() -> Self {
14        Self {
15            available: false,
16            utilization: None,
17            mem_used_mb: None,
18            vendor: None,
19            device_name: None,
20        }
21    }
22}
23
24pub trait GpuProbe {
25    fn status(&self) -> GpuStatus;
26}
27
/// Probe for macOS hosts; on any other OS it reports unavailable.
#[derive(Debug, Clone, Copy, Default)]
pub struct MacGpuProbe;

/// Probe backed by NVIDIA's NVML library, reading device index 0.
#[cfg(feature = "gpu-nvidia")]
#[derive(Debug, Clone, Copy, Default)]
pub struct NvidiaGpuProbe;

/// Probe that always reports no GPU.
#[derive(Debug, Clone, Copy, Default)]
pub struct FallbackGpuProbe;

/// Linux probe: tries NVML (with the `gpu-nvidia` feature), then the `nvidia-smi`,
/// `rocm-smi` and `intel_gpu_top` command-line tools.
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, Copy, Default)]
pub struct LinuxGpuProbe;

/// Windows probe: delegates to whichever NVIDIA/AMD backends were enabled at compile time.
#[cfg(target_os = "windows")]
#[derive(Debug, Clone, Copy, Default)]
pub struct WindowsGpuProbe;

/// AMD-on-Windows probe (behind the `gpu_amd_windows` feature).
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
#[derive(Debug, Clone, Copy, Default)]
pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45    fn status(&self) -> GpuStatus {
46        #[cfg(target_os = "macos")]
47        {
48            // TODO: replace with real Metal/GPU stats. For now report availability with zero utilization.
49            GpuStatus {
50                available: true,
51                utilization: Some(0.0),
52                mem_used_mb: None,
53                vendor: None,
54                device_name: None,
55            }
56        }
57
58        #[cfg(not(target_os = "macos"))]
59        {
60            GpuStatus::unavailable()
61        }
62    }
63}
64
#[cfg(feature = "gpu-nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Reads utilization, memory use and name of NVML device 0.
    ///
    /// NVML is initialised on every call. If initialisation or the device lookup fails, the
    /// result is [`GpuStatus::unavailable`]; a failure reading an individual metric only
    /// leaves that field as `None`.
    fn status(&self) -> GpuStatus {
        use nvml_wrapper::Nvml;

        const BYTES_PER_MIB: u64 = 1024 * 1024;

        let handle = match Nvml::init() {
            Ok(handle) => handle,
            Err(_) => return GpuStatus::unavailable(),
        };
        let gpu = match handle.device_by_index(0) {
            Ok(gpu) => gpu,
            Err(_) => return GpuStatus::unavailable(),
        };

        GpuStatus {
            available: true,
            utilization: gpu.utilization_rates().ok().map(|rates| rates.gpu as f64),
            mem_used_mb: gpu.memory_info().ok().map(|mem| mem.used / BYTES_PER_MIB),
            vendor: Some("NVIDIA".to_string()),
            device_name: gpu.name().ok(),
        }
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99    fn status(&self) -> GpuStatus {
100        GpuStatus::unavailable()
101    }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106    fn status(&self) -> GpuStatus {
107        #[cfg(feature = "gpu-nvidia")]
108        {
109            let status = NvidiaGpuProbe.status();
110            if status.available {
111                return status;
112            }
113        }
114
115        if let Some(status) = nvidia_smi_status() {
116            return status;
117        }
118        if let Some(status) = amd_status() {
119            return status;
120        }
121        if let Some(status) = intel_status() {
122            return status;
123        }
124
125        GpuStatus::unavailable()
126    }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Tries each backend compiled into this build, in order NVML (`gpu-nvidia`) then AMD
    /// (`gpu_amd_windows`), and returns the first status that reports `available`.
    ///
    /// This mirrors the Linux probe. A failing NVML query (e.g. no NVIDIA device present) no
    /// longer short-circuits the AMD backend when both features are enabled. Falls back to
    /// [`GpuStatus::unavailable`] when no backend detects a GPU or none is compiled in.
    fn status(&self) -> GpuStatus {
        #[cfg(feature = "gpu-nvidia")]
        {
            let status = NvidiaGpuProbe.status();
            if status.available {
                return status;
            }
        }

        #[cfg(feature = "gpu_amd_windows")]
        {
            let status = AmdWindowsProbe.status();
            if status.available {
                return status;
            }
        }

        GpuStatus::unavailable()
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Placeholder backend: always reports [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        // TODO: replace with ADLX/WMI implementation when we add a Windows AMD backend.
        GpuStatus::unavailable()
    }
}
156
157pub fn platform_probe() -> Box<dyn GpuProbe> {
158    #[cfg(target_os = "windows")]
159    {
160        return Box::new(WindowsGpuProbe);
161    }
162
163    #[cfg(target_os = "linux")]
164    {
165        Box::new(LinuxGpuProbe)
166    }
167
168    #[cfg(all(
169        feature = "gpu-nvidia",
170        not(target_os = "windows"),
171        not(target_os = "linux"),
172        not(target_os = "macos")
173    ))]
174    {
175        return Box::new(NvidiaGpuProbe);
176    }
177
178    #[cfg(target_os = "macos")]
179    {
180        Box::new(MacGpuProbe)
181    }
182
183    #[cfg(all(
184        not(target_os = "macos"),
185        not(target_os = "windows"),
186        not(target_os = "linux"),
187        not(feature = "gpu-nvidia")
188    ))]
189    {
190        return Box::new(FallbackGpuProbe);
191    }
192}
193
194#[cfg(target_os = "linux")]
195fn nvidia_smi_status() -> Option<GpuStatus> {
196    use std::process::Command;
197
198    let output = Command::new("nvidia-smi")
199        .args([
200            "--query-gpu=name,utilization.gpu,memory.used",
201            "--format=csv,noheader,nounits",
202        ])
203        .output()
204        .ok()?;
205    if !output.status.success() {
206        return None;
207    }
208    let line = String::from_utf8_lossy(&output.stdout);
209    let mut parts = line.lines().next()?.split(',');
210    let name = parts.next()?.trim().to_string();
211    let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
212    let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
213
214    Some(GpuStatus {
215        available: true,
216        utilization,
217        mem_used_mb,
218        vendor: Some("NVIDIA".to_string()),
219        device_name: Some(name),
220    })
221}
222
223#[cfg(target_os = "linux")]
224fn amd_status() -> Option<GpuStatus> {
225    use std::process::Command;
226
227    if let Ok(output) = Command::new("rocm-smi")
228        .args(["--showuse", "--json"])
229        .output()
230    {
231        if output.status.success() {
232            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
233                if let Some(pct) = val
234                    .get("card")
235                    .and_then(|c| c.get(0))
236                    .and_then(|c| c.get("GPU use (%)"))
237                    .and_then(|v| v.as_f64())
238                {
239                    return Some(GpuStatus {
240                        available: true,
241                        utilization: Some(pct),
242                        mem_used_mb: amd_mem_used_mb(),
243                        vendor: Some("AMD".to_string()),
244                        device_name: None,
245                    });
246                }
247            }
248        }
249    }
250
251    let output = Command::new("rocm-smi")
252        .arg("--showuse")
253        .output()
254        .or_else(|_| Command::new("radeontop").arg("--help").output())
255        .ok()?;
256    if !output.status.success() {
257        return None;
258    }
259    let text = String::from_utf8_lossy(&output.stdout);
260    for line in text.lines() {
261        if line.to_ascii_lowercase().contains("gpu use") {
262            if let Some(pct) = line
263                .split('%')
264                .next()
265                .and_then(|s| s.split_whitespace().last())
266                .and_then(|n| n.parse::<f64>().ok())
267            {
268                return Some(GpuStatus {
269                    available: true,
270                    utilization: Some(pct),
271                    mem_used_mb: amd_mem_used_mb(),
272                    vendor: Some("AMD".to_string()),
273                    device_name: None,
274                });
275            }
276        }
277    }
278    None
279}
280
281#[cfg(target_os = "linux")]
282fn amd_mem_used_mb() -> Option<u64> {
283    use std::process::Command;
284
285    if let Ok(output) = Command::new("rocm-smi")
286        .args(["--showmeminfo", "vram", "--json"])
287        .output()
288    {
289        if output.status.success() {
290            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
291                if let Some(used) = val
292                    .get("card")
293                    .and_then(|c| c.get(0))
294                    .and_then(|c| c.get("vram"))
295                    .and_then(|v| v.get("used (B)"))
296                    .and_then(|v| v.as_u64())
297                {
298                    return Some(used / (1024 * 1024));
299                }
300            }
301        }
302    }
303    None
304}
305
306#[cfg(target_os = "linux")]
307fn intel_status() -> Option<GpuStatus> {
308    use std::process::Command;
309
310    let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
311    if !output.status.success() {
312        return None;
313    }
314    let text = String::from_utf8_lossy(&output.stdout);
315    let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
316    for line in text.lines() {
317        if line.to_ascii_lowercase().contains("render/3d") {
318            if let Some(pct) = line
319                .split('%')
320                .next()
321                .and_then(|s| s.split_whitespace().last())
322                .and_then(|n| n.parse::<f64>().ok())
323            {
324                return Some(GpuStatus {
325                    available: true,
326                    utilization: Some(pct),
327                    mem_used_mb,
328                    vendor: Some("Intel".to_string()),
329                    device_name: None,
330                });
331            }
332        }
333    }
334    None
335}
336
337#[cfg(target_os = "linux")]
338fn intel_mem_json(data: &[u8]) -> Option<u64> {
339    let val: serde_json::Value = serde_json::from_slice(data).ok()?;
340    find_mem_value(&val)
341}
342
343#[cfg(target_os = "linux")]
344fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
345    match val {
346        serde_json::Value::Number(n) => n.as_u64(),
347        serde_json::Value::Object(map) => {
348            for (k, v) in map {
349                let key = k.to_ascii_lowercase();
350                if key.contains("mem") {
351                    if let Some(n) = v.as_u64() {
352                        if n > 10_000 {
353                            return Some(n / (1024 * 1024));
354                        }
355                        return Some(n);
356                    }
357                    if let Some(f) = v.as_f64() {
358                        if f > 10_000.0 {
359                            return Some((f / 1024.0 / 1024.0) as u64);
360                        }
361                        return Some(f as u64);
362                    }
363                }
364                if let Some(found) = find_mem_value(v) {
365                    return Some(found);
366                }
367            }
368            None
369        }
370        serde_json::Value::Array(arr) => {
371            for v in arr {
372                if let Some(found) = find_mem_value(v) {
373                    return Some(found);
374                }
375            }
376            None
377        }
378        _ => None,
379    }
380}
381
/// Extracts a memory reading (in MiB) from plain-text tool output.
///
/// Looks at lines containing `"mem"` (case-insensitive). On the first such line that holds
/// a number, it takes the last whitespace-separated word that parses as `f64` after trailing
/// `%`, `m`, `M`, `b`, `B` characters are stripped. Values above 10 000 are treated as bytes
/// and converted to MiB; smaller values are returned as-is (truncated).
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    text.lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace()
                .rev()
                .map(|word| word.trim_end_matches(['%', 'm', 'M', 'b', 'B']))
                .find_map(|word| word.parse::<f64>().ok())
        })
        .map(|value| {
            if value > 10_000.0 {
                (value / 1024.0 / 1024.0) as u64
            } else {
                value as u64
            }
        })
}
402
403pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
404    serde_json::to_value(status).unwrap_or(serde_json::Value::Null)
405}
406
407pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
408    serde_json::to_writer(std::io::stdout(), status)
409}
410
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};

    /// Test double that hands back a fixed, caller-supplied status.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    /// Keys every serialized `GpuStatus` must expose, even when the values are absent.
    const EXPECTED_KEYS: [&str; 5] = [
        "available",
        "utilization",
        "mem_used_mb",
        "vendor",
        "device_name",
    ];

    #[test]
    fn status_serializes_with_expected_keys() {
        let json = to_json_value(&GpuStatus::unavailable());
        let fields = json.as_object().expect("status must be a JSON object");
        for key in EXPECTED_KEYS {
            assert!(fields.contains_key(key), "missing key: {key}");
        }
    }

    #[test]
    fn fallback_reports_unavailable() {
        let GpuStatus {
            available,
            utilization,
            mem_used_mb,
            ..
        } = FallbackGpuProbe.status();
        assert!(!available);
        assert!(utilization.is_none());
        assert!(mem_used_mb.is_none());
    }

    #[test]
    fn mock_probe_serializes_full_status() {
        let canned = GpuStatus {
            available: true,
            utilization: Some(55.5),
            mem_used_mb: Some(2048),
            vendor: Some("MockGPU".to_string()),
            device_name: Some("MockDevice".to_string()),
        };
        let probe = MockProbe { status: canned };

        let json = to_json_value(&probe.status());
        let fields = json.as_object().expect("status must be a JSON object");

        assert_eq!(fields.get("available").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(fields.get("utilization").and_then(|v| v.as_f64()), Some(55.5));
        assert_eq!(fields.get("mem_used_mb").and_then(|v| v.as_u64()), Some(2048));
        assert_eq!(fields.get("vendor").and_then(|v| v.as_str()), Some("MockGPU"));
        assert_eq!(
            fields.get("device_name").and_then(|v| v.as_str()),
            Some("MockDevice")
        );
    }
}