1use crate::error::RuntimeError;
11use std::fmt;
12use std::process::Command;
13
/// GPU backend family that a detected [`Device`] belongs to.
///
/// The `Display` implementation prints the name of the compute API associated
/// with each variant rather than the vendor name itself.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuVendor {
    /// Apple GPU; displayed as `Metal`.
    Apple,
    /// NVIDIA GPU; displayed as `CUDA`.
    Nvidia,
    /// AMD GPU; displayed as `ROCm`.
    Amd,
    /// Intel GPU; displayed as `SYCL`.
    Intel,
    /// Software fallback used when no GPU is detected; displayed as `Emulator`.
    Emulator,
}
28
29impl fmt::Display for GpuVendor {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::Apple => write!(f, "Metal"),
33 Self::Nvidia => write!(f, "CUDA"),
34 Self::Amd => write!(f, "ROCm"),
35 Self::Intel => write!(f, "SYCL"),
36 Self::Emulator => write!(f, "Emulator"),
37 }
38 }
39}
40
41#[derive(Debug, Clone)]
43pub struct Device {
44 pub vendor: GpuVendor,
46 pub name: String,
48}
49
50impl fmt::Display for Device {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{} ({})", self.name, self.vendor)
53 }
54}
55
56pub fn detect_gpu() -> Result<Device, RuntimeError> {
65 if cfg!(target_os = "macos") {
66 return Ok(Device {
67 vendor: GpuVendor::Apple,
68 name: "Apple GPU (Metal)".into(),
69 });
70 }
71
72 if let Some(dev) = detect_nvidia() {
73 return Ok(dev);
74 }
75
76 if let Some(dev) = detect_amd() {
77 return Ok(dev);
78 }
79
80 if let Some(dev) = detect_intel() {
81 return Ok(dev);
82 }
83
84 Ok(Device {
85 vendor: GpuVendor::Emulator,
86 name: "WAVE Emulator (no GPU)".into(),
87 })
88}
89
90fn detect_nvidia() -> Option<Device> {
91 let output = Command::new("nvidia-smi")
92 .arg("--query-gpu=name")
93 .arg("--format=csv,noheader")
94 .output()
95 .ok()?;
96
97 if !output.status.success() {
98 return None;
99 }
100
101 let name = String::from_utf8_lossy(&output.stdout)
102 .lines()
103 .next()
104 .unwrap_or("NVIDIA GPU")
105 .trim()
106 .to_string();
107
108 Some(Device {
109 vendor: GpuVendor::Nvidia,
110 name: format!("{name} (CUDA)"),
111 })
112}
113
114fn detect_amd() -> Option<Device> {
115 let output = Command::new("rocminfo").output().ok()?;
116
117 if !output.status.success() {
118 return None;
119 }
120
121 let stdout = String::from_utf8_lossy(&output.stdout);
122 if !stdout.contains("gfx") {
123 return None;
124 }
125
126 let name = stdout
127 .lines()
128 .find(|l| l.contains("Marketing Name"))
129 .and_then(|l| l.split(':').nth(1))
130 .map_or_else(|| "AMD GPU".into(), |s| s.trim().to_string());
131
132 Some(Device {
133 vendor: GpuVendor::Amd,
134 name: format!("{name} (ROCm)"),
135 })
136}
137
138fn detect_intel() -> Option<Device> {
139 let output = Command::new("sycl-ls").output().ok()?;
140
141 if !output.status.success() {
142 return None;
143 }
144
145 let stdout = String::from_utf8_lossy(&output.stdout);
146 if !stdout.contains("level_zero:gpu") && !stdout.contains("opencl:gpu") {
147 return None;
148 }
149
150 Some(Device {
151 vendor: GpuVendor::Intel,
152 name: "Intel GPU (SYCL)".into(),
153 })
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection succeeds on any machine and yields a device with a
    /// non-empty name.
    #[test]
    fn test_detect_gpu_returns_device() {
        let Device { name, .. } = detect_gpu().unwrap();
        assert!(!name.is_empty());
    }

    /// A device renders as its name followed by the vendor label in
    /// parentheses.
    #[test]
    fn test_device_display() {
        let rendered = Device {
            vendor: GpuVendor::Emulator,
            name: "WAVE Emulator (no GPU)".into(),
        }
        .to_string();

        assert_eq!(rendered, "WAVE Emulator (no GPU) (Emulator)");
    }
}