use anyhow::{Result, bail};
use rlx_core::{validate_sam_device, validate_standard_device};
use rlx_runtime::Device;
/// Parses a device name (as given on the command line) into a [`Device`].
///
/// Accepted spellings, with aliases that map to the same variant:
/// `cpu`, `metal`/`mps`, `mlx`, `cuda`, `rocm`/`hip`, `gpu`/`wgpu`, `vulkan`.
/// Matching is exact and case-sensitive.
///
/// # Errors
///
/// Returns an error naming the unrecognised value and listing the accepted
/// spellings when `s` is none of the above.
pub fn parse_device(s: &str) -> Result<Device> {
    let device = match s {
        "cpu" => Device::Cpu,
        "metal" | "mps" => Device::Metal,
        "mlx" => Device::Mlx,
        "cuda" => Device::Cuda,
        "rocm" | "hip" => Device::Rocm,
        "gpu" | "wgpu" => Device::Gpu,
        "vulkan" => Device::Vulkan,
        other => {
            bail!("unknown device {other} (cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan)")
        }
    };
    Ok(device)
}
/// Parses `s` with [`parse_device`], then runs `validate_standard_device`
/// on the result for the given model `family`.
///
/// # Errors
///
/// Propagates the [`parse_device`] error for an unrecognised name, or any
/// error returned by `validate_standard_device` for `family`.
pub fn parse_standard_device(family: &str, s: &str) -> Result<Device> {
    parse_device(s).and_then(|device| {
        validate_standard_device(family, device)?;
        Ok(device)
    })
}
/// Parses a device name for the `llama32` model family.
///
/// Delegates to [`parse_standard_device`] with the family fixed to
/// `"llama32"`; see that function for accepted names and error cases.
pub fn parse_llama32_device(s: &str) -> Result<Device> {
    const FAMILY: &str = "llama32";
    parse_standard_device(FAMILY, s)
}
/// Parses a device name for the `gemma` model family.
///
/// Delegates to [`parse_standard_device`] with the family fixed to
/// `"gemma"`; see that function for accepted names and error cases.
pub fn parse_gemma_device(s: &str) -> Result<Device> {
    const FAMILY: &str = "gemma";
    parse_standard_device(FAMILY, s)
}
/// Parses a device name for the `qwen35` model family.
///
/// Delegates to [`parse_standard_device`] with the family fixed to
/// `"qwen35"`; see that function for accepted names and error cases.
pub fn parse_qwen35_device(s: &str) -> Result<Device> {
    const FAMILY: &str = "qwen35";
    parse_standard_device(FAMILY, s)
}
/// Parses a device name for the `llada2` model family.
///
/// Delegates to [`parse_standard_device`] with the family fixed to
/// `"llada2"`; see that function for accepted names and error cases.
pub fn parse_llada2_device(s: &str) -> Result<Device> {
    const FAMILY: &str = "llada2";
    parse_standard_device(FAMILY, s)
}
/// Parses a device name for the SAM model `family`.
///
/// Accepts `tpu` in addition to every spelling understood by
/// [`parse_device`], then passes the parsed device to `validate_sam_device`
/// for `family`.
///
/// # Errors
///
/// - For an unrecognised name, returns an error whose message lists `tpu`
///   alongside the shared spellings; the underlying [`parse_device`] error
///   is kept as its cause.
/// - Propagates any error returned by `validate_sam_device`.
pub fn parse_sam_device(family: &str, s: &str) -> Result<Device> {
    let d = match s {
        "tpu" => Device::Tpu,
        // `parse_device`'s hint omits `tpu`, which this function also accepts;
        // wrap the error so the user-facing message lists every valid spelling.
        other => parse_device(other).map_err(|err| {
            err.context(format!(
                "unknown device {other} (tpu|cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan)"
            ))
        })?,
    };
    validate_sam_device(family, d)?;
    Ok(d)
}