1use serde::Serialize;
2
3#[derive(Debug, Clone, Serialize)]
4pub struct GpuStatus {
5 pub available: bool,
6 pub utilization: Option<f64>,
7 pub mem_used_mb: Option<u64>,
8 pub vendor: Option<String>,
9 pub device_name: Option<String>,
10}
11
12impl GpuStatus {
13 pub fn unavailable() -> Self {
14 Self {
15 available: false,
16 utilization: None,
17 mem_used_mb: None,
18 vendor: None,
19 device_name: None,
20 }
21 }
22}
23
24pub trait GpuProbe {
25 fn status(&self) -> GpuStatus;
26}
27
/// Probe for macOS hosts; on any other target it always reports "unavailable".
pub struct MacGpuProbe;

/// Probe that queries the first NVIDIA device (index 0) through NVML.
#[cfg(feature = "gpu-nvidia")]
pub struct NvidiaGpuProbe;

/// Probe used where no dedicated backend exists; always reports "unavailable".
pub struct FallbackGpuProbe;

/// Linux probe: tries NVML (when `gpu-nvidia` is enabled), then `nvidia-smi`,
/// `rocm-smi` and `intel_gpu_top`, returning the first backend that answers.
#[cfg(target_os = "linux")]
pub struct LinuxGpuProbe;

/// Windows probe; which backends it consults depends on the enabled cargo features.
#[cfg(target_os = "windows")]
pub struct WindowsGpuProbe;

/// AMD-on-Windows probe; currently a stub that always reports "unavailable".
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
pub struct AmdWindowsProbe;
43
44impl GpuProbe for MacGpuProbe {
45 fn status(&self) -> GpuStatus {
46 #[cfg(target_os = "macos")]
47 {
48 GpuStatus {
50 available: true,
51 utilization: Some(0.0),
52 mem_used_mb: None,
53 vendor: None,
54 device_name: None,
55 }
56 }
57
58 #[cfg(not(target_os = "macos"))]
59 {
60 GpuStatus::unavailable()
61 }
62 }
63}
64
#[cfg(feature = "gpu-nvidia")]
impl GpuProbe for NvidiaGpuProbe {
    /// Reads utilization, used memory and model name of NVML device 0.
    ///
    /// Failing to initialise NVML or to open device 0 yields
    /// [`GpuStatus::unavailable`]. Individual readings that fail are left as
    /// `None` without failing the whole probe.
    fn status(&self) -> GpuStatus {
        /// Runs the NVML query; `None` means the GPU could not be reached at all.
        fn query() -> Option<GpuStatus> {
            use nvml_wrapper::Nvml;

            let nvml = Nvml::init().ok()?;
            let gpu = nvml.device_by_index(0).ok()?;

            let busy_pct = gpu.utilization_rates().ok().map(|r| r.gpu as f64);
            let used_mib = gpu.memory_info().ok().map(|m| m.used / (1024 * 1024));
            let model = gpu.name().ok();

            Some(GpuStatus {
                available: true,
                utilization: busy_pct,
                mem_used_mb: used_mib,
                vendor: Some("NVIDIA".to_string()),
                device_name: model,
            })
        }

        query().unwrap_or_else(GpuStatus::unavailable)
    }
}
97
98impl GpuProbe for FallbackGpuProbe {
99 fn status(&self) -> GpuStatus {
100 GpuStatus::unavailable()
101 }
102}
103
104#[cfg(target_os = "linux")]
105impl GpuProbe for LinuxGpuProbe {
106 fn status(&self) -> GpuStatus {
107 #[cfg(feature = "gpu-nvidia")]
108 {
109 let status = NvidiaGpuProbe.status();
110 if status.available {
111 return status;
112 }
113 }
114
115 if let Some(status) = nvidia_smi_status() {
116 return status;
117 }
118 if let Some(status) = amd_status() {
119 return status;
120 }
121 if let Some(status) = intel_status() {
122 return status;
123 }
124
125 GpuStatus::unavailable()
126 }
127}
128
#[cfg(target_os = "windows")]
impl GpuProbe for WindowsGpuProbe {
    /// Queries the Windows GPU backends compiled into this build and returns
    /// the first that reports a GPU.
    ///
    /// The order is NVML (feature `gpu-nvidia`), then the AMD probe (feature
    /// `gpu_amd_windows`). If no backend is compiled in, or none reports
    /// `available`, the result is [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        // Fall through rather than returning the NVML result unconditionally,
        // so a build with both features still consults the AMD probe when
        // NVML finds nothing (mirrors the Linux probe).
        #[cfg(feature = "gpu-nvidia")]
        {
            let status = NvidiaGpuProbe.status();
            if status.available {
                return status;
            }
        }

