Skip to main content

rlx_runtime/
device_ext.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Engine-layer extensions for [`rlx_driver::Device`] (plan #58).
//!
//! `is_available` and `available_devices` consult the runtime's
//! backend registry and the enabled Cargo features, both of which are
//! engine-layer concerns. Keeping them here preserves the one-way
//! dependency direction (the driver doesn't know about the engine).
22
23use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
28/// Preferred probe order for ML workloads (highest throughput first).
29///
30/// Used by [`fastest_device`] and by [`crate::cost::fastest_device_for`] when
31/// calibrated cost models are unavailable for every candidate backend.
32pub(crate) const DEVICE_PRIORITY: &[Device] = &[
33    Device::Tpu,
34    Device::Cuda,
35    Device::Rocm,
36    Device::Mlx,
37    Device::Metal,
38    Device::Ane,
39    Device::Gpu,
40    Device::Vulkan,
41    Device::DirectX,
42    Device::OpenGl,
43    Device::WebGpu,
44    Device::Cpu,
45];
46
47/// Check whether `device` has a compiled-in backend or has been
48/// registered by an external crate.
49///
50/// GPU-family builtins (CUDA / ROCm / wgpu / TPU) additionally probe
51/// for a live driver or adapter at runtime so CI hosts that compile
52/// with `--features cuda` but have no NVIDIA stack don't report
53/// false positives. Other devices are Cargo-feature-gated; externally
54/// registered backends are discovered via the registry.
55/// Whether [`crate::CompiledGraph::run_slots`] + [`crate::CompiledGraph::arena_ptr`]
56/// are implemented (host readback layout; not a GPU-mapped arena on CUDA).
57pub fn supports_run_slots(device: Device) -> bool {
58    matches!(
59        device,
60        Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm
61    )
62}
63
64pub fn is_available(device: Device) -> bool {
65    #[cfg(feature = "cuda")]
66    if device == Device::Cuda {
67        return rlx_cuda::is_available();
68    }
69    #[cfg(feature = "rocm")]
70    if device == Device::Rocm {
71        return rlx_rocm::is_available();
72    }
73    #[cfg(feature = "gpu")]
74    if device == Device::Gpu {
75        return rlx_wgpu::is_available();
76    }
77    #[cfg(feature = "vulkan")]
78    if device == Device::Vulkan {
79        return rlx_wgpu::is_vulkan_available();
80    }
81    #[cfg(feature = "tpu")]
82    if device == Device::Tpu {
83        return rlx_tpu::is_available();
84    }
85
86    let feature_gated = match device {
87        Device::Cpu => cfg!(feature = "cpu"),
88        Device::Metal => cfg!(feature = "metal"),
89        Device::Mlx => cfg!(feature = "mlx"),
90        Device::Ane => cfg!(feature = "ane"),
91        Device::Cuda => cfg!(feature = "cuda"),
92        Device::Rocm => cfg!(feature = "rocm"),
93        Device::Tpu => cfg!(feature = "tpu"),
94        Device::Gpu => cfg!(feature = "gpu"),
95        Device::Vulkan => cfg!(feature = "vulkan"),
96        Device::OpenGl => cfg!(feature = "opengl"),
97        Device::DirectX => cfg!(feature = "directx"),
98        Device::WebGpu => cfg!(feature = "webgpu"),
99    };
100    if feature_gated {
101        return true;
102    }
103    crate::registry::registered_devices().contains(&device)
104}
105
/// Apple backends enabled in this build (`metal`, `mlx`, `gpu` on macOS).
///
/// Checks Metal, MLX and Gpu in that order and keeps the ones for which
/// [`is_available`] returns `true`.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut found = Vec::new();
    for candidate in [Device::Metal, Device::Mlx, Device::Gpu] {
        if is_available(candidate) {
            found.push(candidate);
        }
    }
    found
}
114
115/// Every variant currently available — Cargo-feature-gated or
116/// runtime-registered.
117pub fn available_devices() -> Vec<Device> {
118    Device::all()
119        .iter()
120        .copied()
121        .filter(|d| is_available(*d))
122        .collect()
123}
124
125/// Intersection of [`available_devices`] and [`supports_graph`]. Use with
126/// [`crate::GraphDevices`] or [`crate::DevicePolicy`] to restrict the set.
127pub fn devices_for(graph: &Graph) -> Vec<Device> {
128    crate::device_policy::devices_for_with_policy(graph, &crate::DevicePolicy::default())
129}
130
131/// Highest-priority backend that is compiled in and live on this host.
132///
133/// Probes [`DEVICE_PRIORITY`] in order (TPU → CUDA → ROCm → MLX → Metal → …
134/// → CPU). Use this when you want a sensible default `Session` target without
135/// building a graph first. For workload-specific selection, prefer
136/// [`crate::cost::fastest_device_for`].
137pub fn fastest_device() -> Device {
138    fastest_among(&available_devices())
139}
140
141/// Pick the highest-priority entry from `candidates` (see [`DEVICE_PRIORITY`]).
142pub fn fastest_among(candidates: &[Device]) -> Device {
143    for &d in DEVICE_PRIORITY {
144        if candidates.contains(&d) {
145            return d;
146        }
147    }
148    candidates.first().copied().unwrap_or(Device::Cpu)
149}
150
151/// Pretty name with engine-known BLAS variant for the CPU device.
152/// Gives `"CPU (Accelerate)"` etc. when the relevant feature is
153/// on; falls back to the bare driver-side `Device::name()` when
154/// no BLAS feature is selected.
155pub fn full_name(device: Device) -> &'static str {
156    if let Device::Cpu = device {
157        if cfg!(feature = "blas-accelerate") {
158            return "CPU (Accelerate)";
159        }
160        if cfg!(feature = "blas-mkl") {
161            return "CPU (MKL)";
162        }
163        if cfg!(feature = "blas-openblas") {
164            return "CPU (OpenBLAS)";
165        }
166    }
167    device.name()
168}
169
// ── Per-device op-support introspection ──────────────────────────
//
// Callers that want to dispatch graphs to a particular device need
// to know up front whether the device's backend has every op the
// graph uses wired up. Before this API, the only signal was a
// runtime panic ("not yet implemented"), which forced downstream
// crates (e.g. `eda-magnetics::graph::pick_device_for`) to bake
// hand-maintained "what's missing on X" tables into their own
// source — and those tables drift the moment a backend lands the
// missing op.
//
// [`supports`] consults the backend-side knowledge (CPU is the
// reference and assumed complete; MLX / Metal each name the ops
// they don't yet lower) so consumers can ask once and stop
// re-implementing the table.
184
185/// Is `op` lowerable by the backend for `device` *in this build*?
186///
187/// - CPU is the reference; always returns `true`.
188/// - GPU backends return `false` only for the specific ops/variants
189///   their lowering currently rejects. As backends close gaps, the
190///   matches here shrink and consumers automatically pick them up.
191/// - For devices not feature-gated in, returns `false` (you can't
192///   dispatch to a backend that isn't compiled in regardless).
193pub fn supports(device: Device, op: &Op) -> bool {
194    if !is_available(device) {
195        return false;
196    }
197    match device {
198        Device::Cpu => true, // reference backend; ground truth
199        Device::Mlx => mlx_supports(op),
200        Device::Metal => metal_supports(op),
201        Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
202        // Other backends not yet characterised here. Conservative:
203        // assume `false` so callers won't dispatch blind; tighten as
204        // each backend grows a `<x>_supports` arm below.
205        _ => false,
206    }
207}
208
209/// Is every op in `graph` lowerable by `device`?
210///
211/// When a backend is registered, uses the same rewrite + legalization probe as
212/// [`legalize_graph_for_device`] (see [`KernelDispatchReport::compile_ready`]).
213/// Otherwise falls back to per-op [`supports`] heuristics.
214pub fn supports_graph(device: Device, graph: &Graph) -> bool {
215    supports_graph_with_options(device, graph, &CompileOptions::default())
216}
217
218/// Like [`supports_graph`] with explicit [`CompileOptions::kernel_dispatch`].
219pub fn supports_graph_with_options(
220    device: Device,
221    graph: &Graph,
222    options: &CompileOptions,
223) -> bool {
224    if !is_available(device) {
225        return false;
226    }
227    if let Some(backend) = crate::registry::backend_for(device) {
228        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
229            graph.clone(),
230            device.name(),
231            backend.supported_ops(),
232            options.kernel_dispatch,
233        );
234        return report.compile_ready;
235    }
236    graph.nodes().iter().all(|n| supports(device, &n.op))
237}
238
239/// Legalize `graph` for `device` using that backend's claimed [`OpKind`] set.
240///
241/// Applies the same rewrite + legalization path as [`Backend::compile`] (e.g.
242/// CUDA/ROCm rewrites before the legality check). Returns an error when the
243/// backend feature is not enabled or the graph contains unsupported ops.
244///
245/// Does not require a live GPU/TPU driver — only that the backend crate is
246/// compiled in.
247pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
248    let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
249    Ok(graph)
250}
251
252/// Like [`legalize_graph_for_device`] but returns a [`KernelDispatchReport`] for tooling.
253pub fn legalize_graph_for_device_with_report(
254    graph: Graph,
255    device: Device,
256) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
257    legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
258}
259
260/// Like [`legalize_graph_for_device_with_report`] using [`CompileOptions::kernel_dispatch`]
261/// (and the same rewrite path as [`Backend::compile`]).
262pub fn legalize_graph_for_device_with_options(
263    graph: Graph,
264    device: Device,
265    options: &CompileOptions,
266) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
267    let backend = crate::registry::backend_for(device).ok_or_else(|| {
268        format!(
269            "no backend registered for {device:?} — enable the matching \
270             `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
271        )
272    })?;
273    let ops = backend.supported_ops();
274    let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
275        graph,
276        device.name(),
277        ops,
278        options.kernel_dispatch,
279    );
280    if !report.compile_ready {
281        return Err(format!(
282            "{}\n{}",
283            rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
284            rlx_opt::format_dispatch_report(&report)
285        ));
286    }
287    Ok((graph, report))
288}
289
290/// Dispatch report for `graph` on `device` without mutating the graph (static common-ir probe).
291pub fn dispatch_report_for_device(
292    graph: &Graph,
293    device: Device,
294) -> Result<rlx_opt::KernelDispatchReport, String> {
295    dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
296}
297
298/// Like [`dispatch_report_for_device`] with explicit [`CompileOptions::kernel_dispatch`].
299pub fn dispatch_report_for_device_with_options(
300    graph: &Graph,
301    device: Device,
302    options: &CompileOptions,
303) -> Result<rlx_opt::KernelDispatchReport, String> {
304    let backend = crate::registry::backend_for(device)
305        .ok_or_else(|| format!("no backend registered for {device:?}"))?;
306    Ok(rlx_opt::analyze_dispatch(
307        graph,
308        device.name(),
309        backend.supported_ops(),
310        options.kernel_dispatch,
311    ))
312}
313
314/// First op in `graph` that `device` cannot lower after rewrite, or `None`.
315///
316/// Prefer the backend claim-set probe when registered; otherwise [`supports`].
317pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
318    first_unsupported_op_with_options(device, graph, &CompileOptions::default())
319}
320
321/// Like [`first_unsupported_op`] with explicit [`CompileOptions::kernel_dispatch`].
322pub fn first_unsupported_op_with_options<'a>(
323    device: Device,
324    graph: &'a Graph,
325    options: &CompileOptions,
326) -> Option<(usize, &'a Op)> {
327    if !is_available(device) {
328        return graph.nodes().first().map(|n| (0, &n.op));
329    }
330    if let Some(backend) = crate::registry::backend_for(device) {
331        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
332            graph.clone(),
333            device.name(),
334            backend.supported_ops(),
335            options.kernel_dispatch,
336        );
337        if let Some((id, kind)) = report.still_unsupported.first() {
338            let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
339            let op = graph
340                .nodes()
341                .iter()
342                .find(|n| n.id == *id)
343                .map(|n| &n.op)
344                .unwrap_or(&graph.nodes()[0].op);
345            let _ = kind;
346            return Some((idx, op));
347        }
348        return None;
349    }
350    graph
351        .nodes()
352        .iter()
353        .enumerate()
354        .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
355}
356
357#[allow(unused_variables)]
358fn mlx_supports(op: &Op) -> bool {
359    // After Sin/Cos wiring (forward + backward), MLX's `Activation`
360    // dispatch is complete for every variant in `rlx_ir::Activation`.
361    // Add narrow guards here only when a future Op or Activation
362    // variant lands without an MLX lowering.
363    true
364}
365
366#[allow(unused_variables)]
367fn metal_supports(op: &Op) -> bool {
368    // No characterized gaps for the activations rlx-eda exercises.
369    // The Sin/Cos/Tan/Atan MSL kernels landed in `rlx-metal/src/kernels.rs`
370    // (`{sin,cos,tan,atan}_inplace`) alongside the dispatch slots in
371    // `backend.rs:1764`. Narrow this back down if a future Op or
372    // Activation variant lands without a Metal kernel.
373    let _ = op;
374    true
375}
376
377#[allow(unused_variables)]
378fn gpu_family_supports(op: &Op) -> bool {
379    // CUDA / ROCm / wgpu share the same IR surface area as CPU for the
380    // ops V-JEPA2 and other vision models exercise. Narrow when a backend
381    // reports a concrete lowering gap.
382    let _ = op;
383    true
384}
385
386/// Block until `device`'s queue is idle. Metal drains the global queue;
387/// other backends are no-ops.
388pub fn drain_device(device: Device) {
389    #[cfg(all(target_os = "macos", feature = "metal"))]
390    {
391        if device == Device::Metal {
392            rlx_metal::device::drain_command_queue();
393        }
394    }
395    #[cfg(not(all(target_os = "macos", feature = "metal")))]
396    let _ = device;
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// One-element F32 shape shared by the graph-building tests.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    #[test]
    fn cpu_supports_everything_built_in() {
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Sin)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Cos)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Exp)));
        assert!(supports(Device::Cpu, &Op::Binary(BinaryOp::Add)));
    }

    #[test]
    fn unbuilt_device_supports_nothing() {
        // OpenGl isn't a workspace feature; should report false.
        assert!(!supports(Device::OpenGl, &Op::Activation(Activation::Relu)));
    }

    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        // After the {sin,cos,tan,atan}_inplace MSL kernels landed in
        // rlx-metal/src/kernels.rs, Metal has every Activation variant
        // rlx-eda exercises.
        for act in [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ] {
            assert!(
                supports(Device::Metal, &Op::Activation(act)),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    #[test]
    fn graph_walk_reports_first_blocker() {
        let mut g = Graph::new("walk");
        let s = scalar_shape();
        let x = g.input("x", s.clone());
        let _e = g.activation(Activation::Exp, x, s.clone());
        let _sin = g.activation(Activation::Sin, x, s);
        // CPU always supports.
        assert!(supports_graph(Device::Cpu, &g));
        assert!(first_unsupported_op(Device::Cpu, &g).is_none());
    }

    // Checks that the default pick is a live device and agrees with the
    // priority walk over the available set; it makes no claim about which
    // device that is, since that depends on the enabled features.
    #[test]
    fn fastest_device_is_available_and_matches_priority_walk() {
        let pick = fastest_device();
        assert!(is_available(pick));
        assert_eq!(pick, fastest_among(&available_devices()));
    }

    #[test]
    fn fastest_among_respects_priority_order() {
        let pick = fastest_among(&[Device::Cpu, Device::Metal, Device::Mlx]);
        assert_eq!(pick, Device::Mlx);
    }

    // Pure checks of the fallback path: they depend only on the candidate
    // slice, not on which backends are compiled in.
    #[test]
    fn fastest_among_falls_back_to_cpu() {
        assert_eq!(fastest_among(&[Device::Cpu]), Device::Cpu);
        assert_eq!(fastest_among(&[]), Device::Cpu);
    }

    #[test]
    fn devices_for_is_subset_of_available() {
        let mut g = Graph::new("id");
        let x = g.input("x", scalar_shape());
        g.set_outputs(vec![x]);
        for d in devices_for(&g) {
            assert!(is_available(d));
            assert!(supports_graph(d, &g));
        }
    }
}