use std::any::{Any, TypeId};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use once_cell::sync::Lazy;
38
/// Errors returned by [`TypeRegistry`] and [`FactoryRegistry`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegistryError {
    /// No entry is registered for the requested type.
    TypeNotFound,
    /// An entry for the type already exists and the operation does not replace it.
    TypeAlreadyExists,
    /// The registry's lock could not be acquired because it is poisoned.
    LockError,
    /// An entry exists for the type but is not stored in the requested form.
    InvalidCast,
}

impl std::fmt::Display for RegistryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every message is a fixed literal, so write it directly instead of
        // going through the formatting machinery.
        f.write_str(match self {
            Self::TypeNotFound => "Type not found in registry",
            Self::TypeAlreadyExists => "Type already registered",
            Self::LockError => "Failed to acquire lock on registry",
            Self::InvalidCast => "Invalid type cast",
        })
    }
}

impl std::error::Error for RegistryError {}
64
/// Bound for values that can be stored in a [`TypeRegistry`].
///
/// It is blanket-implemented below for every `Any + Send + Sync` type.
pub trait Registrable: Any + Send + Sync {
    /// Returns the [`TypeId`] of the implementing type.
    ///
    /// This method has the same name as [`Any::type_id`]. When both traits are
    /// in scope, call it as `Registrable::type_id(&value)` to avoid an
    /// ambiguous-method error.
    fn type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

impl<T> Registrable for T where T: Any + Send + Sync {}
74
/// A thread-safe map from a type's [`TypeId`] to one registered instance of it.
///
/// The `register`/`register_or_replace` methods store the value itself.
/// The `register_arc`/`register_arc_or_replace` methods store an `Arc<T>`.
/// Both forms use the same key, so a type can have at most one entry.
pub struct TypeRegistry {
    /// Type-erased entries keyed by the registered type, behind a reader-writer lock.
    registry: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
80
81impl TypeRegistry {
82 pub fn new() -> Self {
84 Self {
85 registry: RwLock::new(HashMap::new()),
86 }
87 }
88
89 pub fn global() -> &'static TypeRegistry {
91 static GLOBAL_REGISTRY: Lazy<TypeRegistry> = Lazy::new(TypeRegistry::new);
92 &GLOBAL_REGISTRY
93 }
94
95 pub fn register<T: Registrable>(&self, value: T) -> Result<(), RegistryError> {
97 let type_id = TypeId::of::<T>();
98 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
99
100 if registry.contains_key(&type_id) {
101 return Err(RegistryError::TypeAlreadyExists);
102 }
103
104 registry.insert(type_id, Box::new(value));
105 Ok(())
106 }
107
108 pub fn register_or_replace<T: Registrable>(&self, value: T) -> Result<(), RegistryError> {
110 let type_id = TypeId::of::<T>();
111 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
112 registry.insert(type_id, Box::new(value));
113 Ok(())
114 }
115
116
117 pub fn get_cloned<T: Registrable + Clone>(&self) -> Result<T, RegistryError> {
119 let type_id = TypeId::of::<T>();
120 let registry = self.registry.read().map_err(|_| RegistryError::LockError)?;
121
122 let value = registry
123 .get(&type_id)
124 .ok_or(RegistryError::TypeNotFound)?
125 .downcast_ref::<T>()
126 .ok_or(RegistryError::InvalidCast)?;
127
128 Ok(value.clone())
129 }
130
131 pub fn get<T: Registrable + 'static>(&self) -> Result<Arc<T>, RegistryError> {
133 let type_id = TypeId::of::<T>();
134 let registry = self.registry.read().map_err(|_| RegistryError::LockError)?;
135
136 let value = registry
137 .get(&type_id)
138 .ok_or(RegistryError::TypeNotFound)?;
139
140 if let Some(arc_value) = value.downcast_ref::<Arc<T>>() {
142 return Ok(Arc::clone(arc_value));
143 }
144
145 if let Some(_direct_value) = value.downcast_ref::<T>() {
147 return Err(RegistryError::InvalidCast);
151 }
152
153 Err(RegistryError::InvalidCast)
154 }
155
156 pub fn register_arc<T: Registrable + 'static>(&self, value: Arc<T>) -> Result<(), RegistryError> {
158 let type_id = TypeId::of::<T>();
159 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
160
161 if registry.contains_key(&type_id) {
162 return Err(RegistryError::TypeAlreadyExists);
163 }
164
165 registry.insert(type_id, Box::new(value));
166 Ok(())
167 }
168
169 pub fn register_arc_or_replace<T: Registrable + 'static>(&self, value: Arc<T>) -> Result<(), RegistryError> {
171 let type_id = TypeId::of::<T>();
172 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
173 registry.insert(type_id, Box::new(value));
174 Ok(())
175 }
176
177 pub fn is_registered<T: Registrable>(&self) -> bool {
179 let type_id = TypeId::of::<T>();
180 let registry = self.registry.read().unwrap_or_else(|_| panic!("Registry lock poisoned"));
181 registry.contains_key(&type_id)
182 }
183
184 pub fn unregister<T: Registrable>(&self) -> Result<(), RegistryError> {
186 let type_id = TypeId::of::<T>();
187 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
188
189 if registry.remove(&type_id).is_some() {
190 Ok(())
191 } else {
192 Err(RegistryError::TypeNotFound)
193 }
194 }
195
196 pub fn len(&self) -> usize {
198 let registry = self.registry.read().unwrap_or_else(|_| panic!("Registry lock poisoned"));
199 registry.len()
200 }
201
202 pub fn is_empty(&self) -> bool {
204 self.len() == 0
205 }
206
207 pub fn clear(&self) -> Result<(), RegistryError> {
209 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
210 registry.clear();
211 Ok(())
212 }
213
214 pub fn registered_types(&self) -> Result<Vec<TypeId>, RegistryError> {
216 let registry = self.registry.read().map_err(|_| RegistryError::LockError)?;
217 Ok(registry.keys().copied().collect())
218 }
219}
220
221impl Default for TypeRegistry {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
/// Registers `$value` as the global instance of `$type`.
///
/// Evaluates to the `Result<(), RegistryError>` returned by
/// `TypeRegistry::global().register::<$type>($value)`.
#[macro_export]
macro_rules! register {
    ($type:ty, $value:expr) => {{
        let outcome = $crate::TypeRegistry::global().register::<$type>($value);
        outcome
    }};
}
238
/// Calls `$factory` once and registers the value it returns as the global
/// instance of `$type`, via [`register!`].
///
/// Despite its name, this stores the produced value in the global
/// [`TypeRegistry`]; it does not add a factory to [`FactoryRegistry`].
/// Evaluates to the `Result<(), RegistryError>` of the registration.
#[macro_export]
macro_rules! register_factory {
    ($type:ty, $factory:expr) => {{
        let produced = $factory();
        $crate::register!($type, produced)
    }};
}
249
250
/// Returns a clone of the global instance of `$type`.
///
/// Shorthand for `TypeRegistry::global().get_cloned::<$type>()`.
#[macro_export]
macro_rules! get_cloned {
    ($type:ty) => {
        $crate::TypeRegistry::get_cloned::<$type>($crate::TypeRegistry::global())
    };
}
258
/// Returns the shared `Arc` handle to the global instance of `$type`.
///
/// Shorthand for `TypeRegistry::global().get::<$type>()`. Pair it with `register_arc!`.
#[macro_export]
macro_rules! get {
    ($type:ty) => {
        $crate::TypeRegistry::get::<$type>($crate::TypeRegistry::global())
    };
}
266
/// Wraps `$value` in an `Arc` and registers it as the global shared instance of `$type`.
///
/// Evaluates to the `Result<(), RegistryError>` returned by
/// `TypeRegistry::global().register_arc::<$type>(..)`. Retrieve the instance
/// with `get!($type)`.
#[macro_export]
macro_rules! register_arc {
    ($type:ty, $value:expr) => {{
        // Absolute `::std` path: a plain `std::` would resolve to a local item
        // named `std` if the calling crate defines one.
        let outcome = $crate::TypeRegistry::global()
            .register_arc::<$type>(::std::sync::Arc::new($value));
        outcome
    }};
}
279
280
/// A boxed, thread-safe closure that builds a fresh `T` each time it is called.
pub type Factory<T> = Box<dyn Fn() -> T + Send + Sync>;
283
/// A thread-safe map from a type's [`TypeId`] to the factory that builds values of that type.
pub struct FactoryRegistry {
    /// Type-erased factories keyed by the type they produce, behind a reader-writer lock.
    registry: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
288
289impl FactoryRegistry {
290 pub fn new() -> Self {
291 Self {
292 registry: RwLock::new(HashMap::new()),
293 }
294 }
295
296 pub fn global() -> &'static FactoryRegistry {
297 static GLOBAL_FACTORY_REGISTRY: Lazy<FactoryRegistry> = Lazy::new(FactoryRegistry::new);
298 &GLOBAL_FACTORY_REGISTRY
299 }
300
301 pub fn register_factory<T: 'static>(&self, factory: Factory<T>) -> Result<(), RegistryError> {
302 let type_id = TypeId::of::<T>();
303 let mut registry = self.registry.write().map_err(|_| RegistryError::LockError)?;
304
305 if registry.contains_key(&type_id) {
306 return Err(RegistryError::TypeAlreadyExists);
307 }
308
309 registry.insert(type_id, Box::new(factory));
310 Ok(())
311 }
312
313 pub fn create<T: 'static>(&self) -> Result<T, RegistryError> {
314 let type_id = TypeId::of::<T>();
315 let registry = self.registry.read().map_err(|_| RegistryError::LockError)?;
316
317 let factory = registry
318 .get(&type_id)
319 .ok_or(RegistryError::TypeNotFound)?
320 .downcast_ref::<Factory<T>>()
321 .ok_or(RegistryError::InvalidCast)?;
322
323 Ok(factory())
324 }
325}
326
327impl Default for FactoryRegistry {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // The default test harness runs tests concurrently. Tests that touch
    // `TypeRegistry::global()` therefore each register a type that no other test
    // uses, and remove only that type afterwards. Calling `clear()` here would
    // wipe entries that a concurrently running test still expects to find.

    #[derive(Debug, PartialEq, Clone)]
    struct TestStruct {
        value: i32,
    }

    trait TestTrait: Send + Sync {
        fn get_value(&self) -> i32;
    }

    impl TestTrait for TestStruct {
        fn get_value(&self) -> i32 {
            self.value
        }
    }

    #[test]
    fn test_basic_registration() {
        let registry = TypeRegistry::new();
        let test_struct = TestStruct { value: 42 };

        assert!(registry.register(test_struct).is_ok());

        let value = registry.get_cloned::<TestStruct>().unwrap().value;
        assert_eq!(value, 42);
    }

    #[test]
    fn test_duplicate_registration_fails() {
        let registry = TypeRegistry::new();
        let test_struct1 = TestStruct { value: 42 };
        let test_struct2 = TestStruct { value: 100 };

        assert!(registry.register(test_struct1).is_ok());
        assert_eq!(registry.register(test_struct2), Err(RegistryError::TypeAlreadyExists));
    }

    #[test]
    fn test_registration_replacement() {
        let registry = TypeRegistry::new();
        let test_struct1 = TestStruct { value: 42 };
        let test_struct2 = TestStruct { value: 100 };

        assert!(registry.register(test_struct1).is_ok());
        assert!(registry.register_or_replace(test_struct2).is_ok());

        let value = registry.get_cloned::<TestStruct>().unwrap().value;
        assert_eq!(value, 100);
    }

    #[test]
    fn test_unregister() {
        let registry = TypeRegistry::new();
        let test_struct = TestStruct { value: 42 };

        assert!(registry.register(test_struct).is_ok());
        assert!(registry.is_registered::<TestStruct>());
        assert!(registry.unregister::<TestStruct>().is_ok());
        assert!(!registry.is_registered::<TestStruct>());
    }

    #[test]
    fn test_missing_entries_report_type_not_found() {
        let registry = TypeRegistry::new();

        assert_eq!(registry.get_cloned::<TestStruct>(), Err(RegistryError::TypeNotFound));
        assert!(matches!(registry.get::<TestStruct>(), Err(RegistryError::TypeNotFound)));
        assert_eq!(registry.unregister::<TestStruct>(), Err(RegistryError::TypeNotFound));
    }

    #[test]
    fn test_len_clear_and_registered_types() {
        let registry = TypeRegistry::new();
        assert!(registry.is_empty());

        registry.register(TestStruct { value: 1 }).unwrap();
        registry.register(7_u8).unwrap();
        assert_eq!(registry.len(), 2);

        let mut types = registry.registered_types().unwrap();
        types.sort();
        let mut expected = vec![TypeId::of::<TestStruct>(), TypeId::of::<u8>()];
        expected.sort();
        assert_eq!(types, expected);

        registry.clear().unwrap();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_global_registry() {
        #[derive(Clone)]
        struct GlobalTestStruct {
            value: i32,
        }

        let registry = TypeRegistry::global();
        assert!(registry.register(GlobalTestStruct { value: 123 }).is_ok());

        let value = registry.get_cloned::<GlobalTestStruct>().unwrap().value;
        assert_eq!(value, 123);

        assert!(registry.unregister::<GlobalTestStruct>().is_ok());
        assert!(!registry.is_registered::<GlobalTestStruct>());
    }

    #[test]
    fn test_macros() {
        #[derive(Clone)]
        struct MacroTestStruct {
            value: String,
        }

        let result = register!(MacroTestStruct, MacroTestStruct {
            value: "test".to_string()
        });
        assert!(result.is_ok());

        let value = get_cloned!(MacroTestStruct).unwrap();
        assert_eq!(value.value, "test");

        TypeRegistry::global().unregister::<MacroTestStruct>().unwrap();
    }

    #[test]
    fn test_register_factory_macro() {
        #[derive(Clone)]
        struct FactoryMacroStruct {
            value: i32,
        }

        let result = register_factory!(FactoryMacroStruct, || FactoryMacroStruct { value: 7 });
        assert!(result.is_ok());
        assert_eq!(get_cloned!(FactoryMacroStruct).unwrap().value, 7);

        TypeRegistry::global().unregister::<FactoryMacroStruct>().unwrap();
    }

    #[test]
    fn test_factory_registry() {
        let factory_registry = FactoryRegistry::new();

        let factory: Factory<TestStruct> = Box::new(|| TestStruct { value: 999 });
        assert!(factory_registry.register_factory(factory).is_ok());

        let created = factory_registry.create::<TestStruct>().unwrap();
        assert_eq!(created.value, 999);
    }

    #[test]
    fn test_factory_registry_errors() {
        let factory_registry = FactoryRegistry::new();
        assert_eq!(factory_registry.create::<TestStruct>(), Err(RegistryError::TypeNotFound));

        let first: Factory<TestStruct> = Box::new(|| TestStruct { value: 1 });
        let second: Factory<TestStruct> = Box::new(|| TestStruct { value: 2 });
        assert!(factory_registry.register_factory(first).is_ok());
        assert_eq!(
            factory_registry.register_factory(second),
            Err(RegistryError::TypeAlreadyExists)
        );

        // The first factory is kept.
        assert_eq!(factory_registry.create::<TestStruct>(), Ok(TestStruct { value: 1 }));
    }

    #[test]
    fn test_thread_safety() {
        #[derive(Clone)]
        struct ThreadTestStruct(i32);

        let registry = Arc::new(TypeRegistry::new());
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let registry = Arc::clone(&registry);
                thread::spawn(move || registry.register(ThreadTestStruct(i)).is_ok())
            })
            .collect();

        let successes = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .filter(|&registered| registered)
            .count();

        // Exactly one thread wins the race; every other attempt sees the entry.
        assert_eq!(successes, 1);
        assert_eq!(registry.len(), 1);
        let winner = registry.get_cloned::<ThreadTestStruct>().unwrap().0;
        assert!((0..10).contains(&winner));
    }

    #[test]
    fn test_trait_objects() {
        let registry = TypeRegistry::new();
        let handler: Box<dyn TestTrait> = Box::new(TestStruct { value: 42 });

        assert!(registry.register(handler).is_ok());

        assert!(registry.is_registered::<Box<dyn TestTrait>>());
    }

    #[test]
    fn test_shared_trait_object_round_trip() {
        let registry = TypeRegistry::new();
        let handler: Arc<dyn TestTrait> = Arc::new(TestStruct { value: 7 });

        assert!(registry.register(Arc::clone(&handler)).is_ok());

        let fetched = registry.get_cloned::<Arc<dyn TestTrait>>().unwrap();
        assert_eq!(fetched.get_value(), 7);
        assert!(Arc::ptr_eq(&handler, &fetched));
    }

    #[test]
    fn test_arc_registration_and_get() {
        let registry = TypeRegistry::new();
        let test_struct = Arc::new(TestStruct { value: 123 });

        assert!(registry.register_arc(test_struct.clone()).is_ok());

        let retrieved = registry.get::<TestStruct>().unwrap();
        assert_eq!(retrieved.value, 123);

        assert!(Arc::ptr_eq(&test_struct, &retrieved));
    }

    #[test]
    fn test_get_requires_arc_registration() {
        let registry = TypeRegistry::new();
        assert!(registry.register(TestStruct { value: 1 }).is_ok());

        assert!(matches!(registry.get::<TestStruct>(), Err(RegistryError::InvalidCast)));
    }

    #[test]
    fn test_arc_registration_conflicts_and_replacement() {
        let registry = TypeRegistry::new();

        // By-value and Arc entries share one key per type.
        assert!(registry.register(TestStruct { value: 0 }).is_ok());
        assert_eq!(
            registry.register_arc(Arc::new(TestStruct { value: 1 })),
            Err(RegistryError::TypeAlreadyExists)
        );

        assert!(registry.register_arc_or_replace(Arc::new(TestStruct { value: 2 })).is_ok());
        assert_eq!(registry.get::<TestStruct>().unwrap().value, 2);
        assert_eq!(
            registry.register_arc(Arc::new(TestStruct { value: 3 })),
            Err(RegistryError::TypeAlreadyExists)
        );
    }

    #[test]
    fn test_arc_macros() {
        struct ArcMacroTestStruct {
            value: i32,
        }

        assert!(register_arc!(ArcMacroTestStruct, ArcMacroTestStruct { value: 456 }).is_ok());

        let retrieved = get!(ArcMacroTestStruct).unwrap();
        assert_eq!(retrieved.value, 456);

        TypeRegistry::global().unregister::<ArcMacroTestStruct>().unwrap();
    }
}