Skip to main content

rlx_models_core/
device_capabilities.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared backend policy for RLX model crates.
17//!
18//! Every model family in this workspace targets the same seven execution
19//! backends. Call [`validate_standard_device`] at runner / loader build
20//! time; enable matching `rlx-runtime` features on the model crate
21//! (`metal`, `mlx`, `cuda`, `rocm`, `gpu`, `vulkan`, or `all-backends`).
22
23use anyhow::{Result, bail};
24use rlx_runtime::{Device, memory_estimate};
25
/// Backends every model crate is expected to support when the matching
/// `rlx-runtime` feature is enabled at build time.
///
/// This slice is the single source of truth for [`is_standard_device`],
/// which does a plain membership test against it. Devices not listed here
/// (for example `Device::Tpu`) are treated as non-standard.
pub const STANDARD_DEVICES: &[Device] = &[
    Device::Cpu,
    Device::Metal,
    Device::Mlx,
    Device::Cuda,
    Device::Rocm,
    Device::Gpu,
    Device::Vulkan,
];
37
/// CLI / help string for `--device`.
///
/// Pipe-separated list of accepted spellings; it is interpolated verbatim
/// into the error messages produced by the `validate_*` helpers in this
/// module.
///
/// NOTE(review): the list contains `mps`, `hip` and `wgpu`, which have no
/// same-named variant in [`STANDARD_DEVICES`]; presumably these are aliases
/// resolved by the CLI parser — confirm against the parsing code.
pub const STANDARD_DEVICE_NAMES: &str = "cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan";
40
41/// True when `device` is in [`STANDARD_DEVICES`].
42pub fn is_standard_device(device: Device) -> bool {
43    STANDARD_DEVICES.contains(&device)
44}
45
46/// Fail fast on exotic runtime devices (TPU, ANE, OpenGL, …).
47pub fn validate_standard_device(family: &str, device: Device) -> Result<()> {
48    if is_standard_device(device) {
49        Ok(())
50    } else {
51        bail!(
52            "{family}: device {device:?} is not supported \
53             (use {STANDARD_DEVICE_NAMES})"
54        )
55    }
56}
57
58/// `(free_bytes, total_bytes)` for TIDE MoE VRAM budget sizing.
59///
60/// Override with `RLX_CUDA_FREE_BYTES` / `RLX_CUDA_TOTAL_BYTES` or
61/// `RLX_DEVICE_FREE_BYTES` / `RLX_DEVICE_TOTAL_BYTES`. On Apple Silicon
62/// (Metal / MLX), falls back to unified memory when env vars are unset.
63pub fn device_memory_for_moe_offload(device: Device) -> Option<(usize, usize)> {
64    if let (Ok(free), Ok(total)) = (
65        std::env::var("RLX_CUDA_FREE_BYTES"),
66        std::env::var("RLX_CUDA_TOTAL_BYTES"),
67    ) {
68        if let (Ok(f), Ok(t)) = (free.parse(), total.parse()) {
69            return Some((f, t));
70        }
71    }
72    if let (Ok(free), Ok(total)) = (
73        std::env::var("RLX_DEVICE_FREE_BYTES"),
74        std::env::var("RLX_DEVICE_TOTAL_BYTES"),
75    ) {
76        if let (Ok(f), Ok(t)) = (free.parse(), total.parse()) {
77            return Some((f, t));
78        }
79    }
80    match device {
81        Device::Metal | Device::Mlx => memory_estimate::available_unified_memory().map(|t| (t, t)),
82        Device::Cuda | Device::Rocm | Device::Gpu | Device::Vulkan => {
83            memory_estimate::available_unified_memory().map(|t| (t, t))
84        }
85        _ => None,
86    }
87}
88
89/// SAM v1 also documents `tpu` on [`rlx_sam::Sam::from_safetensors_on`].
90pub fn validate_sam_device(family: &str, device: Device) -> Result<()> {
91    if device == Device::Tpu || is_standard_device(device) {
92        Ok(())
93    } else {
94        bail!(
95            "{family}: device {device:?} is not supported \
96             (use {STANDARD_DEVICE_NAMES} or tpu)"
97        )
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Every listed backend is recognised; TPU stays outside the standard set.
    #[test]
    fn standard_set_covers_cli_backends() {
        for dev in STANDARD_DEVICES {
            assert!(is_standard_device(*dev));
        }
        assert!(!is_standard_device(Device::Tpu));
    }

    /// The standard validator accepts exactly the standard set and reports
    /// the model family in its error message.
    #[test]
    fn validate_standard_device_matches_membership() {
        for dev in STANDARD_DEVICES {
            assert!(validate_standard_device("test", *dev).is_ok());
        }

        let err = validate_standard_device("test-family", Device::Tpu).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("test-family"));
        assert!(msg.contains(STANDARD_DEVICE_NAMES));
    }

    /// The SAM validator accepts the standard set and additionally TPU.
    #[test]
    fn validate_sam_device_additionally_accepts_tpu() {
        for dev in STANDARD_DEVICES {
            assert!(validate_sam_device("sam", *dev).is_ok());
        }
        assert!(validate_sam_device("sam", Device::Tpu).is_ok());
    }
}
112}