// di/provider.rs

1use crate::{
2    KeyedRef, KeyedRefMut, Mut, ServiceDescriptor, Ref, RefMut, Type,
3};
4use std::any::{type_name, Any};
5use std::borrow::Borrow;
6use std::collections::HashMap;
7use std::iter::empty;
8use std::marker::PhantomData;
9use std::ops::Deref;
10
11/// Represents a service provider.
12#[derive(Clone)]
13pub struct ServiceProvider {
14    services: Ref<HashMap<Type, Vec<ServiceDescriptor>>>,
15}
16
// SAFETY: NOTE(review) - soundness depends on everything reachable through
// `ServiceProvider::services` being safe to move to another thread when the `async`
// feature is enabled. Presumably `Ref` is an atomically reference-counted pointer and
// descriptors only hold thread-safe values in that configuration; confirm against the
// definitions of `Ref` and `ServiceDescriptor`.
#[cfg(feature = "async")]
unsafe impl Send for ServiceProvider {}
19
// SAFETY: NOTE(review) - soundness depends on everything reachable through
// `ServiceProvider::services` tolerating concurrent shared access when the `async`
// feature is enabled. Presumably `Ref` is an atomically reference-counted pointer and
// descriptors synchronize their own lazy initialization in that configuration; confirm
// against the definitions of `Ref` and `ServiceDescriptor`.
#[cfg(feature = "async")]
unsafe impl Sync for ServiceProvider {}
22
23impl ServiceProvider {
24    /// Initializes a new service provider.
25    ///
26    /// # Arguments
27    ///
28    /// * `services` - The [`ServiceDescriptor`](crate::ServiceDescriptor) map encapsulated by the provider
29    pub fn new(services: HashMap<Type, Vec<ServiceDescriptor>>) -> Self {
30        Self {
31            services: Ref::new(services),
32        }
33    }
34
35    /// Gets a service of the specified type.
36    pub fn get<T: Any + ?Sized>(&self) -> Option<Ref<T>> {
37        let key = Type::of::<T>();
38
39        if let Some(descriptors) = self.services.get(&key) {
40            if let Some(descriptor) = descriptors.last() {
41                return Some(
42                    descriptor
43                        .get(self)
44                        .downcast_ref::<Ref<T>>()
45                        .unwrap()
46                        .clone(),
47                );
48            }
49        }
50
51        None
52    }
53
54    /// Gets a mutable service of the specified type.
55    pub fn get_mut<T: Any + ?Sized>(&self) -> Option<RefMut<T>> {
56        self.get::<Mut<T>>()
57    }
58
59    /// Gets a keyed service of the specified type.
60    pub fn get_by_key<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRef<TKey, TSvc>> {
61        let key = Type::keyed::<TKey, TSvc>();
62
63        if let Some(descriptors) = self.services.get(&key) {
64            if let Some(descriptor) = descriptors.last() {
65                return Some(KeyedRef::new(
66                    descriptor
67                        .get(self)
68                        .downcast_ref::<Ref<TSvc>>()
69                        .unwrap()
70                        .clone(),
71                ));
72            }
73        }
74
75        None
76    }
77
78    /// Gets a keyed, mutable service of the specified type.
79    pub fn get_by_key_mut<TKey, TSvc: Any + ?Sized>(
80        &self,
81    ) -> Option<KeyedRefMut<TKey, TSvc>> {
82        self.get_by_key::<TKey, Mut<TSvc>>()
83    }
84
85    /// Gets all of the services of the specified type.
86    pub fn get_all<T: Any + ?Sized>(&self) -> impl Iterator<Item = Ref<T>> + '_ {
87        let key = Type::of::<T>();
88
89        if let Some(descriptors) = self.services.get(&key) {
90            ServiceIterator::new(self, descriptors.iter())
91        } else {
92            ServiceIterator::new(self, empty())
93        }
94    }
95
96    /// Gets all of the mutable services of the specified type.
97    pub fn get_all_mut<T: Any + ?Sized>(&self) -> impl Iterator<Item = RefMut<T>> + '_ {
98        self.get_all::<Mut<T>>()
99    }
100
101    /// Gets all of the services of the specified key and type.
102    pub fn get_all_by_key<'a, TKey: 'a, TSvc>(
103        &'a self,
104    ) -> impl Iterator<Item = KeyedRef<TKey, TSvc>> + '_
105    where
106        TSvc: Any + ?Sized,
107    {
108        let key = Type::keyed::<TKey, TSvc>();
109
110        if let Some(descriptors) = self.services.get(&key) {
111            KeyedServiceIterator::new(self, descriptors.iter())
112        } else {
113            KeyedServiceIterator::new(self, empty())
114        }
115    }
116
117    /// Gets all of the mutable services of the specified key and type.
118    pub fn get_all_by_key_mut<'a, TKey: 'a, TSvc>(
119        &'a self,
120    ) -> impl Iterator<Item = KeyedRefMut<TKey, TSvc>> + '_
121    where
122        TSvc: Any + ?Sized,
123    {
124        self.get_all_by_key::<TKey, Mut<TSvc>>()
125    }
126
127    /// Gets a required service of the specified type.
128    ///
129    /// # Panics
130    ///
131    /// The requested service of type `T` does not exist.
132    pub fn get_required<T: Any + ?Sized>(&self) -> Ref<T> {
133        if let Some(service) = self.get::<T>() {
134            service
135        } else {
136            panic!(
137                "No service for type '{}' has been registered.",
138                type_name::<T>()
139            );
140        }
141    }
142
143    /// Gets a required, mutable service of the specified type.
144    ///
145    /// # Panics
146    ///
147    /// The requested service of type `T` does not exist.
148    pub fn get_required_mut<T: Any + ?Sized>(&self) -> RefMut<T> {
149        self.get_required::<Mut<T>>()
150    }
151
152    /// Gets a required keyed service of the specified type.
153    ///
154    /// # Panics
155    ///
156    /// The requested service of type `TSvc` with key `TKey` does not exist.
157    pub fn get_required_by_key<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRef<TKey, TSvc> {
158        if let Some(service) = self.get_by_key::<TKey, TSvc>() {
159            service
160        } else {
161            panic!(
162                "No service for type '{}' with the key '{}' has been registered.",
163                type_name::<TSvc>(),
164                type_name::<TKey>()
165            );
166        }
167    }
168
169    /// Gets a required keyed service of the specified type.
170    ///
171    /// # Panics
172    ///
173    /// The requested service of type `TSvc` with key `TKey` does not exist.
174    pub fn get_required_by_key_mut<TKey, TSvc: Any + ?Sized>(
175        &self,
176    ) -> KeyedRefMut<TKey, TSvc> {
177        self.get_required_by_key::<TKey, Mut<TSvc>>()
178    }
179
180    /// Creates and returns a new service provider that is used to resolve
181    /// services from a newly create scope.
182    pub fn create_scope(&self) -> Self {
183        Self::new(self.services.as_ref().clone())
184    }
185}
186
187/// Represents a scoped [`ServiceProvider`].
188/// 
189/// # Remarks
190/// 
191/// This struct has the exact same functionality as [`ServiceProvider`](crate::ServiceProvider).
192/// When a new instance is created, it also creates a new scope from the source
193/// [`ServiceProvider`](crate::ServiceProvider). The primary use case for this struct is to
194/// explicitly declare that a new scope should be created at the injection call site.
195#[derive(Clone, Default)]
196pub struct ScopedServiceProvider {
197    sp: ServiceProvider
198}
199
200impl From<&ServiceProvider> for ScopedServiceProvider {
201    fn from(value: &ServiceProvider) -> Self {
202        Self { sp: value.create_scope() }
203    }
204}
205
206impl AsRef<ServiceProvider> for ScopedServiceProvider {
207    fn as_ref(&self) -> &ServiceProvider {
208        &self.sp
209    }
210}
211
212impl Borrow<ServiceProvider> for ScopedServiceProvider {
213    fn borrow(&self) -> &ServiceProvider {
214        &self.sp
215    }
216}
217
218impl Deref for ScopedServiceProvider {
219    type Target = ServiceProvider;
220
221    fn deref(&self) -> &Self::Target {
222        &self.sp
223    }
224}
225
226struct ServiceIterator<'a, T>
227where
228    T: Any + ?Sized,
229{
230    provider: &'a ServiceProvider,
231    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
232    _marker: PhantomData<T>,
233}
234
235struct KeyedServiceIterator<'a, TKey, TSvc>
236where
237    TSvc: Any + ?Sized,
238{
239    provider: &'a ServiceProvider,
240    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
241    _key: PhantomData<TKey>,
242    _svc: PhantomData<TSvc>,
243}
244
245impl<'a, T: Any + ?Sized> ServiceIterator<'a, T> {
246    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
247    where
248        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
249    {
250        Self {
251            provider,
252            descriptors: Box::new(descriptors),
253            _marker: PhantomData,
254        }
255    }
256}
257
258impl<'a, T: Any + ?Sized> Iterator for ServiceIterator<'a, T> {
259    type Item = Ref<T>;
260    fn next(&mut self) -> Option<Self::Item> {
261        if let Some(descriptor) = self.descriptors.next() {
262            Some(
263                descriptor
264                    .get(self.provider)
265                    .downcast_ref::<Ref<T>>()
266                    .unwrap()
267                    .clone(),
268            )
269        } else {
270            None
271        }
272    }
273}
274
275impl<'a, TKey, TSvc: Any + ?Sized> KeyedServiceIterator<'a, TKey, TSvc> {
276    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
277    where
278        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
279    {
280        Self {
281            provider,
282            descriptors: Box::new(descriptors),
283            _key: PhantomData,
284            _svc: PhantomData,
285        }
286    }
287}
288
289impl<'a, TKey, TSvc: Any + ?Sized> Iterator for KeyedServiceIterator<'a, TKey, TSvc> {
290    type Item = KeyedRef<TKey, TSvc>;
291    fn next(&mut self) -> Option<Self::Item> {
292        if let Some(descriptor) = self.descriptors.next() {
293            Some(KeyedRef::new(
294                descriptor
295                    .get(self.provider)
296                    .downcast_ref::<Ref<TSvc>>()
297                    .unwrap()
298                    .clone(),
299            ))
300        } else {
301            None
302        }
303    }
304}
305
306impl Default for ServiceProvider {
307    fn default() -> Self {
308        Self {
309            services: Ref::new(HashMap::with_capacity(0)),
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    // Several tests carry `#[allow(clippy::vtable_address_comparisons)]` because they
    // assert instance identity with `Ref::ptr_eq`, which is the property under test.

    use crate::{test::*, *};
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    #[cfg(feature = "async")]
    use std::sync::{Arc, Mutex};

    #[cfg(feature = "async")]
    use std::thread;

    #[test]
    fn get_should_return_none_when_service_is_unregistered() {
        // arrange
        let services = ServiceCollection::new().build_provider().unwrap();

        // act
        let result = services.get::<dyn TestService>();

        // assert
        assert!(result.is_none());
    }

    #[test]
    fn get_by_key_should_return_none_when_service_is_unregistered() {
        // arrange
        let services = ServiceCollection::new().build_provider().unwrap();

        // act
        let result = services.get_by_key::<key::Thingy, dyn TestService>();

        // assert
        assert!(result.is_none());
    }

    #[test]
    fn get_should_return_registered_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let result = services.get::<dyn TestService>();

        // assert
        assert!(result.is_some());
    }

    #[test]
    fn get_by_key_should_return_registered_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let result = services.get_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert!(result.is_some());
    }

    #[test]
    fn get_required_should_return_registered_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let _ = services.get_required::<dyn TestService>();

        // assert
        // didn't panic
    }

    #[test]
    fn get_required_by_key_should_return_registered_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton_with_key::<key::Thingy, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            )
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let thing = services.get_required_by_key::<key::Thingy, dyn Thing>();

        // assert
        // the keyed registration (Thing3) must win over the unkeyed one (Thing1)
        assert_eq!(&thing.to_string(), "di::test::Thing3");
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::TestService' has been registered."
    )]
    fn get_required_should_panic_when_service_is_unregistered() {
        // arrange
        let services = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = services.get_required::<dyn TestService>();

        // assert
        // panics
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::Thing' with the key 'di::test::key::Thing1' has been registered."
    )]
    fn get_required_by_key_should_panic_when_service_is_unregistered() {
        // arrange
        let services = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = services.get_required_by_key::<key::Thing1, dyn Thing>();

        // assert
        // panics
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_same_instance_for_singleton_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(existing::<dyn TestService, TestServiceImpl>(Box::new(
                TestServiceImpl::default(),
            )))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>().from(|sp| {
                    Ref::new(OtherTestServiceImpl::new(
                        sp.get_required::<dyn TestService>(),
                    ))
                }),
            )
            .build_provider()
            .unwrap();

        // act
        let svc2 = services.get_required::<dyn OtherTestService>();
        let svc1 = services.get_required::<dyn OtherTestService>();

        // assert
        assert!(Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_different_instances_for_transient_service() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let svc1 = services.get_required::<dyn TestService>();
        let svc2 = services.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    fn get_all_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl { value: 1 })),
            )
            .add(
                singleton::<dyn TestService, TestService2Impl>()
                    .from(|_| Ref::new(TestService2Impl { value: 2 })),
            );

        let provider = collection.build_provider().unwrap();

        // act
        let services = provider.get_all::<dyn TestService>();
        let values: Vec<_> = services.map(|s| s.value()).collect();

        // assert
        assert_eq!(&values, &[1, 2]);
    }

    #[test]
    fn get_all_by_key_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing2>()
                    .from(|_| Ref::new(Thing2::default())),
            )
            .add(
                singleton_with_key::<key::Thingies, dyn Thing, Thing3>()
                    .from(|_| Ref::new(Thing3::default())),
            );

        let provider = collection.build_provider().unwrap();

        // act
        let services = provider.get_all_by_key::<key::Thingies, dyn Thing>();
        let values: Vec<_> = services.map(|s| s.to_string()).collect();

        // assert
        assert_eq!(
            &values,
            &[
                "di::test::Thing1".to_owned(),
                "di::test::Thing2".to_owned(),
                "di::test::Thing3".to_owned()
            ]
        );
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn two_scoped_service_providers_should_create_different_instances() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let scope1 = services.create_scope();
        let scope2 = services.create_scope();

        // act
        let svc1 = scope1.get_required::<dyn TestService>();
        let svc2 = scope2.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn parent_child_scoped_service_providers_should_create_different_instances() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();

        // act
        let svc1 = scope1.get_required::<dyn TestService>();
        let svc2 = scope2.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&svc1, &svc2));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_eager_created_in_parent() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        // the singleton is resolved in the parent BEFORE any scope exists
        let svc1 = services.get_required::<dyn TestService>();
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();

        // act
        let svc2 = scope1.get_required::<dyn TestService>();
        let svc3 = scope2.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&svc1, &svc2));
        assert!(Ref::ptr_eq(&svc1, &svc3));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_lazy_created_in_parent() {
        // arrange
        let services = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let scope1 = services.create_scope();
        let scope2 = scope1.create_scope();
        // the singleton is resolved in the parent AFTER the scopes were created
        let svc1 = services.get_required::<dyn TestService>();

        // act
        let svc2 = scope1.get_required::<dyn TestService>();
        let svc3 = scope2.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&svc1, &svc2));
        assert!(Ref::ptr_eq(&svc1, &svc3));
    }

    #[test]
    fn service_provider_should_drop_existing_as_service() {
        // arrange
        let file = new_temp_file("drop2");

        // act
        {
            let mut services = ServiceCollection::new();
            services.add(existing_as_self(Droppable::new(file.clone())));
            // `let _ =` drops the provider immediately
            let _ = services.build_provider().unwrap();
        }

        // assert
        // record the outcome first so the file is cleaned up even when the assertion fails
        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_drop_lazy_initialized_service() {
        // arrange
        let file = new_temp_file("drop3");

        // act
        {
            let provider = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
            // resolving the service forces the lazy singleton to be created
            let _ = provider.get_required::<Droppable>();
        }

        // assert
        // record the outcome first so the file is cleaned up even when the assertion fails
        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_not_drop_service_if_never_instantiated() {
        // arrange
        let file = new_temp_file("drop5");

        // act
        {
            // the `Droppable` singleton is registered but never resolved
            let _ = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| {
                    Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))
                }))
                .build_provider()
                .unwrap();
        }

        // assert
        // record the outcome first so the file is cleaned up even when the assertion fails
        let not_dropped = file.exists();
        remove_file(&file).ok();
        assert!(not_dropped);
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn clone_should_be_shallow() {
        // arrange
        let provider1 = ServiceCollection::new()
            .add(
                transient::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .build_provider()
            .unwrap();

        // act
        let provider2 = provider1.clone();

        // assert
        assert!(Ref::ptr_eq(&provider1.services, &provider2.services));
        assert!(std::ptr::eq(
            provider1.services.as_ref(),
            provider2.services.as_ref()
        ));
    }

    /// Wrapper whose bounds only compile when the held value is `Send + Sync + Clone`.
    #[cfg(feature = "async")]
    #[derive(Clone)]
    struct Holder<T: Send + Sync + Clone>(T);

    /// Wraps `value` in a [`Holder`], enforcing the `Send + Sync + Clone` bounds at compile time.
    #[cfg(feature = "async")]
    fn inject<V: Send + Sync + Clone>(value: V) -> Holder<V> {
        Holder(value)
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_async_safe() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestAsyncServiceImpl>()
                    .from(|_| Ref::new(TestAsyncServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let holder = inject(provider);
        let h1 = holder.clone();
        let h2 = holder.clone();
        let value = Arc::new(Mutex::new(0));
        let v1 = value.clone();
        let v2 = value.clone();

        // act
        let t1 = thread::spawn(move || {
            let service = h1.0.get_required::<dyn TestService>();
            let mut result = v1.lock().unwrap();
            *result += service.value();
        });

        let t2 = thread::spawn(move || {
            let service = h2.0.get_required::<dyn TestService>();
            let mut result = v2.lock().unwrap();
            *result += service.value();
        });

        // surface a worker panic directly instead of as a confusing value mismatch below
        t1.join().expect("first worker thread panicked");
        t2.join().expect("second worker thread panicked");

        // assert
        assert_eq!(*value.lock().unwrap(), 3);
    }
}