Skip to main content

rlx_runtime/
registry.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Backend registry — a single registration point for all backends.
//!
//! Adding a new backend (CUDA, ROCm, wgpu, WASM, TPU) is now a self-contained
//! change in its own crate:
//!
//! ```ignore
//! // in rlx-cuda/src/lib.rs
//! #[cfg(feature = "cuda")]
//! pub fn register() {
//!     rlx_runtime::register_backend(Device::Cuda,
//!         || Box::new(CudaBackend) as Box<dyn Backend>);
//! }
//! ```
//!
//! `Session::compile` consults the registry instead of a hardcoded `match`,
//! so the runtime crate has no compile-time knowledge of which backends are
//! available — each enables itself via its Cargo feature.

34use crate::backend::Backend;
35use rlx_driver::Device;
36use std::collections::HashMap;
37use std::sync::{OnceLock, RwLock};
38
39/// Factory closure that constructs a fresh backend instance.
40///
41/// Called once per `Session::compile`. Implementations are typically
42/// stateless (e.g. unit struct `CpuBackend`); the per-graph state lives
43/// inside the returned `Box<dyn Backend>`.
44pub type BackendFactory = fn() -> Box<dyn Backend>;
45
46struct Registry {
47    factories: RwLock<HashMap<Device, BackendFactory>>,
48}
49
50fn registry() -> &'static Registry {
51    static REGISTRY: OnceLock<Registry> = OnceLock::new();
52    REGISTRY.get_or_init(|| {
53        let r = Registry {
54            factories: RwLock::new(HashMap::new()),
55        };
56        register_builtin(&r);
57        r
58    })
59}
60
61/// Register builtin backends compiled into `rlx-runtime`. External
62/// backends (in their own crates) call `register_backend` from their
63/// own init path.
64#[allow(unused_mut, unused_variables)]
65fn register_builtin(r: &Registry) {
66    let mut map = r.factories.write().expect("registry poisoned");
67
68    #[cfg(feature = "cpu")]
69    map.insert(Device::Cpu, || {
70        Box::new(crate::backend::cpu_backend::CpuBackend) as Box<dyn Backend>
71    });
72
73    #[cfg(all(feature = "metal", target_os = "macos"))]
74    map.insert(Device::Metal, || {
75        Box::new(crate::backend::metal_backend::MetalBackend) as Box<dyn Backend>
76    });
77
78    #[cfg(all(feature = "mlx", target_os = "macos"))]
79    map.insert(Device::Mlx, || {
80        Box::new(crate::backend::mlx_backend::MlxBackend) as Box<dyn Backend>
81    });
82
83    #[cfg(feature = "gpu")]
84    map.insert(Device::Gpu, || {
85        Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
86    });
87
88    #[cfg(feature = "vulkan")]
89    map.insert(Device::Vulkan, || {
90        rlx_wgpu::select_vulkan_backend();
91        Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
92    });
93
94    #[cfg(feature = "cuda")]
95    map.insert(Device::Cuda, || {
96        Box::new(crate::backend::cuda_backend::CudaBackend) as Box<dyn Backend>
97    });
98
99    #[cfg(feature = "rocm")]
100    map.insert(Device::Rocm, || {
101        Box::new(crate::backend::rocm_backend::RocmBackend) as Box<dyn Backend>
102    });
103
104    #[cfg(feature = "tpu")]
105    map.insert(Device::Tpu, || {
106        Box::new(crate::backend::tpu_backend::TpuBackend) as Box<dyn Backend>
107    });
108}
109
110/// Register a backend factory for `device`. External backend crates
111/// (rlx-cuda, rlx-rocm, rlx-wgpu, rlx-wasm, …) call this once at startup
112/// (typically from a `pub fn register()` in their lib.rs that the user
113/// invokes — or via a constructor attribute if they use `ctor`/`inventory`).
114///
115/// Re-registering a device replaces the prior factory, so a custom backend
116/// can override a builtin (useful for swap-in alternatives like a tuned
117/// CPU backend).
118pub fn register_backend(device: Device, factory: BackendFactory) {
119    let r = registry();
120    let mut map = r.factories.write().expect("registry poisoned");
121    map.insert(device, factory);
122}
123
124/// Look up a backend factory and instantiate. Returns `None` if no backend
125/// is registered for `device`.
126pub fn backend_for(device: Device) -> Option<Box<dyn Backend>> {
127    let r = registry();
128    let map = r.factories.read().expect("registry poisoned");
129    map.get(&device).map(|f| f())
130}
131
132/// All currently registered devices (deterministic snapshot).
133pub fn registered_devices() -> Vec<Device> {
134    let r = registry();
135    let map = r.factories.read().expect("registry poisoned");
136    let mut out: Vec<Device> = map.keys().copied().collect();
137    out.sort_by_key(|d| format!("{d:?}"));
138    out
139}