Skip to main content

rlx_cli/
device.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Result, bail};
17use rlx_core::{validate_sam_device, validate_standard_device};
18use rlx_runtime::Device;
19
20pub fn parse_device(s: &str) -> Result<Device> {
21    Ok(match s {
22        "cpu" => Device::Cpu,
23        "metal" | "mps" => Device::Metal,
24        "mlx" => Device::Mlx,
25        "cuda" => Device::Cuda,
26        "rocm" | "hip" => Device::Rocm,
27        "gpu" | "wgpu" => Device::Gpu,
28        "vulkan" => Device::Vulkan,
29        other => bail!("unknown device {other} (cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan)"),
30    })
31}
32
33/// Parse a CLI device name and enforce the workspace-wide standard backend set.
34pub fn parse_standard_device(family: &str, s: &str) -> Result<Device> {
35    let d = parse_device(s)?;
36    validate_standard_device(family, d)?;
37    Ok(d)
38}
39
40pub fn parse_llama32_device(s: &str) -> Result<Device> {
41    parse_standard_device("llama32", s)
42}
43
44pub fn parse_gemma_device(s: &str) -> Result<Device> {
45    parse_standard_device("gemma", s)
46}
47
48pub fn parse_qwen35_device(s: &str) -> Result<Device> {
49    parse_standard_device("qwen35", s)
50}
51
52pub fn parse_llada2_device(s: &str) -> Result<Device> {
53    parse_standard_device("llada2", s)
54}
55
56/// Parse a CLI device name and enforce the SAM backend set (standard + `tpu`).
57pub fn parse_sam_device(family: &str, s: &str) -> Result<Device> {
58    let d = match s {
59        "tpu" => Device::Tpu,
60        other => parse_device(other)?,
61    };
62    validate_sam_device(family, d)?;
63    Ok(d)
64}