1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5 pub available: bool,
6 pub utilization: Option<f64>,
7 pub mem_used_mb: Option<u64>,
8 pub vendor: Option<String>,
9 pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13 pub fn unavailable() -> Self {
14 Self {
15 available: false,
16 utilization: None,
17 mem_used_mb: None,
18 vendor: None,
19 device_name: None,
20 }
21 }
22}
23
24pub trait GpuProbe {
25 fn status(&self) -> GpuStatus;
26}
27
/// Probe for macOS hosts. Declared on every platform; its `GpuProbe`
/// implementation reports "unavailable" when not compiled for macOS.
pub struct MacGpuProbe;

/// Probe backed by NVML; only compiled with the `gpu_nvidia` feature.
#[cfg(feature = "gpu_nvidia")]
pub struct NvidiaGpuProbe;

/// Probe that always reports that no GPU information is available.
pub struct FallbackGpuProbe;

/// Probe for Linux hosts; tries the vendor-specific helpers in turn.
#[cfg(target_os = "linux")]
pub struct LinuxGpuProbe;

/// Probe for Windows hosts; delegates to whichever vendor probes are enabled.
#[cfg(target_os = "windows")]
pub struct WindowsGpuProbe;

/// AMD probe for Windows; only compiled with the `gpu_amd_windows` feature.
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45 fn status(&self) -> GpuStatus {
46 #[cfg(target_os = "macos")]
47 {
48 return GpuStatus {
50 available: true,
51 utilization: Some(0.0),
52 mem_used_mb: None,
53 vendor: None,
54 device_name: None,
55 };
56 }
57
58 #[cfg(not(target_os = "macos"))]
59 {
60 return GpuStatus::unavailable();
61 }
62 }
63}
64
#[cfg(feature = "gpu_nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Queries device index 0 through NVML.
    ///
    /// Returns [`GpuStatus::unavailable`] when NVML cannot be initialised or
    /// the device cannot be opened. Individual readings that fail are left as
    /// `None` while the status is still reported as available.
    fn status(&self) -> GpuStatus {
        use nvml_wrapper::Nvml;

        // The closure lets `?` express the two "give up entirely" exits.
        let read_first_device = || -> Option<GpuStatus> {
            let library = Nvml::init().ok()?;
            let gpu = library.device_by_index(0).ok()?;

            let utilization = match gpu.utilization_rates() {
                Ok(rates) => Some(rates.gpu as f64),
                Err(_) => None,
            };
            // NVML reports bytes; the status carries mebibytes.
            let mem_used_mb = match gpu.memory_info() {
                Ok(memory) => Some(memory.used / (1024 * 1024)),
                Err(_) => None,
            };
            let device_name = gpu.name().ok();

            Some(GpuStatus {
                available: true,
                utilization,
                mem_used_mb,
                vendor: Some("NVIDIA".to_string()),
                device_name,
            })
        };

        read_first_device().unwrap_or_else(GpuStatus::unavailable)
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99 fn status(&self) -> GpuStatus {
100 GpuStatus::unavailable()
101 }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106 fn status(&self) -> GpuStatus {
107 #[cfg(feature = "gpu_nvidia")]
108 {
109 let status = NvidiaGpuProbe.status();
110 if status.available {
111 return status;
112 }
113 }
114
115 if let Some(status) = nvidia_smi_status() {
116 return status;
117 }
118 if let Some(status) = amd_status() {
119 return status;
120 }
121 if let Some(status) = intel_status() {
122 return status;
123 }
124
125 GpuStatus::unavailable()
126 }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Picks a status from whichever vendor probes are compiled in.
    ///
    /// With the `gpu_nvidia` feature the NVML probe is consulted first, but
    /// its answer is only used when it reports an available device. Otherwise
    /// control falls through to the AMD probe (feature `gpu_amd_windows`),
    /// mirroring the fall-through behaviour of the Linux probe. Without any
    /// usable vendor probe the result is [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        #[cfg(feature = "gpu_nvidia")]
        {
            let status = NvidiaGpuProbe.status();
            if status.available {
                return status;
            }
        }

        // Exactly one of the two blocks below is compiled, so the function
        // always ends in a `return` and no code is left unreachable.
        #[cfg(feature = "gpu_amd_windows")]
        {
            return AmdWindowsProbe.status();
        }

        #[cfg(not(feature = "gpu_amd_windows"))]
        {
            return GpuStatus::unavailable();
        }
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Currently a stub: performs no detection and always reports
    /// [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        let not_implemented = GpuStatus::unavailable();
        not_implemented
    }
}
156
157pub fn platform_probe() -> Box<dyn GpuProbe> {
158 #[cfg(target_os = "windows")]
159 {
160 return Box::new(WindowsGpuProbe);
161 }
162
163 #[cfg(target_os = "linux")]
164 {
165 return Box::new(LinuxGpuProbe);
166 }
167
168 #[cfg(all(feature = "gpu_nvidia", not(target_os = "windows"), not(target_os = "linux")))]
169 {
170 return Box::new(NvidiaGpuProbe);
171 }
172
173 #[cfg(target_os = "macos")]
174 {
175 return Box::new(MacGpuProbe);
176 }
177
178 #[cfg(all(
179 not(target_os = "macos"),
180 not(target_os = "windows"),
181 not(target_os = "linux"),
182 not(feature = "gpu_nvidia")
183 ))]
184 {
185 return Box::new(FallbackGpuProbe);
186 }
187}
188
189#[cfg(target_os = "linux")]
190fn nvidia_smi_status() -> Option<GpuStatus> {
191 use std::process::Command;
192
193 let output = Command::new("nvidia-smi")
194 .args([
195 "--query-gpu=name,utilization.gpu,memory.used",
196 "--format=csv,noheader,nounits",
197 ])
198 .output()
199 .ok()?;
200 if !output.status.success() {
201 return None;
202 }
203 let line = String::from_utf8_lossy(&output.stdout);
204 let mut parts = line.lines().next()?.split(',');
205 let name = parts.next()?.trim().to_string();
206 let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
207 let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
208
209 Some(GpuStatus {
210 available: true,
211 utilization,
212 mem_used_mb,
213 vendor: Some("NVIDIA".to_string()),
214 device_name: Some(name),
215 })
216}
217
218#[cfg(target_os = "linux")]
219fn amd_status() -> Option<GpuStatus> {
220 use std::process::Command;
221
222 if let Ok(output) = Command::new("rocm-smi")
223 .args(["--showuse", "--json"])
224 .output()
225 {
226 if output.status.success() {
227 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
228 if let Some(pct) = val
229 .get("card")
230 .and_then(|c| c.get(0))
231 .and_then(|c| c.get("GPU use (%)"))
232 .and_then(|v| v.as_f64())
233 {
234 return Some(GpuStatus {
235 available: true,
236 utilization: Some(pct),
237 mem_used_mb: amd_mem_used_mb(),
238 vendor: Some("AMD".to_string()),
239 device_name: None,
240 });
241 }
242 }
243 }
244 }
245
246 let output = Command::new("rocm-smi")
247 .arg("--showuse")
248 .output()
249 .or_else(|_| Command::new("radeontop").arg("--help").output())
250 .ok()?;
251 if !output.status.success() {
252 return None;
253 }
254 let text = String::from_utf8_lossy(&output.stdout);
255 for line in text.lines() {
256 if line.to_ascii_lowercase().contains("gpu use") {
257 if let Some(pct) = line
258 .split('%')
259 .next()
260 .and_then(|s| s.split_whitespace().last())
261 .and_then(|n| n.parse::<f64>().ok())
262 {
263 return Some(GpuStatus {
264 available: true,
265 utilization: Some(pct),
266 mem_used_mb: amd_mem_used_mb(),
267 vendor: Some("AMD".to_string()),
268 device_name: None,
269 });
270 }
271 }
272 }
273 None
274}
275
276#[cfg(target_os = "linux")]
277fn amd_mem_used_mb() -> Option<u64> {
278 use std::process::Command;
279
280 if let Ok(output) = Command::new("rocm-smi")
281 .args(["--showmeminfo", "vram", "--json"])
282 .output()
283 {
284 if output.status.success() {
285 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
286 if let Some(used) = val
287 .get("card")
288 .and_then(|c| c.get(0))
289 .and_then(|c| c.get("vram"))
290 .and_then(|v| v.get("used (B)"))
291 .and_then(|v| v.as_u64())
292 {
293 return Some(used / (1024 * 1024));
294 }
295 }
296 }
297 }
298 None
299}
300
301#[cfg(target_os = "linux")]
302fn intel_status() -> Option<GpuStatus> {
303 use std::process::Command;
304
305 let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
306 if !output.status.success() {
307 return None;
308 }
309 let text = String::from_utf8_lossy(&output.stdout);
310 let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
311 for line in text.lines() {
312 if line.to_ascii_lowercase().contains("render/3d") {
313 if let Some(pct) = line
314 .split('%')
315 .next()
316 .and_then(|s| s.split_whitespace().last())
317 .and_then(|n| n.parse::<f64>().ok())
318 {
319 return Some(GpuStatus {
320 available: true,
321 utilization: Some(pct),
322 mem_used_mb,
323 vendor: Some("Intel".to_string()),
324 device_name: None,
325 });
326 }
327 }
328 }
329 None
330}
331
332#[cfg(target_os = "linux")]
333fn intel_mem_json(data: &[u8]) -> Option<u64> {
334 let val: serde_json::Value = serde_json::from_slice(data).ok()?;
335 find_mem_value(&val)
336}
337
338#[cfg(target_os = "linux")]
339fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
340 match val {
341 serde_json::Value::Number(n) => n.as_u64(),
342 serde_json::Value::Object(map) => {
343 for (k, v) in map {
344 let key = k.to_ascii_lowercase();
345 if key.contains("mem") {
346 if let Some(n) = v.as_u64() {
347 if n > 10_000 {
348 return Some(n / (1024 * 1024));
349 }
350 return Some(n);
351 }
352 if let Some(f) = v.as_f64() {
353 if f > 10_000.0 {
354 return Some((f / 1024.0 / 1024.0) as u64);
355 }
356 return Some(f as u64);
357 }
358 }
359 if let Some(found) = find_mem_value(v) {
360 return Some(found);
361 }
362 }
363 None
364 }
365 serde_json::Value::Array(arr) => {
366 for v in arr {
367 if let Some(found) = find_mem_value(v) {
368 return Some(found);
369 }
370 }
371 None
372 }
373 _ => None,
374 }
375}
376
/// Scans plain text for a memory figure.
///
/// Looks at each line containing "mem" (case-insensitive) and takes the last
/// whitespace-separated token that, after stripping trailing `%`, `m`, `M`,
/// `b` and `B` characters, parses as a finite, non-negative number. Lines
/// without such a token are skipped. Values above 10 000 are treated as bytes
/// and converted to mebibytes; smaller values are returned unchanged.
///
/// Tokens such as `inf`, `NaN` or negative numbers are ignored rather than
/// being cast to a bogus memory amount.
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    text.lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace()
                .rev()
                .map(|word| word.trim_end_matches(['%', 'm', 'M', 'b', 'B']))
                .filter_map(|word| word.parse::<f64>().ok())
                // `f64::from_str` accepts "inf"/"NaN"; those are not readings.
                .find(|value| value.is_finite() && *value >= 0.0)
        })
        .map(|value| {
            if value > 10_000.0 {
                (value / 1024.0 / 1024.0) as u64
            } else {
                value as u64
            }
        })
}
397
398pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
399 serde_json::to_value(status).unwrap_or_else(|_| serde_json::Value::Null)
400}
401
402pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
403 serde_json::to_writer(std::io::stdout(), status)
404}
405
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};

    /// Probe that replays a canned status, so serialization can be tested
    /// without touching real hardware.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    /// Every field must appear in the JSON object, even when it is `None`.
    #[test]
    fn status_serializes_with_expected_keys() {
        let json = to_json_value(&GpuStatus::unavailable());
        let fields = json.as_object().expect("status must be a JSON object");

        let expected = ["available", "utilization", "mem_used_mb", "vendor", "device_name"];
        for key in expected {
            assert!(fields.contains_key(key), "missing key: {key}");
        }
    }

    /// The fallback probe never claims a GPU or any readings.
    #[test]
    fn fallback_reports_unavailable() {
        let reading = FallbackGpuProbe.status();

        assert!(!reading.available);
        assert!(reading.utilization.is_none());
        assert!(reading.mem_used_mb.is_none());
    }

    /// A fully populated status survives the trip to JSON field by field.
    #[test]
    fn mock_probe_serializes_full_status() {
        let canned = GpuStatus {
            available: true,
            utilization: Some(55.5),
            mem_used_mb: Some(2048),
            vendor: Some("MockGPU".to_string()),
            device_name: Some("MockDevice".to_string()),
        };
        let probe = MockProbe { status: canned };

        let json = to_json_value(&probe.status());
        let fields = json.as_object().expect("status must be a JSON object");
        let field = |name: &str| fields.get(name);

        assert_eq!(field("available").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(field("utilization").and_then(|v| v.as_f64()), Some(55.5));
        assert_eq!(field("mem_used_mb").and_then(|v| v.as_u64()), Some(2048));
        assert_eq!(field("vendor").and_then(|v| v.as_str()), Some("MockGPU"));
        assert_eq!(
            field("device_name").and_then(|v| v.as_str()),
            Some("MockDevice")
        );
    }
}