1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5 pub available: bool,
6 pub utilization: Option<f64>,
7 pub mem_used_mb: Option<u64>,
8 pub vendor: Option<String>,
9 pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13 pub fn unavailable() -> Self {
14 Self {
15 available: false,
16 utilization: None,
17 mem_used_mb: None,
18 vendor: None,
19 device_name: None,
20 }
21 }
22}
23
24pub trait GpuProbe {
25 fn status(&self) -> GpuStatus;
26}
27
/// Probe selected by `platform_probe` on macOS; declared on every target.
pub struct MacGpuProbe;
29
/// NVML-backed probe; only compiled when the `gpu-nvidia` feature is on.
#[cfg(feature = "gpu-nvidia")]
pub struct NvidiaGpuProbe;
32
/// Probe that always reports the GPU as unavailable.
pub struct FallbackGpuProbe;
34
/// Linux probe that tries the NVIDIA, AMD and Intel sources in turn.
#[cfg(target_os = "linux")]
pub struct LinuxGpuProbe;
37
/// Windows probe; which backend it uses depends on the enabled features.
#[cfg(target_os = "windows")]
pub struct WindowsGpuProbe;
40
/// WMI-based Windows probe; requires the `gpu-windows` feature.
#[cfg(all(target_os = "windows", feature = "gpu-windows"))]
pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45 fn status(&self) -> GpuStatus {
46 #[cfg(target_os = "macos")]
47 {
48 GpuStatus {
50 available: true,
51 utilization: Some(0.0),
52 mem_used_mb: None,
53 vendor: None,
54 device_name: None,
55 }
56 }
57
58 #[cfg(not(target_os = "macos"))]
59 {
60 GpuStatus::unavailable()
61 }
62 }
63}
64
#[cfg(feature = "gpu-nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Queries NVML for the device at index 0.
    ///
    /// Returns [`GpuStatus::unavailable`] when NVML cannot be initialised or
    /// the device cannot be opened. Failures of the individual queries
    /// (utilization, memory, name) only leave the corresponding field `None`.
    fn status(&self) -> GpuStatus {
        let query = || -> Option<GpuStatus> {
            use nvml_wrapper::Nvml;

            let nvml = Nvml::init().ok()?;
            let gpu = nvml.device_by_index(0).ok()?;

            let utilization = match gpu.utilization_rates() {
                Ok(rates) => Some(rates.gpu as f64),
                Err(_) => None,
            };
            // NVML reports bytes; convert to 1024 * 1024 byte units.
            let mem_used_mb = match gpu.memory_info() {
                Ok(mem) => Some(mem.used / (1024 * 1024)),
                Err(_) => None,
            };

            Some(GpuStatus {
                available: true,
                utilization,
                mem_used_mb,
                vendor: Some("NVIDIA".to_string()),
                device_name: gpu.name().ok(),
            })
        };

        query().unwrap_or_else(GpuStatus::unavailable)
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99 fn status(&self) -> GpuStatus {
100 GpuStatus::unavailable()
101 }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106 fn status(&self) -> GpuStatus {
107 #[cfg(feature = "gpu-nvidia")]
108 {
109 let status = NvidiaGpuProbe.status();
110 if status.available {
111 return status;
112 }
113 }
114
115 if let Some(status) = nvidia_smi_status() {
116 return status;
117 }
118 if let Some(status) = amd_status() {
119 return status;
120 }
121 if let Some(status) = intel_status() {
122 return status;
123 }
124
125 GpuStatus::unavailable()
126 }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Tries each compiled-in Windows backend in priority order.
    ///
    /// Order: NVML (with the `gpu-nvidia` feature), then the WMI probe (with
    /// the `gpu-windows` feature). A backend that reports an unavailable
    /// status is skipped, so with both features enabled the WMI probe still
    /// gets a chance when NVML finds nothing — the same fall-through the
    /// Linux probe uses. Returns [`GpuStatus::unavailable`] when no backend
    /// is compiled in or none produces a reading.
    fn status(&self) -> GpuStatus {
        #[cfg(feature = "gpu-nvidia")]
        {
            let nvml_status = NvidiaGpuProbe.status();
            if nvml_status.available {
                return nvml_status;
            }
        }

        #[cfg(feature = "gpu-windows")]
        {
            let wmi_status = AmdWindowsProbe.status();
            if wmi_status.available {
                return wmi_status;
            }
        }

        GpuStatus::unavailable()
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu-windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Returns the WMI reading when there is one, otherwise
    /// [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        windows_wmi_status().unwrap_or_else(GpuStatus::unavailable)
    }
}
159
/// WMI query hook for Windows.
///
/// Currently a stub: it performs no query and always returns `None`.
#[cfg(target_os = "windows")]
fn windows_wmi_status() -> Option<GpuStatus> {
    Option::None
}
164
165pub fn platform_probe() -> Box<dyn GpuProbe> {
166 #[cfg(target_os = "windows")]
167 {
168 return Box::new(WindowsGpuProbe);
169 }
170
171 #[cfg(target_os = "linux")]
172 {
173 Box::new(LinuxGpuProbe)
174 }
175
176 #[cfg(all(
177 feature = "gpu-nvidia",
178 not(target_os = "windows"),
179 not(target_os = "linux"),
180 not(target_os = "macos")
181 ))]
182 {
183 return Box::new(NvidiaGpuProbe);
184 }
185
186 #[cfg(target_os = "macos")]
187 {
188 Box::new(MacGpuProbe)
189 }
190
191 #[cfg(all(
192 not(target_os = "macos"),
193 not(target_os = "windows"),
194 not(target_os = "linux"),
195 not(feature = "gpu-nvidia")
196 ))]
197 {
198 return Box::new(FallbackGpuProbe);
199 }
200}
201
202#[cfg(target_os = "linux")]
203fn nvidia_smi_status() -> Option<GpuStatus> {
204 use std::process::Command;
205
206 let output = Command::new("nvidia-smi")
207 .args([
208 "--query-gpu=name,utilization.gpu,memory.used",
209 "--format=csv,noheader,nounits",
210 ])
211 .output()
212 .ok()?;
213 if !output.status.success() {
214 return None;
215 }
216 let line = String::from_utf8_lossy(&output.stdout);
217 let mut parts = line.lines().next()?.split(',');
218 let name = parts.next()?.trim().to_string();
219 let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
220 let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
221
222 Some(GpuStatus {
223 available: true,
224 utilization,
225 mem_used_mb,
226 vendor: Some("NVIDIA".to_string()),
227 device_name: Some(name),
228 })
229}
230
231#[cfg(target_os = "linux")]
232fn amd_status() -> Option<GpuStatus> {
233 use std::process::Command;
234
235 if let Ok(output) = Command::new("rocm-smi")
236 .args(["--showuse", "--json"])
237 .output()
238 {
239 if output.status.success() {
240 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
241 if let Some(pct) = val
242 .get("card")
243 .and_then(|c| c.get(0))
244 .and_then(|c| c.get("GPU use (%)"))
245 .and_then(|v| v.as_f64())
246 {
247 return Some(GpuStatus {
248 available: true,
249 utilization: Some(pct),
250 mem_used_mb: amd_mem_used_mb(),
251 vendor: Some("AMD".to_string()),
252 device_name: None,
253 });
254 }
255 }
256 }
257 }
258
259 let output = Command::new("rocm-smi")
260 .arg("--showuse")
261 .output()
262 .or_else(|_| Command::new("radeontop").arg("--help").output())
263 .ok()?;
264 if !output.status.success() {
265 return None;
266 }
267 let text = String::from_utf8_lossy(&output.stdout);
268 for line in text.lines() {
269 if line.to_ascii_lowercase().contains("gpu use") {
270 if let Some(pct) = line
271 .split('%')
272 .next()
273 .and_then(|s| s.split_whitespace().last())
274 .and_then(|n| n.parse::<f64>().ok())
275 {
276 return Some(GpuStatus {
277 available: true,
278 utilization: Some(pct),
279 mem_used_mb: amd_mem_used_mb(),
280 vendor: Some("AMD".to_string()),
281 device_name: None,
282 });
283 }
284 }
285 }
286 None
287}
288
289#[cfg(target_os = "linux")]
290fn amd_mem_used_mb() -> Option<u64> {
291 use std::process::Command;
292
293 if let Ok(output) = Command::new("rocm-smi")
294 .args(["--showmeminfo", "vram", "--json"])
295 .output()
296 {
297 if output.status.success() {
298 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
299 if let Some(used) = val
300 .get("card")
301 .and_then(|c| c.get(0))
302 .and_then(|c| c.get("vram"))
303 .and_then(|v| v.get("used (B)"))
304 .and_then(|v| v.as_u64())
305 {
306 return Some(used / (1024 * 1024));
307 }
308 }
309 }
310 }
311 None
312}
313
314#[cfg(target_os = "linux")]
315fn intel_status() -> Option<GpuStatus> {
316 use std::process::Command;
317
318 let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
319 if !output.status.success() {
320 return None;
321 }
322 let text = String::from_utf8_lossy(&output.stdout);
323 let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
324 for line in text.lines() {
325 if line.to_ascii_lowercase().contains("render/3d") {
326 if let Some(pct) = line
327 .split('%')
328 .next()
329 .and_then(|s| s.split_whitespace().last())
330 .and_then(|n| n.parse::<f64>().ok())
331 {
332 return Some(GpuStatus {
333 available: true,
334 utilization: Some(pct),
335 mem_used_mb,
336 vendor: Some("Intel".to_string()),
337 device_name: None,
338 });
339 }
340 }
341 }
342 None
343}
344
345#[cfg(target_os = "linux")]
346fn intel_mem_json(data: &[u8]) -> Option<u64> {
347 let val: serde_json::Value = serde_json::from_slice(data).ok()?;
348 find_mem_value(&val)
349}
350
351#[cfg(target_os = "linux")]
352fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
353 match val {
354 serde_json::Value::Number(n) => n.as_u64(),
355 serde_json::Value::Object(map) => {
356 for (k, v) in map {
357 let key = k.to_ascii_lowercase();
358 if key.contains("mem") {
359 if let Some(n) = v.as_u64() {
360 if n > 10_000 {
361 return Some(n / (1024 * 1024));
362 }
363 return Some(n);
364 }
365 if let Some(f) = v.as_f64() {
366 if f > 10_000.0 {
367 return Some((f / 1024.0 / 1024.0) as u64);
368 }
369 return Some(f as u64);
370 }
371 }
372 if let Some(found) = find_mem_value(v) {
373 return Some(found);
374 }
375 }
376 None
377 }
378 serde_json::Value::Array(arr) => {
379 for v in arr {
380 if let Some(found) = find_mem_value(v) {
381 return Some(found);
382 }
383 }
384 None
385 }
386 _ => None,
387 }
388}
389
/// Extracts a memory figure from free-form text.
///
/// Scans for lines containing "mem" (case-insensitive). On each, trailing
/// `%`, `m`, `M`, `b`, `B` characters are stripped from every
/// whitespace-separated token and the last token that parses as a finite,
/// non-negative number is used. Lines without such a token are skipped.
///
/// Values above 10 000 are assumed to be bytes and divided by
/// `1024 * 1024`; smaller values are returned unchanged (truncated).
///
/// `f64::from_str` also accepts `nan`, `inf` and negative numbers, which a
/// bare `as u64` cast would turn into `0` or `u64::MAX`; those tokens are
/// ignored here instead.
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    text.lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace()
                .map(|word| word.trim_end_matches(['%', 'm', 'M', 'b', 'B']))
                .filter_map(|word| word.parse::<f64>().ok())
                // Walk from the end of the line, skipping tokens that cannot
                // represent an amount of memory.
                .rfind(|num| num.is_finite() && *num >= 0.0)
                .map(|num| {
                    if num > 10_000.0 {
                        (num / 1024.0 / 1024.0) as u64
                    } else {
                        num as u64
                    }
                })
        })
}
410
411pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
412 serde_json::to_value(status).unwrap_or(serde_json::Value::Null)
413}
414
415pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
416 serde_json::to_writer(std::io::stdout(), status)
417}
418
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};

    /// Keys every serialized status is expected to expose.
    const EXPECTED_KEYS: [&str; 5] = [
        "available",
        "utilization",
        "mem_used_mb",
        "vendor",
        "device_name",
    ];

    /// Probe double that replays a canned status.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    #[test]
    fn status_serializes_with_expected_keys() {
        let value = to_json_value(&GpuStatus::unavailable());
        let fields = value.as_object().expect("status must be a JSON object");
        for key in EXPECTED_KEYS {
            assert!(fields.contains_key(key), "missing key: {key}");
        }
    }

    #[test]
    fn fallback_reports_unavailable() {
        let reported = FallbackGpuProbe.status();
        assert!(!reported.available);
        assert!(reported.utilization.is_none());
        assert!(reported.mem_used_mb.is_none());
    }

    #[test]
    fn mock_probe_serializes_full_status() {
        let canned = GpuStatus {
            available: true,
            utilization: Some(55.5),
            mem_used_mb: Some(2048),
            vendor: Some("MockGPU".to_string()),
            device_name: Some("MockDevice".to_string()),
        };
        let probe = MockProbe { status: canned };

        let value = to_json_value(&probe.status());
        let fields = value.as_object().expect("status must be a JSON object");

        let available = fields.get("available").and_then(|v| v.as_bool());
        let utilization = fields.get("utilization").and_then(|v| v.as_f64());
        let mem_used_mb = fields.get("mem_used_mb").and_then(|v| v.as_u64());
        let vendor = fields.get("vendor").and_then(|v| v.as_str());
        let device_name = fields.get("device_name").and_then(|v| v.as_str());

        assert_eq!(available, Some(true));
        assert_eq!(utilization, Some(55.5));
        assert_eq!(mem_used_mb, Some(2048));
        assert_eq!(vendor, Some("MockGPU"));
        assert_eq!(device_name, Some("MockDevice"));
    }
}