Skip to main content

rlx_driver/
device.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Device selection — which backend to use.
17
/// Target device for graph execution.
///
/// Each variant maps to a backend crate gated by a Cargo feature.
/// Naming a variant here does not mean its backend is compiled in:
/// this type carries no availability check of its own. Availability
/// is answered by the runtime layer (see
/// `rlx_runtime::available_devices()`), which consults the engine's
/// backend registry.
///
/// The type is a plain `Copy` enum, so pass it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    // ── CPU ─────────────────────────────────────────────────
    /// CPU with SIMD (NEON/AVX) + optional BLAS.
    Cpu,

    // ── Apple ───────────────────────────────────────────────
    /// GPU via Apple Metal (Metal Performance Shaders).
    Metal,
    /// Apple MLX framework (unified memory GPU).
    Mlx,
    /// Apple Neural Engine.
    Ane,

    // ── NVIDIA ──────────────────────────────────────────────
    /// NVIDIA GPU via native CUDA (cuBLAS, cuDNN).
    Cuda,

    // ── AMD ─────────────────────────────────────────────────
    /// AMD GPU via ROCm/HIP.
    Rocm,

    // ── Google ──────────────────────────────────────────────
    /// Google TPU via libtpu's PJRT plugin (no Python).
    Tpu,

    // ── Cross-platform GPU ──────────────────────────────────
    /// Portable GPU via wgpu (Metal/Vulkan/DX12/WebGPU).
    Gpu,
    /// Vulkan compute shaders.
    Vulkan,
    /// OpenGL compute shaders (legacy).
    OpenGl,
    /// DirectX 12 compute (Windows).
    DirectX,
    /// WebGPU (WASM target).
    WebGpu,
}
60
61impl Device {
62    /// Human-readable name (no engine-layer info).
63    /// `is_available` / `available` live in rlx-runtime since they
64    /// consult the engine's backend registry — keeping them out of
65    /// the driver layer preserves the one-way dep direction.
66    pub fn name(self) -> &'static str {
67        match self {
68            Device::Cpu => "CPU",
69            Device::Metal => "Metal",
70            Device::Mlx => "MLX",
71            Device::Ane => "ANE",
72            Device::Cuda => "CUDA",
73            Device::Rocm => "ROCm",
74            Device::Tpu => "TPU",
75            Device::Gpu => "GPU (wgpu)",
76            Device::Vulkan => "Vulkan",
77            Device::OpenGl => "OpenGL",
78            Device::DirectX => "DirectX 12",
79            Device::WebGpu => "WebGPU",
80        }
81    }
82
83    /// All variant labels — convenience for callers that want to
84    /// enumerate without listing every variant manually. Pair
85    /// with `rlx_runtime::available_devices()` to filter.
86    pub fn all() -> &'static [Device] {
87        &[
88            Device::Cpu,
89            Device::Metal,
90            Device::Mlx,
91            Device::Ane,
92            Device::Cuda,
93            Device::Rocm,
94            Device::Tpu,
95            Device::Gpu,
96            Device::Vulkan,
97            Device::OpenGl,
98            Device::DirectX,
99            Device::WebGpu,
100        ]
101    }
102}
103
104impl std::fmt::Display for Device {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        write!(f, "{}", self.name())
107    }
108}
109
/// Parse failure produced by [`Device::from_str`] when the input is
/// not one of the recognised device aliases.
///
/// The wrapped `String` is the rejected input; it is echoed back in
/// the `Display` message together with a list of accepted names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceFromStrError(pub String);

impl std::fmt::Display for DeviceFromStrError {
    /// Renders the rejected input followed by a hint listing the
    /// primary device names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_fmt(format_args!(
            "unknown device '{}' (try: cpu, metal, mlx, ane, cuda, rocm, gpu, vulkan, opengl, directx, webgpu, tpu)",
            rejected
        ))
    }
}

