1use crate::backend::Backend;
35use rlx_driver::Device;
36use std::collections::HashMap;
37use std::sync::{OnceLock, RwLock};
38
/// Constructor for a [`Backend`]: a plain function pointer returning a boxed trait object.
///
/// Because this is a `fn` pointer rather than a boxed closure, a factory cannot capture
/// state; non-capturing closures coerce to it, which is how the built-in entries are written.
pub type BackendFactory = fn() -> Box<dyn Backend>;
45
/// Table of backend factories keyed by [`Device`].
///
/// A single instance is stored in the `OnceLock` inside `registry()` and shared for the
/// lifetime of the process.
struct Registry {
    /// `RwLock` lets lookups share read access while registration takes exclusive access.
    factories: RwLock<HashMap<Device, BackendFactory>>,
}
49
50fn registry() -> &'static Registry {
51 static REGISTRY: OnceLock<Registry> = OnceLock::new();
52 REGISTRY.get_or_init(|| {
53 let r = Registry {
54 factories: RwLock::new(HashMap::new()),
55 };
56 register_builtin(&r);
57 r
58 })
59}
60
61#[allow(unused_mut, unused_variables)]
65fn register_builtin(r: &Registry) {
66 let mut map = r.factories.write().expect("registry poisoned");
67
68 #[cfg(feature = "cpu")]
69 map.insert(Device::Cpu, || {
70 Box::new(crate::backend::cpu_backend::CpuBackend) as Box<dyn Backend>
71 });
72
73 #[cfg(all(feature = "metal", target_os = "macos"))]
74 map.insert(Device::Metal, || {
75 Box::new(crate::backend::metal_backend::MetalBackend) as Box<dyn Backend>
76 });
77
78 #[cfg(all(feature = "mlx", rlx_mlx_host))]
79 map.insert(Device::Mlx, || {
80 Box::new(crate::backend::mlx_backend::MlxBackend) as Box<dyn Backend>
81 });
82
83 #[cfg(all(feature = "coreml", any(target_os = "macos", target_os = "ios")))]
84 map.insert(Device::Ane, || {
85 Box::new(crate::backend::coreml_backend::CoremlBackend) as Box<dyn Backend>
86 });
87
88 #[cfg(feature = "gpu")]
89 map.insert(Device::Gpu, || {
90 Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
91 });
92
93 #[cfg(feature = "vulkan")]
94 map.insert(Device::Vulkan, || {
95 rlx_wgpu::select_vulkan_backend();
96 Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
97 });
98
99 #[cfg(feature = "cuda")]
100 map.insert(Device::Cuda, || {
101 Box::new(crate::backend::cuda_backend::CudaBackend) as Box<dyn Backend>
102 });
103
104 #[cfg(feature = "rocm")]
105 map.insert(Device::Rocm, || {
106 Box::new(crate::backend::rocm_backend::RocmBackend) as Box<dyn Backend>
107 });
108
109 #[cfg(feature = "tpu")]
110 map.insert(Device::Tpu, || {
111 Box::new(crate::backend::tpu_backend::TpuBackend) as Box<dyn Backend>
112 });
113}
114
115pub fn register_backend(device: Device, factory: BackendFactory) {
124 let r = registry();
125 let mut map = r.factories.write().expect("registry poisoned");
126 map.insert(device, factory);
127}
128
129pub fn backend_for(device: Device) -> Option<Box<dyn Backend>> {
132 let r = registry();
133 let map = r.factories.read().expect("registry poisoned");
134 map.get(&device).map(|f| f())
135}
136
137pub fn registered_devices() -> Vec<Device> {
139 let r = registry();
140 let map = r.factories.read().expect("registry poisoned");
141 let mut out: Vec<Device> = map.keys().copied().collect();
142 out.sort_by_key(|d| format!("{d:?}"));
143 out
144}