Skip to main content

rlx_runtime/
device_ext.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Engine-layer extensions for [`rlx_driver::Device`] (plan #58).
17//!
18//! `is_available` and `available_devices` consult the runtime's
19//! backend registry + Cargo features, both of which are
20//! engine-layer concerns. Keeping them here preserves the
21//! one-way dep direction (driver doesn't know about engine).
22
23use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
28/// Check whether `device` has a compiled-in backend or has been
29/// registered by an external crate.
30///
31/// GPU-family builtins (CUDA / ROCm / wgpu / TPU) additionally probe
32/// for a live driver or adapter at runtime so CI hosts that compile
33/// with `--features cuda` but have no NVIDIA stack don't report
34/// false positives. Other devices are Cargo-feature-gated; externally
35/// registered backends are discovered via the registry.
36pub fn is_available(device: Device) -> bool {
37    #[cfg(feature = "cuda")]
38    if device == Device::Cuda {
39        return rlx_cuda::is_available();
40    }
41    #[cfg(feature = "rocm")]
42    if device == Device::Rocm {
43        return rlx_rocm::is_available();
44    }
45    #[cfg(feature = "gpu")]
46    if device == Device::Gpu {
47        return rlx_wgpu::is_available();
48    }
49    #[cfg(feature = "vulkan")]
50    if device == Device::Vulkan {
51        return rlx_wgpu::is_vulkan_available();
52    }
53    #[cfg(feature = "tpu")]
54    if device == Device::Tpu {
55        return rlx_tpu::is_available();
56    }
57
58    let feature_gated = match device {
59        Device::Cpu => cfg!(feature = "cpu"),
60        Device::Metal => cfg!(feature = "metal"),
61        Device::Mlx => cfg!(feature = "mlx"),
62        Device::Ane => cfg!(feature = "ane"),
63        Device::Cuda => cfg!(feature = "cuda"),
64        Device::Rocm => cfg!(feature = "rocm"),
65        Device::Tpu => cfg!(feature = "tpu"),
66        Device::Gpu => cfg!(feature = "gpu"),
67        Device::Vulkan => cfg!(feature = "vulkan"),
68        Device::OpenGl => cfg!(feature = "opengl"),
69        Device::DirectX => cfg!(feature = "directx"),
70        Device::WebGpu => cfg!(feature = "webgpu"),
71    };
72    if feature_gated {
73        return true;
74    }
75    crate::registry::registered_devices().contains(&device)
76}
77
/// The subset of Apple-side devices (`Metal`, `Mlx`, `Gpu`) that
/// [`is_available`] reports as usable, in that order.
///
/// Only compiled when the `apple` feature is enabled on macOS.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut usable = Vec::new();
    for candidate in [Device::Metal, Device::Mlx, Device::Gpu] {
        if is_available(candidate) {
            usable.push(candidate);
        }
    }
    usable
}
86
87/// Every variant currently available — Cargo-feature-gated or
88/// runtime-registered.
89pub fn available_devices() -> Vec<Device> {
90    Device::all()
91        .iter()
92        .copied()
93        .filter(|d| is_available(*d))
94        .collect()
95}
96
97/// Pretty name with engine-known BLAS variant for the CPU device.
98/// Gives `"CPU (Accelerate)"` etc. when the relevant feature is
99/// on; falls back to the bare driver-side `Device::name()` when
100/// no BLAS feature is selected.
101pub fn full_name(device: Device) -> &'static str {
102    if let Device::Cpu = device {
103        if cfg!(feature = "blas-accelerate") {
104            return "CPU (Accelerate)";
105        }
106        if cfg!(feature = "blas-mkl") {
107            return "CPU (MKL)";
108        }
109        if cfg!(feature = "blas-openblas") {
110            return "CPU (OpenBLAS)";
111        }
112    }
113    device.name()
114}
115
116// ── Per-device op-support introspection ──────────────────────────
117//
118// Callers that want to dispatch graphs to a particular device need
119// to know up front whether the device's backend has every op the
120// graph uses wired up. Before this API, the only signal was a
121// runtime panic ("not yet implemented"), which forced downstream
122// crates (e.g. `eda-magnetics::graph::pick_device_for`) to bake
123// hand-maintained "what's missing on X" tables into their own
124// source — those drift the moment a backend lands the missing op.
125//
126// [`supports`] consults the backend-side knowledge (CPU is the
127// reference and assumed complete; MLX / Metal each name the ops
128// they don't yet lower) so consumers can ask once and stop
129// re-implementing the table.
130
131/// Is `op` lowerable by the backend for `device` *in this build*?
132///
133/// - CPU is the reference; always returns `true`.
134/// - GPU backends return `false` only for the specific ops/variants
135///   their lowering currently rejects. As backends close gaps, the
136///   matches here shrink and consumers automatically pick them up.
137/// - For devices not feature-gated in, returns `false` (you can't
138///   dispatch to a backend that isn't compiled in regardless).
139pub fn supports(device: Device, op: &Op) -> bool {
140    if !is_available(device) {
141        return false;
142    }
143    match device {
144        Device::Cpu => true, // reference backend; ground truth
145        Device::Mlx => mlx_supports(op),
146        Device::Metal => metal_supports(op),
147        Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
148        // Other backends not yet characterised here. Conservative:
149        // assume `false` so callers won't dispatch blind; tighten as
150        // each backend grows a `<x>_supports` arm below.
151        _ => false,
152    }
153}
154
155/// Is every op in `graph` lowerable by `device`?
156///
157/// When a backend is registered, uses the same rewrite + legalization probe as
158/// [`legalize_graph_for_device`] (see [`KernelDispatchReport::compile_ready`]).
159/// Otherwise falls back to per-op [`supports`] heuristics.
160pub fn supports_graph(device: Device, graph: &Graph) -> bool {
161    supports_graph_with_options(device, graph, &CompileOptions::default())
162}
163
164/// Like [`supports_graph`] with explicit [`CompileOptions::kernel_dispatch`].
165pub fn supports_graph_with_options(
166    device: Device,
167    graph: &Graph,
168    options: &CompileOptions,
169) -> bool {
170    if !is_available(device) {
171        return false;
172    }
173    if let Some(backend) = crate::registry::backend_for(device) {
174        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
175            graph.clone(),
176            device.name(),
177            backend.supported_ops(),
178            options.kernel_dispatch,
179        );
180        return report.compile_ready;
181    }
182    graph.nodes().iter().all(|n| supports(device, &n.op))
183}
184
185/// Legalize `graph` for `device` using that backend's claimed [`OpKind`] set.
186///
187/// Applies the same rewrite + legalization path as [`Backend::compile`] (e.g.
188/// CUDA/ROCm rewrites before the legality check). Returns an error when the
189/// backend feature is not enabled or the graph contains unsupported ops.
190///
191/// Does not require a live GPU/TPU driver — only that the backend crate is
192/// compiled in.
193pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
194    let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
195    Ok(graph)
196}
197
198/// Like [`legalize_graph_for_device`] but returns a [`KernelDispatchReport`] for tooling.
199pub fn legalize_graph_for_device_with_report(
200    graph: Graph,
201    device: Device,
202) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
203    legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
204}
205
206/// Like [`legalize_graph_for_device_with_report`] using [`CompileOptions::kernel_dispatch`]
207/// (and the same rewrite path as [`Backend::compile`]).
208pub fn legalize_graph_for_device_with_options(
209    graph: Graph,
210    device: Device,
211    options: &CompileOptions,
212) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
213    let backend = crate::registry::backend_for(device).ok_or_else(|| {
214        format!(
215            "no backend registered for {device:?} — enable the matching \
216             `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
217        )
218    })?;
219    let ops = backend.supported_ops();
220    let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
221        graph,
222        device.name(),
223        ops,
224        options.kernel_dispatch,
225    );
226    if !report.compile_ready {
227        return Err(format!(
228            "{}\n{}",
229            rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
230            rlx_opt::format_dispatch_report(&report)
231        ));
232    }
233    Ok((graph, report))
234}
235
236/// Dispatch report for `graph` on `device` without mutating the graph (static common-ir probe).
237pub fn dispatch_report_for_device(
238    graph: &Graph,
239    device: Device,
240) -> Result<rlx_opt::KernelDispatchReport, String> {
241    dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
242}
243
244/// Like [`dispatch_report_for_device`] with explicit [`CompileOptions::kernel_dispatch`].
245pub fn dispatch_report_for_device_with_options(
246    graph: &Graph,
247    device: Device,
248    options: &CompileOptions,
249) -> Result<rlx_opt::KernelDispatchReport, String> {
250    let backend = crate::registry::backend_for(device)
251        .ok_or_else(|| format!("no backend registered for {device:?}"))?;
252    Ok(rlx_opt::analyze_dispatch(
253        graph,
254        device.name(),
255        backend.supported_ops(),
256        options.kernel_dispatch,
257    ))
258}
259
260/// First op in `graph` that `device` cannot lower after rewrite, or `None`.
261///
262/// Prefer the backend claim-set probe when registered; otherwise [`supports`].
263pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
264    first_unsupported_op_with_options(device, graph, &CompileOptions::default())
265}
266
267/// Like [`first_unsupported_op`] with explicit [`CompileOptions::kernel_dispatch`].
268pub fn first_unsupported_op_with_options<'a>(
269    device: Device,
270    graph: &'a Graph,
271    options: &CompileOptions,
272) -> Option<(usize, &'a Op)> {
273    if !is_available(device) {
274        return graph.nodes().first().map(|n| (0, &n.op));
275    }
276    if let Some(backend) = crate::registry::backend_for(device) {
277        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
278            graph.clone(),
279            device.name(),
280            backend.supported_ops(),
281            options.kernel_dispatch,
282        );
283        if let Some((id, kind)) = report.still_unsupported.first() {
284            let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
285            let op = graph
286                .nodes()
287                .iter()
288                .find(|n| n.id == *id)
289                .map(|n| &n.op)
290                .unwrap_or(&graph.nodes()[0].op);
291            let _ = kind;
292            return Some((idx, op));
293        }
294        return None;
295    }
296    graph
297        .nodes()
298        .iter()
299        .enumerate()
300        .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
301}
302
303#[allow(unused_variables)]
304fn mlx_supports(op: &Op) -> bool {
305    // After Sin/Cos wiring (forward + backward), MLX's `Activation`
306    // dispatch is complete for every variant in `rlx_ir::Activation`.
307    // Add narrow guards here only when a future Op or Activation
308    // variant lands without an MLX lowering.
309    true
310}
311
312#[allow(unused_variables)]
313fn metal_supports(op: &Op) -> bool {
314    // No characterized gaps for the activations rlx-eda exercises.
315    // The Sin/Cos/Tan/Atan MSL kernels landed in `rlx-metal/src/kernels.rs`
316    // (`{sin,cos,tan,atan}_inplace`) alongside the dispatch slots in
317    // `backend.rs:1764`. Narrow this back down if a future Op or
318    // Activation variant lands without a Metal kernel.
319    let _ = op;
320    true
321}
322
323#[allow(unused_variables)]
324fn gpu_family_supports(op: &Op) -> bool {
325    // CUDA / ROCm / wgpu share the same IR surface area as CPU for the
326    // ops V-JEPA2 and other vision models exercise. Narrow when a backend
327    // reports a concrete lowering gap.
328    let _ = op;
329    true
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// One-element F32 shape shared by the graph-building tests.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    /// The CPU reference backend accepts a sample of built-in ops.
    #[test]
    fn cpu_supports_everything_built_in() {
        let sample_ops = [
            Op::Activation(Activation::Sin),
            Op::Activation(Activation::Cos),
            Op::Activation(Activation::Exp),
            Op::Binary(BinaryOp::Add),
        ];
        for op in &sample_ops {
            assert!(supports(Device::Cpu, op));
        }
    }

    /// A device with no compiled-in backend rejects every op.
    #[test]
    fn unbuilt_device_supports_nothing() {
        // OpenGl isn't a workspace feature, so it must report false.
        let relu = Op::Activation(Activation::Relu);
        assert!(!supports(Device::OpenGl, &relu));
    }

    /// Metal lowers the full activation set rlx-eda exercises.
    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        // The {sin,cos,tan,atan}_inplace MSL kernels in
        // rlx-metal/src/kernels.rs closed the last gaps for these.
        let activations = [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ];
        for act in activations {
            let lowered = supports(Device::Metal, &Op::Activation(act));
            assert!(lowered, "Metal should support Activation::{act:?}");
        }
    }

    /// Walking a small graph on CPU finds no blocker.
    #[test]
    fn graph_walk_reports_first_blocker() {
        let shape = scalar_shape();
        let mut graph = Graph::new("walk");
        let input = graph.input("x", shape.clone());
        let _exp = graph.activation(Activation::Exp, input, shape.clone());
        let _sin = graph.activation(Activation::Sin, input, shape);

        // CPU always supports.
        let whole_graph_ok = supports_graph(Device::Cpu, &graph);
        assert!(whole_graph_ok);
        let blocker = first_unsupported_op(Device::Cpu, &graph);
        assert!(blocker.is_none());
    }
}