/// Compute device / graphics API that a backend can be asked to run on.
///
/// Each variant has a fixed human-readable label, returned by
/// [`Device::name`] and emitted by the `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// Labelled `"CPU"`.
    Cpu,
    /// Labelled `"Metal"`.
    Metal,
    /// Labelled `"MLX"`.
    Mlx,
    /// Labelled `"ANE"`.
    Ane,
    /// Labelled `"CUDA"`.
    Cuda,
    /// Labelled `"ROCm"`.
    Rocm,
    /// Labelled `"TPU"`.
    Tpu,
    /// Labelled `"GPU (wgpu)"`.
    Gpu,
    /// Labelled `"Vulkan"`.
    Vulkan,
    /// Labelled `"OpenGL"`.
    OpenGl,
    /// Labelled `"DirectX 12"`.
    DirectX,
    /// Labelled `"WebGPU"`.
    WebGpu,
}

impl Device {
    /// Returns the fixed, human-readable label for this device.
    pub fn name(self) -> &'static str {
        match self {
            Self::Cpu => "CPU",
            Self::Metal => "Metal",
            Self::Mlx => "MLX",
            Self::Ane => "ANE",
            Self::Cuda => "CUDA",
            Self::Rocm => "ROCm",
            Self::Tpu => "TPU",
            Self::Gpu => "GPU (wgpu)",
            Self::Vulkan => "Vulkan",
            Self::OpenGl => "OpenGL",
            Self::DirectX => "DirectX 12",
            Self::WebGpu => "WebGPU",
        }
    }

    /// Returns every variant exactly once, in declaration order.
    pub fn all() -> &'static [Device] {
        &[
            Self::Cpu,
            Self::Metal,
            Self::Mlx,
            Self::Ane,
            Self::Cuda,
            Self::Rocm,
            Self::Tpu,
            Self::Gpu,
            Self::Vulkan,
            Self::OpenGl,
            Self::DirectX,
            Self::WebGpu,
        ]
    }
}

impl std::fmt::Display for Device {
    /// Writes [`Device::name`], honouring the caller's format spec.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the width/fill/alignment requested by the caller
        // (e.g. `{:<10}` when printing a table). Going through
        // `write!(f, "{}", ..)` would start a fresh format spec and silently
        // drop those flags. Plain `{}` output is unchanged.
        f.pad(self.name())
    }
}
/// Error produced when a string does not name any known device.
///
/// The wrapped `String` is the text that failed to parse; it is echoed back
/// in the `Display` message together with a list of spellings to try.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceFromStrError(pub String);

impl std::fmt::Display for DeviceFromStrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(unrecognized) = self;
        f.write_fmt(format_args!(
            "unknown device '{}' (try: cpu, metal, mlx, ane, cuda, rocm, gpu, vulkan, opengl, directx, webgpu, tpu)",
            unrecognized
        ))
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl std::error::Error for DeviceFromStrError {}
/// Parses a device from its short name or one of its aliases.
///
/// Surrounding whitespace is ignored and matching is ASCII
/// case-insensitive. On failure the error carries the input exactly as it
/// was passed in (untrimmed, original case).
impl std::str::FromStr for Device {
    type Err = DeviceFromStrError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_ascii_lowercase();
        let device = match normalized.as_str() {
            "cpu" => Self::Cpu,
            "metal" | "mps" | "mtl" => Self::Metal,
            "mlx" => Self::Mlx,
            "ane" | "neural-engine" => Self::Ane,
            "cuda" | "nvidia" => Self::Cuda,
            "rocm" | "hip" | "amd" => Self::Rocm,
            "tpu" => Self::Tpu,
            "gpu" | "wgpu" => Self::Gpu,
            "vulkan" | "vk" => Self::Vulkan,
            "opengl" | "gl" => Self::OpenGl,
            "directx" | "dx12" | "d3d12" => Self::DirectX,
            "webgpu" => Self::WebGpu,
            // Report what the caller actually typed, not the normalized key.
            _ => return Err(DeviceFromStrError(s.to_owned())),
        };
        Ok(device)
    }
}
/// Describes which devices a family of backends is able to run on.
pub trait BackendSupport {
    /// Name of the family; used to label it in error messages.
    fn family(&self) -> &'static str;

    /// Returns `true` when this family can run on `device`.
    fn supports(&self, device: Device) -> bool;
}
/// Devices accepted by [`StandardBackends`].
///
/// Any device not listed here is reported as unsupported by that type.
pub const STANDARD_DEVICES: &[Device] = &[
    Device::Cpu,
    Device::Metal,
    Device::Mlx,
    Device::Cuda,
    Device::Rocm,
    Device::Gpu,
];
/// A backend family, identified by the wrapped name, that supports exactly
/// the devices listed in [`STANDARD_DEVICES`].
#[derive(Debug, Clone, Copy)]
pub struct StandardBackends(pub &'static str);

impl BackendSupport for StandardBackends {
    /// Returns the family name this value was constructed with.
    fn family(&self) -> &'static str {
        let Self(family_name) = *self;
        family_name
    }

    /// Returns `true` when `device` appears in [`STANDARD_DEVICES`].
    fn supports(&self, device: Device) -> bool {
        STANDARD_DEVICES
            .iter()
            .any(|&candidate| candidate == device)
    }
}
/// Checks that `device` is supported by the backend family `support`.
///
/// The `?Sized` bound lets callers pass a trait object
/// (`&dyn BackendSupport`) as well as a concrete type.
///
/// # Errors
///
/// Returns a message naming both the device and the family when
/// `support.supports(device)` is `false`.
pub fn validate_device<S: BackendSupport + ?Sized>(
    support: &S,
    device: Device,
) -> Result<Device, String> {
    if !support.supports(device) {
        return Err(format!(
            "device {} is not supported by family `{}`",
            device.name(),
            support.family()
        ));
    }
    Ok(device)
}
#[cfg(test)]
mod from_str_tests {
    use super::*;
    use std::str::FromStr;

    /// Canonical names, aliases and mixed case parse; unknown text is rejected.
    #[test]
    fn parse_basics() {
        let accepted = [
            ("cpu", Device::Cpu),
            ("CUDA", Device::Cuda),
            ("mps", Device::Metal),
            ("wgpu", Device::Gpu),
        ];
        for &(text, expected) in accepted.iter() {
            assert_eq!(Device::from_str(text).unwrap(), expected);
        }
        assert!(Device::from_str("nothing").is_err());
    }

    /// `StandardBackends` accepts listed devices and rejects the TPU.
    #[test]
    fn standard_backends_set() {
        let backends = StandardBackends("qwen3");
        for &device in [Device::Cpu, Device::Metal].iter() {
            assert!(backends.supports(device));
        }
        assert!(!backends.supports(Device::Tpu));
        assert!(validate_device(&backends, Device::Tpu).is_err());
    }
}