rlx_models_core/device_capabilities.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared backend policy for RLX model crates.
17//!
18//! Every model family in this workspace targets the same seven execution
19//! backends. Call [`validate_standard_device`] at runner / loader build
20//! time; enable matching `rlx-runtime` features on the model crate
21//! (`metal`, `mlx`, `cuda`, `rocm`, `gpu`, `vulkan`, or `all-backends`).
22
23use anyhow::{Result, bail};
24use rlx_runtime::{Device, memory_estimate};
25
26/// Backends every model crate is expected to support when the matching
27/// `rlx-runtime` feature is enabled at build time.
28pub const STANDARD_DEVICES: &[Device] = &[
29 Device::Cpu,
30 Device::Metal,
31 Device::Mlx,
32 Device::Cuda,
33 Device::Rocm,
34 Device::Gpu,
35 Device::Vulkan,
36];
37
/// Pipe-separated device spellings shown in `--device` CLI help.
///
/// The same string is interpolated into the error messages produced by
/// [`validate_standard_device`] and [`validate_sam_device`].
///
/// NOTE(review): this lists ten spellings for the seven entries of
/// [`STANDARD_DEVICES`]; `mps`, `hip` and `wgpu` are presumably aliases
/// resolved by the device parser — confirm against that parser.
pub const STANDARD_DEVICE_NAMES: &str = "cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan";
40
41/// True when `device` is in [`STANDARD_DEVICES`].
42pub fn is_standard_device(device: Device) -> bool {
43 STANDARD_DEVICES.contains(&device)
44}
45
46/// Fail fast on exotic runtime devices (TPU, ANE, OpenGL, …).
47pub fn validate_standard_device(family: &str, device: Device) -> Result<()> {
48 if is_standard_device(device) {
49 Ok(())
50 } else {
51 bail!(
52 "{family}: device {device:?} is not supported \
53 (use {STANDARD_DEVICE_NAMES})"
54 )
55 }
56}
57
58/// `(free_bytes, total_bytes)` for TIDE MoE VRAM budget sizing.
59///
60/// Override with `RLX_CUDA_FREE_BYTES` / `RLX_CUDA_TOTAL_BYTES` or
61/// `RLX_DEVICE_FREE_BYTES` / `RLX_DEVICE_TOTAL_BYTES`. On Apple Silicon
62/// (Metal / MLX), falls back to unified memory when env vars are unset.
63pub fn device_memory_for_moe_offload(device: Device) -> Option<(usize, usize)> {
64 if let (Ok(free), Ok(total)) = (
65 std::env::var("RLX_CUDA_FREE_BYTES"),
66 std::env::var("RLX_CUDA_TOTAL_BYTES"),
67 ) {
68 if let (Ok(f), Ok(t)) = (free.parse(), total.parse()) {
69 return Some((f, t));
70 }
71 }
72 if let (Ok(free), Ok(total)) = (
73 std::env::var("RLX_DEVICE_FREE_BYTES"),
74 std::env::var("RLX_DEVICE_TOTAL_BYTES"),
75 ) {
76 if let (Ok(f), Ok(t)) = (free.parse(), total.parse()) {
77 return Some((f, t));
78 }
79 }
80 match device {
81 Device::Metal | Device::Mlx => memory_estimate::available_unified_memory().map(|t| (t, t)),
82 Device::Cuda | Device::Rocm | Device::Gpu | Device::Vulkan => {
83 memory_estimate::available_unified_memory().map(|t| (t, t))
84 }
85 _ => None,
86 }
87}
88
89/// SAM v1 also documents `tpu` on [`rlx_sam::Sam::from_safetensors_on`].
90pub fn validate_sam_device(family: &str, device: Device) -> Result<()> {
91 if device == Device::Tpu || is_standard_device(device) {
92 Ok(())
93 } else {
94 bail!(
95 "{family}: device {device:?} is not supported \
96 (use {STANDARD_DEVICE_NAMES} or tpu)"
97 )
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Every listed backend must be recognised as standard, and TPU must not be.
    #[test]
    fn standard_set_covers_cli_backends() {
        assert!(STANDARD_DEVICES.iter().copied().all(is_standard_device));
        assert!(!is_standard_device(Device::Tpu));
    }
}
112}