use rlx_driver::Device;
/// Error produced when a device name, or a list of device names, cannot be parsed.
///
/// The `Display` implementation prints only `message`; `input` is kept as a
/// separate public field so callers can inspect it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseDeviceError {
/// The text handed to the parsing function that failed, stored unmodified.
pub input: String,
/// Human-readable description of the failure.
pub message: String,
}
impl std::fmt::Display for ParseDeviceError {
    /// Writes the human-readable `message` and nothing else; the offending
    /// `input` is only available through the public field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message is emitted verbatim: width/padding flags on the outer
        // formatter are intentionally not applied.
        f.write_str(&self.message)
    }
}
// Marker impl: nothing is overridden, so the trait's default methods apply.
impl std::error::Error for ParseDeviceError {}
/// Parses a single device name into a [`Device`].
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
/// In addition to the canonical names, a number of aliases are accepted
/// (for example `mtl`, `nvidia`, `hip`, `wgpu`, `vk`, `gl`, `dx12`).
///
/// # Errors
///
/// Returns a [`ParseDeviceError`] when the name is empty after trimming, or
/// when it is not one of the recognised names or aliases. In both cases the
/// error's `input` holds `s` exactly as it was passed in.
pub fn parse_device(s: &str) -> Result<Device, ParseDeviceError> {
    let normalized = s.trim().to_ascii_lowercase();

    // Both failure paths record the caller's original, un-normalised text.
    let fail = |message: String| ParseDeviceError {
        input: s.to_string(),
        message,
    };

    if normalized.is_empty() {
        return Err(fail("empty device name".into()));
    }

    let device = match normalized.as_str() {
        "cpu" => Device::Cpu,
        "metal" | "mtl" => Device::Metal,
        "mlx" => Device::Mlx,
        "ane" | "neural-engine" => Device::Ane,
        "cuda" | "nvidia" => Device::Cuda,
        "rocm" | "hip" | "amd" => Device::Rocm,
        "gpu" | "wgpu" => Device::Gpu,
        "vulkan" | "vk" => Device::Vulkan,
        "opengl" | "gl" => Device::OpenGl,
        "directx" | "dx12" | "d3d12" => Device::DirectX,
        "webgpu" => Device::WebGpu,
        "tpu" => Device::Tpu,
        other => {
            // The message quotes the normalised (trimmed, lowercased) name.
            return Err(fail(format!(
                "unknown device '{other}' (try: cpu, metal, mlx, cuda, rocm, gpu, vulkan, tpu)"
            )));
        }
    };

    Ok(device)
}
/// Returns the canonical lowercase label for `device`.
///
/// Every label returned here is also a spelling that `parse_device` maps back
/// to the same variant. The match is exhaustive on purpose, so adding a
/// variant to `Device` forces this table to be updated.
pub fn device_label(device: Device) -> &'static str {
    // Arms are listed alphabetically by label.
    match device {
        Device::Ane => "ane",
        Device::Cpu => "cpu",
        Device::Cuda => "cuda",
        Device::DirectX => "directx",
        Device::Gpu => "gpu",
        Device::Metal => "metal",
        Device::Mlx => "mlx",
        Device::OpenGl => "opengl",
        Device::Rocm => "rocm",
        Device::Tpu => "tpu",
        Device::Vulkan => "vulkan",
        Device::WebGpu => "webgpu",
    }
}
/// Parses a list of device names into a de-duplicated `Vec<Device>`.
///
/// Entries may be separated by commas, semicolons, or any whitespace
/// (spaces, tabs, newlines). Empty entries produced by repeated separators
/// are skipped, and each device is kept only once, in order of first
/// occurrence.
///
/// # Errors
///
/// Propagates the error from [`parse_device`] for the first entry that is
/// not recognised. Returns an error with the message `device list is empty`
/// (and `input` set to `s`) when no entries remain after splitting.
pub fn parse_device_list(s: &str) -> Result<Vec<Device>, ParseDeviceError> {
    // Treat every whitespace character as a separator, not just ' ', so that
    // tab- or newline-separated lists are split instead of being rejected as
    // a single unknown device.
    let is_separator = |c: char| matches!(c, ',' | ';') || c.is_whitespace();

    let mut out = Vec::new();
    for token in s.split(is_separator).filter(|t| !t.is_empty()) {
        let dev = parse_device(token)?;
        // Linear scan is fine: the list is bounded by the number of variants.
        if !out.contains(&dev) {
            out.push(dev);
        }
    }

    if out.is_empty() {
        return Err(ParseDeviceError {
            input: s.to_string(),
            message: "device list is empty".into(),
        });
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper-case names and aliases resolve to the expected variant, and a
    /// list mixing `,` and `;` separators keeps its input order.
    #[test]
    fn parse_aliases() {
        let upper_case = parse_device("CUDA").unwrap();
        assert_eq!(upper_case, Device::Cuda);

        let alias = parse_device("wgpu").unwrap();
        assert_eq!(alias, Device::Gpu);

        let expected = vec![Device::Cpu, Device::Metal, Device::Mlx];
        let parsed = parse_device_list("cpu, metal;mlx").unwrap();
        assert_eq!(parsed, expected);
    }
}