Skip to main content

rlx_runtime/
device_parse.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Parsing and formatting of string identifiers for [`rlx_driver::Device`], as used in config files, CLI flags, and environment variables.
17
18use rlx_driver::Device;
19
/// Error produced when a device name, or a list of device names, cannot be parsed.
///
/// `input` keeps the text that was being parsed when the failure happened, and
/// `message` is the human-readable explanation. Formatting the error with
/// [`Display`](std::fmt::Display) prints `message` and nothing else.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseDeviceError {
    /// Text that failed to parse, exactly as it was handed to the parser.
    pub input: String,
    /// Explanation intended for the user.
    pub message: String,
}

impl std::fmt::Display for ParseDeviceError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render only the message; callers that need the raw text can read `input`.
        out.write_str(&self.message)
    }
}

// Marker impl: there is no wrapped cause, so `source()` keeps its default `None`.
impl std::error::Error for ParseDeviceError {}
34
35/// Lower-case Cargo feature names and common aliases → [`Device`].
36pub fn parse_device(s: &str) -> Result<Device, ParseDeviceError> {
37    let key = s.trim().to_ascii_lowercase();
38    match key.as_str() {
39        "cpu" => Ok(Device::Cpu),
40        "metal" | "mtl" => Ok(Device::Metal),
41        "mlx" => Ok(Device::Mlx),
42        "ane" | "neural-engine" => Ok(Device::Ane),
43        "cuda" | "nvidia" => Ok(Device::Cuda),
44        "rocm" | "hip" | "amd" => Ok(Device::Rocm),
45        "gpu" | "wgpu" => Ok(Device::Gpu),
46        "vulkan" | "vk" => Ok(Device::Vulkan),
47        "opengl" | "gl" => Ok(Device::OpenGl),
48        "directx" | "dx12" | "d3d12" => Ok(Device::DirectX),
49        "webgpu" => Ok(Device::WebGpu),
50        "tpu" => Ok(Device::Tpu),
51        "" => Err(ParseDeviceError {
52            input: s.to_string(),
53            message: "empty device name".into(),
54        }),
55        other => Err(ParseDeviceError {
56            input: s.to_string(),
57            message: format!(
58                "unknown device '{other}' (try: cpu, metal, mlx, cuda, rocm, gpu, vulkan, tpu)"
59            ),
60        }),
61    }
62}
63
64/// Stable lower-case label for `device` (matches Cargo feature names).
65pub fn device_label(device: Device) -> &'static str {
66    match device {
67        Device::Cpu => "cpu",
68        Device::Metal => "metal",
69        Device::Mlx => "mlx",
70        Device::Ane => "ane",
71        Device::Cuda => "cuda",
72        Device::Rocm => "rocm",
73        Device::Gpu => "gpu",
74        Device::Vulkan => "vulkan",
75        Device::OpenGl => "opengl",
76        Device::DirectX => "directx",
77        Device::WebGpu => "webgpu",
78        Device::Tpu => "tpu",
79    }
80}
81
82/// Parse comma/semicolon/whitespace-separated device lists (`RLX_DEVICES=cpu,metal`).
83pub fn parse_device_list(s: &str) -> Result<Vec<Device>, ParseDeviceError> {
84    let mut out = Vec::new();
85    for part in s.split([',', ';', ' ']) {
86        let part = part.trim();
87        if part.is_empty() {
88            continue;
89        }
90        let dev = parse_device(part)?;
91        if !out.contains(&dev) {
92            out.push(dev);
93        }
94    }
95    if out.is_empty() {
96        return Err(ParseDeviceError {
97            input: s.to_string(),
98            message: "device list is empty".into(),
99        });
100    }
101    Ok(out)
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Device` variant, so round-trip checks cover the whole enum.
    fn all_devices() -> [Device; 12] {
        [
            Device::Cpu,
            Device::Metal,
            Device::Mlx,
            Device::Ane,
            Device::Cuda,
            Device::Rocm,
            Device::Gpu,
            Device::Vulkan,
            Device::OpenGl,
            Device::DirectX,
            Device::WebGpu,
            Device::Tpu,
        ]
    }

    #[test]
    fn parse_aliases() {
        assert_eq!(parse_device("CUDA").unwrap(), Device::Cuda);
        assert_eq!(parse_device("wgpu").unwrap(), Device::Gpu);
        assert_eq!(
            parse_device_list("cpu, metal;mlx").unwrap(),
            vec![Device::Cpu, Device::Metal, Device::Mlx]
        );
    }

    #[test]
    fn labels_are_unique() {
        let labels: std::collections::HashSet<&str> =
            all_devices().into_iter().map(device_label).collect();
        assert_eq!(labels.len(), all_devices().len());
    }

    #[test]
    fn labels_round_trip() {
        for device in all_devices() {
            let label = device_label(device);
            // Labels are unique, so label -> device -> label proves parse(label(d)) == d.
            assert_eq!(parse_device(label).map(device_label), Ok(label));
            // Parsing ignores ASCII case and surrounding whitespace.
            let noisy = format!("  {}\t", label.to_ascii_uppercase());
            assert_eq!(parse_device(&noisy).map(device_label), Ok(label));
        }
    }

    #[test]
    fn blank_name_is_rejected() {
        for input in ["", "   ", "\t\n"] {
            let err = parse_device(input).unwrap_err();
            assert_eq!(err.input, input);
            assert_eq!(err.message, "empty device name");
        }
    }

    #[test]
    fn unknown_name_reports_input() {
        let err = parse_device("  Bogus ").unwrap_err();
        assert_eq!(err.input, "  Bogus ");
        assert!(
            err.message.starts_with("unknown device 'bogus'"),
            "{}",
            err.message
        );
        assert_eq!(err.to_string(), err.message);
    }

    #[test]
    fn list_dedups_and_reports_errors() {
        assert_eq!(
            parse_device_list("cpu,CPU;cpu  metal").unwrap(),
            vec![Device::Cpu, Device::Metal]
        );

        let err = parse_device_list(" , ;").unwrap_err();
        assert_eq!(err.input, " , ;");
        assert_eq!(err.message, "device list is empty");

        let err = parse_device_list("cpu,bogus").unwrap_err();
        assert_eq!(err.input, "bogus");
    }
}
117}