use core::fmt;
#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
use std::panic::AssertUnwindSafe;
use std::sync::OnceLock;
/// Identifies a compute backend: the host CPU or one of several accelerator APIs.
///
/// With the `serde` feature enabled the variants serialize under their
/// lowercased names (for example `DirectMl` becomes `"directml"`), which are
/// the same strings `DeviceType::name` returns.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum DeviceType {
/// Host CPU; `is_available` unconditionally reports `true` for it.
Cpu,
/// CUDA backend; its availability is whatever `cuda_available()` reports,
/// which is always `false` when the `cuda` feature is not compiled in.
Cuda,
/// WebGPU backend; its availability is whatever `webgpu_available()` reports,
/// which is always `false` when the `webgpu` feature is not compiled in.
WebGpu,
/// DirectML backend; its availability is whatever `directml_available()`
/// reports, which is always `false` when the `directml` feature is not compiled in.
DirectMl,
/// CoreML placeholder; `is_available` unconditionally reports `false` and
/// `DeviceCapabilities::probe` labels it "CoreML (not yet supported)".
CoreMl,
}
/// Stores the backend selected by `DeviceType::auto`; it is filled on the
/// first call and the stored value is reused by every later call.
static AUTO_CACHE: OnceLock<DeviceType> = OnceLock::new();
impl DeviceType {
    /// Accelerator backends in the order they are tried, most preferred first.
    const ACCELERATOR_PRIORITY: [Self; 3] = [Self::Cuda, Self::DirectMl, Self::WebGpu];

    /// Selects the most preferred backend that is currently usable.
    ///
    /// CUDA, DirectML and WebGPU are checked in that order and the first one
    /// whose [`is_available`](Self::is_available) check succeeds is chosen;
    /// if none succeeds the result is [`DeviceType::Cpu`]. The choice is
    /// stored in `AUTO_CACHE` on the first call, so later calls return the
    /// same value without probing again.
    #[must_use]
    pub fn auto() -> Self {
        let pick_best = || {
            Self::ACCELERATOR_PRIORITY
                .iter()
                .copied()
                .find(|&backend| backend.is_available())
                .unwrap_or(Self::Cpu)
        };
        *AUTO_CACHE.get_or_init(pick_best)
    }

    /// Reports whether this backend can be used right now.
    ///
    /// The CPU is always available and CoreML never is; the remaining
    /// backends defer to their respective `*_available()` probe function.
    #[must_use]
    pub fn is_available(self) -> bool {
        match self {
            Self::Cpu => true,
            Self::CoreMl => false,
            Self::Cuda => cuda_available(),
            Self::DirectMl => directml_available(),
            Self::WebGpu => webgpu_available(),
        }
    }

