use crate::{
KeyedRef, KeyedRefMut, Mut, ServiceDescriptor, Ref, RefMut, Type,
};
use std::any::{type_name, Any};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::iter::empty;
use std::marker::PhantomData;
use std::ops::Deref;
/// Resolves registered services by type, optionally qualified by a key type.
///
/// The descriptor map lives behind a [`Ref`], so cloning a provider is shallow: every
/// clone shares the same map (this is what the `clone_should_be_shallow` test checks).
#[derive(Clone)]
pub struct ServiceProvider {
    // Every registered descriptor, grouped by the `Type` it is resolved under.
    services: Ref<HashMap<Type, Vec<ServiceDescriptor>>>,
}
// SAFETY: NOTE(review): these impls assert that a provider may be moved to and shared
// between threads when the `async` feature is enabled. Nothing in this file shows why
// that holds for the descriptors stored in `services`; presumably the `async` feature
// makes `Ref` an atomically reference-counted pointer and requires thread-safe factories.
// Confirm this against the `Ref` and `ServiceDescriptor` definitions.
#[cfg(feature = "async")]
unsafe impl Send for ServiceProvider {}
#[cfg(feature = "async")]
unsafe impl Sync for ServiceProvider {}
impl ServiceProvider {
    /// Creates a provider over the given descriptor map.
    ///
    /// Each [`Type`] key identifies a service, optionally qualified by a key type. It maps
    /// to every descriptor registered for that service. The map is stored behind a [`Ref`],
    /// so clones of the returned provider share it.
    pub fn new(services: HashMap<Type, Vec<ServiceDescriptor>>) -> Self {
        Self {
            services: Ref::new(services),
        }
    }

    /// Resolves `descriptor` against this provider and downcasts the result to `Ref<T>`.
    ///
    /// # Panics
    ///
    /// Panics if the descriptor does not produce a `Ref<T>`. That presumably means the
    /// descriptor was stored under the wrong [`Type`] key. This is a registration bug,
    /// not a recoverable condition.
    fn resolve_descriptor<T: Any + ?Sized>(&self, descriptor: &ServiceDescriptor) -> Ref<T> {
        descriptor
            .get(self)
            .downcast_ref::<Ref<T>>()
            .unwrap_or_else(|| {
                panic!(
                    "The service registered for type '{}' could not be resolved as that type.",
                    type_name::<T>()
                )
            })
            .clone()
    }

    /// Returns the service produced by the last descriptor registered for `T`.
    ///
    /// Returns `None` when no descriptor exists for `T`.
    ///
    /// # Panics
    ///
    /// Panics if the stored descriptor does not produce a `Ref<T>`.
    pub fn get<T: Any + ?Sized>(&self) -> Option<Ref<T>> {
        let key = Type::of::<T>();
        // The last descriptor wins, so a later registration overrides earlier ones.
        let descriptor = self.services.get(&key)?.last()?;
        Some(self.resolve_descriptor::<T>(descriptor))
    }

    /// Mutable counterpart of [`get`](Self::get): looks up the service registered as `Mut<T>`.
    pub fn get_mut<T: Any + ?Sized>(&self) -> Option<RefMut<T>> {
        self.get::<Mut<T>>()
    }

    /// Returns the service from the last descriptor registered for `TSvc` under the key
    /// type `TKey`.
    ///
    /// Returns `None` when no such descriptor exists.
    ///
    /// # Panics
    ///
    /// Panics if the stored descriptor does not produce a `Ref<TSvc>`.
    pub fn get_by_key<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRef<TKey, TSvc>> {
        let key = Type::keyed::<TKey, TSvc>();
        let descriptor = self.services.get(&key)?.last()?;
        Some(KeyedRef::new(self.resolve_descriptor::<TSvc>(descriptor)))
    }

    /// Mutable counterpart of [`get_by_key`](Self::get_by_key).
    pub fn get_by_key_mut<TKey, TSvc: Any + ?Sized>(
        &self,
    ) -> Option<KeyedRefMut<TKey, TSvc>> {
        self.get_by_key::<TKey, Mut<TSvc>>()
    }

