use anyhow::{Result, bail};
use rlx_runtime::{Device, memory_estimate};
/// Devices accepted by [`is_standard_device`] / [`validate_standard_device`].
///
/// `Device::Tpu` is deliberately not in this set; only the SAM validator
/// ([`validate_sam_device`]) accepts it in addition to these.
pub const STANDARD_DEVICES: &[Device] = &[
Device::Cpu,
Device::Metal,
Device::Mlx,
Device::Cuda,
Device::Rocm,
Device::Gpu,
Device::Vulkan,
];
/// `|`-separated device names quoted in the "not supported" error messages
/// produced by the validators in this module.
///
/// NOTE(review): this lists more names than [`STANDARD_DEVICES`] has entries
/// (`mps`, `hip`, `wgpu`); presumably these are CLI aliases resolved by a
/// parser elsewhere — confirm they stay in sync with that parser.
pub const STANDARD_DEVICE_NAMES: &str = "cpu|metal|mps|mlx|cuda|rocm|hip|gpu|wgpu|vulkan";
/// Reports whether `device` appears in [`STANDARD_DEVICES`].
pub fn is_standard_device(device: Device) -> bool {
    STANDARD_DEVICES
        .iter()
        .any(|&candidate| candidate == device)
}
/// Checks that `device` is one of the standard backends.
///
/// `family` is used only as a prefix in the error message.
///
/// # Errors
///
/// Returns an error naming `family`, the rejected `device`, and the accepted
/// device names when `device` is not in [`STANDARD_DEVICES`].
pub fn validate_standard_device(family: &str, device: Device) -> Result<()> {
    if !is_standard_device(device) {
        bail!(
            "{family}: device {device:?} is not supported \
             (use {STANDARD_DEVICE_NAMES})"
        );
    }
    Ok(())
}
/// Returns `(free_bytes, total_bytes)` of device memory to budget MoE expert
/// offloading against, or `None` when no figure is available.
///
/// Sources are tried in this order, and the first one that yields a value wins:
/// 1. the `RLX_CUDA_FREE_BYTES` / `RLX_CUDA_TOTAL_BYTES` environment variables;
/// 2. the `RLX_DEVICE_FREE_BYTES` / `RLX_DEVICE_TOTAL_BYTES` environment variables;
/// 3. for Metal, MLX, CUDA, ROCm, generic GPU and Vulkan devices,
///    `memory_estimate::available_unified_memory()`, reported as both free and total.
///
/// An environment pair is used only when both variables are set and both parse
/// as `usize`; otherwise the next source is tried. Any other device without an
/// environment override yields `None`.
pub fn device_memory_for_moe_offload(device: Device) -> Option<(usize, usize)> {
    // NOTE(review): the CUDA-named override is honoured for every device, not
    // only `Device::Cuda` — confirm this is intended.
    env_memory_override("RLX_CUDA_FREE_BYTES", "RLX_CUDA_TOTAL_BYTES")
        .or_else(|| env_memory_override("RLX_DEVICE_FREE_BYTES", "RLX_DEVICE_TOTAL_BYTES"))
        .or_else(|| match device {
            // NOTE(review): discrete GPUs (CUDA/ROCm/Vulkan) get the same
            // unified-memory estimate as Apple devices, which may not reflect
            // VRAM — confirm.
            Device::Metal
            | Device::Mlx
            | Device::Cuda
            | Device::Rocm
            | Device::Gpu
            | Device::Vulkan => memory_estimate::available_unified_memory().map(|t| (t, t)),
            _ => None,
        })
}

/// Reads a `(free, total)` byte pair from two environment variables.
///
/// Returns `None` unless both variables are set and both parse as `usize`.
fn env_memory_override(free_var: &str, total_var: &str) -> Option<(usize, usize)> {
    let free = std::env::var(free_var).ok()?.parse().ok()?;
    let total = std::env::var(total_var).ok()?.parse().ok()?;
    Some((free, total))
}
/// Checks that `device` can run the SAM family.
///
/// SAM accepts every standard backend plus `Device::Tpu`. `family` is used
/// only as a prefix in the error message.
///
/// # Errors
///
/// Returns an error naming `family`, the rejected `device`, and the accepted
/// device names when `device` is neither standard nor `Device::Tpu`.
pub fn validate_sam_device(family: &str, device: Device) -> Result<()> {
    let accepted = is_standard_device(device) || device == Device::Tpu;
    if accepted {
        return Ok(());
    }
    bail!(
        "{family}: device {device:?} is not supported \
         (use {STANDARD_DEVICE_NAMES} or tpu)"
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every entry of the standard set is recognised, and TPU is not.
    #[test]
    fn standard_set_covers_cli_backends() {
        for dev in STANDARD_DEVICES {
            assert!(is_standard_device(*dev), "{dev:?} should be standard");
        }
        assert!(!is_standard_device(Device::Tpu));
    }

    /// The standard validator accepts the standard set and rejects TPU with a
    /// message that names the family and the accepted devices.
    #[test]
    fn standard_validator_accepts_standard_and_rejects_tpu() {
        for dev in STANDARD_DEVICES {
            assert!(
                validate_standard_device("test", *dev).is_ok(),
                "{dev:?} should validate"
            );
        }
        let err = validate_standard_device("test", Device::Tpu)
            .unwrap_err()
            .to_string();
        assert!(err.starts_with("test: device "), "{err}");
        assert!(err.contains(STANDARD_DEVICE_NAMES), "{err}");
    }

    /// The SAM validator accepts the standard set plus TPU.
    #[test]
    fn sam_validator_additionally_accepts_tpu() {
        for dev in STANDARD_DEVICES {
            assert!(
                validate_sam_device("sam", *dev).is_ok(),
                "{dev:?} should validate"
            );
        }
        assert!(validate_sam_device("sam", Device::Tpu).is_ok());
    }
}