use rlx::Device;
/// Parses a user-supplied device name into an rlx [`Device`].
///
/// Matching is case-insensitive and ignores surrounding whitespace. Accepted
/// spellings: `cpu`; `metal`/`mtl`; `mlx`; `gpu`/`wgpu`; `vulkan`/`vk`;
/// `cuda`/`nvidia`; `rocm`/`hip`/`amd`; `tpu`.
///
/// # Errors
/// Returns an error for an empty (or all-whitespace) name and for any name not in
/// the list above. Both messages suggest the same set of canonical names.
pub fn parse_device(s: &str) -> anyhow::Result<Device> {
    // One shared suggestion list so the empty-input and unknown-input hints stay in
    // sync with each other and with the accepted spellings below.
    const SUGGESTED: &str = "cpu, metal, mlx, gpu, vulkan, cuda, rocm, tpu";

    let key = s.trim().to_ascii_lowercase();
    match key.as_str() {
        "cpu" => Ok(Device::Cpu),
        "metal" | "mtl" => Ok(Device::Metal),
        "mlx" => Ok(Device::Mlx),
        "gpu" | "wgpu" => Ok(Device::Gpu),
        "vulkan" | "vk" => Ok(Device::Vulkan),
        "cuda" | "nvidia" => Ok(Device::Cuda),
        "rocm" | "hip" | "amd" => Ok(Device::Rocm),
        "tpu" => Ok(Device::Tpu),
        "" => anyhow::bail!("empty device name (try: {SUGGESTED})"),
        other => anyhow::bail!("unknown device '{other}' — try: {SUGGESTED}"),
    }
}
/// Returns the comma-separated cargo feature list to pass to `--features` when
/// rebuilding this binary for `device`.
///
/// Every list starts with `rlx-engine`. Backends with a dedicated crate feature add
/// it after a comma; `Cpu`, `Ane`, `OpenGl` and `DirectX` get `rlx-engine` alone.
///
/// NOTE(review): `Cpu` maps to plain `rlx-engine`, but `feature_enabled_in_build`
/// checks the `rlx-cpu` feature for it; presumably `rlx-engine` enables `rlx-cpu`
/// in Cargo.toml — confirm.
pub fn recommended_features(device: Device) -> &'static str {
    match device {
        Device::Metal => "rlx-engine,rlx-metal",
        Device::Mlx => "rlx-engine,rlx-mlx",
        Device::Cuda => "rlx-engine,rlx-cuda",
        Device::Rocm => "rlx-engine,rlx-rocm",
        Device::Tpu => "rlx-engine,rlx-tpu",
        // All wgpu-backed variants share one feature.
        Device::Gpu | Device::Vulkan | Device::WebGpu => "rlx-engine,rlx-gpu",
        Device::Cpu | Device::Ane | Device::OpenGl | Device::DirectX => "rlx-engine",
    }
}
pub fn brainjepa_features(device: Device) -> &'static str {
recommended_features(device)
}
/// Reports whether the cargo feature backing `device` was compiled into this binary.
///
/// The answer is fixed at compile time (each arm is a `cfg!` check).
/// `Ane`, `OpenGl` and `DirectX` have no feature here and always report `false`.
pub fn feature_enabled_in_build(device: Device) -> bool {
    match device {
        Device::Ane | Device::OpenGl | Device::DirectX => false,
        Device::Cpu => cfg!(feature = "rlx-cpu"),
        Device::Metal => cfg!(feature = "rlx-metal"),
        Device::Mlx => cfg!(feature = "rlx-mlx"),
        Device::Cuda => cfg!(feature = "rlx-cuda"),
        Device::Rocm => cfg!(feature = "rlx-rocm"),
        Device::Tpu => cfg!(feature = "rlx-tpu"),
        Device::Gpu | Device::Vulkan | Device::WebGpu => cfg!(feature = "rlx-gpu"),
    }
}
/// Asks the rlx runtime whether `device` is usable on this machine.
///
/// Thin delegate to `rlx::runtime::is_available`. Compile-time feature gating is a
/// separate check; see [`feature_enabled_in_build`].
pub fn runtime_available(device: Device) -> bool {
    use rlx::runtime::is_available;
    is_available(device)
}
/// Lists the devices the rlx runtime reports as usable from this binary.
///
/// Thin delegate to `rlx::runtime::available_devices`.
pub fn available_devices() -> Vec<Device> {
    use rlx::runtime::available_devices as runtime_devices;
    runtime_devices()
}
/// Returns the short lowercase name used for `device` in user-facing messages and
/// device lists.
///
/// For every backend that `parse_device` accepts, this is the canonical `--device`
/// spelling. `WebGpu`, `Ane`, `OpenGl` and `DirectX` get distinct descriptive names
/// rather than a shared placeholder, so messages and lists stay unambiguous.
/// NOTE(review): those four names are not accepted by `parse_device`.
fn device_user_name(device: Device) -> &'static str {
    match device {
        Device::Cpu => "cpu",
        Device::Metal => "metal",
        Device::Mlx => "mlx",
        Device::Gpu => "gpu",
        Device::Vulkan => "vulkan",
        Device::Cuda => "cuda",
        Device::Rocm => "rocm",
        Device::Tpu => "tpu",
        Device::WebGpu => "webgpu",
        Device::Ane => "ane",
        Device::OpenGl => "opengl",
        Device::DirectX => "directx",
    }
}
/// Extra rebuild instructions for MLX, or `None` when this binary was compiled with
/// the `rlx-mlx` feature and no extra note is needed.
fn mlx_extra_note() -> Option<&'static str> {
    const NOTE: &str = "MLX needs rlx with the mlx feature enabled:\n\
                        cargo build --release --features rlx-engine,rlx-mlx --bin infer";
    let mlx_compiled_in = cfg!(feature = "rlx-mlx");
    (!mlx_compiled_in).then_some(NOTE)
}
/// One-line explanation of what the platform must provide for `device`.
///
/// Used when a backend is compiled in but the runtime reports it unavailable.
/// Backends without a specific hint get a generic driver reminder.
fn runtime_hint(device: Device) -> &'static str {
    const GENERIC: &str = "Check that the platform driver for this backend is installed.";
    match device {
        Device::Cuda => "Requires an NVIDIA GPU with CUDA drivers installed.",
        Device::Rocm => "Requires an AMD GPU with ROCm installed.",
        Device::Tpu => "Requires a TPU runtime (libtpu / GCP TPU).",
        Device::Metal => "Requires macOS with Metal support.",
        Device::Mlx => "Requires macOS with MLX (Apple Silicon).",
        Device::Gpu | Device::Vulkan | Device::WebGpu => {
            "Requires a wgpu adapter (Metal on macOS, Vulkan on Linux, DX12 on Windows)."
        }
        Device::Cpu | Device::Ane | Device::OpenGl | Device::DirectX => GENERIC,
    }
}
fn format_available_list(devices: &[Device]) -> String {
if devices.is_empty() {
return " (none — rebuild with e.g. `--features rlx-engine,rlx-metal --bin infer`)".into();
}
devices
.iter()
.map(|d| device_user_name(*d))
.collect::<Vec<_>>()
.join(", ")
}
/// Seeds backend-specific environment-variable defaults before `device` is used.
///
/// A variable the user has already set is never overwritten:
/// - `Metal`: `RLX_METAL_UNFUSE_REGIONS=1`, but only if neither it nor
///   `RLX_METAL_NO_FUSION` is set. Also `RLX_METAL_SGEMM_MPS=1`.
/// - `Mlx`: `RLX_MLX_MODE=compiled`.
/// - every other device: no changes.
///
/// NOTE(review): mutating the process environment is only sound while no other
/// thread reads or writes it. This function is assumed to run during
/// single-threaded startup — confirm at call sites.
pub fn prepare_device(device: Device) {
    match device {
        Device::Metal => {
            if std::env::var_os("RLX_METAL_NO_FUSION").is_none() {
                set_env_default("RLX_METAL_UNFUSE_REGIONS", "1");
            }
            set_env_default("RLX_METAL_SGEMM_MPS", "1");
        }
        Device::Mlx => set_env_default("RLX_MLX_MODE", "compiled"),
        _ => {}
    }
}

