1use crate::{KeyedRef, KeyedRefMut, Mut, Ref, RefMut, ServiceDescriptor, Type};
2use std::any::{type_name, Any};
3use std::borrow::Borrow;
4use std::collections::HashMap;
5use std::iter::empty;
6use std::marker::PhantomData;
7use std::ops::Deref;
8
9#[derive(Clone)]
11pub struct ServiceProvider {
12 services: Ref<HashMap<Type, Vec<ServiceDescriptor>>>,
13}
14
15impl ServiceProvider {
16 pub fn new(services: HashMap<Type, Vec<ServiceDescriptor>>) -> Self {
22 Self {
23 services: Ref::new(services),
24 }
25 }
26
27 pub fn get<T: Any + ?Sized>(&self) -> Option<Ref<T>> {
29 let key = Type::of::<T>();
30
31 if let Some(descriptors) = self.services.get(&key) {
32 if let Some(descriptor) = descriptors.last() {
33 return Some(descriptor.get(self).downcast_ref::<Ref<T>>().unwrap().clone());
34 }
35 }
36
37 None
38 }
39
40 #[inline]
42 pub fn get_mut<T: Any + ?Sized>(&self) -> Option<RefMut<T>> {
43 self.get::<Mut<T>>()
44 }
45
46 pub fn get_by_key<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRef<TKey, TSvc>> {
48 let key = Type::keyed::<TKey, TSvc>();
49
50 if let Some(descriptors) = self.services.get(&key) {
51 if let Some(descriptor) = descriptors.last() {
52 return Some(KeyedRef::new(
53 descriptor.get(self).downcast_ref::<Ref<TSvc>>().unwrap().clone(),
54 ));
55 }
56 }
57
58 None
59 }
60
61 #[inline]
63 pub fn get_by_key_mut<TKey, TSvc: Any + ?Sized>(&self) -> Option<KeyedRefMut<TKey, TSvc>> {
64 self.get_by_key::<TKey, Mut<TSvc>>()
65 }
66
67 pub fn get_all<T: Any + ?Sized>(&self) -> impl Iterator<Item = Ref<T>> + '_ {
69 let key = Type::of::<T>();
70
71 if let Some(descriptors) = self.services.get(&key) {
72 ServiceIterator::new(self, descriptors.iter())
73 } else {
74 ServiceIterator::new(self, empty())
75 }
76 }
77
78 #[inline]
80 pub fn get_all_mut<T: Any + ?Sized>(&self) -> impl Iterator<Item = RefMut<T>> + '_ {
81 self.get_all::<Mut<T>>()
82 }
83
84 pub fn get_all_by_key<'a, TKey: 'a, TSvc>(&'a self) -> impl Iterator<Item = KeyedRef<TKey, TSvc>> + 'a
86 where
87 TSvc: Any + ?Sized,
88 {
89 let key = Type::keyed::<TKey, TSvc>();
90
91 if let Some(descriptors) = self.services.get(&key) {
92 KeyedServiceIterator::new(self, descriptors.iter())
93 } else {
94 KeyedServiceIterator::new(self, empty())
95 }
96 }
97
98 #[inline]
100 pub fn get_all_by_key_mut<'a, TKey: 'a, TSvc>(&'a self) -> impl Iterator<Item = KeyedRefMut<TKey, TSvc>> + 'a
101 where
102 TSvc: Any + ?Sized,
103 {
104 self.get_all_by_key::<TKey, Mut<TSvc>>()
105 }
106
107 pub fn get_required<T: Any + ?Sized>(&self) -> Ref<T> {
113 if let Some(service) = self.get::<T>() {
114 service
115 } else {
116 panic!("No service for type '{}' has been registered.", type_name::<T>());
117 }
118 }
119
120 #[inline]
126 pub fn get_required_mut<T: Any + ?Sized>(&self) -> RefMut<T> {
127 self.get_required::<Mut<T>>()
128 }
129
130 pub fn get_required_by_key<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRef<TKey, TSvc> {
136 if let Some(service) = self.get_by_key::<TKey, TSvc>() {
137 service
138 } else {
139 panic!(
140 "No service for type '{}' with the key '{}' has been registered.",
141 type_name::<TSvc>(),
142 type_name::<TKey>()
143 );
144 }
145 }
146
147 #[inline]
153 pub fn get_required_by_key_mut<TKey, TSvc: Any + ?Sized>(&self) -> KeyedRefMut<TKey, TSvc> {
154 self.get_required_by_key::<TKey, Mut<TSvc>>()
155 }
156
157 #[inline]
160 pub fn create_scope(&self) -> Self {
161 Self::new(self.services.as_ref().clone())
162 }
163}
164
165#[derive(Clone, Default)]
173pub struct ScopedServiceProvider {
174 sp: ServiceProvider,
175}
176
177impl From<&ServiceProvider> for ScopedServiceProvider {
178 fn from(value: &ServiceProvider) -> Self {
179 Self {
180 sp: value.create_scope(),
181 }
182 }
183}
184
185impl AsRef<ServiceProvider> for ScopedServiceProvider {
186 #[inline]
187 fn as_ref(&self) -> &ServiceProvider {
188 &self.sp
189 }
190}
191
192impl Borrow<ServiceProvider> for ScopedServiceProvider {
193 #[inline]
194 fn borrow(&self) -> &ServiceProvider {
195 &self.sp
196 }
197}
198
199impl Deref for ScopedServiceProvider {
200 type Target = ServiceProvider;
201
202 #[inline]
203 fn deref(&self) -> &Self::Target {
204 &self.sp
205 }
206}
207
208struct ServiceIterator<'a, T: Any + ?Sized> {
209 provider: &'a ServiceProvider,
210 descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
211 _marker: PhantomData<T>,
212}
213
214struct KeyedServiceIterator<'a, TKey, TSvc: Any + ?Sized> {
215 provider: &'a ServiceProvider,
216 descriptors: Box<dyn Iterator<Item = &'a ServiceDescriptor> + 'a>,
217 _key: PhantomData<TKey>,
218 _svc: PhantomData<TSvc>,
219}
220
221impl<'a, T: Any + ?Sized> ServiceIterator<'a, T> {
222 fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
223 where
224 I: Iterator<Item = &'a ServiceDescriptor> + 'a,
225 {
226 Self {
227 provider,
228 descriptors: Box::new(descriptors),
229 _marker: PhantomData,
230 }
231 }
232}
233
234impl<'a, T: Any + ?Sized> Iterator for ServiceIterator<'a, T> {
235 type Item = Ref<T>;
236 fn next(&mut self) -> Option<Self::Item> {
237 if let Some(descriptor) = self.descriptors.next() {
238 Some(descriptor.get(self.provider).downcast_ref::<Ref<T>>().unwrap().clone())
239 } else {
240 None
241 }
242 }
243}
244
245impl<'a, TKey, TSvc: Any + ?Sized> KeyedServiceIterator<'a, TKey, TSvc> {
246 fn new<I>(provider: &'a ServiceProvider, descriptors: I) -> Self
247 where
248 I: Iterator<Item = &'a ServiceDescriptor> + 'a,
249 {
250 Self {
251 provider,
252 descriptors: Box::new(descriptors),
253 _key: PhantomData,
254 _svc: PhantomData,
255 }
256 }
257}
258
259impl<'a, TKey, TSvc: Any + ?Sized> Iterator for KeyedServiceIterator<'a, TKey, TSvc> {
260 type Item = KeyedRef<TKey, TSvc>;
261
262 fn next(&mut self) -> Option<Self::Item> {
263 if let Some(descriptor) = self.descriptors.next() {
264 Some(KeyedRef::new(
265 descriptor
266 .get(self.provider)
267 .downcast_ref::<Ref<TSvc>>()
268 .unwrap()
269 .clone(),
270 ))
271 } else {
272 None
273 }
274 }
275}
276
277impl Default for ServiceProvider {
278 fn default() -> Self {
279 Self {
280 services: Ref::new(HashMap::new()),
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use crate::{
        existing, existing_as_self, scoped, singleton, singleton_as_self, singleton_with_key, test::*, transient, Ref,
        ServiceCollection,
    };
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    cfg_if::cfg_if! {
        if #[cfg(feature = "async")] {
            use std::sync::{Arc, Mutex};
            use std::thread;
        }
    }

    // Compiles only when `T` is both `Send` and `Sync`.
    #[cfg(feature = "async")]
    fn assert_send_and_sync<T: Send + Sync>(_: T) {}

    #[test]
    fn get_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_by_key_should_return_none_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn TestService>();

        // assert
        assert!(service.is_none());
    }

    #[test]
    fn get_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let service = provider.get::<dyn TestService>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_by_key_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton_with_key::<key::Thingy, dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let service = provider.get_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert!(service.is_some());
    }

    #[test]
    fn get_required_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // reaching this point without a panic is the success condition
    }

    #[test]
    fn get_required_by_key_should_return_registered_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(singleton_with_key::<key::Thingy, dyn Thing, Thing3>().from(|_| Ref::new(Thing3::default())))
            .add(singleton::<dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .build_provider()
            .unwrap();

        // act
        let thing = provider.get_required_by_key::<key::Thingy, dyn Thing>();

        // assert
        assert_eq!(&thing.to_string(), "di::test::Thing3");
    }

    #[test]
    #[should_panic(expected = "No service for type 'dyn di::test::TestService' has been registered.")]
    fn get_required_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required::<dyn TestService>();

        // assert
        // the expected panic is declared on the test attribute
    }

    #[test]
    #[should_panic(
        expected = "No service for type 'dyn di::test::Thing' with the key 'di::test::key::Thing1' has been registered."
    )]
    fn get_required_by_key_should_panic_when_service_is_unregistered() {
        // arrange
        let provider = ServiceCollection::new().build_provider().unwrap();

        // act
        let _ = provider.get_required_by_key::<key::Thing1, dyn Thing>();

        // assert
        // the expected panic is declared on the test attribute
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_same_instance_for_singleton_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(existing::<dyn TestService, TestServiceImpl>(Box::new(
                TestServiceImpl::default(),
            )))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
            )
            .build_provider()
            .unwrap();

        // act
        let first = provider.get_required::<dyn OtherTestService>();
        let second = provider.get_required::<dyn OtherTestService>();

        // assert
        assert!(Ref::ptr_eq(&second, &first));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn get_should_return_different_instances_for_transient_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add(transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let first = provider.get_required::<dyn TestService>();
        let second = provider.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&first, &second));
    }

    #[test]
    fn get_all_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 1 })))
            .add(singleton::<dyn TestService, TestService2Impl>().from(|_| Ref::new(TestService2Impl { value: 2 })));

        let provider = collection.build_provider().unwrap();

        // act
        let values: Vec<_> = provider
            .get_all::<dyn TestService>()
            .map(|service| service.value())
            .collect();

        // assert
        assert_eq!(&values, &[1, 2]);
    }

    #[test]
    fn get_all_by_key_should_return_all_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())))
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing2>().from(|_| Ref::new(Thing2::default())))
            .add(singleton_with_key::<key::Thingies, dyn Thing, Thing3>().from(|_| Ref::new(Thing3::default())));

        let provider = collection.build_provider().unwrap();

        // act
        let names: Vec<_> = provider
            .get_all_by_key::<key::Thingies, dyn Thing>()
            .map(|thing| thing.to_string())
            .collect();

        // assert
        assert_eq!(
            &names,
            &[
                "di::test::Thing1".to_owned(),
                "di::test::Thing2".to_owned(),
                "di::test::Thing3".to_owned()
            ]
        );
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn two_scoped_service_providers_should_create_different_instances() {
        // arrange
        let root = ServiceCollection::new()
            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let left = root.create_scope();
        let right = root.create_scope();

        // act
        let from_left = left.get_required::<dyn TestService>();
        let from_right = right.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_left, &from_right));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn parent_child_scoped_service_providers_should_create_different_instances() {
        // arrange
        let root = ServiceCollection::new()
            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(!Ref::ptr_eq(&from_parent, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_eager_created_in_parent() {
        // arrange
        let root = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        // the singleton is resolved BEFORE any scope is created
        let from_root = root.get_required::<dyn TestService>();
        let parent = root.create_scope();
        let child = parent.create_scope();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn scoped_service_provider_should_have_same_singleton_when_lazy_created_in_parent() {
        // arrange
        let root = ServiceCollection::new()
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();
        let parent = root.create_scope();
        let child = parent.create_scope();
        // the singleton is resolved AFTER the scopes are created
        let from_root = root.get_required::<dyn TestService>();

        // act
        let from_parent = parent.get_required::<dyn TestService>();
        let from_child = child.get_required::<dyn TestService>();

        // assert
        assert!(Ref::ptr_eq(&from_root, &from_parent));
        assert!(Ref::ptr_eq(&from_root, &from_child));
    }

    #[test]
    fn service_provider_should_drop_existing_as_service() {
        // arrange
        let file = new_temp_file("drop2");

        // act
        {
            let mut collection = ServiceCollection::new();
            collection.add(existing_as_self(Droppable::new(file.clone())));
            let _ = collection.build_provider().unwrap();
        }

        // assert
        let dropped = !file.exists();
        // clean up before asserting so a failure does not leave the file behind
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_drop_lazy_initialized_service() {
        // arrange
        let file = new_temp_file("drop3");

        // act
        {
            let provider = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))))
                .build_provider()
                .unwrap();
            let _ = provider.get_required::<Droppable>();
        }

        // assert
        let dropped = !file.exists();
        // clean up before asserting so a failure does not leave the file behind
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_provider_should_not_drop_service_if_never_instantiated() {
        // arrange
        let file = new_temp_file("drop5");

        // act
        {
            let _ = ServiceCollection::new()
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))))
                .build_provider()
                .unwrap();
        }

        // assert
        let not_dropped = file.exists();
        // clean up before asserting so a failure does not leave the file behind
        remove_file(&file).ok();
        assert!(not_dropped);
    }

    #[test]
    #[allow(clippy::vtable_address_comparisons)]
    fn clone_should_be_shallow() {
        // arrange
        let original = ServiceCollection::new()
            .add(transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .build_provider()
            .unwrap();

        // act
        let copy = original.clone();

        // assert
        assert!(Ref::ptr_eq(&original.services, &copy.services));
        assert!(std::ptr::eq(original.services.as_ref(), copy.services.as_ref()));
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_send_and_sync() {
        // arrange
        let mut collection = ServiceCollection::new();
        let descriptor =
            singleton::<dyn TestService, TestAsyncServiceImpl>().from(|_| Ref::new(TestAsyncServiceImpl::default()));

        collection.add(descriptor);

        // act
        let provider = collection.build_provider().unwrap();

        // assert
        assert_send_and_sync(provider);
    }

    // Wrapper whose bounds force the held value to be `Send + Sync + Clone`.
    #[cfg(feature = "async")]
    #[derive(Clone)]
    struct Holder<T: Send + Sync + Clone>(T);

    // Wraps `value` in a `Holder`.
    #[cfg(feature = "async")]
    fn inject<V: Send + Sync + Clone>(value: V) -> Holder<V> {
        Holder(value)
    }

    #[test]
    #[cfg(feature = "async")]
    fn service_provider_should_be_async_safe() {
        // arrange
        let provider = ServiceCollection::new()
            .add(
                singleton::<dyn TestService, TestAsyncServiceImpl>()
                    .from(|_| Ref::new(TestAsyncServiceImpl::default())),
            )
            .build_provider()
            .unwrap();
        let holder = inject(provider);
        let holder_a = holder.clone();
        let holder_b = holder.clone();
        let total = Arc::new(Mutex::new(0));
        let total_a = total.clone();
        let total_b = total.clone();

        // act
        let worker_a = thread::spawn(move || {
            let service = holder_a.0.get_required::<dyn TestService>();
            let mut sum = total_a.lock().unwrap();
            *sum += service.value();
        });

        let worker_b = thread::spawn(move || {
            let service = holder_b.0.get_required::<dyn TestService>();
            let mut sum = total_b.lock().unwrap();
            *sum += service.value();
        });

        worker_a.join().ok();
        worker_b.join().ok();

        // assert
        assert_eq!(*total.lock().unwrap(), 3);
    }
}