rlx_runtime/
device_parse.rs1use rlx_driver::Device;
19
/// Error produced when a device name (or a list of device names) cannot be parsed.
///
/// `input` carries the text that was being parsed when the failure occurred,
/// and `message` carries the human-readable explanation. `Display` prints only
/// the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseDeviceError {
    /// The text that failed to parse.
    pub input: String,
    /// Human-readable description of why parsing failed.
    pub message: String,
}

impl std::fmt::Display for ParseDeviceError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the message is rendered; callers can still inspect `input` directly.
        out.write_str(&self.message)
    }
}

impl std::error::Error for ParseDeviceError {}
34
35pub fn parse_device(s: &str) -> Result<Device, ParseDeviceError> {
37 let key = s.trim().to_ascii_lowercase();
38 match key.as_str() {
39 "cpu" => Ok(Device::Cpu),
40 "metal" | "mtl" => Ok(Device::Metal),
41 "mlx" => Ok(Device::Mlx),
42 "ane" | "neural-engine" => Ok(Device::Ane),
43 "cuda" | "nvidia" => Ok(Device::Cuda),
44 "rocm" | "hip" | "amd" => Ok(Device::Rocm),
45 "gpu" | "wgpu" => Ok(Device::Gpu),
46 "vulkan" | "vk" => Ok(Device::Vulkan),
47 "opengl" | "gl" => Ok(Device::OpenGl),
48 "directx" | "dx12" | "d3d12" => Ok(Device::DirectX),
49 "webgpu" => Ok(Device::WebGpu),
50 "tpu" => Ok(Device::Tpu),
51 "" => Err(ParseDeviceError {
52 input: s.to_string(),
53 message: "empty device name".into(),
54 }),
55 other => Err(ParseDeviceError {
56 input: s.to_string(),
57 message: format!(
58 "unknown device '{other}' (try: cpu, metal, mlx, cuda, rocm, gpu, vulkan, tpu)"
59 ),
60 }),
61 }
62}
63
64pub fn device_label(device: Device) -> &'static str {
66 match device {
67 Device::Cpu => "cpu",
68 Device::Metal => "metal",
69 Device::Mlx => "mlx",
70 Device::Ane => "ane",
71 Device::Cuda => "cuda",
72 Device::Rocm => "rocm",
73 Device::Gpu => "gpu",
74 Device::Vulkan => "vulkan",
75 Device::OpenGl => "opengl",
76 Device::DirectX => "directx",
77 Device::WebGpu => "webgpu",
78 Device::Tpu => "tpu",
79 }
80}
81
82pub fn parse_device_list(s: &str) -> Result<Vec<Device>, ParseDeviceError> {
84 let mut out = Vec::new();
85 for part in s.split([',', ';', ' ']) {
86 let part = part.trim();
87 if part.is_empty() {
88 continue;
89 }
90 let dev = parse_device(part)?;
91 if !out.contains(&dev) {
92 out.push(dev);
93 }
94 }
95 if out.is_empty() {
96 return Err(ParseDeviceError {
97 input: s.to_string(),
98 message: "device list is empty".into(),
99 });
100 }
101 Ok(out)
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_aliases() {
        assert_eq!(parse_device("CUDA").unwrap(), Device::Cuda);
        assert_eq!(parse_device("wgpu").unwrap(), Device::Gpu);
        assert_eq!(
            parse_device_list("cpu, metal;mlx").unwrap(),
            vec![Device::Cpu, Device::Metal, Device::Mlx]
        );
    }

    /// Every canonical label must parse back to the variant it names.
    #[test]
    fn labels_round_trip() {
        let cases = [
            ("cpu", Device::Cpu),
            ("metal", Device::Metal),
            ("mlx", Device::Mlx),
            ("ane", Device::Ane),
            ("cuda", Device::Cuda),
            ("rocm", Device::Rocm),
            ("gpu", Device::Gpu),
            ("vulkan", Device::Vulkan),
            ("opengl", Device::OpenGl),
            ("directx", Device::DirectX),
            ("webgpu", Device::WebGpu),
            ("tpu", Device::Tpu),
        ];
        for (label, expected) in cases {
            assert_eq!(parse_device(label).unwrap(), expected, "label {label:?}");
            assert_eq!(device_label(expected), label);
        }
    }

    #[test]
    fn parse_trims_and_ignores_case() {
        assert_eq!(parse_device("  Metal\t").unwrap(), Device::Metal);
        assert_eq!(parse_device("Neural-Engine").unwrap(), Device::Ane);
    }

    #[test]
    fn parse_rejects_empty_and_unknown() {
        let err = parse_device("   ").unwrap_err();
        assert_eq!(err.input, "   ");
        assert_eq!(err.message, "empty device name");

        let err = parse_device("Quantum").unwrap_err();
        // `input` keeps the caller's text; the message shows the normalised key.
        assert_eq!(err.input, "Quantum");
        assert!(err.to_string().contains("unknown device 'quantum'"));
    }

    #[test]
    fn list_dedups_and_reports_errors() {
        assert_eq!(
            parse_device_list("cuda,nvidia; CUDA cpu").unwrap(),
            vec![Device::Cuda, Device::Cpu]
        );

        let err = parse_device_list(" ,; ").unwrap_err();
        assert_eq!(err.message, "device list is empty");
        assert_eq!(err.input, " ,; ");

        let err = parse_device_list("cpu,bogus").unwrap_err();
        assert_eq!(err.input, "bogus");
    }
}