nestforge_core/
container.rs1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, HashSet},
4 sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
/// Type-indexed store of shared values.
///
/// Both fields live behind an `Arc`, so a clone of a `Container` is another
/// handle onto the same underlying collections rather than a copy of them.
#[derive(Clone, Default)]
pub struct Container {
    /// At most one value per concrete type, keyed by `TypeId` and stored
    /// type-erased as `dyn Any` so it can be downcast on lookup.
    inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
    /// `std::any::type_name` strings recorded alongside the stored values,
    /// used for lookups by name.
    names: Arc<RwLock<HashSet<&'static str>>>,
}
31
32#[derive(Debug, Error)]
33pub enum ContainerError {
34 #[error("Container write lock poisoned")]
35 WriteLockPoisoned,
36 #[error("Container read lock poisoned")]
37 ReadLockPoisoned,
38 #[error("Type already registered: {type_name}")]
39 TypeAlreadyRegistered { type_name: &'static str },
40 #[error("Type not registered: {type_name}")]
41 TypeNotRegistered { type_name: &'static str },
42 #[error("Failed to downcast resolved value: {type_name}")]
43 DowncastFailed { type_name: &'static str },
44 #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
45 TypeNotRegisteredInModule {
46 type_name: &'static str,
47 module_name: &'static str,
48 },
49}
50
51impl Container {
52 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
71 where
72 T: Send + Sync + 'static,
73 {
74 let mut map = self
75 .inner
76 .write()
77 .map_err(|_| ContainerError::WriteLockPoisoned)?;
78
79 let type_id = TypeId::of::<T>();
80
81 if map.contains_key(&type_id) {
83 return Err(ContainerError::TypeAlreadyRegistered {
84 type_name: std::any::type_name::<T>(),
85 });
86 }
87
88 map.insert(type_id, Arc::new(value));
90 self.names
91 .write()
92 .map_err(|_| ContainerError::WriteLockPoisoned)?
93 .insert(std::any::type_name::<T>());
94 Ok(())
95 }
96
97 pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
98 where
99 T: Send + Sync + 'static,
100 {
101 let mut map = self
102 .inner
103 .write()
104 .map_err(|_| ContainerError::WriteLockPoisoned)?;
105
106 map.insert(TypeId::of::<T>(), Arc::new(value));
107 self.names
108 .write()
109 .map_err(|_| ContainerError::WriteLockPoisoned)?
110 .insert(std::any::type_name::<T>());
111 Ok(())
112 }
113
114 pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
115 let names = self
116 .names
117 .read()
118 .map_err(|_| ContainerError::ReadLockPoisoned)?;
119 Ok(names.contains(type_name))
120 }
121
122 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
131 where
132 T: Send + Sync + 'static,
133 {
134 let map = self
135 .inner
136 .read()
137 .map_err(|_| ContainerError::ReadLockPoisoned)?;
138
139 let value = map
140 .get(&TypeId::of::<T>())
141 .ok_or_else(|| ContainerError::TypeNotRegistered {
142 type_name: std::any::type_name::<T>(),
143 })?
144 .clone();
145
146 value
151 .downcast::<T>()
152 .map_err(|_| ContainerError::DowncastFailed {
153 type_name: std::any::type_name::<T>(),
154 })
155 }
156
157 pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
158 where
159 T: Send + Sync + 'static,
160 {
161 self.resolve::<T>().map_err(|err| match err {
162 ContainerError::TypeNotRegistered { type_name } => {
163 ContainerError::TypeNotRegisteredInModule {
164 type_name,
165 module_name,
166 }
167 }
168 other => other,
169 })
170 }
171}