use rlx_driver::Device;
use rlx_ir::{Graph, Op};
use crate::CompileOptions;
/// Report whether `device` is usable in this build.
///
/// When the matching Cargo feature is compiled in, CUDA, ROCm, `Gpu`, Vulkan
/// and TPU are decided solely by their backend crate's runtime probe
/// (`rlx_cuda`, `rlx_rocm`, `rlx_wgpu`, `rlx_tpu`). In every other case the
/// device counts as available if its Cargo feature is enabled, or if it is
/// listed by `crate::registry::registered_devices()`.
pub fn is_available(device: Device) -> bool {
    // Runtime probes: authoritative whenever the backend crate is compiled in.
    #[cfg(feature = "cuda")]
    if device == Device::Cuda {
        return rlx_cuda::is_available();
    }
    #[cfg(feature = "rocm")]
    if device == Device::Rocm {
        return rlx_rocm::is_available();
    }
    #[cfg(feature = "gpu")]
    if device == Device::Gpu {
        return rlx_wgpu::is_available();
    }
    #[cfg(feature = "vulkan")]
    if device == Device::Vulkan {
        return rlx_wgpu::is_vulkan_available();
    }
    #[cfg(feature = "tpu")]
    if device == Device::Tpu {
        return rlx_tpu::is_available();
    }

    // Compile-time gate: was this device's Cargo feature enabled?
    let compiled_in = match device {
        Device::Cpu => cfg!(feature = "cpu"),
        Device::Metal => cfg!(feature = "metal"),
        Device::Mlx => cfg!(feature = "mlx"),
        Device::Ane => cfg!(feature = "ane"),
        Device::Cuda => cfg!(feature = "cuda"),
        Device::Rocm => cfg!(feature = "rocm"),
        Device::Tpu => cfg!(feature = "tpu"),
        Device::Gpu => cfg!(feature = "gpu"),
        Device::Vulkan => cfg!(feature = "vulkan"),
        Device::OpenGl => cfg!(feature = "opengl"),
        Device::DirectX => cfg!(feature = "directx"),
        Device::WebGpu => cfg!(feature = "webgpu"),
    };

    // Otherwise fall back to backends registered at runtime; `||` keeps the
    // registry lookup lazy, exactly like the early return it replaces.
    compiled_in || crate::registry::registered_devices().contains(&device)
}
/// List the Apple-platform devices (Metal, MLX, wgpu `Gpu`) that
/// [`is_available`] currently reports as usable, in that order.
///
/// Only compiled on macOS with the `apple` feature enabled.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut devices = vec![Device::Metal, Device::Mlx, Device::Gpu];
    devices.retain(|&candidate| is_available(candidate));
    devices
}
/// Return every device from `Device::all()` that [`is_available`] reports as
/// usable, preserving the order in which `Device::all()` yields them.
pub fn available_devices() -> Vec<Device> {
    Device::all()
        .iter()
        .filter_map(|&candidate| is_available(candidate).then_some(candidate))
        .collect()
}
/// Human-readable device name, including the BLAS flavour for the CPU device.
///
/// Non-CPU devices return `device.name()` unchanged. For the CPU, the first
/// enabled feature among `blas-accelerate`, `blas-mkl`, `blas-openblas` (in
/// that priority order) selects the suffix; with none enabled the plain
/// `device.name()` is returned.
pub fn full_name(device: Device) -> &'static str {
    if !matches!(device, Device::Cpu) {
        return device.name();
    }
    if cfg!(feature = "blas-accelerate") {
        "CPU (Accelerate)"
    } else if cfg!(feature = "blas-mkl") {
        "CPU (MKL)"
    } else if cfg!(feature = "blas-openblas") {
        "CPU (OpenBLAS)"
    } else {
        device.name()
    }
}
/// Per-op capability check used when no registered backend answers for `device`.
///
/// Returns `false` for unavailable devices. Otherwise:
/// - CPU accepts every op,
/// - MLX, Metal and the GPU family (`Gpu`, `Cuda`, `Rocm`) defer to their
///   dedicated helpers,
/// - every other device rejects every op.
pub fn supports(device: Device, op: &Op) -> bool {
    is_available(device)
        && match device {
            Device::Cpu => true,
            Device::Mlx => mlx_supports(op),
            Device::Metal => metal_supports(op),
            Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
            _ => false,
        }
}
/// Check whether `graph` is compile-ready for `device` using default
/// [`CompileOptions`]; shorthand for [`supports_graph_with_options`].
pub fn supports_graph(device: Device, graph: &Graph) -> bool {
    let defaults = CompileOptions::default();
    supports_graph_with_options(device, graph, &defaults)
}
/// Check whether `graph` is compile-ready for `device` under `options`.
///
/// Returns `false` immediately when the device is unavailable. When a backend
/// is registered for the device, a clone of the graph is run through
/// `rlx_opt::prepare_graph_for_backend_with_report` against the backend's
/// supported ops, and the report's `compile_ready` flag is returned. Without a
/// registered backend, every node must pass [`supports`].
pub fn supports_graph_with_options(
    device: Device,
    graph: &Graph,
    options: &CompileOptions,
) -> bool {
    if !is_available(device) {
        return false;
    }
    match crate::registry::backend_for(device) {
        Some(backend) => {
            // Preparation takes the graph by value, so hand it a copy and keep
            // only the report.
            let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
                graph.clone(),
                device.name(),
                backend.supported_ops(),
                options.kernel_dispatch,
            );
            report.compile_ready
        }
        None => graph.nodes().iter().all(|node| supports(device, &node.op)),
    }
}
/// Legalize `graph` for `device` with default options, discarding the
/// kernel-dispatch report.
///
/// # Errors
/// Propagates the error string from [`legalize_graph_for_device_with_report`].
pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
    legalize_graph_for_device_with_report(graph, device).map(|(legalized, _report)| legalized)
}
/// Legalize `graph` for `device` with default [`CompileOptions`], returning
/// both the rewritten graph and its kernel-dispatch report.
///
/// # Errors
/// Propagates the error string from [`legalize_graph_for_device_with_options`].
pub fn legalize_graph_for_device_with_report(
    graph: Graph,
    device: Device,
) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
    let defaults = CompileOptions::default();
    legalize_graph_for_device_with_options(graph, device, &defaults)
}
/// Legalize `graph` against the backend registered for `device`.
///
/// The graph is passed to `rlx_opt::prepare_graph_for_backend_with_report`
/// with the backend's supported ops and `options.kernel_dispatch`. On success
/// the prepared graph and its dispatch report are returned.
///
/// # Errors
/// - No backend is registered for `device`; the message suggests enabling the
///   matching `rlx-runtime` Cargo feature.
/// - The report is not `compile_ready`; the message is the legalize error
///   summary followed by the formatted dispatch report.
pub fn legalize_graph_for_device_with_options(
    graph: Graph,
    device: Device,
    options: &CompileOptions,
) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
    let backend = match crate::registry::backend_for(device) {
        Some(backend) => backend,
        None => {
            return Err(format!(
                "no backend registered for {device:?} — enable the matching \
                 `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
            ));
        }
    };

    let (legalized, report) = rlx_opt::prepare_graph_for_backend_with_report(
        graph,
        device.name(),
        backend.supported_ops(),
        options.kernel_dispatch,
    );
    if report.compile_ready {
        return Ok((legalized, report));
    }

    let headline = rlx_opt::format_legalize_error(device.name(), &report.still_unsupported);
    let details = rlx_opt::format_dispatch_report(&report);
    Err(format!("{headline}\n{details}"))
}
/// Compute the kernel-dispatch report for `graph` on `device` with default
/// [`CompileOptions`].
///
/// # Errors
/// Propagates the error string from [`dispatch_report_for_device_with_options`].
pub fn dispatch_report_for_device(
    graph: &Graph,
    device: Device,
) -> Result<rlx_opt::KernelDispatchReport, String> {
    let defaults = CompileOptions::default();
    dispatch_report_for_device_with_options(graph, device, &defaults)
}
/// Analyze how `graph` would be dispatched on `device` without rewriting it.
///
/// Delegates to `rlx_opt::analyze_dispatch` with the registered backend's
/// supported ops and `options.kernel_dispatch`.
///
/// # Errors
/// Returns an error string when no backend is registered for `device`.
pub fn dispatch_report_for_device_with_options(
    graph: &Graph,
    device: Device,
    options: &CompileOptions,
) -> Result<rlx_opt::KernelDispatchReport, String> {
    match crate::registry::backend_for(device) {
        Some(backend) => Ok(rlx_opt::analyze_dispatch(
            graph,
            device.name(),
            backend.supported_ops(),
            options.kernel_dispatch,
        )),
        None => Err(format!("no backend registered for {device:?}")),
    }
}
/// Find the first node of `graph` that `device` cannot run, using default
/// [`CompileOptions`]; see [`first_unsupported_op_with_options`].
pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
    let defaults = CompileOptions::default();
    first_unsupported_op_with_options(device, graph, &defaults)
}
/// Find the first node of `graph` that `device` cannot run, under `options`.
///
/// Returns `Some((index, op))`, where `index` is the node's position in
/// `graph.nodes()`, or `None` when nothing blocks compilation.
///
/// * If `device` is not available, the first node is reported at index 0. An
///   empty graph yields `None` here, even though [`supports_graph_with_options`]
///   returns `false` for the same inputs.
/// * If a backend is registered for `device`, a clone of `graph` is run through
///   `rlx_opt::prepare_graph_for_backend_with_report`. The first entry of the
///   report's `still_unsupported` list is then mapped back to `graph` by node id.
///   If that id is not present in `graph`, the first node is reported as a
///   fallback. NOTE(review): this presumably means the node was created during
///   preparation; confirm.
/// * Otherwise the first node rejected by [`supports`] is returned.
pub fn first_unsupported_op_with_options<'a>(
    device: Device,
    graph: &'a Graph,
    options: &CompileOptions,
) -> Option<(usize, &'a Op)> {
    let nodes = graph.nodes();
    if !is_available(device) {
        return nodes.first().map(|n| (0, &n.op));
    }
    if let Some(backend) = crate::registry::backend_for(device) {
        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
            graph.clone(),
            device.name(),
            backend.supported_ops(),
            options.kernel_dispatch,
        );
        // An empty `still_unsupported` list means nothing blocks compilation.
        let (blocker_id, _kind) = report.still_unsupported.first()?;
        // Find index and node in a single pass. The fallback uses `.first()`
        // rather than `[0]`, so an unmatched id on an empty graph yields `None`
        // instead of panicking.
        return nodes
            .iter()
            .enumerate()
            .find(|(_, n)| n.id == *blocker_id)
            .or_else(|| nodes.first().map(|n| (0, n)))
            .map(|(i, n)| (i, &n.op));
    }
    nodes
        .iter()
        .enumerate()
        .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
}
/// Per-op capability check used by [`supports`] for `Device::Mlx`.
///
/// NOTE(review): placeholder. It currently returns `true` for every op. The
/// argument is deliberately unused, hence the `_op` name instead of a lint
/// suppression.
fn mlx_supports(_op: &Op) -> bool {
    true
}
/// Per-op capability check used by [`supports`] for `Device::Metal`.
///
/// NOTE(review): placeholder. It currently returns `true` for every op. The
/// argument is deliberately unused (`_op`), which replaces the previous
/// redundant pair of `#[allow(unused_variables)]` and `let _ = op;`.
fn metal_supports(_op: &Op) -> bool {
    true
}
/// Per-op capability check used by [`supports`] for the GPU family
/// (`Device::Gpu`, `Device::Cuda`, `Device::Rocm`).
///
/// NOTE(review): placeholder. It currently returns `true` for every op. The
/// argument is deliberately unused (`_op`), which replaces the previous
/// redundant pair of `#[allow(unused_variables)]` and `let _ = op;`.
fn gpu_family_supports(_op: &Op) -> bool {
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// Single-element `F32` shape shared by the graph-building tests.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    #[test]
    fn cpu_supports_everything_built_in() {
        let built_ins = [
            Op::Activation(Activation::Sin),
            Op::Activation(Activation::Cos),
            Op::Activation(Activation::Exp),
            Op::Binary(BinaryOp::Add),
        ];
        for op in &built_ins {
            assert!(supports(Device::Cpu, op));
        }
    }

    #[test]
    fn unbuilt_device_supports_nothing() {
        let relu = Op::Activation(Activation::Relu);
        assert!(!supports(Device::OpenGl, &relu));
    }

    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        let activations = [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ];
        for act in activations {
            let op = Op::Activation(act);
            assert!(
                supports(Device::Metal, &op),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    #[test]
    fn graph_walk_reports_first_blocker() {
        let shape = scalar_shape();
        let mut graph = Graph::new("walk");
        let x = graph.input("x", shape.clone());
        let _exp = graph.activation(Activation::Exp, x, shape.clone());
        let _sin = graph.activation(Activation::Sin, x, shape);
        // The CPU path is expected to accept both activations, so no blocker.
        assert!(supports_graph(Device::Cpu, &graph));
        assert!(first_unsupported_op(Device::Cpu, &graph).is_none());
    }
}