    /// Returns an iterator over every service registered for `T`, in descriptor order.
    ///
    /// Each service is resolved only when the iterator reaches it. The iterator is empty
    /// when nothing is registered for `T`.
    pub fn get_all<T: Any + ?Sized>(&self) -> impl Iterator<Item = Ref<T>> + '_ {
        let key = Type::of::<T>();
        match self.services.get(&key) {
            Some(descriptors) => ServiceIterator::new(self, descriptors.iter()),
            None => ServiceIterator::new(self, empty()),
        }
    }

    /// Mutable counterpart of [`get_all`](Self::get_all).
    pub fn get_all_mut<T: Any + ?Sized>(&self) -> impl Iterator<Item = RefMut<T>> + '_ {
        self.get_all::<Mut<T>>()
    }

    /// Returns an iterator over every service registered for `TSvc` under the key type
    /// `TKey`, in descriptor order.
    ///
    /// Each service is resolved only when the iterator reaches it. The iterator is empty
    /// when nothing is registered.
    pub fn get_all_by_key<'a, TKey: 'a, TSvc>(
        &'a self,
    ) -> impl Iterator<Item = KeyedRef<TKey, TSvc>> + '_
    where
        TSvc: Any + ?Sized,
    {
        let key = Type::keyed::<TKey, TSvc>();
        match self.services.get(&key) {
            Some(descriptors) => KeyedServiceIterator::new(self, descriptors.iter()),
            None => KeyedServiceIterator::new(self, empty()),
        }
    }

    /// Mutable counterpart of [`get_all_by_key`](Self::get_all_by_key).
    pub fn get_all_by_key_mut<'a, TKey: 'a, TSvc>(
        &'a self,
    ) -> impl Iterator<Item = KeyedRefMut<TKey, TSvc>> + '_
    where
        TSvc: Any + ?Sized,
    {
        self.get_all_by_key::<TKey, Mut<TSvc>>()
    }

    /// Like [`get`](Self::get), but panics when the service is missing.
    ///
    /// # Panics
    ///
    /// Panics with `No service for type '…' has been registered.` when no descriptor
    /// exists for `T`.
    pub fn get_required<T: Any + ?Sized>(&self) -> Ref<T> {
        self.get::<T>().unwrap_or_else(|| {
            panic!(
                "No service for type '{}' has been registered.",
                type_name::<T>()
            )
        })
    }

    /// Mutable counterpart of [`get_required`](Self::get_required).
    pub fn get_required_mut<T: Any + ?Sized>(&self) -> RefMut<T> {
        self.get_required::<Mut<T>>()
    }

    /// Like [`get_by_key`](Self::get_by_key), but panics when the service is missing.
    ///
    /// # Panics
    ///
    /// Panics with a message naming both `TSvc` and `TKey` when no descriptor exists.
    pub fn get_required_by_key<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRef<TKey, TSvc> {
        self.get_by_key::<TKey, TSvc>().unwrap_or_else(|| {
            panic!(
                "No service for type '{}' with the key '{}' has been registered.",
                type_name::<TSvc>(),
                type_name::<TKey>()
            )
        })
    }

    /// Mutable counterpart of [`get_required_by_key`](Self::get_required_by_key).
    pub fn get_required_by_key_mut<TKey, TSvc: Any + ?Sized>(
        &self,
    ) -> KeyedRefMut<TKey, TSvc> {
        self.get_required_by_key::<TKey, Mut<TSvc>>()
    }

    /// Creates a new provider from a clone of this provider's descriptor map.
    ///
    /// The scope tests in this module expect the following:
    /// - scoped services resolve to distinct instances in each scope, including
    ///   parent/child scopes;
    /// - singletons are shared with the parent, whether they were created before or
    ///   after the scope.
    pub fn create_scope(&self) -> Self {
        Self::new(self.services.as_ref().clone())
    }
}
/// A [`ServiceProvider`] that owns its own scope, typically created from a parent provider.
///
/// The wrapped provider is reachable through [`Deref`], [`AsRef`] and [`Borrow`], so every
/// lookup method can be called on the scope directly.
#[derive(Clone, Default)]
pub struct ScopedServiceProvider {
    sp: ServiceProvider,
}

impl Deref for ScopedServiceProvider {
    type Target = ServiceProvider;

    fn deref(&self) -> &ServiceProvider {
        &self.sp
    }
}

impl Borrow<ServiceProvider> for ScopedServiceProvider {
    fn borrow(&self) -> &ServiceProvider {
        &self.sp
    }
}

impl AsRef<ServiceProvider> for ScopedServiceProvider {
    fn as_ref(&self) -> &ServiceProvider {
        &self.sp
    }
}