    /// Returns the lowercase identifier of the backend (`"cpu"`, `"cuda"`, ...).
    ///
    /// This is also the text produced by the `Display` implementation.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            Self::CoreMl => "coreml",
            Self::DirectMl => "directml",
            Self::WebGpu => "webgpu",
            Self::Cuda => "cuda",
            Self::Cpu => "cpu",
        }
    }

    /// Returns the conventionally capitalised name of the backend
    /// (`"CPU"`, `"CUDA"`, `"WebGPU"`, `"DirectML"`, `"CoreML"`).
    #[must_use]
    pub fn display_name(self) -> &'static str {
        match self {
            Self::CoreMl => "CoreML",
            Self::DirectMl => "DirectML",
            Self::WebGpu => "WebGPU",
            Self::Cuda => "CUDA",
            Self::Cpu => "CPU",
        }
    }

    /// Shorthand for [`DeviceCapabilities::probe`] on this backend.
    #[must_use]
    pub fn probe_caps(self) -> DeviceCapabilities {
        DeviceCapabilities::probe(self)
    }

    /// Lists every backend that is usable right now.
    ///
    /// Accelerators come first, in the same order [`auto`](Self::auto) tries
    /// them, and [`DeviceType::Cpu`] is always the final entry. Unlike
    /// `auto`, the result is not cached: each call probes again.
    #[must_use]
    pub fn list_available() -> Vec<Self> {
        let mut usable = Vec::with_capacity(Self::ACCELERATOR_PRIORITY.len() + 1);
        usable.extend(
            Self::ACCELERATOR_PRIORITY
                .iter()
                .copied()
                .filter(|&backend| backend.is_available()),
        );
        usable.push(Self::Cpu);
        usable
    }

    /// Returns all five variants in declaration order, whether usable or not.
    #[must_use]
    pub const fn all_variants() -> [Self; 5] {
        [Self::Cpu, Self::Cuda, Self::WebGpu, Self::DirectMl, Self::CoreMl]
    }
}
impl Default for DeviceType {
fn default() -> Self {
Self::Cpu
}
}
/// Formats the backend as its lowercase identifier (the value of `name()`).
///
/// The text is emitted through [`fmt::Formatter::pad`], so width, fill,
/// alignment and precision given by the caller are honoured, e.g.
/// `format!("{:>8}", DeviceType::Cpu)` yields `"     cpu"`. A plain `{}`
/// produces exactly the same output as `name()`.
impl fmt::Display for DeviceType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` would bypass the formatter's flags; `pad` applies them.
        f.pad(self.name())
    }
}
/// Description of one backend as reported by `DeviceCapabilities::probe`.
///
/// The derived `Default` value describes the CPU variant (the default
/// `DeviceType`) with `is_available == false`, an empty name, every optional
/// field `None` and every `supports_*` flag `false`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct DeviceCapabilities {
/// Backend this record describes.
pub device_type: DeviceType,
/// Whether the backend was reported usable at the time of the probe.
pub is_available: bool,
/// Human-readable label, e.g. "NVIDIA GPU via CUDA" or "CUDA (unavailable)".
pub device_name: String,
/// Total memory in bytes, when known. `probe` as written in this file
/// always leaves it `None`.
pub memory_total_bytes: Option<u64>,
/// Free memory in bytes, when known. `probe` as written in this file
/// always leaves it `None`.
pub memory_free_bytes: Option<u64>,
/// Backend-specific capability string, when known. `probe` as written in
/// this file always leaves it `None`.
/// NOTE(review): presumably intended for values such as a CUDA compute
/// capability — confirm the expected format.
pub compute_capability: Option<String>,
/// Whether `probe` reports FP16 support for this backend.
pub supports_fp16: bool,
/// Whether `probe` reports BF16 support for this backend.
pub supports_bf16: bool,
/// Whether `probe` reports INT8 support for this backend.
pub supports_int8: bool,
}
impl DeviceCapabilities {
    /// Picks the label for a backend depending on whether its probe succeeded.
    fn availability_label(live: bool, when_live: &str, when_missing: &str) -> String {
        let chosen = if live { when_live } else { when_missing };
        chosen.to_string()
    }

    /// Builds the capability record for `device`.
    ///
    /// * CPU: always available, named by `cpu_device_name()`, INT8 only.
    /// * CUDA: availability from `cuda_available()`; FP16, BF16 and INT8 all
    ///   mirror that availability.
    /// * WebGPU: availability from `webgpu_available()`; no reduced-precision
    ///   flag is ever set.
    /// * DirectML: availability from `directml_available()`; FP16 and INT8
    ///   mirror that availability, BF16 is never set.
    /// * CoreML: always unavailable.
    ///
    /// The memory and compute-capability fields are `None` in every case.
    #[must_use]
    pub fn probe(device: DeviceType) -> Self {
        // Everything not mentioned in an arm below keeps its default:
        // `false` for flags, `None` for the optional fields.
        let blank = Self {
            device_type: device,
            ..Self::default()
        };

        match device {
            DeviceType::Cpu => Self {
                is_available: true,
                device_name: cpu_device_name(),
                supports_int8: true,
                ..blank
            },
            DeviceType::Cuda => {
                let up = cuda_available();
                Self {
                    is_available: up,
                    device_name: Self::availability_label(
                        up,
                        "NVIDIA GPU via CUDA",
                        "CUDA (unavailable)",
                    ),
                    supports_fp16: up,
                    supports_bf16: up,
                    supports_int8: up,
                    ..blank
                }
            }
            DeviceType::WebGpu => {
                let up = webgpu_available();
                Self {
                    is_available: up,
                    device_name: Self::availability_label(
                        up,
                        "GPU via wgpu",
                        "WebGPU (unavailable)",
                    ),
                    ..blank
                }
            }
            DeviceType::DirectMl => {
                let up = directml_available();
                Self {
                    is_available: up,
                    device_name: Self::availability_label(
                        up,
                        "GPU via DirectML",
                        "DirectML (unavailable)",
                    ),
                    supports_fp16: up,
                    supports_int8: up,
                    ..blank
                }
            }
            DeviceType::CoreMl => Self {
                device_name: "CoreML (not yet supported)".to_string(),
                ..blank
            },
        }
    }

    /// Probes every variant returned by `DeviceType::all_variants()`, in that
    /// order, including backends that turn out to be unavailable.
    #[must_use]
    pub fn probe_all() -> Vec<Self> {
        let variants = DeviceType::all_variants();
        let mut reports = Vec::with_capacity(variants.len());
        for &device in &variants {
            reports.push(Self::probe(device));
        }
        reports
    }

    /// Probes the backend chosen by `DeviceType::auto()`.
    #[must_use]
    pub fn best_available() -> Self {
        let chosen = DeviceType::auto();
        Self::probe(chosen)
    }
}
/// Renders the record as `<device_name> [available]` or
/// `<device_name> [unavailable]`, depending on `is_available`.
impl fmt::Display for DeviceCapabilities {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = if self.is_available {
            "available"
        } else {
            "unavailable"
        };
        write!(f, "{} [{}]", self.device_name, status)
    }
}
/// Runs a backend probe and treats a panic inside it as "not available".
///
/// The closure is wrapped in `AssertUnwindSafe` and executed under
/// `std::panic::catch_unwind`. If it completes, its own result is returned;
/// if it unwinds, the result is `false`.
#[cfg(any(feature = "cuda", feature = "webgpu", feature = "directml"))]
fn safe_probe<F: FnOnce() -> bool>(probe: F) -> bool {
    match std::panic::catch_unwind(AssertUnwindSafe(probe)) {
        Ok(found) => found,
        // The panic payload is dropped: a crashing probe simply means "no device".
        Err(_) => false,
    }
}
/// Reports whether a CUDA context can be created.
///
/// The attempt goes through `safe_probe`, so a panic while creating the
/// context counts as "unavailable".
#[cfg(feature = "cuda")]
fn cuda_available() -> bool {
    let try_context = || oxionnx::cuda::CudaContext::try_new().is_some();
    safe_probe(try_context)
}

/// Without the `cuda` feature, CUDA is never available.
#[cfg(not(feature = "cuda"))]
fn cuda_available() -> bool {
    false
}
/// Reports whether a GPU context can be created through the `webgpu` backend.
///
/// The attempt goes through `safe_probe`, so a panic while creating the
/// context counts as "unavailable".
#[cfg(feature = "webgpu")]
fn webgpu_available() -> bool {
    let try_context = || oxionnx::gpu::GpuContext::try_new().is_some();
    safe_probe(try_context)
}

/// Without the `webgpu` feature, WebGPU is never available.
#[cfg(not(feature = "webgpu"))]
fn webgpu_available() -> bool {
    false
}
/// Reports whether a DirectML context can be created.
///
/// The attempt goes through `safe_probe`, so a panic while creating the
/// context counts as "unavailable".
#[cfg(feature = "directml")]
fn directml_available() -> bool {
    let try_context = || oxionnx::directml::DirectMLContext::try_new().is_some();
    safe_probe(try_context)
}

/// Without the `directml` feature, DirectML is never available.
#[cfg(not(feature = "directml"))]
fn directml_available() -> bool {
    false
}
/// Builds the CPU label in the form `CPU (<arch>-<pointer width in bits>)`,
/// for example `CPU (x86_64-64)` on a 64-bit x86 target.
fn cpu_device_name() -> String {
    let arch = std::env::consts::ARCH;
    let pointer_bits = usize::BITS;
    format!("CPU ({}-{})", arch, pointer_bits)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant with its expected `name()` and `display_name()`.
    const NAME_TABLE: [(DeviceType, &str, &str); 5] = [
        (DeviceType::Cpu, "cpu", "CPU"),
        (DeviceType::Cuda, "cuda", "CUDA"),
        (DeviceType::WebGpu, "webgpu", "WebGPU"),
        (DeviceType::DirectMl, "directml", "DirectML"),
        (DeviceType::CoreMl, "coreml", "CoreML"),
    ];

    #[test]
    fn cpu_always_available() {
        let cpu = DeviceType::Cpu;
        assert!(cpu.is_available());
        assert_eq!(cpu.name(), "cpu");
    }

    #[test]
    fn auto_returns_available_device() {
        assert!(DeviceType::auto().is_available());
    }

    #[test]
    fn default_is_cpu() {
        let fallback: DeviceType = Default::default();
        assert_eq!(fallback, DeviceType::Cpu);
    }

    #[test]
    fn display_matches_name() {
        for (device, short, _) in NAME_TABLE.iter() {
            assert_eq!(format!("{}", device), *short);
        }
    }

    #[test]
    fn display_names_are_stable() {
        for (device, _, pretty) in NAME_TABLE.iter() {
            assert_eq!(device.display_name(), *pretty);
        }
    }

    #[test]
    fn coreml_never_available() {
        let coreml = DeviceType::CoreMl;
        assert!(!coreml.is_available());
    }

    #[test]
    fn all_variants_has_five_entries() {
        let variants = DeviceType::all_variants();
        assert_eq!(variants.len(), 5);
    }

    #[test]
    fn capabilities_cpu_is_available() {
        let DeviceCapabilities {
            device_type,
            is_available,
            supports_int8,
            ..
        } = DeviceCapabilities::probe(DeviceType::Cpu);
        assert!(is_available);
        assert!(supports_int8);
        assert_eq!(device_type, DeviceType::Cpu);
    }
}