Skip to main content

rlx_runtime/
registry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend registry — a single registration point for all backends.
17//!
18//! Adding a new backend (CUDA, ROCm, wgpu, WASM, TPU) is now a self-contained
19//! change in its own crate:
20//!
21//! ```ignore
22//! // in rlx-cuda/src/lib.rs
23//! #[cfg(feature = "cuda")]
24//! pub fn register() {
25//!     rlx_runtime::register_backend(Device::Cuda,
26//!         || Box::new(CudaBackend) as Box<dyn Backend>);
27//! }
28//! ```
29//!
30//! `Session::compile` consults the registry instead of a hardcoded `match`,
31//! so the runtime crate has no compile-time knowledge of which backends are
32//! available — each enables itself via its Cargo feature.
33
34use crate::backend::Backend;
35use rlx_driver::Device;
36use std::collections::HashMap;
37use std::sync::{OnceLock, RwLock};
38
/// Factory that constructs a fresh backend instance.
///
/// Called once per `Session::compile`. Implementations are typically
/// stateless (e.g. unit struct `CpuBackend`); the per-graph state lives
/// inside the returned `Box<dyn Backend>`.
///
/// Note that this is a plain `fn` pointer rather than a boxed closure, so
/// only free functions and non-capturing closures can serve as factories.
pub type BackendFactory = fn() -> Box<dyn Backend>;
45
/// Table mapping each `Device` to the factory that builds its backend.
///
/// The map sits behind an `RwLock`: lookups take the shared (read) side,
/// registration takes the exclusive (write) side.
struct Registry {
    factories: RwLock<HashMap<Device, BackendFactory>>,
}
49
50fn registry() -> &'static Registry {
51    static REGISTRY: OnceLock<Registry> = OnceLock::new();
52    REGISTRY.get_or_init(|| {
53        let r = Registry {
54            factories: RwLock::new(HashMap::new()),
55        };
56        register_builtin(&r);
57        r
58    })
59}
60
61/// Register builtin backends compiled into `rlx-runtime`. External
62/// backends (in their own crates) call `register_backend` from their
63/// own init path.
64#[allow(unused_mut, unused_variables)]
65fn register_builtin(r: &Registry) {
66    let mut map = r.factories.write().expect("registry poisoned");
67
68    #[cfg(feature = "cpu")]
69    map.insert(Device::Cpu, || {
70        Box::new(crate::backend::cpu_backend::CpuBackend) as Box<dyn Backend>
71    });
72
73    #[cfg(all(feature = "metal", target_os = "macos"))]
74    map.insert(Device::Metal, || {
75        Box::new(crate::backend::metal_backend::MetalBackend) as Box<dyn Backend>
76    });
77
78    #[cfg(all(feature = "mlx", rlx_mlx_host))]
79    map.insert(Device::Mlx, || {
80        Box::new(crate::backend::mlx_backend::MlxBackend) as Box<dyn Backend>
81    });
82
83    #[cfg(all(feature = "coreml", any(target_os = "macos", target_os = "ios")))]
84    map.insert(Device::Ane, || {
85        Box::new(crate::backend::coreml_backend::CoremlBackend) as Box<dyn Backend>
86    });
87
88    #[cfg(feature = "gpu")]
89    map.insert(Device::Gpu, || {
90        Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
91    });
92
93    #[cfg(feature = "vulkan")]
94    map.insert(Device::Vulkan, || {
95        rlx_wgpu::select_vulkan_backend();
96        Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
97    });
98
99    #[cfg(feature = "cuda")]
100    map.insert(Device::Cuda, || {
101        Box::new(crate::backend::cuda_backend::CudaBackend) as Box<dyn Backend>
102    });
103
104    #[cfg(feature = "rocm")]
105    map.insert(Device::Rocm, || {
106        Box::new(crate::backend::rocm_backend::RocmBackend) as Box<dyn Backend>
107    });
108
109    #[cfg(feature = "tpu")]
110    map.insert(Device::Tpu, || {
111        Box::new(crate::backend::tpu_backend::TpuBackend) as Box<dyn Backend>
112    });
113}
114
115/// Register a backend factory for `device`. External backend crates
116/// (rlx-cuda, rlx-rocm, rlx-wgpu, rlx-wasm, …) call this once at startup
117/// (typically from a `pub fn register()` in their lib.rs that the user
118/// invokes — or via a constructor attribute if they use `ctor`/`inventory`).
119///
120/// Re-registering a device replaces the prior factory, so a custom backend
121/// can override a builtin (useful for swap-in alternatives like a tuned
122/// CPU backend).
123pub fn register_backend(device: Device, factory: BackendFactory) {
124    let r = registry();
125    let mut map = r.factories.write().expect("registry poisoned");
126    map.insert(device, factory);
127}
128
129/// Look up a backend factory and instantiate. Returns `None` if no backend
130/// is registered for `device`.
131pub fn backend_for(device: Device) -> Option<Box<dyn Backend>> {
132    let r = registry();
133    let map = r.factories.read().expect("registry poisoned");
134    map.get(&device).map(|f| f())
135}
136
137/// All currently registered devices (deterministic snapshot).
138pub fn registered_devices() -> Vec<Device> {
139    let r = registry();
140    let map = r.factories.read().expect("registry poisoned");
141    let mut out: Vec<Device> = map.keys().copied().collect();
142    out.sort_by_key(|d| format!("{d:?}"));
143    out
144}