use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use super::error::ContainerError;
/// Type-erased, thread-shareable value stored in the container.
type AnyArc = Arc<dyn Any + Send + Sync>;

/// How a registered type is produced when it is resolved.
#[derive(Clone)]
enum Binding {
    /// One shared instance; every resolution returns a clone of this `Arc`.
    Singleton(AnyArc),
    /// A constructor invoked on every resolution.
    ///
    /// Held in an `Arc` (not a `Box`) so [`Container::make`] can clone the handle out of the
    /// map and run the factory *after* releasing the lock.
    Factory(Arc<dyn Fn() -> AnyArc + Send + Sync>),
}

/// A type-keyed service container.
///
/// Each concrete type `T` has at most one binding: either a shared singleton or a factory.
/// Bindings are keyed by `TypeId::of::<T>()`, so `make::<T>()` only finds bindings that were
/// registered with exactly `T`. Registering a new binding for a type replaces the previous one.
/// All methods take `&self`; the binding map is guarded by an `RwLock`.
pub struct Container {
    bindings: RwLock<HashMap<TypeId, Binding>>,
}

impl Default for Container {
    fn default() -> Self {
        Self::new()
    }
}

impl Container {
    /// Creates an empty container with no bindings.
    pub fn new() -> Self {
        Self {
            bindings: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `factory` as the producer of `T`; every `make::<T>()` call invokes it anew.
    ///
    /// Replaces any existing binding for `T`.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn bind<T, F>(&self, factory: F)
    where
        T: Any + Send + Sync + 'static,
        F: Fn() -> T + Send + Sync + 'static,
    {
        // The explicit `-> AnyArc` makes the `Arc<T>` -> `Arc<dyn Any + ...>` coercion happen
        // inside the closure, so the closure itself satisfies `Fn() -> AnyArc`.
        let erased = move || -> AnyArc { Arc::new(factory()) };
        self.insert_binding(TypeId::of::<T>(), Binding::Factory(Arc::new(erased)));
    }

    /// Registers `instance` as the single shared value for `T`.
    ///
    /// Replaces any existing binding for `T`.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn singleton<T>(&self, instance: T)
    where
        T: Any + Send + Sync + 'static,
    {
        self.insert_binding(TypeId::of::<T>(), Binding::Singleton(Arc::new(instance)));
    }

    /// Resolves `T` from its binding.
    ///
    /// A singleton binding yields a clone of the shared `Arc`. A factory binding invokes the
    /// factory and yields the newly built value. The factory runs after the internal lock has
    /// been released, so it may call back into this container (`make`, `bind`, `singleton`, ...)
    /// without deadlocking.
    ///
    /// # Errors
    /// - `ContainerError::NotRegistered` if no binding exists for `T`.
    /// - `ContainerError::TypeMismatch` if the stored value is not a `T`. Bindings are keyed by
    ///   `TypeId::of::<T>()`, so this indicates an internal inconsistency.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned, or propagates a panic from the factory.
    pub fn make<T>(&self) -> Result<Arc<T>, ContainerError>
    where
        T: Any + Send + Sync + 'static,
    {
        // Cloning the binding is only a refcount bump. The read guard is a temporary of this
        // statement, so it is released here, before any user-supplied factory code runs.
        let binding = self
            .bindings
            .read()
            .expect("container lock poisoned")
            .get(&TypeId::of::<T>())
            .cloned();

        let value = match binding {
            Some(Binding::Singleton(instance)) => instance,
            Some(Binding::Factory(factory)) => factory(),
            None => return Err(ContainerError::NotRegistered(std::any::type_name::<T>())),
        };

        value
            .downcast::<T>()
            .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>()))
    }

    /// Replaces the binding for `T` with a singleton derived from its current value.
    ///
    /// The current value is resolved with `make::<T>()`, so a factory binding is invoked once.
    /// That value is passed to `extender`, and the result is registered via `singleton`. As a
    /// consequence, a factory binding becomes a singleton binding.
    ///
    /// Resolving and re-registering are separate lock acquisitions. A binding for `T` that
    /// another thread registers in between is overwritten.
    ///
    /// # Errors
    /// Propagates any error from `make::<T>()`; in that case nothing is registered.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn extend<T, F>(&self, extender: F) -> Result<(), ContainerError>
    where
        T: Any + Send + Sync + 'static,
        F: FnOnce(Arc<T>) -> T,
    {
        let existing = self.make::<T>()?;
        self.singleton(extender(existing));
        Ok(())
    }

    /// Replaces the binding for `T` with `instance` as a singleton.
    ///
    /// Identical in effect to [`Container::singleton`].
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn swap<T>(&self, instance: T)
    where
        T: Any + Send + Sync + 'static,
    {
        self.singleton(instance);
    }

    /// Stores `binding` under `key`, replacing any previous binding.
    ///
    /// The replaced binding is dropped only after the write lock is released. If that drop
    /// releases the last reference to a value whose `Drop` uses this container, it cannot
    /// deadlock. A panic in such a `Drop` cannot poison the lock.
    fn insert_binding(&self, key: TypeId, binding: Binding) {
        let previous = {
            let mut map = self.bindings.write().expect("container lock poisoned");
            map.insert(key, binding)
        };
        drop(previous);
    }
}