use crate::backend::Backend;
use rlx_driver::Device;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
/// Constructor for a backend: a plain function pointer that takes no arguments
/// and returns a boxed [`Backend`] trait object.
///
/// Being a `fn` pointer (not a boxed closure), a factory cannot capture state.
pub type BackendFactory = fn() -> Box<dyn Backend>;
/// Table of backend factories keyed by [`Device`].
struct Registry {
// Guarded by an `RwLock`: lookups in this file take the read lock,
// registration takes the write lock.
factories: RwLock<HashMap<Device, BackendFactory>>,
}
/// Returns the shared backend registry, creating it on first use.
///
/// The first call builds an empty registry and runs `register_builtin` on it
/// before storing it in a function-local `OnceLock`; every later call returns
/// a reference to that same instance.
fn registry() -> &'static Registry {
    static REGISTRY: OnceLock<Registry> = OnceLock::new();

    // One-time initialiser: an empty table seeded with the built-in backends.
    fn build() -> Registry {
        let fresh = Registry {
            factories: RwLock::new(HashMap::new()),
        };
        register_builtin(&fresh);
        fresh
    }

    REGISTRY.get_or_init(build)
}
/// Inserts a factory for every backend that is compiled into this build.
///
/// Each insertion is gated on a Cargo feature (and, for Metal and MLX, on
/// `target_os = "macos"`), so the resulting set of devices depends on the
/// build configuration.
///
/// # Panics
///
/// Panics with "registry poisoned" if the write lock on `r.factories` is
/// poisoned.
// The `allow` is needed because every use of `map` below sits behind a `cfg`:
// with no backend feature enabled the binding is never used or mutated.
#[allow(unused_mut, unused_variables)]
fn register_builtin(r: &Registry) {
// Held for the whole function so all built-ins are inserted under one lock.
let mut map = r.factories.write().expect("registry poisoned");
// The explicit `as Box<dyn Backend>` casts give each closure the return type
// required by `BackendFactory`.
#[cfg(feature = "cpu")]
map.insert(Device::Cpu, || {
Box::new(crate::backend::cpu_backend::CpuBackend) as Box<dyn Backend>
});
#[cfg(all(feature = "metal", target_os = "macos"))]
map.insert(Device::Metal, || {
Box::new(crate::backend::metal_backend::MetalBackend) as Box<dyn Backend>
});
#[cfg(all(feature = "mlx", target_os = "macos"))]
map.insert(Device::Mlx, || {
Box::new(crate::backend::mlx_backend::MlxBackend) as Box<dyn Backend>
});
#[cfg(feature = "gpu")]
map.insert(Device::Gpu, || {
Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
});
// `Device::Vulkan` constructs the same `WgpuBackend` as `Device::Gpu`, but its
// factory calls `rlx_wgpu::select_vulkan_backend()` first, on every invocation.
// NOTE(review): presumably this switches wgpu to its Vulkan backend; confirm
// against `rlx_wgpu`, including whether the selection also affects `Device::Gpu`.
#[cfg(feature = "vulkan")]
map.insert(Device::Vulkan, || {
rlx_wgpu::select_vulkan_backend();
Box::new(crate::backend::wgpu_backend::WgpuBackend) as Box<dyn Backend>
});
#[cfg(feature = "cuda")]
map.insert(Device::Cuda, || {
Box::new(crate::backend::cuda_backend::CudaBackend) as Box<dyn Backend>
});
#[cfg(feature = "rocm")]
map.insert(Device::Rocm, || {
Box::new(crate::backend::rocm_backend::RocmBackend) as Box<dyn Backend>
});
#[cfg(feature = "tpu")]
map.insert(Device::Tpu, || {
Box::new(crate::backend::tpu_backend::TpuBackend) as Box<dyn Backend>
});
}
/// Registers `factory` as the constructor used for `device`.
///
/// If a factory is already registered for `device` it is replaced; the
/// previous entry is discarded.
///
/// # Panics
///
/// Panics with "registry poisoned" if the registry lock is poisoned.
pub fn register_backend(device: Device, factory: BackendFactory) {
    registry()
        .factories
        .write()
        .expect("registry poisoned")
        .insert(device, factory);
}
/// Builds a backend for `device` using the factory currently registered for it.
///
/// Returns `None` when no factory is registered for `device`. Each call
/// invokes the factory again.
///
/// The registry lock is released *before* the factory runs, so a factory may
/// itself call into this module (for example `register_backend`) without
/// contending with a lock held by its own caller, and a slow factory does not
/// block concurrent registrations.
///
/// # Panics
///
/// Panics with "registry poisoned" if the registry lock is poisoned.
pub fn backend_for(device: Device) -> Option<Box<dyn Backend>> {
    let map = registry().factories.read().expect("registry poisoned");
    // `BackendFactory` is a `fn` pointer and therefore `Copy`: take it out of
    // the map so the read guard does not have to outlive the lookup.
    let factory = map.get(&device).copied();
    drop(map);
    factory.map(|make| make())
}
/// Lists every device that currently has a registered factory.
///
/// The result is sorted by each device's `Debug` representation, so the order
/// is deterministic regardless of hash-map iteration order.
///
/// # Panics
///
/// Panics with "registry poisoned" if the registry lock is poisoned.
pub fn registered_devices() -> Vec<Device> {
    // The read guard is a temporary of this statement, so the lock is released
    // as soon as the keys have been copied out and before any sorting happens.
    let mut out: Vec<Device> = registry()
        .factories
        .read()
        .expect("registry poisoned")
        .keys()
        .copied()
        .collect();
    // The sort key allocates a `String`; the cached variant builds it once per
    // element instead of once per comparison. Both sorts are stable.
    out.sort_by_cached_key(|d| format!("{d:?}"));
    out
}