/// Sets environment variable `key` to `value` unless it is already present.
fn set_env_default(key: &str, value: &str) {
    if std::env::var_os(key).is_none() {
        // SAFETY: `set_var` is unsafe because concurrent environment access from other
        // threads is undefined behavior. Its only caller, `prepare_device`, is assumed
        // to run before any other thread is spawned.
        unsafe { std::env::set_var(key, value) };
    }
}
/// Verifies that `device` can actually be used by this binary.
///
/// First applies backend environment defaults via [`prepare_device`]. Then checks,
/// in order:
/// 1. the backend's cargo feature was compiled in ([`feature_enabled_in_build`]);
/// 2. the rlx runtime reports the backend usable ([`runtime_available`]).
///
/// # Errors
/// Returns a multi-line, user-facing message. For (1) it explains how to rebuild;
/// for (2) it describes what the platform needs. Both end with the devices that do
/// work, as reported by [`available_devices`].
pub fn ensure_device(device: Device) -> anyhow::Result<()> {
    prepare_device(device);
    let name = device_user_name(device);

    if !feature_enabled_in_build(device) {
        let msg = feature_missing_message(device, name, recommended_features(device));
        anyhow::bail!("{msg}");
    }
    if !runtime_available(device) {
        let msg = runtime_missing_message(device, name);
        anyhow::bail!("{msg}");
    }
    Ok(())
}