impl From<&ServiceProvider> for ScopedServiceProvider {
    /// Opens a fresh scope on `parent` via [`ServiceProvider::create_scope`].
    fn from(parent: &ServiceProvider) -> Self {
        let sp = parent.create_scope();
        Self { sp }
    }
}
struct ServiceIterator<'a, T>
where
T: Any + ?Sized,
{
provider: &'a ServiceProvider,
descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
_marker: PhantomData<T>,
}
struct KeyedServiceIterator<'a, TKey, TSvc>
where
TSvc: Any + ?Sized,
{
provider: &'a ServiceProvider,
descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
_key: PhantomData<TKey>,
_svc: PhantomData<TSvc>,
}
impl<'a, T: Any + ?Sized> ServiceIterator<'a, T> {
    /// Boxes `descriptors` so that each one is resolved against `provider` only when the
    /// iterator is advanced.
    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
    where
        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
    {
        let boxed: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a> = Box::new(descriptors);
        Self {
            descriptors: boxed,
            provider,
            _marker: PhantomData,
        }
    }
}
impl<'a, T: Any + ?Sized> Iterator for ServiceIterator<'a, T> {
type Item = Ref<T>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(descriptor) = self.descriptors.next() {
Some(
descriptor
.get(self.provider)
.downcast_ref::<Ref<T>>()
.unwrap()
.clone(),
)
} else {
None
}
}
}
impl<'a, TKey, TSvc: Any + ?Sized> KeyedServiceIterator<'a, TKey, TSvc> {
    /// Boxes `descriptors` so that each one is resolved against `provider` only when the
    /// iterator is advanced.
    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
    where
        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
    {
        let boxed: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a> = Box::new(descriptors);
        Self {
            descriptors: boxed,
            provider,
            _svc: PhantomData,
            _key: PhantomData,
        }
    }
}
impl<'a, TKey, TSvc: Any + ?Sized> Iterator for KeyedServiceIterator<'a, TKey, TSvc> {
type Item = KeyedRef<TKey, TSvc>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(descriptor) = self.descriptors.next() {
Some(KeyedRef::new(
descriptor
.get(self.provider)
.downcast_ref::<Ref<TSvc>>()
.unwrap()
.clone(),
))
} else {
None
}
}
}
impl Default for ServiceProvider {
fn default() -> Self {
Self {
services: Ref::new(HashMap::with_capacity(0)),
}
}
}
#[cfg(test)]
mod tests {
    // Behavioral tests for `ServiceProvider`. They cover:
    // - plain and keyed lookup;
    // - required lookup and its panic messages;
    // - service lifetimes (singleton / transient / scoped);
    // - drop semantics;
    // - sharing a provider across threads (`async` feature only).
    //
    // Tests that compare `dyn` pointers with `Ref::ptr_eq` allow
    // `clippy::vtable_address_comparisons` on purpose: instance identity is exactly what
    // they verify.
    use crate::{test::*, *};
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};
    #[cfg(feature = "async")]
    use std::sync::{Arc, Mutex};
    #[cfg(feature = "async")]
    use std::thread;

    #[test]
    fn get_should_return_none_when_service_is_unregistered() {
        let services = ServiceCollection::new().build_provider().unwrap();
        let result = services.get::<dyn TestService>();
        assert!(result.is_none());
    }

    #[test]
    fn get_by_key_should_return_none_when_service_is_unregistered() {
        let services = ServiceCollection::new().build_provider().unwrap();
        let result = services.get_by_key::<key::Thingy, dyn TestService>();
        assert!(result.is_none());
    }

    #[test]
    fn get_should_return_registered_service() {
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let result = services.get::<dyn TestService>();
        assert!(result.is_some());
    }

    #[test]
    fn get_by_key_should_return_registered_service() {
        let services = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();
        let result = services.get_by_key::<key::Thingy, dyn Thing>();
        assert!(result.is_some());
    }

    #[test]
    fn get_required_should_return_registered_service() {
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let _ = services.get_required::<dyn TestService>();
    }

    #[test]
    fn get_required_by_key_should_return_registered_service() {
        // The unkeyed `Thing1` registration must not shadow the keyed `Thing3` one.
        let services = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();
        let thing = services.get_required_by_key::<key::Thingy, dyn Thing>();
        assert_eq!(&thing.to_string(), "di::test::Thing3");
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::TestService' has been registered."
    )]
    fn get_required_should_panic_when_service_is_unregistered() {
        let services = ServiceCollection::new().build_provider().unwrap();
        let _ = services.get_required::<dyn TestService>();
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::Thing' with the key 'di::test::key::Thing1' has been registered."
    )]
    fn get_required_by_key_should_panic_when_service_is_unregistered() {
        let services = ServiceCollection::new().build_provider().unwrap();
        let _ = services.get_required_by_key::<key::Thing1, dyn Thing>();
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_same_instance_for_singleton_service() {
        let services = ServiceCollection::new()
            .add(existing::<dyn TestService, TestServiceImpl>(Box::new(
                TestServiceImpl::default(),
            )))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>().from(|sp| {
                    Ref::new(OtherTestServiceImpl::new(
                        sp.get_required::<dyn TestService>(),
                    ))
                }),
            )
            .build_provider()
            .unwrap();
        let svc2 = services.get_required::<dyn OtherTestService>();
        let svc1 = services.get_required::<dyn OtherTestService>();
        assert!(Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_different_instances_for_transient_service() {
        let services = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let svc1 = services.get_required::<dyn TestService>();
        let svc2 = services.get_required::<dyn TestService>();
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    fn get_all_should_return_all_services() {
        let mut collection = ServiceCollection::new();
        collection
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl { value: 1 })),
            )
            .add(
                singleton::<dyn TestService, TestService2Impl>()
                    .from(|_| Ref::new(TestService2Impl { value: 2 })),
            );
        let provider = collection.build_provider().unwrap();
        let services = provider.get_all::<dyn TestService>();
        let values: Vec<_> = services.map(|s| s.value()).collect();
        // Services come back in registration order.
        assert_eq!(&values, &[1, 2]);
    }

    #[test]
    fn get_all_by_key_should_return_all_services() {
        let mut collection = ServiceCollection::new();
        collection
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing2>()
                    .from(|_| Ref::new(Thing2::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            );
        let provider = collection.build_provider().unwrap();
        let services = provider.get_all_by_key::<key::Thingies, dyn Thing>();
        let values: Vec<_> = services.map(|s| s.to_string()).collect();
        assert_eq!(
            &values,
            &[
                "di::test::Thing1".to_owned(),
                "di::test::Thing2".to_owned(),
                "di::test::Thing3".to_owned()
            ]
        );
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn two_scoped_service_providers_should_create_different_instances() {
        let services = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let scope1 = services.create_scope();
        let scope2 = services.create_scope();
        let svc1 = scope1.get_required::<dyn TestService>();
        let svc2 = scope2.get_required::<dyn TestService>();
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn parent_child_scoped_service_providers_should_create_different_instances() {
        let services = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();
        let svc1 = scope1.get_required::<dyn TestService>();
        let svc2 = scope2.get_required::<dyn TestService>();
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_eager_created_in_parent() {
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        // The singleton is materialized in the parent before any scope exists.
        let svc1 = services.get_required::<dyn TestService>();
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();
        let svc2 = scope1.get_required::<dyn TestService>();
        let svc3 = scope2.get_required::<dyn TestService>();
        assert!(Ref::ptr_eq(&svc1, &svc2));
        assert!(Ref::ptr_eq(&svc1, &svc3));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_lazy_created_in_parent() {
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        // The scopes exist before the singleton is first materialized.
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();
        let svc1 = services.get_required::<dyn TestService>();
        let svc2 = scope1.get_required::<dyn TestService>();
        let svc3 = scope2.get_required::<dyn TestService>();
        assert!(Ref::ptr_eq(&svc1, &svc2));
        assert!(Ref::ptr_eq(&svc1, &svc3));
    }

    #[test]
    fn service_provider_should_drop_existing_as_service() {
        let file = new_temp_file("drop2");
        {
            let mut services = ServiceCollection::new();
            services.add(existing_as_self(Droppable::new(file.clone())));
            let _ = services.build_provider().unwrap();
        }
        // `Droppable` presumably removes its file when dropped, so a missing file means
        // the provider released it.
        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_drop_lazy_initialized_service() {
        let file = new_temp_file("drop3");
        {
            let provider = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
            let _ = provider.get_required::<Droppable>();
        }
        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_not_drop_service_if_never_instantiated() {
        let file = new_temp_file("drop5");
        {
            let _ = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
        }
        // The singleton was never resolved, so no `Droppable` existed to remove the file.
        let not_dropped = file.exists();
        remove_file(&file).ok();
        assert!(not_dropped);
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn clone_should_be_shallow() {
        let provider1 = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let provider2 = provider1.clone();
        assert!(Ref::ptr_eq(&provider1.services, &provider2.services));
        assert!(std::ptr::eq(
            provider1.services.as_ref(),
            provider2.services.as_ref()
        ));
    }

    // Compile-time check helper: only accepts values that are `Send + Sync + Clone`.
    #[cfg(feature = "async")]
    #[derive(Clone)]
    struct Holder<T: Send + Sync + Clone>(T);

    #[cfg(feature = "async")]
    fn inject<V: Send + Sync + Clone>(value: V) -> Holder<V> {
        Holder(value)
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_async_safe() {
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestAsyncServiceImpl>()
                    .from(|_| Ref::new(TestAsyncServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let holder = inject(provider);
        let h1 = holder.clone();
        let h2 = holder.clone();
        let value = Arc::new(Mutex::new(0));
        let v1 = value.clone();
        let v2 = value.clone();
        let t1 = thread::spawn(move || {
            let service = h1.0.get_required::<dyn TestService>();
            let mut result = v1.lock().unwrap();
            *result += service.value();
        });
        let t2 = thread::spawn(move || {
            let service = h2.0.get_required::<dyn TestService>();
            let mut result = v2.lock().unwrap();
            *result += service.value();
        });
        // Propagate worker panics. Otherwise a failure inside a thread would surface as a
        // confusing mismatch (or a poisoned-lock panic) in the final assertion.
        t1.join().expect("first worker thread panicked");
        t2.join().expect("second worker thread panicked");
        assert_eq!(*value.lock().unwrap(), 3);
    }
}