#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use rustc_hash::FxHashMap as HashMap;
use std::any::Any;
use std::sync::{Arc, Mutex};
#[cfg(any(feature = "derive", feature = "debug"))]
pub use coi_derive::*;
#[cfg(feature = "debug")]
use petgraph::{
algo::toposort,
graph::{DiGraph, NodeIndex},
};
#[cfg(feature = "debug")]
use std::fmt::{self, Debug};
/// Errors produced while resolving values from a [`Container`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No provider is registered under the given key, neither in the container being
    /// queried nor in any of its ancestors.
    #[error("Key not found: {0}")]
    KeyNotFound(String),
    /// The key is registered (or cached), but for a type other than the one requested.
    #[error("Type mismatch for key: {0}")]
    TypeMismatch(String),
    /// An arbitrary boxed error; `Box<dyn Error + Send + Sync>` converts into this variant
    /// via `From`, so providers can propagate their own failures with `?`.
    #[error("Inner error: {0}")]
    Inner(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
}
/// Shorthand for a `std::result::Result` whose error type is [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Marker trait for types that can be provided by and resolved from a [`Container`].
///
/// The container stores resolved values type-erased as `dyn Any + Send + Sync`, which is
/// why implementors must be `Send + Sync + 'static`.
pub trait Inject: Send + Sync + 'static {}

/// An `Arc` of an injectable type (sized or a trait object) is itself injectable.
impl<T: Inject + ?Sized> Inject for Arc<T> {}
/// Controls whether, and where, values produced by a registration are cached.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RegistrationKind {
    /// The provider runs on every resolution; nothing is cached.
    Transient,
    /// The value is cached in the container that resolves it. Every container created with
    /// [`Container::scoped`] receives its own copy of the registration and its own cache.
    Scoped,
    /// The value is cached in the container the registration was built into. Scoped child
    /// containers do not copy the registration and resolve it through their parent, so
    /// they share that one cached value.
    Singleton,
}
/// A provider paired with the [`RegistrationKind`] that decides how its values are cached.
#[derive(Clone, Debug)]
pub struct Registration<T> {
    /// How values produced by `provider` are cached.
    kind: RegistrationKind,
    /// The provider itself (type-erased once stored inside a container).
    provider: T,
}

