Skip to main content

di/
provider.rs

1use crate::{KeyedRef, KeyedRefMut, Mut, Ref, RefMut, ServiceDescriptor, Type};
2use std::any::{type_name, Any};
3use std::borrow::Borrow;
4use std::collections::HashMap;
5use std::iter::empty;
6use std::marker::PhantomData;
7use std::ops::Deref;
8
9/// Represents a service provider.
10#[derive(Clone)]
11pub struct ServiceProvider {
12    services: Ref<HashMap<Type, Vec<ServiceDescriptor>>>,
13}
14
15impl ServiceProvider {
16    /// Initializes a new service provider.
17    ///
18    /// # Arguments
19    ///
20    /// * `services` - The map of [service descriptors](ServiceDescriptor) encapsulated by the provider
21    pub fn new(services: HashMap<Type, Vec<ServiceDescriptor>>) -> Self {
22        Self {
23            services: Ref::new(services),
24        }
25    }
26
27    /// Gets a service of the specified type.
28    pub fn get<T: Any + ?Sized>(&self) -> Option<Ref<T>> {
29        let key = Type::of::<T>();
30
31        if let Some(descriptors) = self.services.get(&key) {
32            if let Some(descriptor) = descriptors.last() {
33                return Some(descriptor.get(self).downcast_ref::<Ref<T>>().unwrap().clone());
34            }
35        }
36
37        None
38    }
39
40    /// Gets a mutable service of the specified type.
41    #[inline]
42    pub fn get_mut<T: Any + ?Sized>(&self) -> Option<RefMut<T>> {
43        self.get::<Mut<T>>()
44    }
45
46    /// Gets a keyed service of the specified type.
47    pub fn get_by_key<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRef<TKey, TSvc>> {
48        let key = Type::keyed::<TKey, TSvc>();
49
50        if let Some(descriptors) = self.services.get(&key) {
51            if let Some(descriptor) = descriptors.last() {
52                return Some(KeyedRef::new(
53                    descriptor.get(self).downcast_ref::<Ref<TSvc>>().unwrap().clone(),
54                ));
55            }
56        }
57
58        None
59    }
60
61    /// Gets a keyed, mutable service of the specified type.
62    #[inline]
63    pub fn get_by_key_mut<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRefMut<TKey, TSvc>> {
64        self.get_by_key::<TKey, Mut<TSvc>>()
65    }
66
67    /// Gets all of the services of the specified type.
68    pub fn get_all<T: Any + ?Sized>(&self) -> impl Iterator<Item = Ref<T>> + '_ {
69        let key = Type::of::<T>();
70
71        if let Some(descriptors) = self.services.get(&key) {
72            ServiceIterator::new(self, descriptors.iter())
73        } else {
74            ServiceIterator::new(self, empty())
75        }
76    }
77
78    /// Gets all of the mutable services of the specified type.
79    #[inline]
80    pub fn get_all_mut<T: Any + ?Sized>(&self) -> impl Iterator<Item = RefMut<T>> + '_ {
81        self.get_all::<Mut<T>>()
82    }
83
84    /// Gets all of the services of the specified key and type.
85    pub fn get_all_by_key<'a, TKey: 'a, TSvc>(&'a self) -> impl Iterator<Item = KeyedRef<TKey, TSvc>> + 'a
86    where
87        TSvc: Any + ?Sized,
88    {
89        let key = Type::keyed::<TKey, TSvc>();
90
91        if let Some(descriptors) = self.services.get(&key) {
92            KeyedServiceIterator::new(self, descriptors.iter())
93        } else {
94            KeyedServiceIterator::new(self, empty())
95        }
96    }
97
98    /// Gets all of the mutable services of the specified key and type.
99    #[inline]
100    pub fn get_all_by_key_mut<'a, TKey: 'a, TSvc>(&'a self) -> impl Iterator<Item = KeyedRefMut<TKey, TSvc>> + 'a
101    where
102        TSvc: Any + ?Sized,
103    {
104        self.get_all_by_key::<TKey, Mut<TSvc>>()
105    }
106
107    /// Gets a required service of the specified type.
108    ///
109    /// # Panics
110    ///
111    /// The requested service of type `T` does not exist.
112    pub fn get_required<T: Any + ?Sized>(&self) -> Ref<T> {
113        if let Some(service) = self.get::<T>() {
114            service
115        } else {
116            panic!("No service for type '{}' has been registered.", type_name::<T>());
117        }
118    }
119
120    /// Gets a required, mutable service of the specified type.
121    ///
122    /// # Panics
123    ///
124    /// The requested service of type `T` does not exist.
125    #[inline]
126    pub fn get_required_mut<T: Any + ?Sized>(&self) -> RefMut<T> {
127        self.get_required::<Mut<T>>()
128    }
129
130    /// Gets a required keyed service of the specified type.
131    ///
132    /// # Panics
133    ///
134    /// The requested service of type `TSvc` with key `TKey` does not exist.
135    pub fn get_required_by_key<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRef<TKey, TSvc> {
136        if let Some(service) = self.get_by_key::<TKey, TSvc>() {
137            service
138        } else {
139            panic!(
140                "No service for type '{}' with the key '{}' has been registered.",
141                type_name::<TSvc>(),
142                type_name::<TKey>()
143            );
144        }
145    }
146
147    /// Gets a required keyed service of the specified type.
148    ///
149    /// # Panics
150    ///
151    /// The requested service of type `TSvc` with key `TKey` does not exist.
152    #[inline]
153    pub fn get_required_by_key_mut<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRefMut<TKey, TSvc> {
154        self.get_required_by_key::<TKey, Mut<TSvc>>()
155    }
156
157    /// Creates and returns a new service provider that is used to resolve
158    /// services from a newly create scope.
159    #[inline]
160    pub fn create_scope(&self) -> Self {
161        Self::new(self.services.as_ref().clone())
162    }
163}
164
165/// Represents a scoped [ServiceProvider].
166///
167/// # Remarks
168///
169/// This struct has the exact same functionality as [ServiceProvider]. When a new instance is created, it also creates
170/// a new scope from the source [ServiceProvider]. The primary use case for this struct is to explicitly declare that
171/// a new scope should be created at the injection call site.
172#[derive(Clone, Default)]
173pub struct ScopedServiceProvider {
174    sp: ServiceProvider,
175}
176
177impl From<&ServiceProvider> for ScopedServiceProvider {
178    fn from(value: &ServiceProvider) -> Self {
179        Self {
180            sp: value.create_scope(),
181        }
182    }
183}
184
185impl AsRef<ServiceProvider> for ScopedServiceProvider {
186    #[inline]
187    fn as_ref(&self) -> &ServiceProvider {
188        &self.sp
189    }
190}
191
192impl Borrow<ServiceProvider> for ScopedServiceProvider {
193    #[inline]
194    fn borrow(&self) -> &ServiceProvider {
195        &self.sp
196    }
197}
198
199impl Deref for ScopedServiceProvider {
200    type Target = ServiceProvider;
201
202    #[inline]
203    fn deref(&self) -> &Self::Target {
204        &self.sp
205    }
206}
207
208struct ServiceIterator<'a, T: Any + ?Sized> {
209    provider: &'a ServiceProvider,
210    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
211    _marker: PhantomData<T>,
212}
213
214struct KeyedServiceIterator<'a, TKey, TSvc: Any + ?Sized> {
215    provider: &'a ServiceProvider,
216    descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
217    _key: PhantomData<TKey>,
218    _svc: PhantomData<TSvc>,
219}
220
221impl<'a, T: Any + ?Sized> ServiceIterator<'a, T> {
222    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
223    where
224        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
225    {
226        Self {
227            provider,
228            descriptors: Box::new(descriptors),
229            _marker: PhantomData,
230        }
231    }
232}
233
234impl<'a, T: Any + ?Sized> Iterator for ServiceIterator<'a, T> {
235    type Item = Ref<T>;
236    fn next(&mut self) -> Option<Self::Item> {
237        if let Some(descriptor) = self.descriptors.next() {
238            Some(descriptor.get(self.provider).downcast_ref::<Ref<T>>().unwrap().clone())
239        } else {
240            None
241        }
242    }
243}
244
245impl<'a, TKey, TSvc: Any + ?Sized> KeyedServiceIterator<'a, TKey, TSvc> {
246    fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
247    where
248        I: Iterator<Item = &'a ServiceDescriptor> + 'a,
249    {
250        Self {
251            provider,
252            descriptors: Box::new(descriptors),
253            _key: PhantomData,
254            _svc: PhantomData,
255        }
256    }
257}
258
259impl<'a, TKey, TSvc: Any + ?Sized> Iterator for KeyedServiceIterator<'a, TKey, TSvc> {
260    type Item = KeyedRef<TKey, TSvc>;
261
262    fn next(&mut self) -> Option<Self::Item> {
263        if let Some(descriptor) = self.descriptors.next() {
264            Some(KeyedRef::new(
265                descriptor
266                    .get(self.provider)
267                    .downcast_ref::<Ref<TSvc>>()
268                    .unwrap()
269                    .clone(),
270            ))
271        } else {
272            None
273        }
274    }
275}
276
277impl Default for ServiceProvider {
278    fn default() -> Self {
279        Self {
280            services: Ref::new(HashMap::new()),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use crate::{
        existing, existing_as_self, scoped, singleton, singleton_as_self, singleton_with_key, test::*, transient, Ref,
        ServiceCollection,
    };
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    cfg_if::cfg_if! {
        if #[cfg(feature = "async")] {
            use std::sync::{Arc, Mutex};
            use std::thread;
        }
    }

    // compile-time probe: only accepts values that are both `Send` and `Sync`
    #[cfg(feature = "async")]
    fn assert_send_and_sync<T: Send + Sync>(_: T) {}

    #[test]
    fn get_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_by_key_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_by_key_should_return_registered_service() {
        // arrange: the same service type is registered both with and without a key
        let provider = ServiceCollection::new()
            .add(singleton_with_key::<key::Thingy, dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_required_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // reaching this point means the call did not panic
    }

    #[test]
    fn get_required_by_key_should_return_registered_service() {
        // arrange: keyed and unkeyed registrations use different implementations
        let provider = ServiceCollection::new()
            .add(singleton_with_key::<key::Thingy, dyn Thing, Thing3>().from(|_| Ref::new(Thing3::default())))
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let thing = provider.get_required_by_key::<key::Thingy, dyn Thing>();
        let name = thing.to_string();

        // assert
        assert_eq!(name, "di::test::Thing3");
    }

    #[test]
    #[should_panic(expected = "No service for type 'dyn di::test::TestService' has been registered.")]
    fn get_required_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // the call above is expected to panic
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::Thing' with the key 'di::test::key::Thing1' has been registered."
    )]
    fn get_required_by_key_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required_by_key::<key::Thing1, dyn Thing>();

        // assert
        // the call above is expected to panic
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_same_instance_for_singleton_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(existing::<dyn TestService, TestServiceImpl>(Box::new(
                TestServiceImpl::default(),
            )))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
            )
            .build_provider()
            .unwrap();

        // act
        let first = provider.get_required::<dyn OtherTestService>();
        let second = provider.get_required::<dyn OtherTestService>();

        // assert
        assert!(Ref::ptr_eq(&second, &first));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_different_instances_for_transient_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let first = provider.get_required::<dyn TestService>();
        let second = provider.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&first, &second));
    }

    #[test]
    fn get_all_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 1 })))
            .add(singleton::<dyn TestService, TestService2Impl>().from(|_| Ref::new(TestService2Impl { value: 2 })));

        let provider = collection.build_provider().unwrap();

        // act
        let resolved = provider.get_all::<dyn TestService>();
        let values: Vec<_> = resolved.map(|service| service.value()).collect();

        // assert: services come back in registration order
        assert_eq!(&values, &[1, 2]);
    }

    #[test]
    fn get_all_by_key_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing2>().from(|_| Ref::new(Thing2::default())))
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing3>().from(|_| Ref::new(Thing3::default())));

        let provider = collection.build_provider().unwrap();

        // act
        let resolved = provider.get_all_by_key::<key::Thingies, dyn Thing>();
        let names: Vec<_> = resolved.map(|thing| thing.to_string()).collect();

        // assert: services come back in registration order
        assert_eq!(
            &names,
            &[
                String::from("di::test::Thing1"),
                String::from("di::test::Thing2"),
                String::from("di::test::Thing3")
            ]
        );
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn two_scoped_service_providers_should_create_different_instances() {
        // arrange: two sibling scopes created from the same root provider
        let root = ServiceCollection::new()
            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let left = root.create_scope();
        let right = root.create_scope();

        // act
        let from_left = left.get_required::<dyn TestService>();
        let from_right = right.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_left, &from_right));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn parent_child_scoped_service_providers_should_create_different_instances() {
        // arrange: the second scope is created from the first one
        let root = ServiceCollection::new()
            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_parent, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_eager_created_in_parent() {
        // arrange: the singleton is resolved from the root before any scope exists
        let root = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let from_root = root.get_required::<dyn TestService>();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_lazy_created_in_parent() {
        // arrange: the singleton is resolved from the root only after the scopes exist
        let root = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();
        let from_root = root.get_required::<dyn TestService>();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    fn service_provider_should_drop_existing_as_service() {
        // arrange
        let path = new_temp_file("drop2");

        // act: the collection and the provider both go out of scope inside this block
        {
            let mut collection = ServiceCollection::new();
            collection.add(existing_as_self(Droppable::new(path.clone())));
            let _ = collection.build_provider().unwrap();
        }

        // assert: record the outcome first, then clean up, so the file never outlives the test
        let dropped = !path.exists();
        remove_file(&path).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_drop_lazy_initialized_service() {
        // arrange
        let path = new_temp_file("drop3");

        // act: the service is resolved once, then the provider goes out of scope
        {
            let provider = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(path.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))))
                .build_provider()
                .unwrap();
            let _ = provider.get_required::<Droppable>();
        }

        // assert: record the outcome first, then clean up, so the file never outlives the test
        let dropped = !path.exists();
        remove_file(&path).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_not_drop_service_if_never_instantiated() {
        // arrange
        let path = new_temp_file("drop5");

        // act: the provider is built and discarded without resolving the service
        {
            let _ = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(path.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))))
                .build_provider()
                .unwrap();
        }

        // assert: record the outcome first, then clean up, so the file never outlives the test
        let untouched = path.exists();
        remove_file(&path).ok();
        assert!(untouched);
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn clone_should_be_shallow() {
        // arrange
        let original = ServiceCollection::new()
            .add(transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let copy = original.clone();

        // assert: both providers point at the very same descriptor map
        assert!(Ref::ptr_eq(&original.services, &copy.services));
        assert!(std::ptr::eq(original.services.as_ref(), copy.services.as_ref()));
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_send_and_sync() {
        // arrange
        let mut collection = ServiceCollection::new();
        let descriptor =
            singleton::<dyn TestService, TestAsyncServiceImpl>().from(|_| Ref::new(TestAsyncServiceImpl::default()));

        collection.add(descriptor);

        // act
        let provider = collection.build_provider().unwrap();

        // assert: this only compiles when the provider is `Send + Sync`
        assert_send_and_sync(provider);
    }

    // wrapper whose bounds force the held value to be shareable across threads
    #[cfg(feature = "async")]
    #[derive(Clone)]
    struct Holder<T: Send + Sync + Clone>(T);

    // wraps `value` in a `Holder`
    #[cfg(feature = "async")]
    fn inject<V: Send + Sync + Clone>(value: V) -> Holder<V> {
        Holder(value)
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_async_safe() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestAsyncServiceImpl>()
                    .from(|_| Ref::new(TestAsyncServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let holder = inject(provider);
        let first_holder = holder.clone();
        let second_holder = holder.clone();
        let total = Arc::new(Mutex::new(0));
        let first_total = total.clone();
        let second_total = total.clone();

        // act: each thread resolves the service and adds its value to the shared total
        let first_thread = thread::spawn(move || {
            let service = first_holder.0.get_required::<dyn TestService>();
            let mut sum = first_total.lock().unwrap();
            *sum += service.value();
        });

        let second_thread = thread::spawn(move || {
            let service = second_holder.0.get_required::<dyn TestService>();
            let mut sum = second_total.lock().unwrap();
            *sum += service.value();
        });

        first_thread.join().ok();
        second_thread.join().ok();

        // assert
        assert_eq!(*total.lock().unwrap(), 3);
    }
}