use rlx_core::flow_bridge::compile_options_from_profile;
use rlx_flow::CompileProfile;
use rlx_ir::logical_kernel::KernelDispatchConfig;
use rlx_runtime::{CompileOptions, Device};
/// Reports whether `device` is one of the backends selected for the
/// decomposed RoPE path.
///
/// Returns `true` for `Metal`, `Mlx`, `Gpu` and `Vulkan`, and `false` for
/// every other device.
pub fn moonvit_use_decomposed_rope(device: Device) -> bool {
    let decomposed_backends = [Device::Metal, Device::Mlx, Device::Gpu, Device::Vulkan];
    decomposed_backends.contains(&device)
}
/// Reports whether the host device chosen for `requested` differs from the
/// request itself.
///
/// Returns `true` when [`locateanything_host_device`] resolves `requested` to
/// a different device, and `false` when the request is honoured as-is.
pub fn locateanything_uses_cpu_host(requested: Device) -> bool {
    let resolved_host = locateanything_host_device(requested);
    resolved_host != requested
}
/// Resolves the device that actually hosts execution for a requested device.
///
/// Every request, CPU or otherwise, currently resolves to `Device::Cpu`.
pub fn locateanything_host_device(requested: Device) -> Device {
    // The request does not influence the result: all devices map to the CPU.
    let _ = requested;
    Device::Cpu
}
/// Returns the device used for vision encoding given the LM device.
///
/// Delegates to [`locateanything_host_device`], so the result follows the
/// same host-resolution policy.
pub fn vision_encode_device(lm_device: Device) -> Device {
    locateanything_host_device(lm_device)
}
/// Returns the host device for the language model given the requested device.
///
/// Delegates to [`locateanything_host_device`], so the result follows the
/// same host-resolution policy.
pub fn lm_host_device(lm_device: Device) -> Device {
    locateanything_host_device(lm_device)
}
/// Reports whether the GPU KV path is enabled for `requested`.
///
/// Returns `true` only when `requested` is one of `Mlx`, `Cuda`, `Rocm`,
/// `Gpu` or `Vulkan` *and* [`locateanything_uses_cpu_host`] reports that the
/// request is not being redirected to another host device.
pub fn lm_gpu_kv_enabled(requested: Device) -> bool {
    let kv_capable = matches!(
        requested,
        Device::Mlx | Device::Cuda | Device::Rocm | Device::Gpu | Device::Vulkan
    );
    // A request redirected to a different host never gets the GPU KV path.
    kv_capable && !locateanything_uses_cpu_host(requested)
}
/// Reports whether the active-extent path is enabled for `requested`.
///
/// Returns `false` when [`locateanything_uses_cpu_host`] reports a redirected
/// host or when `requested` is `Device::Metal`; returns `true` otherwise.
pub fn lm_active_extent_enabled(requested: Device) -> bool {
    let excluded = locateanything_uses_cpu_host(requested) || requested == Device::Metal;
    !excluded
}
/// Builds the compile options for the LM decode graph on `device`.
///
/// Starts from `CompileProfile::llama32_decode()` and sets
/// `profile.fusion.skip` for `Gpu`, `Vulkan` and `Mlx`. The profile is then
/// converted with `compile_options_from_profile` using the default
/// `KernelDispatchConfig`.
pub fn lm_decode_compile_options(device: Device) -> CompileOptions {
    let skip_fusion = matches!(device, Device::Gpu | Device::Vulkan | Device::Mlx);

    let mut profile = CompileProfile::llama32_decode();
    if skip_fusion {
        profile.fusion.skip = true;
    }

    let dispatch = KernelDispatchConfig::default();
    compile_options_from_profile(&profile, device, dispatch)
}
/// Builds the compile options for the LM prefill graph on `device`.
///
/// Starts from `CompileProfile::llama32_prefill()` and sets
/// `profile.fusion.skip` for `Gpu` and `Vulkan`. The profile is then
/// converted with `compile_options_from_profile` using the default
/// `KernelDispatchConfig`.
pub fn lm_prefill_compile_options(device: Device) -> CompileOptions {
    let skip_fusion = matches!(device, Device::Gpu | Device::Vulkan);

    let mut profile = CompileProfile::llama32_prefill();
    if skip_fusion {
        profile.fusion.skip = true;
    }

    let dispatch = KernelDispatchConfig::default();
    compile_options_from_profile(&profile, device, dispatch)
}
/// Runs `f` with `RLX_DISABLE_MPSGRAPH` set to `"1"` when `device` is
/// `Device::Metal`; on any other device `f` runs without touching the
/// environment.
///
/// The variable is cleared again once `f` finishes. The clean-up is performed
/// by a drop guard, so it also happens when `f` panics and the stack unwinds;
/// the setting therefore cannot leak into later compilations.
///
/// Returns whatever `f` returns.
///
/// NOTE(review): the variable is unconditionally unset afterwards, so a value
/// that was already present before the call is not restored.
pub fn metal_lm_compile_guard<R, F>(device: Device, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Unsets `RLX_DISABLE_MPSGRAPH` when dropped.
    struct MpsGraphDisabled;

    impl Drop for MpsGraphDisabled {
        fn drop(&mut self) {
            rlx_ir::env::unset("RLX_DISABLE_MPSGRAPH");
        }
    }

    if device != Device::Metal {
        return f();
    }

    rlx_ir::env::set("RLX_DISABLE_MPSGRAPH", "1");
    // Bound to a named local so it lives until the end of the function:
    // it is dropped after `f` returns, or during unwinding if `f` panics.
    let _restore = MpsGraphDisabled;
    f()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every requested device resolves to the CPU host, and only non-CPU
    /// requests count as "using the CPU host" fallback.
    #[test]
    fn host_device_cpu_only_for_now() {
        for requested in [Device::Cpu, Device::Metal, Device::Mlx, Device::Cuda] {
            assert_eq!(locateanything_host_device(requested), Device::Cpu);
        }
        assert!(locateanything_uses_cpu_host(Device::Metal));
        assert!(!locateanything_uses_cpu_host(Device::Cpu));
    }

    /// Devices redirected to the CPU host get neither the GPU KV path nor
    /// the active-extent path.
    #[test]
    fn gpu_kv_disabled_on_cpu_host() {
        for requested in [Device::Metal, Device::Mlx] {
            assert!(!lm_gpu_kv_enabled(requested));
        }
        assert!(!lm_active_extent_enabled(Device::Metal));
    }
}