use crate::{Container, Injectable};
use std::sync::Arc;
/// A type that can be built from dependencies looked up in a [`Container`].
///
/// Implementors name what they need through [`Service::Dependencies`] and
/// turn the resolved value into an instance in [`Service::create`].
pub trait Service: Injectable + Sized {
    /// The dependency bundle passed to [`Service::create`].
    ///
    /// It must implement [`Resolvable`] so it can be obtained from a container.
    type Dependencies: Resolvable;

    /// Builds the service from its already-resolved dependencies.
    fn create(deps: Self::Dependencies) -> Self;
}
/// A dependency, or bundle of dependencies, that can be looked up in a
/// [`Container`].
pub trait Resolvable: Sized {
    /// Attempts to obtain `Self` from `container`.
    ///
    /// Returns `None` when resolution fails.
    fn resolve(container: &Container) -> Option<Self>;
}
/// The empty dependency list: it always resolves successfully.
impl Resolvable for () {
    #[inline]
    fn resolve(_: &Container) -> Option<Self> {
        // Nothing has to be looked up, so the container is never consulted.
        let nothing_needed = ();
        Some(nothing_needed)
    }
}
/// A required dependency: resolves to whatever `Container::try_get::<T>()`
/// returns, so a `None` from the container makes resolution fail.
impl<T> Resolvable for Arc<T>
where
    T: Injectable,
{
    #[inline]
    fn resolve(container: &Container) -> Option<Self> {
        let lookup = container.try_get::<T>();
        lookup
    }
}
/// An optional dependency: resolution itself never fails. Whatever
/// `Container::try_get::<T>()` returns (including `None`) is passed through
/// as the resolved value.
impl<T> Resolvable for Option<Arc<T>>
where
    T: Injectable,
{
    #[inline]
    fn resolve(container: &Container) -> Option<Self> {
        let maybe_instance = container.try_get::<T>();
        Some(maybe_instance)
    }
}
macro_rules! impl_resolvable_tuple {
($($T:ident),+) => {
impl<$($T: Injectable),+> Resolvable for ($(Arc<$T>,)+) {
#[inline]
fn resolve(container: &Container) -> Option<Self> {
Some(($(container.try_get::<$T>()?,)+))
}
}
};
}
impl_resolvable_tuple!(A, B);
impl_resolvable_tuple!(A, B, C);
impl_resolvable_tuple!(A, B, C, D);
impl_resolvable_tuple!(A, B, C, D, E);
impl_resolvable_tuple!(A, B, C, D, E, F);
impl_resolvable_tuple!(A, B, C, D, E, F, G);
impl_resolvable_tuple!(A, B, C, D, E, F, G, H);
impl_resolvable_tuple!(A, B, C, D, E, F, G, H, I);
impl_resolvable_tuple!(A, B, C, D, E, F, G, H, I, J);
impl_resolvable_tuple!(A, B, C, D, E, F, G, H, I, J, K);
impl_resolvable_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
/// Registration helpers for types implementing [`Service`].
///
/// This file implements the trait for [`Container`]; the method docs below
/// describe what that implementation does.
pub trait ServiceProvider {
    /// Registers `T` through a factory that resolves `T::Dependencies` and
    /// calls [`Service::create`].
    ///
    /// The `Container` implementation passes the factory to `Container::lazy`;
    /// the factory panics if the dependencies cannot be resolved when it runs.
    fn provide<T: Service>(&self);

    /// Resolves `T::Dependencies` immediately and, on success, registers the
    /// created instance.
    ///
    /// The `Container` implementation returns `true` when the instance was
    /// registered and `false` when the dependencies could not be resolved.
    fn provide_singleton<T: Service>(&self) -> bool;

    /// Registers `T` through a factory like [`ServiceProvider::provide`], but
    /// the `Container` implementation passes it to `Container::transient`.
    fn provide_transient<T: Service>(&self);
}
/// Resolves `T`'s dependencies from `container` and builds the service.
///
/// `kind` is the label inserted into the panic message (`"service"` or
/// `"transient service"`).
///
/// # Panics
///
/// Panics when `T::Dependencies` cannot be resolved. The message keeps the
/// historical `Failed to resolve dependencies for …` prefix and appends the
/// service's type name so the failing registration can be identified.
fn build_service_or_panic<T: Service>(container: &Container, kind: &str) -> T {
    match T::Dependencies::resolve(container) {
        Some(deps) => T::create(deps),
        None => panic!(
            "Failed to resolve dependencies for {} `{}`",
            kind,
            std::any::type_name::<T>()
        ),
    }
}

