1pub use ferrum_interfaces::{
7 backend::KernelExecutor, ComputeBackend, DeviceMemoryManager, TensorFactory, TensorLike,
8 TensorOps, TensorRef,
9};
10
11pub use ferrum_types::{DataType, Device, Result};
13
14pub mod backends;
16pub mod memory;
17
18pub use backends::*;
20pub use memory::*;
21
22use once_cell::sync::Lazy;
23use std::sync::Arc;
24
25static BACKEND_REGISTRY: Lazy<Arc<DefaultBackendRegistry>> =
27 Lazy::new(|| Arc::new(DefaultBackendRegistry::new()));
28
29pub fn global_backend_registry() -> Arc<DefaultBackendRegistry> {
31 BACKEND_REGISTRY.clone()
32}
33
34pub struct DefaultBackendRegistry {
36 compute_backends:
37 parking_lot::RwLock<std::collections::HashMap<String, Arc<dyn ComputeBackend>>>,
38}
39
40impl DefaultBackendRegistry {
41 pub fn new() -> Self {
42 Self {
43 compute_backends: parking_lot::RwLock::new(std::collections::HashMap::new()),
44 }
45 }
46
47 pub fn register_compute_backend(
48 &self,
49 name: &str,
50 backend: Arc<dyn ComputeBackend>,
51 ) -> Result<()> {
52 self.compute_backends
53 .write()
54 .insert(name.to_string(), backend);
55 Ok(())
56 }
57
58 pub fn get_compute_backend(&self, name: &str) -> Option<Arc<dyn ComputeBackend>> {
59 self.compute_backends.read().get(name).cloned()
60 }
61
62 pub fn list_backends(&self) -> Vec<String> {
63 self.compute_backends.read().keys().cloned().collect()
64 }
65}
66
67impl std::fmt::Debug for DefaultBackendRegistry {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("DefaultBackendRegistry")
70 .field("backend_count", &self.compute_backends.read().len())
71 .finish()
72 }
73}
74
75impl Default for DefaultBackendRegistry {
76 fn default() -> Self {
77 Self::new()
78 }
79}