Skip to main content

wave_runtime/
device.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! GPU detection for the WAVE runtime.
//!
//! Probes the system for available GPU hardware and returns a `Device` with
//! vendor and name information. Falls back to the WAVE emulator when no
//! supported GPU is found.
9
10use crate::error::RuntimeError;
11use std::fmt;
12use std::process::Command;
13
/// Hardware family of a detected [`Device`], or the software fallback.
///
/// The `Display` output of each variant is the name of the compute stack
/// WAVE drives it through, not the vendor's name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuVendor {
    /// Apple platform GPU, targeted through Metal.
    Apple,
    /// NVIDIA hardware, targeted through CUDA/PTX.
    Nvidia,
    /// AMD hardware, targeted through `ROCm`/HIP.
    Amd,
    /// Intel hardware, targeted through oneAPI/SYCL.
    Intel,
    /// No supported GPU found: the WAVE software emulator.
    Emulator,
}
28
29impl fmt::Display for GpuVendor {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::Apple => write!(f, "Metal"),
33            Self::Nvidia => write!(f, "CUDA"),
34            Self::Amd => write!(f, "ROCm"),
35            Self::Intel => write!(f, "SYCL"),
36            Self::Emulator => write!(f, "Emulator"),
37        }
38    }
39}
40
41/// Detected GPU device.
42#[derive(Debug, Clone)]
43pub struct Device {
44    /// Hardware vendor.
45    pub vendor: GpuVendor,
46    /// Human-readable device name.
47    pub name: String,
48}
49
50impl fmt::Display for Device {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        write!(f, "{} ({})", self.name, self.vendor)
53    }
54}
55
56/// Detect the best available GPU on the system.
57///
58/// Checks in order: Metal (macOS), CUDA (nvidia-smi), `ROCm` (rocminfo),
59/// oneAPI (sycl-ls). Falls back to the WAVE emulator if no GPU is found.
60///
61/// # Errors
62///
63/// Returns `RuntimeError::Device` if detection encounters an unrecoverable error.
64pub fn detect_gpu() -> Result<Device, RuntimeError> {
65    if cfg!(target_os = "macos") {
66        return Ok(Device {
67            vendor: GpuVendor::Apple,
68            name: "Apple GPU (Metal)".into(),
69        });
70    }
71
72    if let Some(dev) = detect_nvidia() {
73        return Ok(dev);
74    }
75
76    if let Some(dev) = detect_amd() {
77        return Ok(dev);
78    }
79
80    if let Some(dev) = detect_intel() {
81        return Ok(dev);
82    }
83
84    Ok(Device {
85        vendor: GpuVendor::Emulator,
86        name: "WAVE Emulator (no GPU)".into(),
87    })
88}
89
90fn detect_nvidia() -> Option<Device> {
91    let output = Command::new("nvidia-smi")
92        .arg("--query-gpu=name")
93        .arg("--format=csv,noheader")
94        .output()
95        .ok()?;
96
97    if !output.status.success() {
98        return None;
99    }
100
101    let name = String::from_utf8_lossy(&output.stdout)
102        .lines()
103        .next()
104        .unwrap_or("NVIDIA GPU")
105        .trim()
106        .to_string();
107
108    Some(Device {
109        vendor: GpuVendor::Nvidia,
110        name: format!("{name} (CUDA)"),
111    })
112}
113
114fn detect_amd() -> Option<Device> {
115    let output = Command::new("rocminfo").output().ok()?;
116
117    if !output.status.success() {
118        return None;
119    }
120
121    let stdout = String::from_utf8_lossy(&output.stdout);
122    if !stdout.contains("gfx") {
123        return None;
124    }
125
126    let name = stdout
127        .lines()
128        .find(|l| l.contains("Marketing Name"))
129        .and_then(|l| l.split(':').nth(1))
130        .map_or_else(|| "AMD GPU".into(), |s| s.trim().to_string());
131
132    Some(Device {
133        vendor: GpuVendor::Amd,
134        name: format!("{name} (ROCm)"),
135    })
136}
137
138fn detect_intel() -> Option<Device> {
139    let output = Command::new("sycl-ls").output().ok()?;
140
141    if !output.status.success() {
142        return None;
143    }
144
145    let stdout = String::from_utf8_lossy(&output.stdout);
146    if !stdout.contains("level_zero:gpu") && !stdout.contains("opencl:gpu") {
147        return None;
148    }
149
150    Some(Device {
151        vendor: GpuVendor::Intel,
152        name: "Intel GPU (SYCL)".into(),
153    })
154}
155
#[cfg(test)]
mod tests {
    use super::{detect_gpu, Device, GpuVendor};

    #[test]
    fn test_detect_gpu_returns_device() {
        let detected = detect_gpu().expect("GPU detection should not fail");
        assert!(!detected.name.is_empty(), "detected device must have a name");
    }

    #[test]
    fn test_device_display() {
        let emulator = Device {
            vendor: GpuVendor::Emulator,
            name: String::from("WAVE Emulator (no GPU)"),
        };
        let rendered = format!("{emulator}");
        assert_eq!(rendered, "WAVE Emulator (no GPU) (Emulator)");
    }
}