/// Wires [`Service`] types into a [`Container`].
///
/// NOTE(review): the factories registered by `provide` and
/// `provide_transient` own a clone of the container. If `Container` clones
/// share their storage, each stored factory keeps that storage alive —
/// confirm this is intended.
impl ServiceProvider for Container {
    /// Registers a factory for `T` via `Container::lazy`.
    ///
    /// # Panics
    ///
    /// The registered factory panics, naming `T`, if `T::Dependencies`
    /// cannot be resolved when it runs.
    #[inline]
    fn provide<T: Service>(&self) {
        let container = self.clone();
        self.lazy(move || build_service_or_panic::<T>(&container, "service"));
    }

    /// Resolves `T::Dependencies` now and registers the created instance via
    /// `Container::singleton`.
    ///
    /// Returns `false`, registering nothing, when resolution fails.
    #[inline]
    fn provide_singleton<T: Service>(&self) -> bool {
        if let Some(deps) = T::Dependencies::resolve(self) {
            self.singleton(T::create(deps));
            true
        } else {
            false
        }
    }

    /// Registers a factory for `T` via `Container::transient`.
    ///
    /// # Panics
    ///
    /// The registered factory panics, naming `T`, if `T::Dependencies`
    /// cannot be resolved when it runs.
    #[inline]
    fn provide_transient<T: Service>(&self) {
        let container = self.clone();
        self.transient(move || build_service_or_panic::<T>(&container, "transient service"));
    }
}
/// A group of registrations that can be applied to a [`Container`] in one
/// call.
pub trait ServiceModule {
    /// Registers everything this module offers into `container`.
    fn register(container: &Container);
}
/// Introspection for a dependency bundle: lists the type names of the
/// dependencies it is made of.
pub trait DependencyInfo {
    /// Returns one type name per dependency in the bundle.
    fn dependency_names() -> Vec<&'static str>;
}
/// The empty dependency list has no names to report.
impl DependencyInfo for () {
    fn dependency_names() -> Vec<&'static str> {
        Vec::new()
    }
}
/// A single required dependency reports the type name of `T`, not of the
/// surrounding `Arc`.
impl<T> DependencyInfo for Arc<T>
where
    T: Injectable,
{
    fn dependency_names() -> Vec<&'static str> {
        let name = std::any::type_name::<T>();
        vec![name]
    }
}
/// A single optional dependency reports the type name of `T`, exactly as the
/// required `Arc<T>` form does.
impl<T> DependencyInfo for Option<Arc<T>>
where
    T: Injectable,
{
    fn dependency_names() -> Vec<&'static str> {
        std::iter::once(std::any::type_name::<T>()).collect()
    }
}
macro_rules! impl_dependency_info_tuple {
($($T:ident),+) => {
impl<$($T: Injectable),+> DependencyInfo for ($(Arc<$T>,)+) {
fn dependency_names() -> Vec<&'static str> {
vec![$(std::any::type_name::<$T>()),+]
}
}
};
}
impl_dependency_info_tuple!(A, B);
impl_dependency_info_tuple!(A, B, C);
impl_dependency_info_tuple!(A, B, C, D);
impl_dependency_info_tuple!(A, B, C, D, E);
impl_dependency_info_tuple!(A, B, C, D, E, F);
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf service with no dependencies.
    #[derive(Clone)]
    struct Config {
        debug: bool,
    }

    impl Service for Config {
        type Dependencies = ();

        fn create(_: ()) -> Self {
            Config { debug: true }
        }
    }

    /// Service with a single required dependency on `Config`.
    #[derive(Clone)]
    struct Database {
        url: String,
    }

    impl Service for Database {
        type Dependencies = Arc<Config>;

        fn create(config: Arc<Config>) -> Self {
            Database {
                url: if config.debug {
                    "debug://localhost".into()
                } else {
                    "prod://server".into()
                },
            }
        }
    }

    /// Second leaf service, used as one half of a tuple dependency.
    #[derive(Clone)]
    struct Cache {
        size: usize,
    }

    impl Service for Cache {
        type Dependencies = ();

        fn create(_: ()) -> Self {
            Cache { size: 1024 }
        }
    }

    /// Service that depends on two other services through a tuple.
    #[derive(Clone)]
    struct UserRepository {
        db: Arc<Database>,
        cache: Arc<Cache>,
    }

    impl Service for UserRepository {
        type Dependencies = (Arc<Database>, Arc<Cache>);

        fn create((db, cache): (Arc<Database>, Arc<Cache>)) -> Self {
            UserRepository { db, cache }
        }
    }

    #[test]
    fn test_service_no_deps() {
        let container = Container::new();
        container.provide::<Config>();

        let config = container.get::<Config>().unwrap();
        assert!(config.debug);
    }

    #[test]
    fn test_service_single_dep() {
        let container = Container::new();
        container.provide::<Config>();
        container.provide::<Database>();

        let db = container.get::<Database>().unwrap();
        assert_eq!(db.url, "debug://localhost");
    }

    #[test]
    fn test_service_multiple_deps() {
        let container = Container::new();
        container.provide::<Config>();
        container.provide::<Database>();
        container.provide::<Cache>();
        container.provide::<UserRepository>();

        let repo = container.get::<UserRepository>().unwrap();
        assert_eq!(repo.db.url, "debug://localhost");
        assert_eq!(repo.cache.size, 1024);
    }

    #[test]
    fn test_provide_singleton() {
        let container = Container::new();
        container.provide::<Config>();

        let result = container.provide_singleton::<Database>();
        assert!(result);

        let db = container.get::<Database>().unwrap();
        assert_eq!(db.url, "debug://localhost");
    }

    #[test]
    fn test_provide_singleton_missing_dep() {
        // `Config` is never registered, so `Database` cannot be resolved.
        let container = Container::new();
        let result = container.provide_singleton::<Database>();
        assert!(!result);
    }

    #[test]
    fn test_provide_transient() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Every `create` call takes the next counter value, so two distinct
        // constructions yield two distinct values.
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        #[derive(Clone)]
        struct Counter(u32);

        impl Service for Counter {
            type Dependencies = ();

            fn create(_: ()) -> Self {
                Counter(COUNTER.fetch_add(1, Ordering::SeqCst))
            }
        }

        let container = Container::new();
        container.provide_transient::<Counter>();

        let c1 = container.get::<Counter>().unwrap();
        let c2 = container.get::<Counter>().unwrap();
        assert_ne!(c1.0, c2.0);
    }

    #[test]
    fn test_optional_dependency() {
        #[derive(Clone)]
        struct OptionalCache;

        #[derive(Clone)]
        struct ServiceWithOptional {
            cache: Option<Arc<OptionalCache>>,
        }

        impl Service for ServiceWithOptional {
            type Dependencies = Option<Arc<OptionalCache>>;

            fn create(cache: Option<Arc<OptionalCache>>) -> Self {
                ServiceWithOptional { cache }
            }
        }

        // Without the optional dependency the service still resolves.
        let container = Container::new();
        container.provide::<ServiceWithOptional>();
        let svc = container.get::<ServiceWithOptional>().unwrap();
        assert!(svc.cache.is_none());

        // With it registered, the service receives the instance.
        let container2 = Container::new();
        container2.singleton(OptionalCache);
        container2.provide::<ServiceWithOptional>();
        let svc2 = container2.get::<ServiceWithOptional>().unwrap();
        assert!(svc2.cache.is_some());
    }

    #[test]
    fn test_dependency_info() {
        use std::any::type_name;

        assert_eq!(
            <() as DependencyInfo>::dependency_names(),
            Vec::<&str>::new()
        );

        // Compare against `type_name` itself instead of a hard-coded path:
        // its output format is unspecified and it embeds the crate and module
        // path, so a literal string breaks whenever this file moves.
        assert_eq!(
            <Arc<Config> as DependencyInfo>::dependency_names(),
            vec![type_name::<Config>()]
        );
        assert_eq!(
            <Option<Arc<Cache>> as DependencyInfo>::dependency_names(),
            vec![type_name::<Cache>()]
        );
        assert_eq!(
            <(Arc<Database>, Arc<Cache>) as DependencyInfo>::dependency_names(),
            vec![type_name::<Database>(), type_name::<Cache>()]
        );
    }

    #[test]
    fn test_service_module() {
        struct TestModule;

        impl ServiceModule for TestModule {
            fn register(container: &Container) {
                container.provide::<Config>();
                container.provide::<Cache>();
            }
        }

        let container = Container::new();
        TestModule::register(&container);

        assert!(container.contains::<Config>());
        assert!(container.contains::<Cache>());
    }
}