use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, Mutex};
use crate::{RegistryError, RegistryEvent};
/// Lazily initialised, mutex-guarded slot for an optional shared callback that
/// takes a `&RegistryEvent`; `None` means no callback is installed.
type TraceCallback = LazyLock<Mutex<Option<Arc<dyn Fn(&RegistryEvent) + Send + Sync>>>>;
/// Type-indexed registry backed by static storage supplied by the implementor.
///
/// Implementors provide the two accessors [`RegistryApi::trace`] and
/// [`RegistryApi::storage`]; every other method has a default implementation
/// built on top of them. At most one value is stored per concrete type, keyed
/// by its [`TypeId`].
pub trait RegistryApi {
    /// Returns the static slot that holds the optional trace callback.
    fn trace() -> &'static TraceCallback;

    /// Installs `callback` as the trace callback, replacing any previous one.
    ///
    /// A poisoned callback lock is recovered rather than propagated.
    fn set_trace_callback(&self, callback: impl Fn(&RegistryEvent) + Send + Sync + 'static) {
        let mut slot = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
        *slot = Some(Arc::new(callback));
    }

    /// Removes the trace callback, if one is installed.
    fn clear_trace_callback(&self) {
        let mut slot = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
        *slot = None;
    }

    /// Passes `event` to the installed trace callback; does nothing when no
    /// callback is installed.
    fn emit_event(&self, event: &RegistryEvent) {
        // Clone the `Arc` inside a short scope so the callback lock is already
        // released by the time the callback runs.
        let callback = {
            let slot = Self::trace().lock().unwrap_or_else(|p| p.into_inner());
            slot.clone()
        };
        if let Some(cb) = callback {
            cb(event);
        }
    }

    /// Returns the static map that stores the registered values.
    fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>;

    /// Registers `value` under its own type, wrapping it in a fresh [`Arc`].
    ///
    /// See [`RegistryApi::register_arc`] for the emitted events.
    fn register<T: Send + Sync + 'static>(&self, value: T) {
        self.register_arc(Arc::new(value));
    }

    /// Registers an already shared `value` under type `T`, replacing any value
    /// previously stored for `T`.
    ///
    /// Emits `RegistryEvent::Register` before the insert and
    /// `RegistryEvent::RegisterCompleted` after it. A poisoned storage lock is
    /// recovered rather than propagated.
    fn register_arc<T: Send + Sync + 'static>(&self, value: Arc<T>) {
        let type_name = std::any::type_name::<T>();
        self.emit_event(&RegistryEvent::Register { type_name });
        Self::storage()
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .insert(TypeId::of::<T>(), value);
        self.emit_event(&RegistryEvent::RegisterCompleted { type_name });
    }

    /// Looks up the value registered for type `T`.
    ///
    /// Emits `RegistryEvent::Get` with the lookup outcome once the storage lock
    /// has been released. No event is emitted when the lock itself fails.
    ///
    /// # Errors
    ///
    /// * `RegistryError::RegistryLock` if the storage lock is poisoned.
    /// * `RegistryError::TypeNotFound` if nothing is registered for `T`.
    /// * `RegistryError::TypeMismatch` if the stored value cannot be downcast
    ///   to `T`.
    fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, RegistryError> {
        let type_name = std::any::type_name::<T>();
        // The guard is a temporary of this statement, so the lock is released
        // before the downcast and before the trace callback runs.
        let entry = Self::storage()
            .lock()
            .map_err(|_| RegistryError::RegistryLock)?
            .get(&TypeId::of::<T>())
            .cloned();
        let result = match entry {
            Some(any_arc) => any_arc
                .downcast::<T>()
                .map_err(|_| RegistryError::TypeMismatch { type_name }),
            None => Err(RegistryError::TypeNotFound { type_name }),
        };
        self.emit_event(&RegistryEvent::Get {
            type_name,
            found: result.is_ok(),
        });
        result
    }

    /// Returns an owned clone of the value registered for type `T`.
    ///
    /// # Errors
    ///
    /// Fails with the same errors as [`RegistryApi::get`].
    fn get_cloned<T: Send + Sync + Clone + 'static>(&self) -> Result<T, RegistryError> {
        self.get::<T>().map(|shared| (*shared).clone())
    }

    /// Like [`RegistryApi::get`], but collapses every error into `None`.
    fn try_get<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.get::<T>().ok()
    }

    /// Reports whether a value is registered for type `T`.
    ///
    /// Emits `RegistryEvent::Contains` with the outcome.
    ///
    /// # Errors
    ///
    /// Returns `RegistryError::RegistryLock` if the storage lock is poisoned.
    fn contains<T: Send + Sync + 'static>(&self) -> Result<bool, RegistryError> {
        let found = Self::storage()
            .lock()
            .map_err(|_| RegistryError::RegistryLock)?
            .contains_key(&TypeId::of::<T>());
        self.emit_event(&RegistryEvent::Contains {
            type_name: std::any::type_name::<T>(),
            found,
        });
        Ok(found)
    }

    /// Removes every registered value after emitting `RegistryEvent::Clear`.
    ///
    /// A poisoned storage lock is recovered, matching
    /// [`RegistryApi::register_arc`], so the registry is always emptied instead
    /// of the call silently doing nothing.
    #[doc(hidden)]
    fn clear(&self) {
        self.emit_event(&RegistryEvent::Clear {});
        Self::storage()
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .clear();
    }
}
#[cfg(test)]
mod tests {
    use super::{RegistryApi, TraceCallback};
    use crate::RegistryError;
    use serial_test::serial;
    use std::any::{type_name, Any, TypeId};
    use std::collections::HashMap;
    use std::sync::{mpsc, Arc, Barrier, LazyLock, Mutex};
    use std::thread;

