rlx_cli/device.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_core::{validate_sam_device, validate_standard_device};
18use rlx_runtime::Device;
19use std::str::FromStr;
20
21/// Parse a device alias. Delegates to the upstream `FromStr for Device`
22/// impl in `rlx-driver`, which knows every alias (`mps`, `wgpu`, `hip`,
23/// …) and returns a uniform error string. Kept as a wrapper that
24/// surfaces `anyhow::Result` so existing callers don't change.
25pub fn parse_device(s: &str) -> Result<Device> {
26 Device::from_str(s).map_err(|e| anyhow::anyhow!("{e}"))
27}
28
29/// Parse a CLI device name and enforce the workspace-wide standard backend set.
30pub fn parse_standard_device(family: &str, s: &str) -> Result<Device> {
31 let d = parse_device(s)?;
32 validate_standard_device(family, d)?;
33 Ok(d)
34}
35
36pub fn parse_llama32_device(s: &str) -> Result<Device> {
37 parse_standard_device("llama32", s)
38}
39
40pub fn parse_gemma_device(s: &str) -> Result<Device> {
41 parse_standard_device("gemma", s)
42}
43
44pub fn parse_qwen35_device(s: &str) -> Result<Device> {
45 parse_standard_device("qwen35", s)
46}
47
48pub fn parse_llada2_device(s: &str) -> Result<Device> {
49 parse_standard_device("llada2", s)
50}
51
52/// Parse a CLI device name and enforce the SAM backend set (standard + `tpu`).
53pub fn parse_sam_device(family: &str, s: &str) -> Result<Device> {
54 let d = match s {
55 "tpu" => Device::Tpu,
56 other => parse_device(other)?,
57 };
58 validate_sam_device(family, d)?;
59 Ok(d)
60}