cortenforge_tools/
gpu_probe.rs

1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5    pub available: bool,
6    pub utilization: Option<f64>,
7    pub mem_used_mb: Option<u64>,
8    pub vendor: Option<String>,
9    pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13    pub fn unavailable() -> Self {
14        Self {
15            available: false,
16            utilization: None,
17            mem_used_mb: None,
18            vendor: None,
19            device_name: None,
20        }
21    }
22}
23
24pub trait GpuProbe {
25    fn status(&self) -> GpuStatus;
26}
27
/// Probe for macOS hosts; reports the GPU as unavailable on every other target.
pub struct MacGpuProbe;

/// Probe backed by NVML (`nvml_wrapper`); only built with the `gpu_nvidia` feature.
#[cfg(feature = "gpu_nvidia")]
pub struct NvidiaGpuProbe;

/// Probe that always reports [`GpuStatus::unavailable`].
pub struct FallbackGpuProbe;

/// Linux probe that tries the NVIDIA, AMD and Intel backends in turn.
#[cfg(target_os = "linux")]
pub struct LinuxGpuProbe;

/// Windows probe that dispatches to the vendor backends enabled at compile time.
#[cfg(target_os = "windows")]
pub struct WindowsGpuProbe;

/// Windows AMD probe (`gpu_amd_windows` feature); currently a stub.
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45    fn status(&self) -> GpuStatus {
46        #[cfg(target_os = "macos")]
47        {
48            // TODO: replace with real Metal/GPU stats. For now report availability with zero utilization.
49            return GpuStatus {
50                available: true,
51                utilization: Some(0.0),
52                mem_used_mb: None,
53                vendor: None,
54                device_name: None,
55            };
56        }
57
58        #[cfg(not(target_os = "macos"))]
59        {
60            return GpuStatus::unavailable();
61        }
62    }
63}
64
#[cfg(feature = "gpu_nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Reads utilization, used memory and the name of GPU index 0 through NVML.
    ///
    /// Returns [`GpuStatus::unavailable`] when NVML cannot be initialised or
    /// device 0 cannot be opened. Individual metrics that fail to read become
    /// `None`, while the GPU is still reported as available.
    fn status(&self) -> GpuStatus {
        use nvml_wrapper::Nvml;

        const BYTES_PER_MB: u64 = 1024 * 1024;

        if let Ok(nvml) = Nvml::init() {
            if let Ok(device) = nvml.device_by_index(0) {
                return GpuStatus {
                    available: true,
                    utilization: device.utilization_rates().ok().map(|r| r.gpu as f64),
                    mem_used_mb: device.memory_info().ok().map(|m| m.used / BYTES_PER_MB),
                    vendor: Some("NVIDIA".to_string()),
                    device_name: device.name().ok(),
                };
            }
        }

        GpuStatus::unavailable()
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99    fn status(&self) -> GpuStatus {
100        GpuStatus::unavailable()
101    }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106    fn status(&self) -> GpuStatus {
107        #[cfg(feature = "gpu_nvidia")]
108        {
109            let status = NvidiaGpuProbe.status();
110            if status.available {
111                return status;
112            }
113        }
114
115        if let Some(status) = nvidia_smi_status() {
116            return status;
117        }
118        if let Some(status) = amd_status() {
119            return status;
120        }
121        if let Some(status) = intel_status() {
122            return status;
123        }
124
125        GpuStatus::unavailable()
126    }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Queries the vendor backends compiled into this build and returns the
    /// first one that reports an available GPU.
    ///
    /// The order is:
    /// 1. NVIDIA via NVML (`gpu_nvidia` feature).
    /// 2. AMD (`gpu_amd_windows` feature).
    ///
    /// Returns [`GpuStatus::unavailable`] when no enabled backend finds a GPU,
    /// including when no backend feature is enabled at all.
    fn status(&self) -> GpuStatus {
        // Fall through to the next backend when one reports no GPU (as the Linux
        // probe does). Otherwise, enabling `gpu_nvidia` would hide an AMD GPU
        // on hosts without NVIDIA hardware.
        #[cfg(feature = "gpu_nvidia")]
        {
            let nvidia = NvidiaGpuProbe.status();
            if nvidia.available {
                return nvidia;
            }
        }

        #[cfg(feature = "gpu_amd_windows")]
        {
            let amd = AmdWindowsProbe.status();
            if amd.available {
                return amd;
            }
        }

        GpuStatus::unavailable()
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Stub backend: always reports [`GpuStatus::unavailable`] for now.
    fn status(&self) -> GpuStatus {
        // TODO: replace with ADLX/WMI implementation when we add a Windows AMD backend.
        GpuStatus::unavailable()
    }
}
156
157pub fn platform_probe() -> Box<dyn GpuProbe> {
158    #[cfg(target_os = "windows")]
159    {
160        return Box::new(WindowsGpuProbe);
161    }
162
163    #[cfg(target_os = "linux")]
164    {
165        return Box::new(LinuxGpuProbe);
166    }
167
168    #[cfg(all(feature = "gpu_nvidia", not(target_os = "windows"), not(target_os = "linux")))]
169    {
170        return Box::new(NvidiaGpuProbe);
171    }
172
173    #[cfg(target_os = "macos")]
174    {
175        return Box::new(MacGpuProbe);
176    }
177
178    #[cfg(all(
179        not(target_os = "macos"),
180        not(target_os = "windows"),
181        not(target_os = "linux"),
182        not(feature = "gpu_nvidia")
183    ))]
184    {
185        return Box::new(FallbackGpuProbe);
186    }
187}
188
189#[cfg(target_os = "linux")]
190fn nvidia_smi_status() -> Option<GpuStatus> {
191    use std::process::Command;
192
193    let output = Command::new("nvidia-smi")
194        .args([
195            "--query-gpu=name,utilization.gpu,memory.used",
196            "--format=csv,noheader,nounits",
197        ])
198        .output()
199        .ok()?;
200    if !output.status.success() {
201        return None;
202    }
203    let line = String::from_utf8_lossy(&output.stdout);
204    let mut parts = line.lines().next()?.split(',');
205    let name = parts.next()?.trim().to_string();
206    let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
207    let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
208
209    Some(GpuStatus {
210        available: true,
211        utilization,
212        mem_used_mb,
213        vendor: Some("NVIDIA".to_string()),
214        device_name: Some(name),
215    })
216}
217
218#[cfg(target_os = "linux")]
219fn amd_status() -> Option<GpuStatus> {
220    use std::process::Command;
221
222    if let Ok(output) = Command::new("rocm-smi")
223        .args(["--showuse", "--json"])
224        .output()
225    {
226        if output.status.success() {
227            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
228                if let Some(pct) = val
229                    .get("card")
230                    .and_then(|c| c.get(0))
231                    .and_then(|c| c.get("GPU use (%)"))
232                    .and_then(|v| v.as_f64())
233                {
234                    return Some(GpuStatus {
235                        available: true,
236                        utilization: Some(pct),
237                        mem_used_mb: amd_mem_used_mb(),
238                        vendor: Some("AMD".to_string()),
239                        device_name: None,
240                    });
241                }
242            }
243        }
244    }
245
246    let output = Command::new("rocm-smi")
247        .arg("--showuse")
248        .output()
249        .or_else(|_| Command::new("radeontop").arg("--help").output())
250        .ok()?;
251    if !output.status.success() {
252        return None;
253    }
254    let text = String::from_utf8_lossy(&output.stdout);
255    for line in text.lines() {
256        if line.to_ascii_lowercase().contains("gpu use") {
257            if let Some(pct) = line
258                .split('%')
259                .next()
260                .and_then(|s| s.split_whitespace().last())
261                .and_then(|n| n.parse::<f64>().ok())
262            {
263                return Some(GpuStatus {
264                    available: true,
265                    utilization: Some(pct),
266                    mem_used_mb: amd_mem_used_mb(),
267                    vendor: Some("AMD".to_string()),
268                    device_name: None,
269                });
270            }
271        }
272    }
273    None
274}
275
276#[cfg(target_os = "linux")]
277fn amd_mem_used_mb() -> Option<u64> {
278    use std::process::Command;
279
280    if let Ok(output) = Command::new("rocm-smi")
281        .args(["--showmeminfo", "vram", "--json"])
282        .output()
283    {
284        if output.status.success() {
285            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
286                if let Some(used) = val
287                    .get("card")
288                    .and_then(|c| c.get(0))
289                    .and_then(|c| c.get("vram"))
290                    .and_then(|v| v.get("used (B)"))
291                    .and_then(|v| v.as_u64())
292                {
293                    return Some(used / (1024 * 1024));
294                }
295            }
296        }
297    }
298    None
299}
300
301#[cfg(target_os = "linux")]
302fn intel_status() -> Option<GpuStatus> {
303    use std::process::Command;
304
305    let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
306    if !output.status.success() {
307        return None;
308    }
309    let text = String::from_utf8_lossy(&output.stdout);
310    let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
311    for line in text.lines() {
312        if line.to_ascii_lowercase().contains("render/3d") {
313            if let Some(pct) = line
314                .split('%')
315                .next()
316                .and_then(|s| s.split_whitespace().last())
317                .and_then(|n| n.parse::<f64>().ok())
318            {
319                return Some(GpuStatus {
320                    available: true,
321                    utilization: Some(pct),
322                    mem_used_mb,
323                    vendor: Some("Intel".to_string()),
324                    device_name: None,
325                });
326            }
327        }
328    }
329    None
330}
331
332#[cfg(target_os = "linux")]
333fn intel_mem_json(data: &[u8]) -> Option<u64> {
334    let val: serde_json::Value = serde_json::from_slice(data).ok()?;
335    find_mem_value(&val)
336}
337
338#[cfg(target_os = "linux")]
339fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
340    match val {
341        serde_json::Value::Number(n) => n.as_u64(),
342        serde_json::Value::Object(map) => {
343            for (k, v) in map {
344                let key = k.to_ascii_lowercase();
345                if key.contains("mem") {
346                    if let Some(n) = v.as_u64() {
347                        if n > 10_000 {
348                            return Some(n / (1024 * 1024));
349                        }
350                        return Some(n);
351                    }
352                    if let Some(f) = v.as_f64() {
353                        if f > 10_000.0 {
354                            return Some((f / 1024.0 / 1024.0) as u64);
355                        }
356                        return Some(f as u64);
357                    }
358                }
359                if let Some(found) = find_mem_value(v) {
360                    return Some(found);
361                }
362            }
363            None
364        }
365        serde_json::Value::Array(arr) => {
366            for v in arr {
367                if let Some(found) = find_mem_value(v) {
368                    return Some(found);
369                }
370            }
371            None
372        }
373        _ => None,
374    }
375}
376
/// Heuristically extracts a memory figure, in MB, from free-form tool output.
///
/// Scans lines containing `"mem"` (case-insensitive) and takes the last
/// whitespace-separated token on the first such line that parses as a number,
/// after stripping trailing `%`, `m`, `M`, `b`, `B`. Values above 10 000 are
/// treated as bytes and converted to MB.
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    text.lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace().rev().find_map(|word| {
                word.trim_end_matches(['%', 'm', 'M', 'b', 'B'])
                    .parse::<f64>()
                    .ok()
            })
        })
        .map(|num| {
            if num > 10_000.0 {
                (num / 1024.0 / 1024.0) as u64
            } else {
                num as u64
            }
        })
}
397
398pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
399    serde_json::to_value(status).unwrap_or_else(|_| serde_json::Value::Null)
400}
401
402pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
403    serde_json::to_writer(std::io::stdout(), status)
404}
405
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};
    use serde_json::{Map, Value};

    /// Test double that returns a fixed, caller-supplied status.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    /// Serializes `status` and returns the resulting JSON object.
    ///
    /// Panics, failing the calling test, if the serialized form is not an object.
    fn serialized_object(status: &GpuStatus) -> Map<String, Value> {
        to_json_value(status)
            .as_object()
            .cloned()
            .expect("status must be a JSON object")
    }

    #[test]
    fn status_serializes_with_expected_keys() {
        let obj = serialized_object(&GpuStatus::unavailable());
        let keys = ["available", "utilization", "mem_used_mb", "vendor", "device_name"];
        keys.iter()
            .for_each(|key| assert!(obj.contains_key(*key), "missing key: {key}"));
    }

    #[test]
    fn fallback_reports_unavailable() {
        let GpuStatus {
            available,
            utilization,
            mem_used_mb,
            ..
        } = FallbackGpuProbe.status();
        assert!(!available);
        assert!(utilization.is_none());
        assert!(mem_used_mb.is_none());
    }

    #[test]
    fn mock_probe_serializes_full_status() {
        let probe: Box<dyn GpuProbe> = Box::new(MockProbe {
            status: GpuStatus {
                available: true,
                utilization: Some(55.5),
                mem_used_mb: Some(2048),
                vendor: Some("MockGPU".to_string()),
                device_name: Some("MockDevice".to_string()),
            },
        });

        let obj = serialized_object(&probe.status());
        assert_eq!(obj.get("available").and_then(Value::as_bool), Some(true));
        assert_eq!(obj.get("utilization").and_then(Value::as_f64), Some(55.5));
        assert_eq!(obj.get("mem_used_mb").and_then(Value::as_u64), Some(2048));
        assert_eq!(obj.get("vendor").and_then(Value::as_str), Some("MockGPU"));
        assert_eq!(obj.get("device_name").and_then(Value::as_str), Some("MockDevice"));
    }
}