    // Every test shares these statics, hence `#[serial]` on each test and the
    // `API.clear()` call at the start of each one.
    static STORAGE: LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> =
        LazyLock::new(|| Mutex::new(HashMap::new()));
    static TRACE: TraceCallback = LazyLock::new(|| Mutex::new(None));

    /// Minimal implementor wiring the trait to the test statics.
    struct Api;

    impl RegistryApi for Api {
        fn storage() -> &'static LazyLock<Mutex<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>> {
            &STORAGE
        }

        fn trace() -> &'static TraceCallback {
            &TRACE
        }
    }

    const API: Api = Api;

    /// Captures the `Display` text of every traced event.
    ///
    /// The callback is uninstalled on drop, so a failing assertion cannot leave
    /// it registered for the tests that run afterwards.
    struct Recorder {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl Recorder {
        /// Installs a recording trace callback on `API`.
        fn install() -> Self {
            let events = Arc::new(Mutex::new(Vec::<String>::new()));
            let sink = Arc::clone(&events);
            API.set_trace_callback(move |event| {
                sink.lock().unwrap().push(event.to_string());
            });
            Self { events }
        }

        /// Returns a copy of everything recorded so far.
        fn snapshot(&self) -> Vec<String> {
            self.events.lock().unwrap().clone()
        }
    }

    impl Drop for Recorder {
        fn drop(&mut self) {
            API.clear_trace_callback();
        }
    }

    // Expected event text is built from `type_name::<T>()` rather than literal
    // type names, because the exact output of `type_name` is not guaranteed to
    // stay the same across compiler versions.

    fn register_msg<T>() -> String {
        format!("register {{ type_name: {} }}", type_name::<T>())
    }

    fn register_completed_msg<T>() -> String {
        format!("register_completed {{ type_name: {} }}", type_name::<T>())
    }

    fn get_msg<T>(found: bool) -> String {
        format!("get {{ type_name: {}, found: {} }}", type_name::<T>(), found)
    }

    fn contains_msg<T>(found: bool) -> String {
        format!(
            "contains {{ type_name: {}, found: {} }}",
            type_name::<T>(),
            found
        )
    }

    #[test]
    #[serial]
    fn test_register_and_get_primitive() -> Result<(), RegistryError> {
        API.clear();
        API.register(42i32);
        let num: Arc<i32> = API.get()?;
        assert_eq!(*num, 42);
        let num_2 = API.get::<i32>()?;
        assert_eq!(*num_2, 42);
        Ok(())
    }

    #[test]
    #[serial]
    fn test_register_and_get_string() {
        API.clear();
        let s = "test".to_string();
        API.register(s.clone());
        let retrieved: Arc<String> = API.get().unwrap();
        assert_eq!(&*retrieved, &s);
        API.clear();
    }

    #[test]
    #[serial]
    fn test_get_nonexistent() {
        API.clear();
        let result: Result<Arc<String>, RegistryError> = API.get();
        assert_eq!(
            result.unwrap_err(),
            RegistryError::TypeNotFound {
                type_name: type_name::<String>()
            }
        );
    }

    #[test]
    #[serial]
    fn test_thread_safety() {
        API.clear();
        let barrier = Arc::new(Barrier::new(2));
        let (main_tx, thread_rx) = mpsc::channel();
        let (thread_tx, main_rx) = mpsc::channel();
        let worker_barrier = Arc::clone(&barrier);
        let worker = thread::spawn(move || {
            API.register(100u32);
            thread_tx.send(100u32).unwrap();
            let main_value: String = thread_rx.recv().unwrap();
            worker_barrier.wait();
            let s: Arc<String> = API.get().unwrap();
            assert_eq!(&*s, &main_value);
        });
        // The worker's value must be visible from the main thread.
        let thread_value = main_rx.recv().unwrap();
        let num: Arc<u32> = API.get().unwrap();
        assert_eq!(*num, thread_value);
        // And the main thread's value must be visible from the worker.
        let main_string = "main_thread_value".to_string();
        API.register(main_string.clone());
        main_tx.send(main_string.clone()).unwrap();
        barrier.wait();
        worker.join().unwrap();
        API.clear();
    }

    #[test]
    #[serial]
    fn test_multiple_types() {
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Num(i32);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Text(String);
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct Numbers(Vec<i32>);

        API.clear();
        let num_val = Num(42);
        let text_val = Text("hello".to_string());
        let nums_val = Numbers(vec![1, 2, 3]);
        API.register(num_val.clone());
        API.register(text_val.clone());
        API.register(nums_val.clone());
        let num: Arc<Num> = API.get().unwrap();
        assert_eq!(num.0, num_val.0);
        let text: Arc<Text> = API.get().unwrap();
        assert_eq!(text.0, text_val.0);
        let nums: Arc<Numbers> = API.get().unwrap();
        assert_eq!(&nums.0, &nums_val.0);
        API.clear();
    }

    #[test]
    #[serial]
    fn test_custom_type() {
        #[derive(Debug, PartialEq, Eq, Clone)]
        struct MyStruct {
            field: String,
        }

        API.clear();
        let my_value = MyStruct {
            field: "test".into(),
        };
        API.register(my_value.clone());
        let retrieved: Arc<MyStruct> = API.get().unwrap();
        assert_eq!(&*retrieved, &my_value);
    }

    #[test]
    #[serial]
    fn test_tuple_type() -> Result<(), RegistryError> {
        API.clear();
        let tuple = (1, "test");
        API.register(tuple);
        let retrieved = API.get::<(i32, &str)>()?;
        assert_eq!(&*retrieved, &tuple);
        Ok(())
    }

    #[test]
    #[serial]
    fn test_overwrite_same_type() {
        API.clear();
        API.register(10i32);
        API.register(20i32);
        let num: Arc<i32> = API.get().unwrap();
        assert_eq!(*num, 20);
    }

    #[test]
    #[serial]
    fn test_get_cloned() {
        API.clear();
        API.register("hello".to_string());
        let value: String = API.get_cloned::<String>().unwrap();
        assert_eq!(value, "hello");
    }

    #[test]
    #[serial]
    fn test_try_get() {
        API.clear();
        assert!(API.try_get::<u64>().is_none());
        API.register(99u64);
        let val = API.try_get::<u64>().expect("should be Some after register");
        assert_eq!(*val, 99);
    }

    #[test]
    #[serial]
    fn test_contains() {
        API.clear();
        assert!(!API.contains::<u32>().unwrap());
        API.register(1u32);
        assert!(API.contains::<u32>().unwrap());
    }

    #[test]
    #[serial]
    fn test_function_pointer_registration() {
        API.clear();
        let multiply_by_two: fn(i32) -> i32 = |x| x * 2;
        API.register(multiply_by_two);
        let doubler: Arc<fn(i32) -> i32> = API.get().unwrap();
        let result = doubler(21);
        assert_eq!(result, 42);
    }

    #[test]
    #[serial]
    fn test_trace_callback_register_event() {
        API.clear();
        let recorder = Recorder::install();
        API.register(5u8);
        assert_eq!(
            recorder.snapshot(),
            vec![register_msg::<u8>(), register_completed_msg::<u8>()]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_get_event() {
        API.clear();
        let recorder = Recorder::install();
        API.register(42i32);
        let _ = API.get::<i32>();
        assert_eq!(
            recorder.snapshot(),
            vec![
                register_msg::<i32>(),
                register_completed_msg::<i32>(),
                get_msg::<i32>(true),
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_contains_event() {
        API.clear();
        let recorder = Recorder::install();
        let _ = API.contains::<String>();
        API.register("test".to_string());
        let _ = API.contains::<String>();
        assert_eq!(
            recorder.snapshot(),
            vec![
                contains_msg::<String>(false),
                register_msg::<String>(),
                register_completed_msg::<String>(),
                contains_msg::<String>(true),
            ]
        );
    }

    #[test]
    #[serial]
    fn test_trace_callback_clear_event() {
        API.clear();
        let recorder = Recorder::install();
        API.clear();
        assert_eq!(
            recorder.snapshot(),
            vec!["Clearing the Registry".to_string()]
        );
    }

    #[test]
    #[serial]
    fn test_clear_trace_callback_stops_events() {
        API.clear();
        let recorder = Recorder::install();
        API.register(10u16);
        let expected = vec![register_msg::<u16>(), register_completed_msg::<u16>()];
        assert_eq!(recorder.snapshot(), expected);

        // Once the callback is removed, further operations must not be recorded.
        API.clear_trace_callback();
        API.register(20u16);
        let _ = API.get::<u16>();
        let _ = API.contains::<u16>();
        assert_eq!(recorder.snapshot(), expected);
    }

    #[test]
    #[serial]
    fn test_register_arc_directly() {
        API.clear();
        let value = Arc::new(42i32);
        let clone = value.clone();
        API.register_arc(value);
        let retrieved: Arc<i32> = API.get().unwrap();
        assert_eq!(*retrieved, 42);
        // Three owners: `clone`, the registry entry and `retrieved`.
        assert_eq!(Arc::strong_count(&clone), 3);
    }
}