oximedia_ml/device.rs
1//! Execution device abstraction.
2//!
3//! [`DeviceType`] is the user-facing handle for selecting which backend
4//! an [`OnnxModel`](crate::model::OnnxModel) should run on. Backends
5//! are feature-gated at build time; runtime availability is probed via
6//! [`DeviceType::is_available`] so callers can ask for
7//! [`DeviceType::auto`] without needing `cfg!` plumbing of their own.
8//!
9//! ## Probe cascade
10//!
11//! [`DeviceType::auto`] walks the following order, first success wins:
12//!
//! 1. **CUDA** — `oxionnx::cuda::CudaContext::try_new()`
//!    (requires the `cuda` feature).
//! 2. **DirectML** — `oxionnx::directml::DirectMLContext::try_new()`
//!    (requires the `directml` feature; always `None` off Windows).
//! 3. **WebGPU** — `oxionnx::gpu::GpuContext::try_new()`
//!    (requires the `webgpu` feature).
19//! 4. **CPU** — always available.
20//!
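//! Every probe is wrapped in `std::panic::catch_unwind`, so a probe that
//! panics is reported as "unavailable" instead of unwinding into the
//! caller. Only unwinding panics are caught; a build with
//! `panic = "abort"` still aborts. The result of `auto()` is memoised in a
//! `OnceLock`, so calling it twice does not re-init the CUDA driver or wgpu
//! adapter. `is_available()` and `DeviceCapabilities::probe` are not
//! memoised and probe again on every call.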
25//!
26//! ## Capability introspection
27//!
28//! [`DeviceCapabilities`] carries a richer description (device name,
29//! memory, compute capability, dtype support) and is produced by
30//! [`DeviceCapabilities::probe`] for a specific [`DeviceType`], or
31//! [`DeviceCapabilities::probe_all`] for every compiled-in backend at
32//! once.
33//!
34//! ## Example
35//!
36//! ```
37//! use oximedia_ml::{DeviceCapabilities, DeviceType};
38//!
39//! // Pick the strongest available backend (always succeeds — falls
40//! // back to CPU if nothing else is compiled in / usable).
41//! let device = DeviceType::auto();
42//! assert!(device.is_available());
43//!
44//! // Full capability report for the selected device.
45//! let caps = DeviceCapabilities::best_available();
46//! assert_eq!(caps.device_type, device);
47//! ```
48
49use core::fmt;
50#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
51use std::panic::AssertUnwindSafe;
52use std::sync::OnceLock;
53
/// Backend on which ML inference is executed.
///
/// Accelerated variants are feature-gated at build time and probed at run
/// time; [`DeviceType::Cpu`] is the universal fallback.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize),
    serde(rename_all = "lowercase")
)]
pub enum DeviceType {
    /// Pure-Rust CPU execution; always usable.
    Cpu,
    /// NVIDIA CUDA through `oxionnx::cuda` (needs the `cuda` feature).
    Cuda,
    /// wgpu-based WebGPU compute through `oxionnx::gpu` (needs the `webgpu`
    /// feature).
    WebGpu,
    /// Microsoft DirectML through `oxionnx::directml` (needs the `directml`
    /// feature; Windows-only at runtime).
    DirectMl,
    /// Apple CoreML — reserved. It is never reported as available and
    /// `auto()` never picks it. The variant exists so consumers can match
    /// exhaustively today, without a breaking change if CoreML support lands.
    CoreMl,
}
74
75/// Memoised result of [`DeviceType::auto`].
76///
77/// The probe cascade has observable side effects (CUDA driver
78/// initialisation, wgpu adapter enumeration), and some of those mutate
79/// thread-local state. Caching the first result makes subsequent calls
80/// cheap and avoids re-running those side effects.
81static AUTO_CACHE: OnceLock<DeviceType> = OnceLock::new();
82
83impl DeviceType {
84 /// Return the preferred device available in this build, in the order
85 /// **CUDA → DirectML → WebGPU → CPU**. Always succeeds because CPU is
86 /// unconditionally available.
87 ///
88 /// The result is memoised for the lifetime of the process.
89 #[must_use]
90 pub fn auto() -> Self {
91 *AUTO_CACHE.get_or_init(|| {
92 if Self::Cuda.is_available() {
93 return Self::Cuda;
94 }
95 if Self::DirectMl.is_available() {
96 return Self::DirectMl;
97 }
98 if Self::WebGpu.is_available() {
99 return Self::WebGpu;
100 }
101 Self::Cpu
102 })
103 }
104
105 /// Report whether this device is usable in the current build /
106 /// runtime environment. A device may be compiled in (feature-gated)
107 /// yet still unavailable at runtime — e.g. no GPU detected.
108 #[must_use]
109 pub fn is_available(self) -> bool {
110 match self {
111 Self::Cpu => true,
112 Self::Cuda => cuda_available(),
113 Self::WebGpu => webgpu_available(),
114 Self::DirectMl => directml_available(),
115 Self::CoreMl => false,
116 }
117 }
118
119 /// Short canonical name matching the feature flag / CLI spelling.
120 ///
121 /// Retained for backward compatibility with existing call sites;
122 /// [`DeviceType::display_name`] is preferred for human-facing output.
123 #[must_use]
124 pub fn name(self) -> &'static str {
125 match self {
126 Self::Cpu => "cpu",
127 Self::Cuda => "cuda",
128 Self::WebGpu => "webgpu",
129 Self::DirectMl => "directml",
130 Self::CoreMl => "coreml",
131 }
132 }
133
134 /// Human-facing label — identical to [`Self::name`] for now, but
135 /// conceptually distinct so downstream UIs can swap in a friendlier
136 /// string later without affecting programmatic lookups.
137 #[must_use]
138 pub fn display_name(self) -> &'static str {
139 match self {
140 Self::Cpu => "CPU",
141 Self::Cuda => "CUDA",
142 Self::WebGpu => "WebGPU",
143 Self::DirectMl => "DirectML",
144 Self::CoreMl => "CoreML",
145 }
146 }
147
148 /// Run the full cascade probe and return the richer
149 /// [`DeviceCapabilities`] record for this device.
150 #[must_use]
151 pub fn probe_caps(self) -> DeviceCapabilities {
152 DeviceCapabilities::probe(self)
153 }
154
155 /// Return every [`DeviceType`] whose backend is currently usable.
156 ///
157 /// The returned vector always contains [`DeviceType::Cpu`] and is
158 /// ordered by the probe cascade (CPU last).
159 #[must_use]
160 pub fn list_available() -> Vec<Self> {
161 let mut out = Vec::with_capacity(4);
162 if Self::Cuda.is_available() {
163 out.push(Self::Cuda);
164 }
165 if Self::DirectMl.is_available() {
166 out.push(Self::DirectMl);
167 }
168 if Self::WebGpu.is_available() {
169 out.push(Self::WebGpu);
170 }
171 // CoreMl is never currently available.
172 out.push(Self::Cpu);
173 out
174 }
175
176 /// Every variant in enum declaration order.
177 ///
178 /// Used by [`DeviceCapabilities::probe_all`] and the test suite.
179 #[must_use]
180 pub const fn all_variants() -> [Self; 5] {
181 [
182 Self::Cpu,
183 Self::Cuda,
184 Self::WebGpu,
185 Self::DirectMl,
186 Self::CoreMl,
187 ]
188 }
189}
190
191impl Default for DeviceType {
192 fn default() -> Self {
193 Self::Cpu
194 }
195}
196
197impl fmt::Display for DeviceType {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 f.write_str(self.name())
200 }
201}
202
203/// Rich capability description for a single [`DeviceType`].
204///
205/// Produced by [`DeviceCapabilities::probe`]. Fields are populated
206/// best-effort — anything the backend does not expose is left as `None` /
207/// `false`, which lets callers use a single code path regardless of how
208/// much telemetry the driver provides.
209#[derive(Clone, Debug, Default, PartialEq, Eq)]
210#[cfg_attr(feature = "serde", derive(serde::Serialize))]
211pub struct DeviceCapabilities {
212 /// Which device this record describes.
213 pub device_type: DeviceType,
214 /// Whether the device is currently available for inference.
215 pub is_available: bool,
216 /// Human-facing device name (e.g. "CPU (x86_64)", "NVIDIA GPU via CUDA").
217 pub device_name: String,
218 /// Total device memory in bytes, if known.
219 pub memory_total_bytes: Option<u64>,
220 /// Free device memory in bytes, if known.
221 pub memory_free_bytes: Option<u64>,
222 /// Compute capability string (e.g. "8.6" for Ampere). `None` for CPU
223 /// / WebGPU / DirectML / CoreML.
224 pub compute_capability: Option<String>,
225 /// Whether FP16 (half-precision float) is supported.
226 pub supports_fp16: bool,
227 /// Whether BF16 (bfloat16) is supported.
228 pub supports_bf16: bool,
229 /// Whether INT8 quantised inference is supported.
230 pub supports_int8: bool,
231}
232
233impl DeviceCapabilities {
234 /// Probe a specific device and describe its capabilities.
235 ///
236 /// Always returns a record; unavailable devices get `is_available =
237 /// false` with the rest populated from static knowledge.
238 #[must_use]
239 pub fn probe(device: DeviceType) -> Self {
240 match device {
241 DeviceType::Cpu => Self {
242 device_type: DeviceType::Cpu,
243 is_available: true,
244 device_name: cpu_device_name(),
245 memory_total_bytes: None,
246 memory_free_bytes: None,
247 compute_capability: None,
248 supports_fp16: false,
249 supports_bf16: false,
250 supports_int8: true,
251 },
252 DeviceType::Cuda => {
253 let live = cuda_available();
254 Self {
255 device_type: DeviceType::Cuda,
256 is_available: live,
257 device_name: if live {
258 "NVIDIA GPU via CUDA".to_string()
259 } else {
260 "CUDA (unavailable)".to_string()
261 },
262 memory_total_bytes: None,
263 memory_free_bytes: None,
264 compute_capability: None,
265 supports_fp16: live,
266 supports_bf16: live,
267 supports_int8: live,
268 }
269 }
270 DeviceType::WebGpu => {
271 let live = webgpu_available();
272 Self {
273 device_type: DeviceType::WebGpu,
274 is_available: live,
275 device_name: if live {
276 "GPU via wgpu".to_string()
277 } else {
278 "WebGPU (unavailable)".to_string()
279 },
280 memory_total_bytes: None,
281 memory_free_bytes: None,
282 compute_capability: None,
283 supports_fp16: false,
284 supports_bf16: false,
285 supports_int8: false,
286 }
287 }
288 DeviceType::DirectMl => {
289 let live = directml_available();
290 Self {
291 device_type: DeviceType::DirectMl,
292 is_available: live,
293 device_name: if live {
294 "GPU via DirectML".to_string()
295 } else {
296 "DirectML (unavailable)".to_string()
297 },
298 memory_total_bytes: None,
299 memory_free_bytes: None,
300 compute_capability: None,
301 supports_fp16: live,
302 supports_bf16: false,
303 supports_int8: live,
304 }
305 }
306 DeviceType::CoreMl => Self {
307 device_type: DeviceType::CoreMl,
308 is_available: false,
309 device_name: "CoreML (not yet supported)".to_string(),
310 memory_total_bytes: None,
311 memory_free_bytes: None,
312 compute_capability: None,
313 supports_fp16: false,
314 supports_bf16: false,
315 supports_int8: false,
316 },
317 }
318 }
319
320 /// Probe every [`DeviceType`] variant and return a record per device.
321 ///
322 /// The resulting vector has exactly [`DeviceType::all_variants`] entries
323 /// and is ordered to match that array.
324 #[must_use]
325 pub fn probe_all() -> Vec<Self> {
326 DeviceType::all_variants()
327 .iter()
328 .copied()
329 .map(Self::probe)
330 .collect()
331 }
332
333 /// Return capabilities for the best currently-available device.
334 ///
335 /// Equivalent to `DeviceCapabilities::probe(DeviceType::auto())`, but
336 /// exposed as a named constructor for callers that only need the
337 /// capability record.
338 #[must_use]
339 pub fn best_available() -> Self {
340 Self::probe(DeviceType::auto())
341 }
342}
343
344impl fmt::Display for DeviceCapabilities {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 write!(
347 f,
348 "{} [{}]",
349 self.device_name,
350 if self.is_available {
351 "available"
352 } else {
353 "unavailable"
354 }
355 )
356 }
357}
358
359// ---------------------------------------------------------------------------
360// Probe primitives
361// ---------------------------------------------------------------------------
362
/// Evaluate `probe`, turning a panic into `false` ("unavailable") instead of
/// letting it propagate to the caller.
///
/// Only unwinding panics are intercepted. Under `panic = "abort"`, a
/// panicking probe still aborts the process. `AssertUnwindSafe` is used
/// because the probe closures in this module capture no state, so nothing
/// is left half-updated for anyone to observe after a panic.
///
/// Compiled only when at least one foreign backend feature is enabled.
/// Without those features every `*_available()` helper is a constant
/// `false`, and this function would be dead code.
#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
fn safe_probe<F: FnOnce() -> bool>(probe: F) -> bool {
    match std::panic::catch_unwind(AssertUnwindSafe(probe)) {
        Ok(available) => available,
        Err(_panic_payload) => false,
    }
}
374
/// Live CUDA probe (feature `cuda`): usable iff a CUDA context can be
/// created. A panicking probe counts as unavailable (see `safe_probe`).
#[cfg(feature = "cuda")]
fn cuda_available() -> bool {
    let try_context = || oxionnx::cuda::CudaContext::try_new().is_some();
    safe_probe(try_context)
}

/// Stub used when the `cuda` feature is disabled: CUDA is never available.
#[cfg(not(feature = "cuda"))]
fn cuda_available() -> bool {
    false
}
384
/// Live WebGPU probe (feature `webgpu`): usable iff a wgpu context can be
/// created. A panicking probe counts as unavailable (see `safe_probe`).
#[cfg(feature = "webgpu")]
fn webgpu_available() -> bool {
    let try_context = || oxionnx::gpu::GpuContext::try_new().is_some();
    safe_probe(try_context)
}

/// Stub used when the `webgpu` feature is disabled: WebGPU is never available.
#[cfg(not(feature = "webgpu"))]
fn webgpu_available() -> bool {
    false
}
394
/// Live DirectML probe (feature `directml`): usable iff a DirectML context
/// can be created. A panicking probe counts as unavailable (see `safe_probe`).
#[cfg(feature = "directml")]
fn directml_available() -> bool {
    let try_context = || oxionnx::directml::DirectMLContext::try_new().is_some();
    safe_probe(try_context)
}

/// Stub used when the `directml` feature is disabled: DirectML is never
/// available.
#[cfg(not(feature = "directml"))]
fn directml_available() -> bool {
    false
}
404
/// Best-effort CPU label built from the compile-time target architecture
/// and pointer width, e.g. `"CPU (x86_64-64)"` on a 64-bit x86 target.
fn cpu_device_name() -> String {
    let arch = std::env::consts::ARCH;
    let pointer_bits = usize::BITS;
    format!("CPU ({arch}-{pointer_bits})")
}
413
414// ---------------------------------------------------------------------------
415// Tests (pure unit tests — heavier synthetic tests live under tests/)
416// ---------------------------------------------------------------------------
417
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_always_available() {
        assert!(DeviceType::Cpu.is_available());
        assert_eq!(DeviceType::Cpu.name(), "cpu");
    }

    #[test]
    fn auto_returns_available_device() {
        let device = DeviceType::auto();
        assert!(device.is_available());
    }

    #[test]
    fn auto_is_memoised() {
        // Repeated calls must hand back the cached value.
        let first = DeviceType::auto();
        assert_eq!(DeviceType::auto(), first);
    }

    #[test]
    fn default_is_cpu() {
        assert_eq!(DeviceType::default(), DeviceType::Cpu);
    }

    #[test]
    fn display_matches_name() {
        assert_eq!(format!("{}", DeviceType::Cpu), "cpu");
        assert_eq!(format!("{}", DeviceType::Cuda), "cuda");
        assert_eq!(format!("{}", DeviceType::WebGpu), "webgpu");
        assert_eq!(format!("{}", DeviceType::DirectMl), "directml");
        assert_eq!(format!("{}", DeviceType::CoreMl), "coreml");
    }

    #[test]
    fn display_names_are_stable() {
        assert_eq!(DeviceType::Cpu.display_name(), "CPU");
        assert_eq!(DeviceType::Cuda.display_name(), "CUDA");
        assert_eq!(DeviceType::WebGpu.display_name(), "WebGPU");
        assert_eq!(DeviceType::DirectMl.display_name(), "DirectML");
        assert_eq!(DeviceType::CoreMl.display_name(), "CoreML");
    }

    #[test]
    fn name_is_lowercase_display_name() {
        for device in DeviceType::all_variants().iter() {
            assert_eq!(device.name(), device.display_name().to_lowercase());
        }
    }

    #[test]
    fn coreml_never_available() {
        assert!(!DeviceType::CoreMl.is_available());
    }

    #[test]
    fn all_variants_has_five_entries() {
        assert_eq!(DeviceType::all_variants().len(), 5);
    }

    #[test]
    fn all_variants_are_distinct() {
        let unique: std::collections::HashSet<DeviceType> =
            DeviceType::all_variants().iter().copied().collect();
        assert_eq!(unique.len(), DeviceType::all_variants().len());
    }

    #[test]
    fn list_available_is_consistent() {
        let devices = DeviceType::list_available();
        assert_eq!(devices.last(), Some(&DeviceType::Cpu));
        assert!(!devices.contains(&DeviceType::CoreMl));
        assert!(devices.iter().all(|d| d.is_available()));
    }

    #[test]
    fn capabilities_cpu_is_available() {
        let caps = DeviceCapabilities::probe(DeviceType::Cpu);
        assert!(caps.is_available);
        assert!(caps.supports_int8);
        assert_eq!(caps.device_type, DeviceType::Cpu);
    }

    #[test]
    fn capabilities_coreml_is_unavailable() {
        let caps = DeviceCapabilities::probe(DeviceType::CoreMl);
        assert!(!caps.is_available);
        assert_eq!(caps.device_type, DeviceType::CoreMl);
    }

    #[test]
    fn probe_all_follows_all_variants_order() {
        let reported: Vec<DeviceType> = DeviceCapabilities::probe_all()
            .iter()
            .map(|caps| caps.device_type)
            .collect();
        assert_eq!(reported, DeviceType::all_variants().to_vec());
    }

    #[test]
    fn best_available_describes_auto() {
        assert_eq!(
            DeviceCapabilities::best_available().device_type,
            DeviceType::auto()
        );
    }

    #[test]
    fn cpu_device_name_mentions_arch() {
        let caps = DeviceCapabilities::probe(DeviceType::Cpu);
        assert!(caps.device_name.contains(std::env::consts::ARCH));
    }
}