        #[cfg(feature = "gpu_amd_windows")]
        {
            let status = AmdWindowsProbe.status();
            if status.available {
                return status;
            }
        }

        GpuStatus::unavailable()
    }
}
148
#[cfg(all(target_os = "windows", feature = "gpu_amd_windows"))]
impl GpuProbe for AmdWindowsProbe {
    /// Stub AMD-on-Windows backend.
    ///
    /// No query is implemented yet, so this always yields
    /// [`GpuStatus::unavailable`].
    fn status(&self) -> GpuStatus {
        GpuStatus::unavailable()
    }
}
156
157pub fn platform_probe() -> Box<dyn GpuProbe> {
158 #[cfg(target_os = "windows")]
159 {
160 return Box::new(WindowsGpuProbe);
161 }
162
163 #[cfg(target_os = "linux")]
164 {
165 Box::new(LinuxGpuProbe)
166 }
167
168 #[cfg(all(
169 feature = "gpu-nvidia",
170 not(target_os = "windows"),
171 not(target_os = "linux"),
172 not(target_os = "macos")
173 ))]
174 {
175 return Box::new(NvidiaGpuProbe);
176 }
177
178 #[cfg(target_os = "macos")]
179 {
180 Box::new(MacGpuProbe)
181 }
182
183 #[cfg(all(
184 not(target_os = "macos"),
185 not(target_os = "windows"),
186 not(target_os = "linux"),
187 not(feature = "gpu-nvidia")
188 ))]
189 {
190 return Box::new(FallbackGpuProbe);
191 }
192}
193
194#[cfg(target_os = "linux")]
195fn nvidia_smi_status() -> Option<GpuStatus> {
196 use std::process::Command;
197
198 let output = Command::new("nvidia-smi")
199 .args([
200 "--query-gpu=name,utilization.gpu,memory.used",
201 "--format=csv,noheader,nounits",
202 ])
203 .output()
204 .ok()?;
205 if !output.status.success() {
206 return None;
207 }
208 let line = String::from_utf8_lossy(&output.stdout);
209 let mut parts = line.lines().next()?.split(',');
210 let name = parts.next()?.trim().to_string();
211 let utilization = parts.next().and_then(|v| v.trim().parse::<f64>().ok());
212 let mem_used_mb = parts.next().and_then(|v| v.trim().parse::<u64>().ok());
213
214 Some(GpuStatus {
215 available: true,
216 utilization,
217 mem_used_mb,
218 vendor: Some("NVIDIA".to_string()),
219 device_name: Some(name),
220 })
221}
222
223#[cfg(target_os = "linux")]
224fn amd_status() -> Option<GpuStatus> {
225 use std::process::Command;
226
227 if let Ok(output) = Command::new("rocm-smi")
228 .args(["--showuse", "--json"])
229 .output()
230 {
231 if output.status.success() {
232 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
233 if let Some(pct) = val
234 .get("card")
235 .and_then(|c| c.get(0))
236 .and_then(|c| c.get("GPU use (%)"))
237 .and_then(|v| v.as_f64())
238 {
239 return Some(GpuStatus {
240 available: true,
241 utilization: Some(pct),
242 mem_used_mb: amd_mem_used_mb(),
243 vendor: Some("AMD".to_string()),
244 device_name: None,
245 });
246 }
247 }
248 }
249 }
250
251 let output = Command::new("rocm-smi")
252 .arg("--showuse")
253 .output()
254 .or_else(|_| Command::new("radeontop").arg("--help").output())
255 .ok()?;
256 if !output.status.success() {
257 return None;
258 }
259 let text = String::from_utf8_lossy(&output.stdout);
260 for line in text.lines() {
261 if line.to_ascii_lowercase().contains("gpu use") {
262 if let Some(pct) = line
263 .split('%')
264 .next()
265 .and_then(|s| s.split_whitespace().last())
266 .and_then(|n| n.parse::<f64>().ok())
267 {
268 return Some(GpuStatus {
269 available: true,
270 utilization: Some(pct),
271 mem_used_mb: amd_mem_used_mb(),
272 vendor: Some("AMD".to_string()),
273 device_name: None,
274 });
275 }
276 }
277 }
278 None
279}
280
281#[cfg(target_os = "linux")]
282fn amd_mem_used_mb() -> Option<u64> {
283 use std::process::Command;
284
285 if let Ok(output) = Command::new("rocm-smi")
286 .args(["--showmeminfo", "vram", "--json"])
287 .output()
288 {
289 if output.status.success() {
290 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&output.stdout) {
291 if let Some(used) = val
292 .get("card")
293 .and_then(|c| c.get(0))
294 .and_then(|c| c.get("vram"))
295 .and_then(|v| v.get("used (B)"))
296 .and_then(|v| v.as_u64())
297 {
298 return Some(used / (1024 * 1024));
299 }
300 }
301 }
302 }
303 None
304}
305
306#[cfg(target_os = "linux")]
307fn intel_status() -> Option<GpuStatus> {
308 use std::process::Command;
309
310 let output = Command::new("intel_gpu_top").arg("--json").output().ok()?;
311 if !output.status.success() {
312 return None;
313 }
314 let text = String::from_utf8_lossy(&output.stdout);
315 let mem_used_mb = intel_mem_json(&output.stdout).or_else(|| intel_mem_text(&text));
316 for line in text.lines() {
317 if line.to_ascii_lowercase().contains("render/3d") {
318 if let Some(pct) = line
319 .split('%')
320 .next()
321 .and_then(|s| s.split_whitespace().last())
322 .and_then(|n| n.parse::<f64>().ok())
323 {
324 return Some(GpuStatus {
325 available: true,
326 utilization: Some(pct),
327 mem_used_mb,
328 vendor: Some("Intel".to_string()),
329 device_name: None,
330 });
331 }
332 }
333 }
334 None
335}
336
337#[cfg(target_os = "linux")]
338fn intel_mem_json(data: &[u8]) -> Option<u64> {
339 let val: serde_json::Value = serde_json::from_slice(data).ok()?;
340 find_mem_value(&val)
341}
342
343#[cfg(target_os = "linux")]
344fn find_mem_value(val: &serde_json::Value) -> Option<u64> {
345 match val {
346 serde_json::Value::Number(n) => n.as_u64(),
347 serde_json::Value::Object(map) => {
348 for (k, v) in map {
349 let key = k.to_ascii_lowercase();
350 if key.contains("mem") {
351 if let Some(n) = v.as_u64() {
352 if n > 10_000 {
353 return Some(n / (1024 * 1024));
354 }
355 return Some(n);
356 }
357 if let Some(f) = v.as_f64() {
358 if f > 10_000.0 {
359 return Some((f / 1024.0 / 1024.0) as u64);
360 }
361 return Some(f as u64);
362 }
363 }
364 if let Some(found) = find_mem_value(v) {
365 return Some(found);
366 }
367 }
368 None
369 }
370 serde_json::Value::Array(arr) => {
371 for v in arr {
372 if let Some(found) = find_mem_value(v) {
373 return Some(found);
374 }
375 }
376 None
377 }
378 _ => None,
379 }
380}
381
/// Scans free-form text for a memory figure and returns it in MiB.
///
/// The first line that mentions "mem" (case-insensitive) and contains a number
/// wins. Within that line, the right-most token that parses as a number is
/// used, after stripping trailing `%`, `m`, `M`, `b` and `B` characters.
///
/// Values above 10 000 are treated as bytes and converted to MiB; anything
/// else is truncated to an integer as-is. NOTE(review): MiB readings above
/// 10 000 would be misread as bytes by this heuristic.
#[cfg(target_os = "linux")]
fn intel_mem_text(text: &str) -> Option<u64> {
    let reading = text
        .lines()
        .filter(|line| line.to_ascii_lowercase().contains("mem"))
        .find_map(|line| {
            line.split_whitespace().rev().find_map(|word| {
                word.trim_end_matches(['%', 'm', 'M', 'b', 'B'])
                    .parse::<f64>()
                    .ok()
            })
        })?;

    let mib = if reading > 10_000.0 {
        (reading / 1024.0 / 1024.0) as u64
    } else {
        reading as u64
    };
    Some(mib)
}
402
403pub fn to_json_value(status: &GpuStatus) -> serde_json::Value {
404 serde_json::to_value(status).unwrap_or(serde_json::Value::Null)
405}
406
407pub fn write_status_json(status: &GpuStatus) -> Result<(), serde_json::Error> {
408 serde_json::to_writer(std::io::stdout(), status)
409}
410
#[cfg(test)]
mod tests {
    use super::{to_json_value, FallbackGpuProbe, GpuProbe, GpuStatus};

