Skip to main content

oximedia_ml/
device.rs

1//! Execution device abstraction.
2//!
3//! [`DeviceType`] is the user-facing handle for selecting which backend
4//! an [`OnnxModel`](crate::model::OnnxModel) should run on. Backends
5//! are feature-gated at build time; runtime availability is probed via
6//! [`DeviceType::is_available`] so callers can ask for
7//! [`DeviceType::auto`] without needing `cfg!` plumbing of their own.
8//!
9//! ## Probe cascade
10//!
11//! [`DeviceType::auto`] walks the following order, first success wins:
12//!
//! 1. **CUDA** — `oxionnx::cuda::CudaContext::try_new()`
//!    (requires the `cuda` feature).
//! 2. **DirectML** — `oxionnx::directml::DirectMLContext::try_new()`
//!    (requires the `directml` feature; always `None` off Windows).
//! 3. **WebGPU** — `oxionnx::gpu::GpuContext::try_new()`
//!    (requires the `webgpu` feature).
19//! 4. **CPU** — always available.
20//!
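//! Every probe is wrapped in `std::panic::catch_unwind`, so a panic raised
//! while probing a backend is reported as "unavailable" instead of
//! unwinding through our call stack. The result of `auto()` is memoised in
//! a `OnceLock`, so repeated `auto()` calls do not re-run the probes (and
//! therefore do not re-init the CUDA driver / wgpu adapter). Direct calls to
//! [`DeviceType::is_available`] and [`DeviceCapabilities::probe`] are not
//! cached and probe the backend again on every call.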
25//!
26//! ## Capability introspection
27//!
28//! [`DeviceCapabilities`] carries a richer description (device name,
29//! memory, compute capability, dtype support) and is produced by
30//! [`DeviceCapabilities::probe`] for a specific [`DeviceType`], or
31//! [`DeviceCapabilities::probe_all`] for every compiled-in backend at
32//! once.
33//!
34//! ## Example
35//!
36//! ```
37//! use oximedia_ml::{DeviceCapabilities, DeviceType};
38//!
39//! // Pick the strongest available backend (always succeeds — falls
40//! // back to CPU if nothing else is compiled in / usable).
41//! let device = DeviceType::auto();
42//! assert!(device.is_available());
43//!
44//! // Full capability report for the selected device.
45//! let caps = DeviceCapabilities::best_available();
46//! assert_eq!(caps.device_type, device);
47//! ```
48
49use core::fmt;
50#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
51use std::panic::AssertUnwindSafe;
52use std::sync::OnceLock;
53
/// Backend on which ML inference is executed.
///
/// Variants are plain, copyable tags; whether a given backend can actually
/// be used is a separate runtime question.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize),
    serde(rename_all = "lowercase")
)]
pub enum DeviceType {
    /// Pure-Rust execution on the CPU. Always available.
    Cpu,
    /// NVIDIA CUDA through `oxionnx-cuda` (needs the `cuda` feature).
    Cuda,
    /// WebGPU / wgpu compute through `oxionnx-gpu` (needs the `webgpu`
    /// feature).
    WebGpu,
    /// Microsoft DirectML (needs the `directml` feature; Windows-only at
    /// runtime).
    DirectMl,
    /// Apple CoreML — a reserved variant. No `coreml` feature exists yet,
    /// so this device is never reported as available and `auto()` never
    /// selects it. It is declared up front so API consumers can match on
    /// `DeviceType` exhaustively without guessing whether CoreML will be
    /// added later.
    CoreMl,
}
74
/// Process-wide cache holding the outcome of [`DeviceType::auto`].
///
/// Filled on the first `auto()` call through `OnceLock::get_or_init`; every
/// later call copies the stored value instead of re-running the probe
/// cascade. The probes are described as having observable side effects
/// (CUDA driver initialisation, wgpu adapter enumeration, possibly
/// thread-local state), so running them once keeps repeat calls cheap and
/// avoids repeating those side effects.
static AUTO_CACHE: OnceLock<DeviceType> = OnceLock::new();
82
83impl DeviceType {
84    /// Return the preferred device available in this build, in the order
85    /// **CUDA → DirectML → WebGPU → CPU**. Always succeeds because CPU is
86    /// unconditionally available.
87    ///
88    /// The result is memoised for the lifetime of the process.
89    #[must_use]
90    pub fn auto() -> Self {
91        *AUTO_CACHE.get_or_init(|| {
92            if Self::Cuda.is_available() {
93                return Self::Cuda;
94            }
95            if Self::DirectMl.is_available() {
96                return Self::DirectMl;
97            }
98            if Self::WebGpu.is_available() {
99                return Self::WebGpu;
100            }
101            Self::Cpu
102        })
103    }
104
105    /// Report whether this device is usable in the current build /
106    /// runtime environment. A device may be compiled in (feature-gated)
107    /// yet still unavailable at runtime — e.g. no GPU detected.
108    #[must_use]
109    pub fn is_available(self) -> bool {
110        match self {
111            Self::Cpu => true,
112            Self::Cuda => cuda_available(),
113            Self::WebGpu => webgpu_available(),
114            Self::DirectMl => directml_available(),
115            Self::CoreMl => false,
116        }
117    }
118
119    /// Short canonical name matching the feature flag / CLI spelling.
120    ///
121    /// Retained for backward compatibility with existing call sites;
122    /// [`DeviceType::display_name`] is preferred for human-facing output.
123    #[must_use]
124    pub fn name(self) -> &'static str {
125        match self {
126            Self::Cpu => "cpu",
127            Self::Cuda => "cuda",
128            Self::WebGpu => "webgpu",
129            Self::DirectMl => "directml",
130            Self::CoreMl => "coreml",
131        }
132    }
133
134    /// Human-facing label — identical to [`Self::name`] for now, but
135    /// conceptually distinct so downstream UIs can swap in a friendlier
136    /// string later without affecting programmatic lookups.
137    #[must_use]
138    pub fn display_name(self) -> &'static str {
139        match self {
140            Self::Cpu => "CPU",
141            Self::Cuda => "CUDA",
142            Self::WebGpu => "WebGPU",
143            Self::DirectMl => "DirectML",
144            Self::CoreMl => "CoreML",
145        }
146    }
147
148    /// Run the full cascade probe and return the richer
149    /// [`DeviceCapabilities`] record for this device.
150    #[must_use]
151    pub fn probe_caps(self) -> DeviceCapabilities {
152        DeviceCapabilities::probe(self)
153    }
154
155    /// Return every [`DeviceType`] whose backend is currently usable.
156    ///
157    /// The returned vector always contains [`DeviceType::Cpu`] and is
158    /// ordered by the probe cascade (CPU last).
159    #[must_use]
160    pub fn list_available() -> Vec<Self> {
161        let mut out = Vec::with_capacity(4);
162        if Self::Cuda.is_available() {
163            out.push(Self::Cuda);
164        }
165        if Self::DirectMl.is_available() {
166            out.push(Self::DirectMl);
167        }
168        if Self::WebGpu.is_available() {
169            out.push(Self::WebGpu);
170        }
171        // CoreMl is never currently available.
172        out.push(Self::Cpu);
173        out
174    }
175
176    /// Every variant in enum declaration order.
177    ///
178    /// Used by [`DeviceCapabilities::probe_all`] and the test suite.
179    #[must_use]
180    pub const fn all_variants() -> [Self; 5] {
181        [
182            Self::Cpu,
183            Self::Cuda,
184            Self::WebGpu,
185            Self::DirectMl,
186            Self::CoreMl,
187        ]
188    }
189}
190
191impl Default for DeviceType {
192    fn default() -> Self {
193        Self::Cpu
194    }
195}
196
197impl fmt::Display for DeviceType {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        f.write_str(self.name())
200    }
201}
202
203/// Rich capability description for a single [`DeviceType`].
204///
205/// Produced by [`DeviceCapabilities::probe`]. Fields are populated
206/// best-effort — anything the backend does not expose is left as `None` /
207/// `false`, which lets callers use a single code path regardless of how
208/// much telemetry the driver provides.
209#[derive(Clone, Debug, Default, PartialEq, Eq)]
210#[cfg_attr(feature = "serde", derive(serde::Serialize))]
211pub struct DeviceCapabilities {
212    /// Which device this record describes.
213    pub device_type: DeviceType,
214    /// Whether the device is currently available for inference.
215    pub is_available: bool,
216    /// Human-facing device name (e.g. "CPU (x86_64)", "NVIDIA GPU via CUDA").
217    pub device_name: String,
218    /// Total device memory in bytes, if known.
219    pub memory_total_bytes: Option<u64>,
220    /// Free device memory in bytes, if known.
221    pub memory_free_bytes: Option<u64>,
222    /// Compute capability string (e.g. "8.6" for Ampere). `None` for CPU
223    /// / WebGPU / DirectML / CoreML.
224    pub compute_capability: Option<String>,
225    /// Whether FP16 (half-precision float) is supported.
226    pub supports_fp16: bool,
227    /// Whether BF16 (bfloat16) is supported.
228    pub supports_bf16: bool,
229    /// Whether INT8 quantised inference is supported.
230    pub supports_int8: bool,
231}
232
233impl DeviceCapabilities {
234    /// Probe a specific device and describe its capabilities.
235    ///
236    /// Always returns a record; unavailable devices get `is_available =
237    /// false` with the rest populated from static knowledge.
238    #[must_use]
239    pub fn probe(device: DeviceType) -> Self {
240        match device {
241            DeviceType::Cpu => Self {
242                device_type: DeviceType::Cpu,
243                is_available: true,
244                device_name: cpu_device_name(),
245                memory_total_bytes: None,
246                memory_free_bytes: None,
247                compute_capability: None,
248                supports_fp16: false,
249                supports_bf16: false,
250                supports_int8: true,
251            },
252            DeviceType::Cuda => {
253                let live = cuda_available();
254                Self {
255                    device_type: DeviceType::Cuda,
256                    is_available: live,
257                    device_name: if live {
258                        "NVIDIA GPU via CUDA".to_string()
259                    } else {
260                        "CUDA (unavailable)".to_string()
261                    },
262                    memory_total_bytes: None,
263                    memory_free_bytes: None,
264                    compute_capability: None,
265                    supports_fp16: live,
266                    supports_bf16: live,
267                    supports_int8: live,
268                }
269            }
270            DeviceType::WebGpu => {
271                let live = webgpu_available();
272                Self {
273                    device_type: DeviceType::WebGpu,
274                    is_available: live,
275                    device_name: if live {
276                        "GPU via wgpu".to_string()
277                    } else {
278                        "WebGPU (unavailable)".to_string()
279                    },
280                    memory_total_bytes: None,
281                    memory_free_bytes: None,
282                    compute_capability: None,
283                    supports_fp16: false,
284                    supports_bf16: false,
285                    supports_int8: false,
286                }
287            }
288            DeviceType::DirectMl => {
289                let live = directml_available();
290                Self {
291                    device_type: DeviceType::DirectMl,
292                    is_available: live,
293                    device_name: if live {
294                        "GPU via DirectML".to_string()
295                    } else {
296                        "DirectML (unavailable)".to_string()
297                    },
298                    memory_total_bytes: None,
299                    memory_free_bytes: None,
300                    compute_capability: None,
301                    supports_fp16: live,
302                    supports_bf16: false,
303                    supports_int8: live,
304                }
305            }
306            DeviceType::CoreMl => Self {
307                device_type: DeviceType::CoreMl,
308                is_available: false,
309                device_name: "CoreML (not yet supported)".to_string(),
310                memory_total_bytes: None,
311                memory_free_bytes: None,
312                compute_capability: None,
313                supports_fp16: false,
314                supports_bf16: false,
315                supports_int8: false,
316            },
317        }
318    }
319
320    /// Probe every [`DeviceType`] variant and return a record per device.
321    ///
322    /// The resulting vector has exactly [`DeviceType::all_variants`] entries
323    /// and is ordered to match that array.
324    #[must_use]
325    pub fn probe_all() -> Vec<Self> {
326        DeviceType::all_variants()
327            .iter()
328            .copied()
329            .map(Self::probe)
330            .collect()
331    }
332
333    /// Return capabilities for the best currently-available device.
334    ///
335    /// Equivalent to `DeviceCapabilities::probe(DeviceType::auto())`, but
336    /// exposed as a named constructor for callers that only need the
337    /// capability record.
338    #[must_use]
339    pub fn best_available() -> Self {
340        Self::probe(DeviceType::auto())
341    }
342}
343
344impl fmt::Display for DeviceCapabilities {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        write!(
347            f,
348            "{} [{}]",
349            self.device_name,
350            if self.is_available {
351                "available"
352            } else {
353                "unavailable"
354            }
355        )
356    }
357}
358
359// ---------------------------------------------------------------------------
360// Probe primitives
361// ---------------------------------------------------------------------------
362
/// Execute `probe` inside `catch_unwind` and collapse the outcome to a
/// plain `bool`.
///
/// Returns `true` only when the closure finishes and itself returns `true`;
/// a probe that panics is counted as "unavailable" rather than letting the
/// panic unwind into our caller.
///
/// Gated on the backend features because without any of them every
/// `*_available()` function is a constant `false` and this helper would be
/// dead code.
#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
fn safe_probe<F: FnOnce() -> bool>(probe: F) -> bool {
    let outcome = std::panic::catch_unwind(AssertUnwindSafe(probe));
    matches!(outcome, Ok(true))
}
374
/// Whether a CUDA context can be created right now.
///
/// Constant `false` when the `cuda` feature is not compiled in; otherwise
/// the result of a panic-guarded `CudaContext::try_new()` probe.
fn cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        safe_probe(|| oxionnx::cuda::CudaContext::try_new().is_some())
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
384
/// Whether a wgpu compute context can be created right now.
///
/// Constant `false` when the `webgpu` feature is not compiled in; otherwise
/// the result of a panic-guarded `GpuContext::try_new()` probe.
fn webgpu_available() -> bool {
    #[cfg(feature = "webgpu")]
    {
        safe_probe(|| oxionnx::gpu::GpuContext::try_new().is_some())
    }
    #[cfg(not(feature = "webgpu"))]
    {
        false
    }
}
394
/// Whether a DirectML context can be created right now.
///
/// Constant `false` when the `directml` feature is not compiled in;
/// otherwise the result of a panic-guarded `DirectMLContext::try_new()`
/// probe.
fn directml_available() -> bool {
    #[cfg(feature = "directml")]
    {
        safe_probe(|| oxionnx::directml::DirectMLContext::try_new().is_some())
    }
    #[cfg(not(feature = "directml"))]
    {
        false
    }
}
404
/// Build the CPU label: `CPU (<arch>-<pointer width in bits>)`, for example
/// `CPU (x86_64-64)`.
fn cpu_device_name() -> String {
    let arch = std::env::consts::ARCH;
    let pointer_bits = usize::BITS;
    format!("CPU ({arch}-{pointer_bits})")
}
413
414// ---------------------------------------------------------------------------
415// Tests (pure unit tests — heavier synthetic tests live under tests/)
416// ---------------------------------------------------------------------------
417
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU is the unconditional fallback and keeps its canonical name.
    #[test]
    fn cpu_always_available() {
        let cpu = DeviceType::Cpu;
        assert!(cpu.is_available());
        assert_eq!(cpu.name(), "cpu");
    }

    /// Whatever `auto()` picks must itself probe as usable.
    #[test]
    fn auto_returns_available_device() {
        assert!(DeviceType::auto().is_available());
    }

    /// `Default` resolves to the CPU backend.
    #[test]
    fn default_is_cpu() {
        let fallback = DeviceType::default();
        assert_eq!(fallback, DeviceType::Cpu);
    }

    /// `Display` prints the lowercase programmatic name of every variant.
    #[test]
    fn display_matches_name() {
        let expected = [
            (DeviceType::Cpu, "cpu"),
            (DeviceType::Cuda, "cuda"),
            (DeviceType::WebGpu, "webgpu"),
            (DeviceType::DirectMl, "directml"),
            (DeviceType::CoreMl, "coreml"),
        ];
        for (device, text) in expected {
            assert_eq!(format!("{}", device), text);
        }
    }

    /// Human-facing labels are part of the stable surface.
    #[test]
    fn display_names_are_stable() {
        let expected = [
            (DeviceType::Cpu, "CPU"),
            (DeviceType::Cuda, "CUDA"),
            (DeviceType::WebGpu, "WebGPU"),
            (DeviceType::DirectMl, "DirectML"),
            (DeviceType::CoreMl, "CoreML"),
        ];
        for (device, label) in expected {
            assert_eq!(device.display_name(), label);
        }
    }

    /// The reserved CoreML variant never reports itself as usable.
    #[test]
    fn coreml_never_available() {
        let coreml = DeviceType::CoreMl;
        assert!(!coreml.is_available());
    }

    /// `all_variants` covers all five enum variants.
    #[test]
    fn all_variants_has_five_entries() {
        let variants = DeviceType::all_variants();
        assert_eq!(variants.len(), 5);
    }

    /// The CPU capability record is available, INT8-capable and tagged CPU.
    #[test]
    fn capabilities_cpu_is_available() {
        let report = DeviceCapabilities::probe(DeviceType::Cpu);
        assert!(report.is_available);
        assert!(report.supports_int8);
        assert_eq!(report.device_type, DeviceType::Cpu);
    }
}