Skip to main content

rlx_runtime/
device_ext.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Engine-layer extensions for [`rlx_driver::Device`] (plan #58).
17//!
18//! `is_available` and `available_devices` consult the runtime's
19//! backend registry + Cargo features, both of which are
20//! engine-layer concerns. Keeping them here preserves the
21//! one-way dep direction (driver doesn't know about engine).
22
23use rlx_driver::Device;
24use rlx_ir::{Graph, Op};
25
26use crate::CompileOptions;
27
/// Preferred probe order for ML workloads (highest throughput first).
///
/// Used by [`fastest_device`] and by [`crate::cost::fastest_device_for`] when
/// calibrated cost models are unavailable for every candidate backend.
///
/// [`fastest_among`] walks this slice front to back and returns the first
/// entry that also appears in its candidate list, so an earlier position
/// means a higher preference. CPU sits last and is therefore chosen only
/// when no other listed device is among the candidates.
///
/// A device missing from this list can only be returned by
/// [`fastest_among`] through its "first candidate" fallback, so keep the
/// list in sync with the `Device` variants.
pub(crate) const DEVICE_PRIORITY: &[Device] = &[
    Device::Tpu,
    Device::Cuda,
    Device::Rocm,
    Device::Mlx,
    Device::Metal,
    Device::Ane,
    Device::Gpu,
    Device::Vulkan,
    Device::DirectX,
    Device::OpenGl,
    Device::WebGpu,
    Device::Cpu,
];
46
47/// Check whether `device` has a compiled-in backend or has been
48/// registered by an external crate.
49///
50/// GPU-family builtins (CUDA / ROCm / wgpu / TPU) additionally probe
51/// for a live driver or adapter at runtime so CI hosts that compile
52/// with `--features cuda` but have no NVIDIA stack don't report
53/// false positives. Other devices are Cargo-feature-gated; externally
54/// registered backends are discovered via the registry.
55/// Whether [`crate::CompiledGraph::run_slots`] + [`crate::CompiledGraph::arena_ptr`]
56/// are implemented (host readback layout; not a GPU-mapped arena on CUDA).
57pub fn supports_run_slots(device: Device) -> bool {
58    matches!(
59        device,
60        Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm
61    )
62}
63
64pub fn is_available(device: Device) -> bool {
65    #[cfg(feature = "cuda")]
66    if device == Device::Cuda {
67        return rlx_cuda::is_available();
68    }
69    #[cfg(feature = "rocm")]
70    if device == Device::Rocm {
71        return rlx_rocm::is_available();
72    }
73    #[cfg(feature = "gpu")]
74    if device == Device::Gpu {
75        return rlx_wgpu::is_available();
76    }
77    #[cfg(feature = "vulkan")]
78    if device == Device::Vulkan {
79        return rlx_wgpu::is_vulkan_available();
80    }
81    #[cfg(feature = "tpu")]
82    if device == Device::Tpu {
83        return rlx_tpu::is_available();
84    }
85
86    let feature_gated = match device {
87        Device::Cpu => cfg!(feature = "cpu"),
88        Device::Metal => cfg!(feature = "metal"),
89        Device::Mlx => cfg!(feature = "mlx"),
90        Device::Ane => cfg!(any(feature = "coreml", feature = "ane")),
91        Device::Cuda => cfg!(feature = "cuda"),
92        Device::Rocm => cfg!(feature = "rocm"),
93        Device::Tpu => cfg!(feature = "tpu"),
94        Device::Gpu => cfg!(feature = "gpu"),
95        Device::Vulkan => cfg!(feature = "vulkan"),
96        Device::OpenGl => cfg!(feature = "opengl"),
97        Device::DirectX => cfg!(feature = "directx"),
98        Device::WebGpu => cfg!(feature = "webgpu"),
99    };
100    if feature_gated {
101        return true;
102    }
103    crate::registry::registered_devices().contains(&device)
104}
105
/// Apple-platform backends that [`is_available`] reports as usable, picked
/// from `Metal`, `Mlx`, `Gpu` and `Ane` (returned in that order).
///
/// Only compiled when the `apple` feature is on and the target is macOS.
#[cfg(all(feature = "apple", target_os = "macos"))]
pub fn available_apple_devices() -> Vec<Device> {
    let mut apple = vec![Device::Metal, Device::Mlx, Device::Gpu, Device::Ane];
    apple.retain(|&candidate| is_available(candidate));
    apple
}
115
116/// Every variant currently available — Cargo-feature-gated or
117/// runtime-registered.
118pub fn available_devices() -> Vec<Device> {
119    Device::all()
120        .iter()
121        .copied()
122        .filter(|d| is_available(*d))
123        .collect()
124}
125
126/// Intersection of [`available_devices`] and [`supports_graph`]. Use with
127/// [`crate::GraphDevices`] or [`crate::DevicePolicy`] to restrict the set.
128pub fn devices_for(graph: &Graph) -> Vec<Device> {
129    crate::device_policy::devices_for_with_policy(graph, &crate::DevicePolicy::default())
130}
131
132/// Highest-priority backend that is compiled in and live on this host.
133///
134/// Probes [`DEVICE_PRIORITY`] in order (TPU → CUDA → ROCm → MLX → Metal → …
135/// → CPU). Use this when you want a sensible default `Session` target without
136/// building a graph first. For workload-specific selection, prefer
137/// [`crate::cost::fastest_device_for`].
138pub fn fastest_device() -> Device {
139    fastest_among(&available_devices())
140}
141
142/// Pick the highest-priority entry from `candidates` (see [`DEVICE_PRIORITY`]).
143pub fn fastest_among(candidates: &[Device]) -> Device {
144    for &d in DEVICE_PRIORITY {
145        if candidates.contains(&d) {
146            return d;
147        }
148    }
149    candidates.first().copied().unwrap_or(Device::Cpu)
150}
151
152/// Pretty name with engine-known BLAS variant for the CPU device.
153/// Gives `"CPU (Accelerate)"` etc. when the relevant feature is
154/// on; falls back to the bare driver-side `Device::name()` when
155/// no BLAS feature is selected.
156pub fn full_name(device: Device) -> &'static str {
157    if let Device::Cpu = device {
158        if cfg!(feature = "blas-accelerate") {
159            return "CPU (Accelerate)";
160        }
161        if cfg!(feature = "blas-mkl") {
162            return "CPU (MKL)";
163        }
164        if cfg!(feature = "blas-openblas") {
165            return "CPU (OpenBLAS)";
166        }
167    }
168    device.name()
169}
170
171// ── Per-device op-support introspection ──────────────────────────
172//
173// Callers that want to dispatch graphs to a particular device need
174// to know up front whether the device's backend has every op the
175// graph uses wired up. Before this API, the only signal was a
176// runtime panic ("not yet implemented"), which forced downstream
177// crates (e.g. `eda-magnetics::graph::pick_device_for`) to bake
178// hand-maintained "what's missing on X" tables into their own
179// source — those drift the moment a backend lands the missing op.
180//
181// [`supports`] consults the backend-side knowledge (CPU is the
182// reference and assumed complete; MLX / Metal each name the ops
183// they don't yet lower) so consumers can ask once and stop
184// re-implementing the table.
185
186/// Is `op` lowerable by the backend for `device` *in this build*?
187///
188/// - CPU is the reference; always returns `true`.
189/// - GPU backends return `false` only for the specific ops/variants
190///   their lowering currently rejects. As backends close gaps, the
191///   matches here shrink and consumers automatically pick them up.
192/// - For devices not feature-gated in, returns `false` (you can't
193///   dispatch to a backend that isn't compiled in regardless).
194pub fn supports(device: Device, op: &Op) -> bool {
195    if !is_available(device) {
196        return false;
197    }
198    match device {
199        Device::Cpu => true, // reference backend; ground truth
200        Device::Mlx => mlx_supports(op),
201        Device::Metal => metal_supports(op),
202        Device::Ane => coreml_supports(op),
203        Device::Gpu | Device::Cuda | Device::Rocm => gpu_family_supports(op),
204        // Other backends not yet characterised here. Conservative:
205        // assume `false` so callers won't dispatch blind; tighten as
206        // each backend grows a `<x>_supports` arm below.
207        _ => false,
208    }
209}
210
211/// Is every op in `graph` lowerable by `device`?
212///
213/// When a backend is registered, uses the same rewrite + legalization probe as
214/// [`legalize_graph_for_device`] (see [`KernelDispatchReport::compile_ready`]).
215/// Otherwise falls back to per-op [`supports`] heuristics.
216pub fn supports_graph(device: Device, graph: &Graph) -> bool {
217    supports_graph_with_options(device, graph, &CompileOptions::default())
218}
219
220/// Like [`supports_graph`] with explicit [`CompileOptions::kernel_dispatch`].
221pub fn supports_graph_with_options(
222    device: Device,
223    graph: &Graph,
224    options: &CompileOptions,
225) -> bool {
226    if !is_available(device) {
227        return false;
228    }
229    if let Some(backend) = crate::registry::backend_for(device) {
230        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
231            graph.clone(),
232            device.name(),
233            backend.supported_ops(),
234            options.kernel_dispatch,
235        );
236        return report.compile_ready;
237    }
238    graph.nodes().iter().all(|n| supports(device, &n.op))
239}
240
241/// Legalize `graph` for `device` using that backend's claimed [`OpKind`] set.
242///
243/// Applies the same rewrite + legalization path as [`Backend::compile`] (e.g.
244/// CUDA/ROCm rewrites before the legality check). Returns an error when the
245/// backend feature is not enabled or the graph contains unsupported ops.
246///
247/// Does not require a live GPU/TPU driver — only that the backend crate is
248/// compiled in.
249pub fn legalize_graph_for_device(graph: Graph, device: Device) -> Result<Graph, String> {
250    let (graph, _report) = legalize_graph_for_device_with_report(graph, device)?;
251    Ok(graph)
252}
253
254/// Like [`legalize_graph_for_device`] but returns a [`KernelDispatchReport`] for tooling.
255pub fn legalize_graph_for_device_with_report(
256    graph: Graph,
257    device: Device,
258) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
259    legalize_graph_for_device_with_options(graph, device, &CompileOptions::default())
260}
261
262/// Like [`legalize_graph_for_device_with_report`] using [`CompileOptions::kernel_dispatch`]
263/// (and the same rewrite path as [`Backend::compile`]).
264pub fn legalize_graph_for_device_with_options(
265    graph: Graph,
266    device: Device,
267    options: &CompileOptions,
268) -> Result<(Graph, rlx_opt::KernelDispatchReport), String> {
269    let backend = crate::registry::backend_for(device).ok_or_else(|| {
270        format!(
271            "no backend registered for {device:?} — enable the matching \
272             `rlx-runtime` Cargo feature (e.g. `metal`, `gpu`, `cuda`)"
273        )
274    })?;
275    let ops = backend.supported_ops();
276    let (graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
277        graph,
278        device.name(),
279        ops,
280        options.kernel_dispatch,
281    );
282    if !report.compile_ready {
283        return Err(format!(
284            "{}\n{}",
285            rlx_opt::format_legalize_error(device.name(), &report.still_unsupported),
286            rlx_opt::format_dispatch_report(&report)
287        ));
288    }
289    Ok((graph, report))
290}
291
292/// Dispatch report for `graph` on `device` without mutating the graph (static common-ir probe).
293pub fn dispatch_report_for_device(
294    graph: &Graph,
295    device: Device,
296) -> Result<rlx_opt::KernelDispatchReport, String> {
297    dispatch_report_for_device_with_options(graph, device, &CompileOptions::default())
298}
299
300/// Like [`dispatch_report_for_device`] with explicit [`CompileOptions::kernel_dispatch`].
301pub fn dispatch_report_for_device_with_options(
302    graph: &Graph,
303    device: Device,
304    options: &CompileOptions,
305) -> Result<rlx_opt::KernelDispatchReport, String> {
306    let backend = crate::registry::backend_for(device)
307        .ok_or_else(|| format!("no backend registered for {device:?}"))?;
308    Ok(rlx_opt::analyze_dispatch(
309        graph,
310        device.name(),
311        backend.supported_ops(),
312        options.kernel_dispatch,
313    ))
314}
315
316/// First op in `graph` that `device` cannot lower after rewrite, or `None`.
317///
318/// Prefer the backend claim-set probe when registered; otherwise [`supports`].
319pub fn first_unsupported_op(device: Device, graph: &Graph) -> Option<(usize, &Op)> {
320    first_unsupported_op_with_options(device, graph, &CompileOptions::default())
321}
322
323/// Like [`first_unsupported_op`] with explicit [`CompileOptions::kernel_dispatch`].
324pub fn first_unsupported_op_with_options<'a>(
325    device: Device,
326    graph: &'a Graph,
327    options: &CompileOptions,
328) -> Option<(usize, &'a Op)> {
329    if !is_available(device) {
330        return graph.nodes().first().map(|n| (0, &n.op));
331    }
332    if let Some(backend) = crate::registry::backend_for(device) {
333        let (_, report) = rlx_opt::prepare_graph_for_backend_with_report(
334            graph.clone(),
335            device.name(),
336            backend.supported_ops(),
337            options.kernel_dispatch,
338        );
339        if let Some((id, kind)) = report.still_unsupported.first() {
340            let idx = graph.nodes().iter().position(|n| n.id == *id).unwrap_or(0);
341            let op = graph
342                .nodes()
343                .iter()
344                .find(|n| n.id == *id)
345                .map(|n| &n.op)
346                .unwrap_or(&graph.nodes()[0].op);
347            let _ = kind;
348            return Some((idx, op));
349        }
350        return None;
351    }
352    graph
353        .nodes()
354        .iter()
355        .enumerate()
356        .find_map(|(i, n)| (!supports(device, &n.op)).then_some((i, &n.op)))
357}
358
359#[allow(unused_variables)]
360fn mlx_supports(op: &Op) -> bool {
361    // After Sin/Cos wiring (forward + backward), MLX's `Activation`
362    // dispatch is complete for every variant in `rlx_ir::Activation`.
363    // Add narrow guards here only when a future Op or Activation
364    // variant lands without an MLX lowering.
365    true
366}
367
368#[allow(unused_variables)]
369fn metal_supports(op: &Op) -> bool {
370    // No characterized gaps for the activations rlx-eda exercises.
371    // The Sin/Cos/Tan/Atan MSL kernels landed in `rlx-metal/src/kernels.rs`
372    // (`{sin,cos,tan,atan}_inplace`) alongside the dispatch slots in
373    // `backend.rs:1764`. Narrow this back down if a future Op or
374    // Activation variant lands without a Metal kernel.
375    let _ = op;
376    true
377}
378
379/// CoreML / ANE lowers a fixed, declared op set (see `rlx_coreml::mil`).
380/// Unlike the GPU backends — whose lowering covers the whole IR surface —
381/// CoreML is an inference compiler with a finite op claim, so we check
382/// membership directly against the backend's published list.
383fn coreml_supports(op: &Op) -> bool {
384    crate::backend::COREML_SUPPORTED_OPS.contains(&op.kind())
385}
386
387#[allow(unused_variables)]
388fn gpu_family_supports(op: &Op) -> bool {
389    // CUDA / ROCm / wgpu share the same IR surface area as CPU for the
390    // ops V-JEPA2 and other vision models exercise. Narrow when a backend
391    // reports a concrete lowering gap.
392    let _ = op;
393    true
394}
395
396/// Block until `device`'s queue is idle. Metal drains the global queue;
397/// other backends are no-ops.
398pub fn drain_device(device: Device) {
399    #[cfg(all(target_os = "macos", feature = "metal"))]
400    {
401        if device == Device::Metal {
402            rlx_metal::device::drain_command_queue();
403        }
404    }
405    #[cfg(not(all(target_os = "macos", feature = "metal")))]
406    let _ = device;
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::{Activation, BinaryOp};
    use rlx_ir::{DType, Graph, Shape};

    /// One-element f32 shape shared by the graph-building tests.
    fn scalar_shape() -> Shape {
        Shape::new(&[1], DType::F32)
    }

    #[test]
    fn cpu_supports_everything_built_in() {
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Sin)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Cos)));
        assert!(supports(Device::Cpu, &Op::Activation(Activation::Exp)));
        assert!(supports(Device::Cpu, &Op::Binary(BinaryOp::Add)));
    }

    #[test]
    fn unbuilt_device_supports_nothing() {
        // OpenGl isn't a workspace feature; should report false.
        assert!(!supports(Device::OpenGl, &Op::Activation(Activation::Relu)));
    }

    #[test]
    #[cfg(feature = "metal")]
    fn metal_supports_full_activation_set() {
        // After the {sin,cos,tan,atan}_inplace MSL kernels landed in
        // rlx-metal/src/kernels.rs, Metal has every Activation variant
        // rlx-eda exercises.
        for act in [
            Activation::Sin,
            Activation::Cos,
            Activation::Tan,
            Activation::Atan,
            Activation::Exp,
        ] {
            assert!(
                supports(Device::Metal, &Op::Activation(act)),
                "Metal should support Activation::{act:?}"
            );
        }
    }

    #[test]
    fn graph_walk_reports_first_blocker() {
        let mut g = Graph::new("walk");
        let s = scalar_shape();
        let x = g.input("x", s.clone());
        let _e = g.activation(Activation::Exp, x, s.clone());
        let _sin = g.activation(Activation::Sin, x, s);
        // CPU always supports: no blocker to report.
        assert!(supports_graph(Device::Cpu, &g));
        assert!(first_unsupported_op(Device::Cpu, &g).is_none());
        // An unavailable device is blocked from the very first node.
        if !is_available(Device::OpenGl) {
            assert!(!supports_graph(Device::OpenGl, &g));
            let blocker = first_unsupported_op(Device::OpenGl, &g);
            assert_eq!(blocker.map(|(idx, _)| idx), Some(0));
        }
    }

    #[test]
    fn fastest_device_returns_cpu_when_only_cpu_is_available() {
        let available = available_devices();
        let pick = fastest_device();
        assert!(is_available(pick));
        assert_eq!(pick, fastest_among(&available));
        // The scenario the test name describes, checked whenever it applies.
        if available == [Device::Cpu] {
            assert_eq!(pick, Device::Cpu);
        }
    }

    #[test]
    fn fastest_among_respects_priority_order() {
        let pick = fastest_among(&[Device::Cpu, Device::Metal, Device::Mlx]);
        assert_eq!(pick, Device::Mlx);
    }

    #[test]
    fn fastest_among_falls_back_to_cpu_without_candidates() {
        assert_eq!(fastest_among(&[]), Device::Cpu);
    }

    #[test]
    fn device_priority_covers_every_device() {
        // A device missing from the priority list would only ever be picked
        // through `fastest_among`'s first-candidate fallback.
        for d in Device::all().iter() {
            assert!(
                DEVICE_PRIORITY.contains(d),
                "{d:?} is missing from DEVICE_PRIORITY"
            );
        }
    }

    #[test]
    fn devices_for_is_subset_of_available() {
        let mut g = Graph::new("id");
        let x = g.input("x", scalar_shape());
        g.set_outputs(vec![x]);
        for d in devices_for(&g) {
            assert!(is_available(d));
            assert!(supports_graph(d, &g));
        }
    }
}