use std::any::{Any, TypeId};
use std::collections::HashMap;
use cfg_block::cfg_block;
// Type-erasure aliases and the bound used by the public API, selected by the `sync` feature.
// With `sync` enabled, stored values and factories must additionally be `Send + Sync`;
// without it only `'static` is required.
cfg_block! {
if #[cfg(feature = "sync")] {
// Type-erased value as stored in the maps and as returned by factories (`Send + Sync` variant).
type AnyDyn = dyn Any + Send + Sync;
// Type-erased factory: receives the container and returns a boxed, type-erased value.
type Factory = dyn Fn(&Container) -> Box<AnyDyn> + Send + Sync;
/// Bound placed on types registered with the builder and retrieved via `get`/`resolve`.
///
/// With the `sync` feature enabled this is `Send + Sync + 'static`.
pub trait SyncBounds: Send + Sync + 'static {}
// Blanket impl: every type meeting the bounds implements `SyncBounds` automatically.
impl<T: Send + Sync + 'static> SyncBounds for T {}
} else {
// Type-erased value as stored in the maps and as returned by factories.
type AnyDyn = dyn Any;
// Type-erased factory: receives the container and returns a boxed, type-erased value.
type Factory = dyn Fn(&Container) -> Box<AnyDyn>;
/// Bound placed on types registered with the builder and retrieved via `get`/`resolve`.
///
/// Without the `sync` feature this is just `'static`.
pub trait SyncBounds: 'static {}
// Blanket impl: every `'static` type implements `SyncBounds` automatically.
impl<T: 'static> SyncBounds for T {}
}
}
/// A dependency declaration collected through the `inventory` crate.
///
/// `Container::validate` iterates every collected declaration and panics if any
/// declared type has neither a singleton nor a factory registered.
pub struct DeclaredDependency {
/// Function returning the `TypeId` of the declared type; `validate` calls it to
/// obtain the key it looks up in the container's maps.
pub type_id: fn() -> TypeId,
/// Name of the declared type, used only in the `validate` panic message.
pub type_name: &'static str,
}
// Registers `DeclaredDependency` as an `inventory`-collected type so that
// `inventory::iter::<DeclaredDependency>` (used by `Container::validate`) can enumerate entries.
inventory::collect!(DeclaredDependency);
/// Holds type-erased singleton values and factories, each keyed by `TypeId`.
///
/// Instances are produced by `ContainerBuilder::build`, which moves the builder's
/// maps into these fields.
pub struct Container {
// Ready-made values; looked up by `try_get`/`get` and borrowed out as `&T`.
singletons: HashMap<TypeId, Box<AnyDyn>>,
// Closures that build a fresh value on each `try_resolve`/`resolve` call.
factories: HashMap<TypeId, Box<Factory>>,
}
// Formatting impls for `Container`, compiled only when the `debug` feature is enabled.
cfg_block! {
    #[cfg(feature = "debug")] {
        use std::fmt;

        /// Diagnostic view: a `Container` struct showing only the number of
        /// singletons and factories (the stored values themselves are type-erased).
        impl fmt::Debug for Container {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut out = f.debug_struct("Container");
                out.field("singletons", &self.singletons.len());
                out.field("factories", &self.factories.len());
                out.finish()
            }
        }

        /// One-line summary with the singleton and factory counts.
        impl fmt::Display for Container {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let singleton_total = self.singletons.len();
                let factory_total = self.factories.len();
                write!(
                    f,
                    "Container ({} singletons, {} factories)",
                    singleton_total, factory_total
                )
            }
        }
    }
}
/// Accumulates singleton and factory registrations before a `Container` is built.
///
/// `build` moves both maps into the resulting `Container` unchanged.
pub struct ContainerBuilder {
// Values registered through `singleton`, keyed by the `TypeId` of the value's type.
singletons: HashMap<TypeId, Box<AnyDyn>>,
// Closures registered through `factory`, keyed by the `TypeId` of the type they produce.
factories: HashMap<TypeId, Box<Factory>>,
}
impl Default for ContainerBuilder {
fn default() -> Self {
Self::new()
}
}
impl ContainerBuilder {
    /// Creates a builder with no singletons and no factories.
    #[must_use]
    pub fn new() -> Self {
        let singletons = HashMap::default();
        let factories = HashMap::default();
        Self {
            singletons,
            factories,
        }
    }

    /// Registers `value` as the singleton for type `T` and returns the builder.
    ///
    /// The value is boxed and stored under `TypeId::of::<T>()`; registering the same
    /// type again replaces the previously stored value.
    #[must_use]
    pub fn singleton<T: SyncBounds>(mut self, value: T) -> Self {
        let key = TypeId::of::<T>();
        let erased: Box<AnyDyn> = Box::new(value);
        self.singletons.insert(key, erased);
        self
    }

    /// Registers `f` as the factory for type `T` and returns the builder.
    ///
    /// The closure is wrapped so that its result is boxed and type-erased, then stored
    /// under `TypeId::of::<T>()`; registering the same type again replaces the
    /// previously stored factory.
    #[must_use]
    pub fn factory<T, F>(mut self, f: F) -> Self
    where
        T: SyncBounds,
        F: Fn(&Container) -> T + SyncBounds,
    {
        // Adapt the typed closure to the type-erased `Factory` signature.
        let erased = move |container: &Container| -> Box<AnyDyn> {
            let produced: T = f(container);
            Box::new(produced)
        };
        let key = TypeId::of::<T>();
        self.factories.insert(key, Box::new(erased));
        self
    }

    /// Consumes the builder and moves its registrations into a [`Container`].
    #[must_use]
    pub fn build(self) -> Container {
        let Self {
            singletons,
            factories,
        } = self;
        Container {
            singletons,
            factories,
        }
    }
}
impl Container {
#[must_use]
pub fn builder() -> ContainerBuilder {
ContainerBuilder::new()
}
#[must_use]
pub fn get<T: SyncBounds>(&self) -> &T {
self.try_get().expect("dependency not registered")
}
#[must_use]
pub fn try_get<T: SyncBounds>(&self) -> Option<&T> {
self.singletons
.get(&TypeId::of::<T>())
.and_then(|boxed| boxed.downcast_ref::<T>())
}
#[must_use]
pub fn resolve<T: SyncBounds>(&self) -> T {
self.try_resolve()
.expect("factory not registered or failed")
}
#[must_use]
pub fn try_resolve<T: SyncBounds>(&self) -> Option<T> {
let factory = self.factories.get(&TypeId::of::<T>())?;
let boxed = factory(self);
boxed.downcast().ok().map(|b| *b)
}
#[must_use]
pub fn contains<T: 'static>(&self) -> bool {
self.singletons.contains_key(&TypeId::of::<T>())
|| self.factories.contains_key(&TypeId::of::<T>())
}
pub fn validate(&self) {
let mut missing: Vec<&'static str> = Vec::new();
for dep in inventory::iter::<DeclaredDependency> {
let type_id = (dep.type_id)();
let registered =
self.singletons.contains_key(&type_id) || self.factories.contains_key(&type_id);
if !registered {
missing.push(dep.type_name);
}
}
if !missing.is_empty() {
panic!(
"Container is missing {} declared dependenc{}: [{}]",
missing.len(),
if missing.len() == 1 { "y" } else { "ies" },
missing.join(", ")
);
}
}
#[must_use]
pub fn singleton_count(&self) -> usize {
self.singletons.len()
}
#[must_use]
pub fn factory_count(&self) -> usize {
self.factories.len()
}
}