    /// JSON keys every serialized `GpuStatus` must carry.
    const EXPECTED_KEYS: [&str; 5] = [
        "available",
        "utilization",
        "mem_used_mb",
        "vendor",
        "device_name",
    ];

    /// Test double that replays a fixed status.
    struct MockProbe {
        status: GpuStatus,
    }

    impl GpuProbe for MockProbe {
        fn status(&self) -> GpuStatus {
            self.status.clone()
        }
    }

    #[test]
    fn status_serializes_with_expected_keys() {
        let value = to_json_value(&GpuStatus::unavailable());
        let obj = value.as_object().expect("status must be a JSON object");
        for key in EXPECTED_KEYS {
            assert!(obj.contains_key(key), "missing key: {key}");
        }
    }

    #[test]
    fn fallback_reports_unavailable() {
        let GpuStatus {
            available,
            utilization,
            mem_used_mb,
            ..
        } = FallbackGpuProbe.status();
        assert!(!available);
        assert!(utilization.is_none());
        assert!(mem_used_mb.is_none());
    }

    #[test]
    fn mock_probe_serializes_full_status() {
        let probe = MockProbe {
            status: GpuStatus {
                available: true,
                utilization: Some(55.5),
                mem_used_mb: Some(2048),
                vendor: Some("MockGPU".to_string()),
                device_name: Some("MockDevice".to_string()),
            },
        };

        let value = to_json_value(&probe.status());
        let obj = value.as_object().expect("status must be a JSON object");
        let field = |name: &str| obj.get(name);

        assert_eq!(field("available").and_then(|v| v.as_bool()), Some(true));
        assert_eq!(field("utilization").and_then(|v| v.as_f64()), Some(55.5));
        assert_eq!(field("mem_used_mb").and_then(|v| v.as_u64()), Some(2048));
        assert_eq!(field("vendor").and_then(|v| v.as_str()), Some("MockGPU"));
        assert_eq!(
            field("device_name").and_then(|v| v.as_str()),
            Some("MockDevice")
        );
    }
}