impl std::error::Error for DeviceFromStrError {}
126
127impl std::str::FromStr for Device {
128    type Err = DeviceFromStrError;
129
130    fn from_str(s: &str) -> Result<Self, Self::Err> {
131        let key = s.trim().to_ascii_lowercase();
132        Ok(match key.as_str() {
133            "cpu" => Device::Cpu,
134            "metal" | "mps" | "mtl" => Device::Metal,
135            "mlx" => Device::Mlx,
136            "ane" | "neural-engine" => Device::Ane,
137            "cuda" | "nvidia" => Device::Cuda,
138            "rocm" | "hip" | "amd" => Device::Rocm,
139            "gpu" | "wgpu" => Device::Gpu,
140            "vulkan" | "vk" => Device::Vulkan,
141            "opengl" | "gl" => Device::OpenGl,
142            "directx" | "dx12" | "d3d12" => Device::DirectX,
143            "webgpu" => Device::WebGpu,
144            "tpu" => Device::Tpu,
145            _ => return Err(DeviceFromStrError(s.to_string())),
146        })
147    }
148}
149
/// Per-family backend support filter.
///
/// Each model family declares which devices it can execute on (e.g.,
/// SAM adds TPU on top of the standard set; some VLM crates exclude
/// MLX until the vision tower lands). A single shared
/// [`BackendSupport`] impl per family lets [`validate_device`] return
/// uniform error messages instead of every model crate hand-rolling
/// the same `match` ladder.
///
/// Both methods take `&self` and no generic parameters.
pub trait BackendSupport {
    /// Short stable family identifier (`"qwen3"`, `"llama32"`, `"sam"`).
    ///
    /// [`validate_device`] embeds this string in its error message.
    fn family(&self) -> &'static str;

    /// `true` if this family can execute on `device` today.
    ///
    /// [`validate_device`] calls this to decide between `Ok` and `Err`.
    fn supports(&self, device: Device) -> bool;
}
165
/// Workspace-wide standard backend set: CPU, Metal, MLX, CUDA, ROCm, GPU.
///
/// New families default to this set via [`StandardBackends`] until they
/// need to opt in/out. Mirrors `rlx_core::STANDARD_DEVICES`.
///
/// The remaining [`Device`] variants (ANE, TPU, Vulkan, OpenGL,
/// DirectX, WebGPU) are not part of this set, so [`StandardBackends`]
/// reports them as unsupported.
pub const STANDARD_DEVICES: &[Device] = &[
    Device::Cpu,
    Device::Metal,
    Device::Mlx,
    Device::Cuda,
    Device::Rocm,
    Device::Gpu,
];
178
179/// Default [`BackendSupport`] for families on the standard backend set.
180#[derive(Debug, Clone, Copy)]
181pub struct StandardBackends(pub &'static str);
182
183impl BackendSupport for StandardBackends {
184    fn family(&self) -> &'static str {
185        self.0
186    }
187    fn supports(&self, device: Device) -> bool {
188        STANDARD_DEVICES.contains(&device)
189    }
190}
191
192/// Validate that `device` is supported by `family`. Returns the same device
193/// on success; on failure, formats a uniform error string. Callers that
194/// need a typed error should use [`BackendSupport::supports`] directly.
195pub fn validate_device<S: BackendSupport>(support: &S, device: Device) -> Result<Device, String> {
196    if support.supports(device) {
197        Ok(device)
198    } else {
199        Err(format!(
200            "device {} is not supported by family `{}`",
201            device.name(),
202            support.family()
203        ))
204    }
205}
206
#[cfg(test)]
mod from_str_tests {
    use super::*;
    use std::str::FromStr;

    /// Parsing accepts canonical names, aliases and mixed case, and
    /// rejects unknown input.
    #[test]
    fn parse_basics() {
        let accepted = [
            ("cpu", Device::Cpu),
            ("CUDA", Device::Cuda),
            ("mps", Device::Metal),
            ("wgpu", Device::Gpu),
        ];
        for &(text, expected) in accepted.iter() {
            assert_eq!(Device::from_str(text).unwrap(), expected);
        }
        assert!(Device::from_str("nothing").is_err());
    }

    /// `StandardBackends` admits standard devices and refuses TPU, and
    /// `validate_device` surfaces that refusal as an error.
    #[test]
    fn standard_backends_set() {
        let family = StandardBackends("qwen3");
        for &device in [Device::Cpu, Device::Metal].iter() {
            assert!(family.supports(device));
        }
        assert!(!family.supports(Device::Tpu));
        assert!(validate_device(&family, Device::Tpu).is_err());
    }
}