impl<T> Registration<T> {
    /// Creates a registration of `provider` with the given caching `kind`.
    pub fn new(kind: RegistrationKind, provider: T) -> Self {
        Self { kind, provider }
    }
}
/// State shared by every clone of a [`Container`] handle (kept behind its mutex).
#[derive(Clone, Debug)]
struct InnerContainer {
/// Registrations by key. Each provider is an `Arc<dyn Provide<Output = T> + Send + Sync>`
/// wrapped in a second `Arc` erased to `dyn Any`, so `resolve` can downcast it back for a
/// requested `T`.
provider_map: HashMap<&'static str, Registration<Arc<dyn Any + Send + Sync>>>,
/// Cached scoped/singleton values by key, each an `Arc<T>` erased to `dyn Any`.
resolved_map: HashMap<&'static str, Arc<dyn Any + Send + Sync>>,
/// The container this one was created from by `Container::scoped`, if any; keys missing
/// from `provider_map` are resolved through it.
parent: Option<Container>,
/// Declared dependency keys of each registered key (only with the `debug` feature).
#[cfg(feature = "debug")]
dependency_map: HashMap<&'static str, &'static [&'static str]>,
}
impl InnerContainer {
fn check_resolved<T>(&self, key: &'static str) -> Option<Result<Arc<T>>>
where
T: Inject + ?Sized,
{
self.resolved_map.get(key).map(|v| {
v.downcast_ref::<Arc<T>>()
.map(Arc::clone)
.ok_or_else(|| Error::TypeMismatch(key.to_owned()))
})
}
}
/// A dependency-injection container that resolves registered values by key.
///
/// Cloning a `Container` is cheap: every clone is a handle to the same mutex-guarded state,
/// meaning the same registrations and the same cache of resolved values.
#[derive(Clone, Debug)]
pub struct Container(Arc<Mutex<InnerContainer>>);
/// A problem found by [`Container::analyze`] in the registered dependency graph.
#[cfg(feature = "debug")]
#[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
#[derive(Debug, thiserror::Error)]
pub enum AnalysisError {
    /// The dependency graph contains a cycle that includes this key.
    #[error("Cycle detected at node `{0}`")]
    Cycle(&'static str),
    /// The first key declares a dependency on the second key, which has no registration.
    #[error("Node `{0}` depends on `{1}`, the latter of which is not registered")]
    Missing(&'static str, &'static str),
}
#[cfg(feature = "debug")]
/// A node (and, via `Default`, an edge weight) of the graph built by `Container::dependency_graph`.
///
/// NOTE(review): `Default` appears to be needed because the type doubles as the edge weight
/// and petgraph's `extend_with_edges` fills edge weights with `E::default()` — confirm.
#[derive(Clone, Default)]
struct AnalysisNode {
/// How the key is registered; `None` for dependency keys that have no registration.
registration: Option<RegistrationKind>,
/// The registration key this node stands for.
id: &'static str,
}
#[cfg(feature = "debug")]
impl fmt::Display for AnalysisNode {
    /// Renders the node as `<kind> - <id>`, with `MISSING` as the kind for unregistered keys.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self.registration {
            None => "MISSING",
            Some(RegistrationKind::Transient) => "Transient",
            Some(RegistrationKind::Scoped) => "Scoped",
            Some(RegistrationKind::Singleton) => "Singleton",
        };
        write!(f, "{} - {}", label, self.id)
    }
}
impl Container {
fn new(container: InnerContainer) -> Self {
Self(Arc::new(Mutex::new(container)))
}
pub fn resolve<T>(&self, key: &'static str) -> Result<Arc<T>>
where
T: Inject + ?Sized,
{
let (kind, provider) = {
let container = self.0.lock().unwrap();
if let Some(resolved) = container.check_resolved::<T>(key) {
return resolved;
}
let registration = match container.provider_map.get(key) {
Some(provider) => provider,
None => {
return match &container.parent {
Some(parent) => {
let parent = parent.clone();
parent.resolve::<T>(key)
}
None => Err(Error::KeyNotFound(key.to_owned())),
};
}
};
(
registration.kind,
registration
.provider
.downcast_ref::<Arc<dyn Provide<Output = T> + Send + Sync + 'static>>()
.map(Arc::clone)
.ok_or_else(|| Error::TypeMismatch(key.to_owned()))?,
)
};
let provided = provider.provide(self);
match kind {
RegistrationKind::Transient => provided,
RegistrationKind::Scoped | RegistrationKind::Singleton => {
let mut container = self.0.lock().unwrap();
Ok(container
.resolved_map
.entry(key)
.or_insert(Arc::new(provided?))
.downcast_ref::<Arc<T>>()
.map(Arc::clone)
.unwrap())
}
}
}
pub fn scoped(&self) -> Container {
let container: &InnerContainer = &self.0.lock().unwrap();
Container::new(InnerContainer {
provider_map: container
.provider_map
.iter()
.filter_map(|(k, v)| match v.kind {
kind @ RegistrationKind::Scoped | kind @ RegistrationKind::Transient => Some((
*k,
Registration {
kind,
provider: Arc::clone(&v.provider),
},
)),
_ => None,
})
.collect(),
resolved_map: HashMap::default(),
#[cfg(feature = "debug")]
dependency_map: container.dependency_map.clone(),
parent: Some(self.clone()),
})
}
#[cfg(feature = "debug")]
fn dependency_graph(&self) -> DiGraph<AnalysisNode, AnalysisNode> {
let container = self.0.lock().unwrap();
let mut graph = DiGraph::<AnalysisNode, AnalysisNode>::new();
let mut key_to_node = container
.dependency_map
.keys()
.map(|k| -> (&'static str, NodeIndex) {
let kind = container.provider_map[k].kind;
let n = graph.add_node(AnalysisNode {
registration: Some(kind),
id: k,
});
(k, n)
})
.collect::<HashMap<&str, _>>();
for (k, deps) in &container.dependency_map {
let kn = key_to_node[k as &str];
let edges = deps
.iter()
.map(|dep| {
let vn = match key_to_node.get(dep) {
Some(vn) => *vn,
None => {
let vn = graph.add_node(AnalysisNode {
registration: None,
id: dep,
});
key_to_node.insert(dep, vn);
key_to_node[dep]
}
};
(kn, vn)
})
.collect::<Vec<_>>();
graph.extend_with_edges(&edges[..]);
}
graph
}
#[cfg(feature = "debug")]
#[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
pub fn analyze(&self) -> std::result::Result<(), Vec<AnalysisError>> {
use petgraph::Direction;
let graph = self.dependency_graph();
let mut errors = graph
.node_indices()
.filter(|i| graph[*i].registration.is_none())
.flat_map(|i| {
let to = &graph[i].id;
graph
.neighbors_directed(i, Direction::Incoming)
.map(|from| AnalysisError::Missing(graph[from].id, to))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
if let Err(cycle) = toposort(&graph, None) {
errors.push(AnalysisError::Cycle(graph[cycle.node_id()].id));
}
if !errors.is_empty() {
Err(errors)
} else {
Ok(())
}
}
#[cfg(feature = "debug")]
#[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
pub fn dot_graph(&self) -> String {
use petgraph::dot::{Config, Dot};
let graph = self.dependency_graph();
format!("{}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
}
}
/// Collects provider registrations and builds a [`Container`] from them.
#[derive(Clone, Debug, Default)]
pub struct ContainerBuilder {
    /// Type-erased registrations by key, in the form the built container expects.
    provider_map: HashMap<&'static str, Registration<Arc<dyn Any + Send + Sync>>>,
    /// Declared dependency keys of each registered key.
    #[cfg(feature = "debug")]
    dependency_map: HashMap<&'static str, &'static [&'static str]>,
}
impl ContainerBuilder {
    /// Creates a builder with no registrations.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `provider` under `key` as a [`RegistrationKind::Transient`] registration.
    ///
    /// Shorthand for [`ContainerBuilder::register_as`] with a transient [`Registration`].
    #[inline]
    #[must_use]
    pub fn register<P, T>(self, key: &'static str, provider: P) -> Self
    where
        T: Inject + ?Sized,
        P: Provide<Output = T> + Send + Sync + 'static,
    {
        self.register_as(key, Registration::new(RegistrationKind::Transient, provider))
    }

    /// Erases the concrete provider type `P` into the shared `dyn Provide<Output = T>` object
    /// that `Container::resolve` downcasts back to.
    fn get_arc<P, T>(provider: P) -> Arc<dyn Provide<Output = T> + Send + Sync>
    where
        T: Inject + ?Sized,
        P: Provide<Output = T> + Send + Sync + 'static,
    {
        Arc::new(provider)
    }

    /// Adds `registration` under `key`, replacing any earlier registration with that key.
    ///
    /// With the `debug` feature the provider's declared dependencies are recorded as well, for
    /// use by `Container::analyze` and `Container::dot_graph`.
    #[must_use]
    pub fn register_as<P, T>(mut self, key: &'static str, registration: Registration<P>) -> Self
    where
        T: Inject + ?Sized,
        P: Provide<Output = T> + Send + Sync + 'static,
    {
        #[cfg(feature = "debug")]
        let deps = registration.provider.dependencies();
        // Two layers of `Arc`: the outer one is erased to `dyn Any` so a single map can hold
        // providers of every output type; resolution downcasts it back to the inner
        // `Arc<dyn Provide<Output = T>>` for the requested `T`.
        let provider =
            Arc::new(Self::get_arc(registration.provider)) as Arc<dyn Any + Send + Sync>;
        self.provider_map
            .insert(key, Registration::new(registration.kind, provider));
        #[cfg(feature = "debug")]
        self.dependency_map.insert(key, deps);
        self
    }

    /// Consumes the builder and creates a root [`Container`] (one without a parent) that holds
    /// these registrations and starts with an empty cache.
    #[must_use]
    pub fn build(self) -> Container {
        Container::new(InnerContainer {
            provider_map: self.provider_map,
            resolved_map: HashMap::default(),
            parent: None,
            #[cfg(feature = "debug")]
            dependency_map: self.dependency_map,
        })
    }
}
/// Produces values of type [`Provide::Output`] for a [`Container`].
///
/// Without the `debug` feature, any `Fn(&Container) -> Result<Arc<T>>` closure implements
/// this trait; with it, a `(dependencies, closure)` pair does (see [`provide_closure!`]).
pub trait Provide {
    /// The type produced; callers receive it as `Arc<Self::Output>`.
    type Output: Inject + ?Sized;

    /// Produces a value, resolving whatever it depends on from `container`.
    ///
    /// # Errors
    ///
    /// Returns any error raised while producing the value, for example a failure to resolve
    /// one of its dependencies.
    fn provide(&self, container: &Container) -> Result<Arc<Self::Output>>;

    /// The keys this provider resolves from the container; used to build the dependency
    /// graph inspected by `Container::analyze` and `Container::dot_graph`.
    #[cfg(feature = "debug")]
    #[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
    fn dependencies(&self) -> &'static [&'static str];
}
#[cfg(not(feature = "debug"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "debug"))))]
/// Any closure taking `&Container` and returning `Result<Arc<T>>` is a provider of `T`.
impl<T, F> Provide for F
where
    T: Inject + ?Sized,
    F: Fn(&Container) -> Result<Arc<T>>,
{
    type Output = T;

    /// Calls the closure with `container`.
    fn provide(&self, container: &Container) -> Result<Arc<T>> {
        let factory = self;
        factory(container)
    }
}
#[cfg(not(feature = "debug"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "debug"))))]
/// A closure trait object with the provider signature is a provider of `T`.
impl<T> Provide for dyn Fn(&Container) -> Result<Arc<T>>
where
    T: Inject + ?Sized,
{
    type Output = T;

    /// Calls the wrapped closure with `container`.
    fn provide(&self, container: &Container) -> Result<Arc<T>> {
        let factory = self;
        factory(container)
    }
}
#[cfg(feature = "debug")]
#[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
/// A `(dependencies, closure)` pair is a provider of `T`: the closure produces the value and
/// the slice lists the keys it resolves.
impl<T, F> Provide for (&'static [&'static str], F)
where
    T: Inject + ?Sized,
    F: Fn(&Container) -> Result<Arc<T>>,
{
    type Output = T;

    /// Calls the closure half of the pair with `container`.
    fn provide(&self, container: &Container) -> Result<Arc<T>> {
        let (_, factory) = self;
        factory(container)
    }

    /// Returns the dependency-key half of the pair.
    fn dependencies(&self) -> &'static [&'static str] {
        let &(dependencies, _) = self;
        dependencies
    }
}
#[cfg(feature = "debug")]
#[cfg_attr(docsrs, doc(cfg(feature = "debug")))]
/// A `(dependencies, closure trait object)` pair is a provider of `T`.
impl<T> Provide
    for (
        &'static [&'static str],
        dyn Fn(&Container) -> Result<Arc<T>>,
    )
where
    T: Inject + ?Sized,
{
    type Output = T;

    /// Calls the closure half of the pair with `container`.
    fn provide(&self, container: &Container) -> Result<Arc<T>> {
        let factory = &self.1;
        factory(container)
    }

    /// Returns the dependency-key half of the pair.
    fn dependencies(&self) -> &'static [&'static str] {
        let dependencies = self.0;
        dependencies
    }
}
/// Builds a [`Container`] from a comma-separated list of `key => provider` entries.
///
/// Each key identifier is stringified and used as the registration key. An entry may end
/// with `; scoped`, `; singleton` or `; transient` to choose its [`RegistrationKind`];
/// entries without a suffix are transient. A trailing comma is accepted.
///
/// ```ignore
/// // `repo_provider` and `service_provider` are placeholders for any `Provide` values.
/// let container = container! {
///     repository => repo_provider; singleton,
///     service => service_provider,
/// };
/// ```
#[macro_export]
macro_rules! container {
    (@registration $provider:expr; scoped) => {
        $crate::Registration::new(
            $crate::RegistrationKind::Scoped,
            $provider
        )
    };
    (@registration $provider:expr; singleton) => {
        $crate::Registration::new(
            $crate::RegistrationKind::Singleton,
            $provider
        )
    };
    (@registration $provider:expr; transient) => {
        $crate::Registration::new(
            $crate::RegistrationKind::Transient,
            $provider
        )
    };
    (@registration $provider:expr) => {
        $crate::Registration::new(
            $crate::RegistrationKind::Transient,
            $provider
        )
    };
    // Internal rules are invoked through `$crate::` so the macro also works when called by
    // path (`coi::container!`) without being imported, and when the crate is renamed.
    (@line $builder:ident $key:ident $provider:expr $(; $call:ident)?) => {
        $builder = $builder.register_as(
            stringify!($key),
            $crate::container!(@registration $provider $(; $call)?),
        );
    };
    ($($key:ident => $provider:expr $(; $call:ident)?),+) => {
        $crate::container!{ $( $key => $provider $(; $call)?, )+ }
    };
    ($($key:ident => $provider:expr $(; $call:ident)?,)+) => {
        {
            let mut builder = $crate::ContainerBuilder::new();
            $($crate::container!(@line builder $key $provider $(; $call)?);)+
            builder.build()
        }
    }
}
/// Builds a provider from a closure whose parameters are its dependencies.
///
/// Every parameter must be written as `name: Arc<Type>`; it is resolved from the container
/// under the key `"name"` as `Type`, and a failed resolution is propagated with `?`. To be
/// usable as a [`Provide`], the body must evaluate to `Result<Arc<T>>`. Parameters without
/// an explicit type are rejected with a compile error. With the `debug` feature, the result
/// is a `(dependency keys, closure)` pair so the container can analyze dependencies.
///
/// ```ignore
/// // `Pool` and `Service` are placeholder `Inject` types.
/// let provider = provide_closure!(|pool: Arc<Pool>| Ok(Arc::new(Service::new(pool))));
/// ```
#[macro_export]
macro_rules! provide_closure {
    // Normalize to the trailing-comma form handled below; recursion goes through `$crate::`
    // so it works even when this macro is invoked by path without being imported.
    ($($move:ident)? |$($arg:ident: Arc<$ty:ty>),*| $(-> $res:ty)? $block:block) => {
        $crate::provide_closure!($($move)? |$($arg: Arc<$ty>,)*| $(-> $res)? $block)
    };
    ($($move:ident)? |$($arg:ident: Arc<$ty:ty>,)*| $(-> $res:ty)? $block:block) => {
        {
            use $crate::__provide_closure_impl;
            __provide_closure_impl!($($move)? |$($arg: $ty,)*| $(-> $res)? $block)
        }
    };
    ($($move:ident)? |$($arg:ident),*| $(-> $res:ty)? $block:block) => {
        compile_error!("this macro requires closure arguments to have explicitly defined parameter types")
    };
}
/// Implementation detail of `provide_closure!` (build without the `debug` feature).
///
/// Expands to a closure over `&Container` that binds each `name` to the `Arc<Type>` resolved
/// under the key `"name"` (returning early with `?` on failure) and then evaluates `$block`.
#[doc(hidden)]
#[macro_export]
#[cfg(not(feature = "debug"))]
macro_rules! __provide_closure_impl {
// Parameters arrive with their `Arc<...>` wrapper already stripped and a trailing comma.
($($move:ident)? |$($arg:ident: $ty:ty,)*| $(-> $res:ty)? $block:block) => {
$($move)? |_container: &$crate::Container| $(-> $res)? {
$(let $arg = _container.resolve::<$ty>(stringify!($arg))?;)*
$block
}
};
}
/// Implementation detail of `provide_closure!` (build with the `debug` feature).
///
/// Expands to a `(dependency keys, closure)` pair. The keys are the stringified parameter
/// names, which are exactly the keys the closure resolves before evaluating `$block`.
#[doc(hidden)]
#[macro_export]
#[cfg(feature = "debug")]
macro_rules! __provide_closure_impl {
// Parameters arrive with their `Arc<...>` wrapper already stripped and a trailing comma.
($($move:ident)? |$($arg:ident: $ty:ty,)*| $(-> $res:ty)? $block:block) => {
(
&[$(stringify!($arg),)*],
$($move)? |_container: &$crate::Container| $(-> $res)? {
$(let $arg = _container.resolve::<$ty>(stringify!($arg))?;)*
$block
}
)
};
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn ensure_display() {
        use std::io;
        let error = Error::KeyNotFound("S".to_owned());
        let displayed = format!("{}", error);
        assert_eq!(displayed, "Key not found: S");
        let error = Error::TypeMismatch("S2".to_owned());
        let displayed = format!("{}", error);
        assert_eq!(displayed, "Type mismatch for key: S2");
        let error = Error::Inner(Box::new(io::Error::new(io::ErrorKind::NotFound, "oh no!")));
        let displayed = format!("{}", error);
        assert_eq!(displayed, "Inner error: oh no!");
    }

    #[test]
    fn ensure_debug() {
        let error = Error::KeyNotFound("S".to_owned());
        let debugged = format!("{:?}", error);
        assert_eq!(debugged, "KeyNotFound(\"S\")");
        let error = Error::TypeMismatch("S2".to_owned());
        let debugged = format!("{:?}", error);
        assert_eq!(debugged, "TypeMismatch(\"S2\")");
    }

    #[test]
    fn container_builder_is_clonable() {
        let builder = ContainerBuilder::new();
        for _ in 0..2 {
            let builder = builder.clone();
            let _container = builder.build();
        }
    }

    #[test]
    fn container_is_clonable() {
        let container = ContainerBuilder::new().build();
        // Cloning is exactly what this test exercises, so the redundant-clone lint is moot.
        #[allow(clippy::redundant_clone)]
        let _container = container.clone();
    }

    #[test]
    fn unregistered_key_is_not_found_through_scopes() {
        struct Unregistered;
        impl Inject for Unregistered {}

        // The scoped child has no registration, so resolution falls back to the (empty)
        // parent, which must report the key as missing.
        let scoped = ContainerBuilder::new().build().scoped();
        let error = scoped
            .resolve::<Unregistered>("unregistered")
            .err()
            .expect("resolving an unregistered key must fail");
        assert_eq!(format!("{}", error), "Key not found: unregistered");
    }
}