1use crate::{
2 KeyedRef, KeyedRefMut, Mut, ServiceDescriptor, Ref, RefMut, Type,
3};
4use std::any::{type_name, Any};
5use std::borrow::Borrow;
6use std::collections::HashMap;
7use std::iter::empty;
8use std::marker::PhantomData;
9use std::ops::Deref;
10
/// Resolves registered services by type or by key.
///
/// The registrations are a map from a `Type` to the list of `ServiceDescriptor`s
/// held for that type, stored behind a shared `Ref`. Cloning a provider clones
/// that `Ref` handle rather than the map itself.
#[derive(Clone)]
pub struct ServiceProvider {
    // All descriptors known to this provider, grouped by the type they were registered for.
    services: Ref<HashMap<Type, Vec<ServiceDescriptor>>>,
}
16
// NOTE(review): with the `async` feature enabled the provider is manually declared
// `Send` and `Sync`. The justification is not visible in this file; presumably
// `Ref` is a thread-safe pointer and `ServiceDescriptor` is safe to share under
// that feature. TODO confirm and record the invariant as a `SAFETY:` comment.
#[cfg(feature = "async")]
unsafe impl Send for ServiceProvider {}

#[cfg(feature = "async")]
unsafe impl Sync for ServiceProvider {}
22
23impl ServiceProvider {
24 pub fn new(services: HashMap<Type, Vec<ServiceDescriptor>>) -> Self {
30 Self {
31 services: Ref::new(services),
32 }
33 }
34
35 pub fn get<T: Any + ?Sized>(&self) -> Option<Ref<T>> {
37 let key = Type::of::<T>();
38
39 if let Some(descriptors) = self.services.get(&key) {
40 if let Some(descriptor) = descriptors.last() {
41 return Some(
42 descriptor
43 .get(self)
44 .downcast_ref::<Ref<T>>()
45 .unwrap()
46 .clone(),
47 );
48 }
49 }
50
51 None
52 }
53
54 pub fn get_mut<T: Any + ?Sized>(&self) -> Option<RefMut<T>> {
56 self.get::<Mut<T>>()
57 }
58
59 pub fn get_by_key<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRef<TKey, TSvc>> {
61 let key = Type::keyed::<TKey, TSvc>();
62
63 if let Some(descriptors) = self.services.get(&key) {
64 if let Some(descriptor) = descriptors.last() {
65 return Some(KeyedRef::new(
66 descriptor
67 .get(self)
68 .downcast_ref::<Ref<TSvc>>()
69 .unwrap()
70 .clone(),
71 ));
72 }
73 }
74
75 None
76 }
77
78 pub fn get_by_key_mut<TKey, TSvc: Any + ?Sized>(
80 &self,
81 ) -> Option<KeyedRefMut<TKey, TSvc>> {
82 self.get_by_key::<TKey, Mut<TSvc>>()
83 }
84
85 pub fn get_all<T: Any + ?Sized>(&self) -> impl Iterator<Item = Ref<T>> + '_ {
87 let key = Type::of::<T>();
88
89 if let Some(descriptors) = self.services.get(&key) {
90 ServiceIterator::new(self, descriptors.iter())
91 } else {
92 ServiceIterator::new(self, empty())
93 }
94 }
95
96 pub fn get_all_mut<T: Any + ?Sized>(&self) -> impl Iterator<Item = RefMut<T>> + '_ {
98 self.get_all::<Mut<T>>()
99 }
100
101 pub fn get_all_by_key<'a, TKey: 'a, TSvc>(
103 &'a self,
104 ) -> impl Iterator<Item = KeyedRef<TKey, TSvc>> + '_
105 where
106 TSvc: Any + ?Sized,
107 {
108 let key = Type::keyed::<TKey, TSvc>();
109
110 if let Some(descriptors) = self.services.get(&key) {
111 KeyedServiceIterator::new(self, descriptors.iter())
112 } else {
113 KeyedServiceIterator::new(self, empty())
114 }
115 }
116
117 pub fn get_all_by_key_mut<'a, TKey: 'a, TSvc>(
119 &'a self,
120 ) -> impl Iterator<Item = KeyedRefMut<TKey, TSvc>> + '_
121 where
122 TSvc: Any + ?Sized,
123 {
124 self.get_all_by_key::<TKey, Mut<TSvc>>()
125 }
126
127 pub fn get_required<T: Any + ?Sized>(&self) -> Ref<T> {
133 if let Some(service) = self.get::<T>() {
134 service
135 } else {
136 panic!(
137 "No service for type '{}' has been registered.",
138 type_name::<T>()
139 );
140 }
141 }
142
143 pub fn get_required_mut<T: Any + ?Sized>(&self) -> RefMut<T> {
149 self.get_required::<Mut<T>>()
150 }
151
152 pub fn get_required_by_key<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRef<TKey, TSvc> {
158 if let Some(service) = self.get_by_key::<TKey, TSvc>() {
159 service
160 } else {
161 panic!(
162 "No service for type '{}' with the key '{}' has been registered.",
163 type_name::<TSvc>(),
164 type_name::<TKey>()
165 );
166 }
167 }
168
169 pub fn get_required_by_key_mut<TKey, TSvc: Any + ?Sized>(
175 &self,
176 ) -> KeyedRefMut<TKey, TSvc> {
177 self.get_required_by_key::<TKey, Mut<TSvc>>()
178 }
179
180 pub fn create_scope(&self) -> Self {
183 Self::new(self.services.as_ref().clone())
184 }
185}
186
/// A wrapper that owns a single `ServiceProvider`.
///
/// The default value wraps a default `ServiceProvider`.
#[derive(Clone, Default)]
pub struct ScopedServiceProvider {
    // The wrapped provider.
    sp: ServiceProvider
}
199
200impl From<&ServiceProvider> for ScopedServiceProvider {
201 fn from(value: &ServiceProvider) -> Self {
202 Self { sp: value.create_scope() }
203 }
204}
205
206impl AsRef<ServiceProvider> for ScopedServiceProvider {
207 fn as_ref(&self) -> &ServiceProvider {
208 &self.sp
209 }
210}
211
212impl Borrow<ServiceProvider> for ScopedServiceProvider {
213 fn borrow(&self) -> &ServiceProvider {
214 &self.sp
215 }
216}
217
218impl Deref for ScopedServiceProvider {
219 type Target = ServiceProvider;
220
221 fn deref(&self) -> &Self::Target {
222 &self.sp
223 }
224}
225
/// Iterator state for walking a sequence of descriptors and producing `Ref<T>` items.
struct ServiceIterator<'a, T>
where
    T: Any + ?Sized,
{
    // The provider handed to each descriptor when it is asked for its service.
    provider: &'a ServiceProvider,
    // The descriptors still to be visited; boxed so any iterator type can be stored.
    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
    // Ties the otherwise unused service type `T` to the struct.
    _marker: PhantomData<T>,
}
234
/// Iterator state for walking a sequence of descriptors and producing
/// `KeyedRef<TKey, TSvc>` items.
struct KeyedServiceIterator<'a, TKey, TSvc>
where
    TSvc: Any + ?Sized,
{
    // The provider handed to each descriptor when it is asked for its service.
    provider: &'a ServiceProvider,
    // The descriptors still to be visited; boxed so any iterator type can be stored.
    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
    // Ties the otherwise unused key type `TKey` to the struct.
    _key: PhantomData<TKey>,
    // Ties the otherwise unused service type `TSvc` to the struct.
    _svc: PhantomData<TSvc>,
}
244
245impl<'a, T: Any + ?Sized> ServiceIterator<'a, T> {
246 fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
247 where
248 I: Iterator<Item = &'a ServiceDescriptor> + 'a,
249 {
250 Self {
251 provider,
252 descriptors: Box::new(descriptors),
253 _marker: PhantomData,
254 }
255 }
256}
257
258impl<'a, T: Any + ?Sized> Iterator for ServiceIterator<'a, T> {
259 type Item = Ref<T>;
260 fn next(&mut self) -> Option<Self::Item> {
261 if let Some(descriptor) = self.descriptors.next() {
262 Some(
263 descriptor
264 .get(self.provider)
265 .downcast_ref::<Ref<T>>()
266 .unwrap()
267 .clone(),
268 )
269 } else {
270 None
271 }
272 }
273}
274
275impl<'a, TKey, TSvc: Any + ?Sized> KeyedServiceIterator<'a, TKey, TSvc> {
276 fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
277 where
278 I: Iterator<Item = &'a ServiceDescriptor> + 'a,
279 {
280 Self {
281 provider,
282 descriptors: Box::new(descriptors),
283 _key: PhantomData,
284 _svc: PhantomData,
285 }
286 }
287}
288
289impl<'a, TKey, TSvc: Any + ?Sized> Iterator for KeyedServiceIterator<'a, TKey, TSvc> {
290 type Item = KeyedRef<TKey, TSvc>;
291 fn next(&mut self) -> Option<Self::Item> {
292 if let Some(descriptor) = self.descriptors.next() {
293 Some(KeyedRef::new(
294 descriptor
295 .get(self.provider)
296 .downcast_ref::<Ref<TSvc>>()
297 .unwrap()
298 .clone(),
299 ))
300 } else {
301 None
302 }
303 }
304}
305
306impl Default for ServiceProvider {
307 fn default() -> Self {
308 Self {
309 services: Ref::new(HashMap::with_capacity(0)),
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    // Several tests compare trait-object `Ref`s by pointer identity, which is what
    // `clippy::vtable_address_comparisons` warns about; the lint is allowed on those
    // tests because identity is exactly what they check.

    use crate::{test::*, *};
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    #[cfg(feature = "async")]
    use std::sync::{Arc, Mutex};

    #[cfg(feature = "async")]
    use std::thread;

    #[test]
    fn get_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_by_key_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_by_key_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_required_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // reaching this point without a panic is the success condition
    }

    #[test]
    fn get_required_by_key_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let keyed = provider.get_required_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert_eq!(&keyed.to_string(), "di::test::Thing3");
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::TestService' has been registered."
    )]
    fn get_required_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // the expected panic is checked by `should_panic`
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::Thing' with the key 'di::test::key::Thing1' has been registered."
    )]
    fn get_required_by_key_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required_by_key::<key::Thing1, dyn Thing>();

        // assert
        // the expected panic is checked by `should_panic`
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_same_instance_for_singleton_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(existing::<dyn TestService, TestServiceImpl>(Box::new(
                TestServiceImpl::default(),
            )))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>().from(|sp| {
                    Ref::new(OtherTestServiceImpl::new(
                        sp.get_required::<dyn TestService>(),
                    ))
                }),
            )
            .build_provider()
            .unwrap();

        // act
        let earlier = provider.get_required::<dyn OtherTestService>();
        let later = provider.get_required::<dyn OtherTestService>();

        // assert
        assert!(Ref::ptr_eq(&later, &earlier));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_different_instances_for_transient_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let first = provider.get_required::<dyn TestService>();
        let second = provider.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&first, &second));
    }

    #[test]
    fn get_all_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl { value: 1 })),
            )
            .add(
                singleton::<dyn TestService, TestService2Impl>()
                    .from(|_| Ref::new(TestService2Impl { value: 2 })),
            );

        let provider = collection.build_provider().unwrap();

        // act
        let values: Vec<_> = provider
            .get_all::<dyn TestService>()
            .map(|service| service.value())
            .collect();

        // assert
        assert_eq!(&values, &[1, 2]);
    }

    #[test]
    fn get_all_by_key_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing2>()
                    .from(|_| Ref::new(Thing2::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            );

        let provider = collection.build_provider().unwrap();

        // act
        let names: Vec<_> = provider
            .get_all_by_key::<key::Thingies, dyn Thing>()
            .map(|thing| thing.to_string())
            .collect();

        // assert
        assert_eq!(
            &names,
            &[
                "di::test::Thing1".to_owned(),
                "di::test::Thing2".to_owned(),
                "di::test::Thing3".to_owned()
            ]
        );
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn two_scoped_service_providers_should_create_different_instances() {
        // arrange
        let root = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let left = root.create_scope();
        let right = root.create_scope();

        // act
        let from_left = left.get_required::<dyn TestService>();
        let from_right = right.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_left, &from_right));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn parent_child_scoped_service_providers_should_create_different_instances() {
        // arrange
        let root = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_parent, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_eager_created_in_parent() {
        // arrange
        let root = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        // resolve from the root BEFORE any scope exists
        let from_root = root.get_required::<dyn TestService>();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_lazy_created_in_parent() {
        // arrange
        let root = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();
        // resolve from the root AFTER the scopes have been created
        let from_root = root.get_required::<dyn TestService>();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    fn service_provider_should_drop_existing_as_service() {
        // arrange
        let path = new_temp_file("drop2");

        // act
        {
            let mut collection = ServiceCollection::new();
            collection.add(existing_as_self(Droppable::new(path.clone())));
            let _ = collection.build_provider().unwrap();
        }

        // assert
        let was_dropped = !path.exists();
        remove_file(&path).ok();
        assert!(was_dropped);
    }

    #[test]
    fn service_provider_should_drop_lazy_initialized_service() {
        // arrange
        let path = new_temp_file("drop3");

        // act
        {
            let provider = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(path.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
            let _ = provider.get_required::<Droppable>();
        }

        // assert
        let was_dropped = !path.exists();
        remove_file(&path).ok();
        assert!(was_dropped);
    }

    #[test]
    fn service_provider_should_not_drop_service_if_never_instantiated() {
        // arrange
        let path = new_temp_file("drop5");

        // act
        {
            let _ = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(path.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
        }

        // assert
        let still_there = path.exists();
        remove_file(&path).ok();
        assert!(still_there);
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn clone_should_be_shallow() {
        // arrange
        let original = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let copy = original.clone();

        // assert
        assert!(Ref::ptr_eq(&original.services, &copy.services));
        assert!(std::ptr::eq(
            original.services.as_ref(),
            copy.services.as_ref()
        ));
    }

    // Wrapper whose bounds force the wrapped value to be `Send + Sync + Clone`.
    #[cfg(feature = "async")]
    #[derive(Clone)]
    struct Holder<T: Send + Sync + Clone>(T);

    // Wraps `value` in a `Holder`; only compiles when `V` satisfies the holder's bounds.
    #[cfg(feature = "async")]
    fn inject<V: Send + Sync + Clone>(value: V) -> Holder<V> {
        Holder(value)
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_async_safe() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestAsyncServiceImpl>()
                    .from(|_| Ref::new(TestAsyncServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let holder = inject(provider);
        let holder_a = holder.clone();
        let holder_b = holder.clone();
        let total = Arc::new(Mutex::new(0));
        let total_a = total.clone();
        let total_b = total.clone();

        // act
        let worker_a = thread::spawn(move || {
            let service = holder_a.0.get_required::<dyn TestService>();
            let mut sum = total_a.lock().unwrap();
            *sum += service.value();
        });

        let worker_b = thread::spawn(move || {
            let service = holder_b.0.get_required::<dyn TestService>();
            let mut sum = total_b.lock().unwrap();
            *sum += service.value();
        });

        worker_a.join().ok();
        worker_b.join().ok();

        // assert
        assert_eq!(*total.lock().unwrap(), 3);
    }
}