rlx_runtime/registry.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Backend registry — a single registration point for all backends.
17//!
18//! Adding a new backend (CUDA, ROCm, wgpu, WASM, TPU) is now a self-contained
19//! change in its own crate:
20//!
21//! ```ignore
22//! // in rlx-cuda/src/lib.rs
23//! #[cfg(feature = "cuda")]
24//! pub fn register() {
25//! rlx_runtime::register_backend(Device::Cuda,
26//! || Box::new(CudaBackend) as Box<dyn Backend>);
27//! }
28//! ```
29//!
30//! `Session::compile` consults the registry instead of a hardcoded `match`,
31//! so the runtime crate has no compile-time knowledge of which backends are
32//! available — each enables itself via its Cargo feature.
33
34use crate::backend::Backend;
35use rlx_driver::Device;
36use std::collections::HashMap;
37use std::sync::{OnceLock, RwLock};
38
/// Factory that constructs a fresh backend instance.
///
/// This is a plain function pointer rather than a boxed closure: named
/// functions and non-capturing closures coerce to it, but a closure that
/// captures state does not.
///
/// Called once per `Session::compile`. Implementations are typically
/// stateless (e.g. unit struct `CpuBackend`); the per-graph state lives
/// inside the returned `Box<dyn Backend>`.
pub type BackendFactory = fn() -> Box<dyn Backend>;
45
/// Table mapping each `Device` to the factory that builds its backend.
struct Registry {
    /// Device → factory map, wrapped in an `RwLock` so it can be read and
    /// updated through a shared reference.
    factories: RwLock<HashMap<Device, BackendFactory>>,
}
49
50fn registry() -> &'static Registry {
51 static REGISTRY: OnceLock<Registry> = OnceLock::new();
52 REGISTRY.get_or_init(|| {
53 let r = Registry {
54 factories: RwLock::new(HashMap::new()),
55 };
56 register_builtin(&r);
57 r
58 })
59}
60
61/// Register builtin backends compiled into `rlx-runtime`. External
62/// backends (in their own crates) call `register_backend` from their
63/// own init path.
64#[allow(unused_mut, unused_variables)]
65fn register_builtin(r: &Registry) {
66 let mut map = r.factories.write().expect("registry poisoned");
67
68 #[cfg(feature = "cpu")]
69 map.insert(Device::Cpu, || {
70 Box::new(crate::backend::cpu_backend::CpuBackend) as Box<dyn Backend>
71 });
72
73 #[cfg(all(feature = "metal", target_os = "macos"))]
74 map.insert(Device::Metal, || {
75 Box::new(crate::backend::metal_backend::MetalBackend) as Box<dyn Backend>
76 });
77
78 #[cfg(all(feature = "mlx", target_os = "macos"))]
79 map.insert(Device::Mlx, || {
80 Box::new(crate::backend::mlx_backend::MlxBackend) as Box<dyn Backend>
81 });
82
83 #[cfg(feature = "gpu")]
84 map.insert(Device::Gpu, || {
85 Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
86 });
87
88 #[cfg(feature = "vulkan")]
89 map.insert(Device::Vulkan, || {
90 rlx_wgpu::select_vulkan_backend();
91 Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
92 });
93
94 #[cfg(feature = "cuda")]
95 map.insert(Device::Cuda, || {
96 Box::new(crate::backend::cuda_backend::CudaBackend) as Box<dyn Backend>
97 });
98
99 #[cfg(feature = "rocm")]
100 map.insert(Device::Rocm, || {
101 Box::new(crate::backend::rocm_backend::RocmBackend) as Box<dyn Backend>
102 });
103
104 #[cfg(feature = "tpu")]
105 map.insert(Device::Tpu, || {
106 Box::new(crate::backend::tpu_backend::TpuBackend) as Box<dyn Backend>
107 });
108}
109
110/// Register a backend factory for `device`. External backend crates
111/// (rlx-cuda, rlx-rocm, rlx-wgpu, rlx-wasm, …) call this once at startup
112/// (typically from a `pub fn register()` in their lib.rs that the user
113/// invokes — or via a constructor attribute if they use `ctor`/`inventory`).
114///
115/// Re-registering a device replaces the prior factory, so a custom backend
116/// can override a builtin (useful for swap-in alternatives like a tuned
117/// CPU backend).
118pub fn register_backend(device: Device, factory: BackendFactory) {
119 let r = registry();
120 let mut map = r.factories.write().expect("registry poisoned");
121 map.insert(device, factory);
122}
123
124/// Look up a backend factory and instantiate. Returns `None` if no backend
125/// is registered for `device`.
126pub fn backend_for(device: Device) -> Option<Box<dyn Backend>> {
127 let r = registry();
128 let map = r.factories.read().expect("registry poisoned");
129 map.get(&device).map(|f| f())
130}
131
132/// All currently registered devices (deterministic snapshot).
133pub fn registered_devices() -> Vec<Device> {
134 let r = registry();
135 let map = r.factories.read().expect("registry poisoned");
136 let mut out: Vec<Device> = map.keys().copied().collect();
137 out.sort_by_key(|d| format!("{d:?}"));
138 out
139}