/// Builds the error text for a device whose cargo feature is absent from this build.
fn feature_missing_message(device: Device, name: &str, feats: &str) -> String {
    let mut text = format!(
        "RLX device '{name}' is not enabled in this brainjepa build.\n\n\
         Rebuild with:\n\
         cargo build --release --no-default-features --features {feats}\n\n\
         Example infer:\n\
         cargo run --release --no-default-features --features {feats} --bin infer -- \\\n\
         --device {name} --input <fmri.safetensors>",
    );
    let mlx_note = if device == Device::Mlx {
        mlx_extra_note()
    } else {
        None
    };
    if let Some(note) = mlx_note {
        text.push_str("\n\n");
        text.push_str(note);
    }
    text.push_str("\n\nBackends that work in this binary right now: ");
    text.push_str(&format_available_list(&available_devices()));
    text
}

/// Builds the error text for a device that is compiled in but unusable at runtime.
fn runtime_missing_message(device: Device, name: &str) -> String {
    let mut text = format!(
        "RLX device '{name}' is compiled in but not available on this machine.\n\n{}",
        runtime_hint(device),
    );
    if matches!(device, Device::Gpu | Device::Vulkan) {
        text.push_str(
            "\n\nOn macOS, `--device metal` is usually faster than `--device gpu` (native Metal vs wgpu).",
        );
    }
    text.push_str("\n\nDevices that work here: ");
    text.push_str(&format_available_list(&available_devices()));
    text
}
/// Human-readable long name for `device`, as reported by `rlx::runtime::full_name`,
/// returned as an owned `String`.
pub fn display_name(device: Device) -> String {
    let full = rlx::runtime::full_name(device);
    full.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Alternate spellings map to their canonical backend.
    #[test]
    fn parse_aliases() {
        assert_eq!(parse_device("wgpu").unwrap(), Device::Gpu);
        assert_eq!(parse_device("MTL").unwrap(), Device::Metal);
        assert_eq!(parse_device("nvidia").unwrap(), Device::Cuda);
    }

    /// Input is trimmed and lowercased before matching.
    #[test]
    fn parse_trims_and_ignores_case() {
        assert_eq!(parse_device("  Vulkan\n").unwrap(), Device::Vulkan);
        assert_eq!(parse_device("HIP").unwrap(), Device::Rocm);
        assert_eq!(parse_device("\tcpu ").unwrap(), Device::Cpu);
    }

    /// Empty, whitespace-only and unrecognized names are errors, not panics.
    #[test]
    fn parse_rejects_empty_and_unknown() {
        assert!(parse_device("").is_err());
        assert!(parse_device("   ").is_err());
        assert!(parse_device("quantum").is_err());
    }

    /// The wgpu family shares one feature list; brainjepa mirrors the recommendation.
    #[test]
    fn feature_lists_are_consistent() {
        assert_eq!(
            recommended_features(Device::Gpu),
            recommended_features(Device::Vulkan)
        );
        assert_eq!(
            brainjepa_features(Device::Cuda),
            recommended_features(Device::Cuda)
        );
    }

    /// Backends with no dedicated cargo feature are never reported as compiled in.
    #[test]
    fn featureless_backends_not_enabled() {
        assert!(!feature_enabled_in_build(Device::Ane));
        assert!(!feature_enabled_in_build(Device::OpenGl));
        assert!(!feature_enabled_in_build(Device::DirectX));
    }

    #[test]
    fn cpu_always_ok_in_default_build() {
        ensure_device(Device::Cpu).expect("cpu should be available");
    }
}