Skip to main content

di/
collection.rs

1use crate::{fmt, validate, Ref, ServiceDescriptor, ServiceDescriptorBuilder, ServiceProvider, Type, ValidationError};
2use std::any::Any;
3use std::collections::HashMap;
4use std::fmt::{Formatter, Result as FormatResult};
5use std::iter::{DoubleEndedIterator, ExactSizeIterator};
6use std::ops::Index;
7use std::slice::{Iter, IterMut};
8use std::vec::IntoIter;
9
macro_rules! decorate {
    (($($traits:tt)+), ($($bounds:tt)+)) => {
        /// Decorates an existing service descriptor with a new one that wraps the original.
        ///
        /// # Arguments
        ///
        /// * `activate` - The function that will be called to decorate the resolved service instance
        ///
        /// # Remarks
        ///
        /// This function will only decorate the last registered [ServiceDescriptor] for the specified service type. If
        /// there are multiple, the others are ignored. If you need to decorate all services of a particular service
        /// type, consider using [decorate_all](Self::decorate_all) instead. If the service to be decorated is not
        /// registered, this function does nothing. The decorator [ServiceDescriptor] is created with the same
        /// [lifetime](crate::ServiceLifetime) as the original service registration. The implementation type of the
        /// decorator is determined by the generic parameter `TImpl`. If the original and decorator implementation types
        /// are the same, the original, decorated [ServiceDescriptor] is not replaced to prevent infinite recursion.
        ///
        /// # Example
        ///
        /// ```
        /// use di::{injectable, Injectable, ServiceCollection, Ref};
        ///
        /// trait Counter {
        ///     fn count(&self) -> usize;
        /// }
        ///
        /// #[injectable(Counter)]
        /// struct SingleCount;
        ///
        /// impl Counter for SingleCount {
        ///     fn count(&self) -> usize {
        ///         1
        ///     }
        /// }
        ///
        /// struct DoubleCount(Ref<dyn Counter>);
        ///
        /// impl Counter for DoubleCount {
        ///     fn count(&self) -> usize {
        ///         self.0.count() * 2
        ///     }
        /// }
        ///
        /// let provider = ServiceCollection::new()
        ///     .add(SingleCount::transient())
        ///     .decorate::<dyn Counter, DoubleCount>(|_, decorated| Ref::new(DoubleCount(decorated)))
        ///     .build_provider()
        ///     .unwrap();
        /// let counter = provider.get_required::<dyn Counter>();
        ///
        /// assert_eq!(counter.count(), 2);
        /// ```
        pub fn decorate<TSvc: ?Sized + $($traits)+, TImpl>(
            &mut self,
            activate: impl Fn(&ServiceProvider, Ref<TSvc>) -> Ref<TSvc> + $($bounds)+,
        ) -> &mut Self {
            let service_type = Type::of::<TSvc>();
            let impl_type = Type::of::<TImpl>();

            // only the most recently registered descriptor for the service type is decorated
            let last = self
                .items
                .iter_mut()
                .rev()
                .find(|item| item.service_type() == service_type);

            if let Some(item) = last {
                // a registration that already uses the decorator's implementation type is left alone so the
                // decorator never wraps itself
                if item.implementation_type() != impl_type {
                    let original = item.clone();
                    let builder = ServiceDescriptorBuilder::<TSvc, TImpl>::new(original.lifetime(), impl_type);

                    *item = builder.from(move |sp| {
                        // resolve through the original descriptor, then hand that instance to the decorator
                        let decorated = original.get(sp).downcast_ref::<Ref<TSvc>>().unwrap().clone();
                        activate(sp, decorated)
                    });
                }
            }

            self
        }

        /// Decorates all existing service descriptors with a new one that wraps the original.
        ///
        /// # Arguments
        ///
        /// * `activate` - The function that will be called to decorate the resolved service instance
        ///
        /// # Remarks
        ///
        /// This function decorates all registered [ServiceDescriptor] instances for the specified service type. If
        /// there are none, this function does nothing. Each decorator [ServiceDescriptor] is created with the same
        /// [lifetime](crate::ServiceLifetime) as the descriptor it wraps. Registrations whose implementation type is
        /// already the decorator type `TImpl` are left untouched.
        ///
        /// # Example
        ///
        /// ```
        /// use di::{injectable, Injectable, ServiceCollection, Ref};
        /// use std::sync::atomic::{AtomicUsize, Ordering};
        ///
        /// trait Feature {
        ///     fn show(&self);
        /// }
        ///
        /// #[injectable(Feature)]
        /// struct Feature1;
        ///
        /// impl Feature for Feature1 {
        ///     fn show(&self) {
        ///     }
        /// }
        ///
        /// #[injectable(Feature)]
        /// struct Feature2;
        ///
        /// impl Feature for Feature2 {
        ///     fn show(&self) {
        ///     }
        /// }
        ///
        /// #[injectable]
        /// struct Tracker(AtomicUsize);
        ///
        /// impl Tracker {
        ///     fn track(&self) {
        ///         self.0.fetch_add(1, Ordering::Relaxed);
        ///     }
        ///
        ///     fn count(&self) -> usize {
        ///         self.0.load(Ordering::Relaxed)
        ///     }
        /// }
        ///
        /// struct FeatureTracker {
        ///     feature: Ref<dyn Feature>,
        ///     tracker: Ref<Tracker>,
        /// }
        ///
        /// impl Feature for FeatureTracker {
        ///     fn show(&self) {
        ///         self.tracker.track();
        ///         self.feature.show();
        ///     }
        /// }
        ///
        /// let provider = ServiceCollection::new()
        ///     .add(Tracker::singleton())
        ///     .try_add_to_all(Feature1::transient())
        ///     .try_add_to_all(Feature2::transient())
        ///     .decorate_all::<dyn Feature, FeatureTracker>(|sp, decorated| {
        ///         Ref::new(FeatureTracker { feature: decorated, tracker: sp.get_required::<Tracker>() })
        ///     })
        ///     .build_provider()
        ///     .unwrap();
        /// let features = provider.get_all::<dyn Feature>();
        /// let tracker = provider.get_required::<Tracker>();
        ///
        /// for feature in features {
        ///     feature.show();
        /// }
        ///
        /// assert_eq!(tracker.count(), 2);
        /// ```
        pub fn decorate_all<TSvc: ?Sized + $($traits)+, TImpl>(
            &mut self,
            activate: impl Fn(&ServiceProvider, Ref<TSvc>) -> Ref<TSvc> + $($bounds)+,
        ) -> &mut Self {
            let service_type = Type::of::<TSvc>();
            // the activation function is shared by every decorator created below
            let func = Ref::new(activate);

            for item in self.items.iter_mut() {
                let impl_type = Type::of::<TImpl>();

                // skip other services and registrations that already use the decorator's implementation type
                if item.service_type() != service_type || item.implementation_type() == impl_type {
                    continue;
                }

                let original = item.clone();
                let activate = func.clone();
                let builder = ServiceDescriptorBuilder::<TSvc, TImpl>::new(original.lifetime(), impl_type);

                *item = builder.from(move |sp| {
                    // resolve through the original descriptor, then hand that instance to the decorator
                    let decorated = original.get(sp).downcast_ref::<Ref<TSvc>>().unwrap().clone();
                    (activate)(sp, decorated)
                });
            }

            self
        }
    };
}
204
205/// Represents a service collection.
206#[derive(Default)]
207pub struct ServiceCollection {
208    items: Vec<ServiceDescriptor>,
209}
210
211impl ServiceCollection {
212    /// Creates and returns a new instance of the service collection.
213    #[inline]
214    pub fn new() -> Self {
215        Self::default()
216    }
217
218    /// Returns true if the collection contains no elements.
219    #[inline]
220    pub fn is_empty(&self) -> bool {
221        self.items.is_empty()
222    }
223
224    /// Returns the number of elements in the collection.
225    #[inline]
226    pub fn len(&self) -> usize {
227        self.items.len()
228    }
229
230    /// Removes all elements from the collection.
231    #[inline]
232    pub fn clear(&mut self) {
233        self.items.clear()
234    }
235
236    /// Removes and returns the element at position index within the collection.
237    ///
238    /// # Argument
239    ///
240    /// * `index` - The index of the element to remove
241    ///
242    /// # Panics
243    ///
244    /// Panics if `index` is out of bounds.
245    #[inline]
246    pub fn remove(&mut self, index: usize) -> ServiceDescriptor {
247        self.items.remove(index)
248    }
249
250    /// Adds a service using the specified service descriptor.
251    ///
252    /// # Arguments
253    ///
254    /// * `descriptor` - The [ServiceDescriptor] to register
255    pub fn add<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
256        self.items.push(descriptor.into());
257        self
258    }
259
260    /// Adds a service using the specified service descriptor if the service has not already been registered.
261    ///
262    /// # Arguments
263    ///
264    /// * `descriptor` - The [ServiceDescriptor] to register
265    pub fn try_add<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
266        let new_item = descriptor.into();
267        let service_type = new_item.service_type();
268
269        for item in &self.items {
270            if item.service_type() == service_type {
271                return self;
272            }
273        }
274
275        self.items.push(new_item);
276        self
277    }
278
279    /// Adds a service using the specified service descriptor if the service with same service and
280    /// implementation type has not already been registered.
281    ///
282    /// # Arguments
283    ///
284    /// * `descriptor` - The [ServiceDescriptor] to register
285    pub fn try_add_to_all<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
286        let new_item = descriptor.into();
287        let service_type = new_item.service_type();
288        let implementation_type = new_item.implementation_type();
289
290        if service_type == implementation_type {
291            return self;
292        }
293
294        for item in &self.items {
295            if item.service_type() == service_type && item.implementation_type() == implementation_type {
296                return self;
297            }
298        }
299
300        self.items.push(new_item);
301        self
302    }
303
304    /// Adds the specified service descriptors if each of the services are not already registered
305    /// with the same service and implementation type.
306    ///
307    /// # Arguments
308    ///
309    /// * `descriptors` - The [ServiceDescriptor] sequence to register
310    pub fn try_add_all(&mut self, descriptors: impl IntoIterator<Item = ServiceDescriptor>) -> &mut Self {
311        for descriptor in descriptors {
312            self.try_add_to_all(descriptor);
313        }
314        self
315    }
316
317    /// Removes the first service descriptor with the same service type and adds the replacement.
318    ///
319    /// # Arguments
320    ///
321    /// * `descriptor` - The replacement [ServiceDescriptor]
322    pub fn replace<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
323        let new_item = descriptor.into();
324        let service_type = new_item.service_type();
325
326        for i in 0..self.items.len() {
327            if self.items[i].service_type() == service_type {
328                self.items.remove(i);
329                break;
330            }
331        }
332
333        self.items.push(new_item);
334        self
335    }
336
337    /// Adds or replaces a service with the specified descriptor if the service has not already been registered.
338    ///
339    /// # Arguments
340    ///
341    /// * `descriptor` - The replacement [ServiceDescriptor]
342    #[inline]
343    pub fn try_replace<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
344        self.try_add(descriptor)
345    }
346
347    /// Removes all specified descriptors of the specified type.
348    pub fn remove_all<T: Any + ?Sized>(&mut self) -> &mut Self {
349        let service_type = Type::of::<T>();
350
351        for i in (0..self.items.len()).rev() {
352            if self.items[i].service_type() == service_type {
353                self.items.remove(i);
354            }
355        }
356
357        self
358    }
359
360    /// Builds and returns a new [ServiceProvider].
361    pub fn build_provider(&self) -> Result<ServiceProvider, ValidationError> {
362        validate(self)?;
363
364        let mut services = HashMap::with_capacity(self.items.len());
365
366        for item in &self.items {
367            let key = item.service_type().clone();
368            let descriptors = services.entry(key).or_insert_with(Vec::new);
369
370            // dependencies are only interesting for validation. after a ServiceProvider is created, no further
371            // validation occurs. prevent copying unnecessary memory and allow it to potentially be freed if the
372            // ServiceCollection is dropped.
373            descriptors.push(item.clone_with(false));
374        }
375
376        for values in services.values_mut() {
377            values.shrink_to_fit();
378        }
379
380        services.shrink_to_fit();
381        Ok(ServiceProvider::new(services))
382    }
383
384    /// Gets a read-only iterator for the collection
385    #[inline]
386    pub fn iter(&self) -> impl ExactSizeIterator<Item = &ServiceDescriptor> + DoubleEndedIterator {
387        self.items.iter()
388    }
389
390    cfg_if::cfg_if! {
391        if #[cfg(feature = "async")] {
392            decorate!((Any + Send + Sync), (Send + Sync + 'static));
393        } else {
394            decorate!((Any), ('static));
395        }
396    }
397}
398
399impl<'a> IntoIterator for &'a ServiceCollection {
400    type Item = &'a ServiceDescriptor;
401    type IntoIter = Iter<'a, ServiceDescriptor>;
402
403    fn into_iter(self) -> Self::IntoIter {
404        self.items.iter()
405    }
406}
407
408impl<'a> IntoIterator for &'a mut ServiceCollection {
409    type Item = &'a mut ServiceDescriptor;
410    type IntoIter = IterMut<'a, ServiceDescriptor>;
411
412    fn into_iter(self) -> Self::IntoIter {
413        self.items.iter_mut()
414    }
415}
416
417impl IntoIterator for ServiceCollection {
418    type Item = ServiceDescriptor;
419    type IntoIter = IntoIter<Self::Item>;
420
421    fn into_iter(self) -> Self::IntoIter {
422        self.items.into_iter()
423    }
424}
425
426impl Index<usize> for ServiceCollection {
427    type Output = ServiceDescriptor;
428
429    fn index(&self, index: usize) -> &Self::Output {
430        &self.items[index]
431    }
432}
433
434impl std::fmt::Debug for ServiceCollection {
435    fn fmt(&self, f: &mut Formatter<'_>) -> FormatResult {
436        fmt::write(self, fmt::text::Renderer, f)
437    }
438}
439
440impl std::fmt::Display for ServiceCollection {
441    fn fmt(&self, f: &mut Formatter<'_>) -> FormatResult {
442        cfg_if::cfg_if! {
443            if #[cfg(feature = "fmt")] {
444                if f.alternate() {
445                    return fmt::write(self, fmt::terminal::Renderer, f);
446                }
447            }
448        }
449
450        fmt::write(self, fmt::text::Renderer, f)
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{existing, existing_as_self, singleton, singleton_as_self, test::*, transient};
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    /// An `existing` registration of a default [TestServiceImpl] as `dyn TestService`.
    fn existing_test_service() -> ServiceDescriptor {
        existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default()))
    }

    /// A singleton registration of a default [TestServiceImpl] as `dyn TestService`.
    fn singleton_test_service() -> ServiceDescriptor {
        singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default()))
    }

    #[test]
    fn is_empty_should_return_true_when_empty() {
        let services = ServiceCollection::default();
        assert!(services.is_empty());
    }

    #[test]
    fn length_should_return_zero_when_empty() {
        let services = ServiceCollection::default();
        assert_eq!(services.len(), 0);
    }

    #[test]
    fn is_empty_should_return_false_when_not_empty() {
        let mut services = ServiceCollection::new();
        services.add(existing_test_service());
        assert!(!services.is_empty());
    }

    #[test]
    fn length_should_return_count_when_not_empty() {
        let mut services = ServiceCollection::new();
        services.add(existing_test_service());
        assert_eq!(services.len(), 1);
    }

    #[test]
    fn clear_should_remove_all_elements() {
        let mut services = ServiceCollection::new();
        services.add(existing_test_service());
        services.clear();
        assert!(services.is_empty());
    }

    #[test]
    fn try_add_should_do_nothing_when_service_is_registered() {
        let mut services = ServiceCollection::new();
        services.add(singleton_test_service()).try_add(singleton_test_service());
        assert_eq!(services.len(), 1);
    }

    #[test]
    fn try_add_to_all_should_add_descriptor_when_implementation_is_unregistered() {
        let mut services = ServiceCollection::new();
        let other = singleton::<dyn OtherTestService, OtherTestServiceImpl>()
            .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>())));

        services.add(existing_test_service()).try_add_to_all(other);

        assert_eq!(services.len(), 2);
    }

    #[test]
    fn try_add_to_all_should_not_add_descriptor_when_implementation_is_registered() {
        let mut services = ServiceCollection::new();
        let duplicate =
            transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default()));

        services.add(existing_test_service()).try_add_to_all(duplicate);

        assert_eq!(services.len(), 1);
    }

    #[test]
    fn try_add_all_should_only_add_descriptors_for_unregistered_implementations() {
        let mut services = ServiceCollection::new();

        services.try_add_all(vec![
            existing_test_service(),
            transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())),
        ]);

        assert_eq!(services.len(), 1);
    }

    #[test]
    fn replace_should_replace_first_registered_service() {
        let mut services = ServiceCollection::new();

        services
            .add(singleton_test_service())
            .add(singleton_test_service())
            .replace(singleton_test_service());

        assert_eq!(services.len(), 2);
    }

    #[test]
    fn remove_all_should_remove_registered_services() {
        let mut services = ServiceCollection::new();

        services
            .add(singleton_test_service())
            .add(singleton_test_service())
            .remove_all::<dyn TestService>();

        assert!(services.is_empty());
    }

    #[test]
    fn try_replace_should_do_nothing_when_service_is_registered() {
        let mut services = ServiceCollection::new();

        services
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 1 })))
            .try_replace(
                singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 2 })),
            );

        let provider = services.build_provider().unwrap();
        let service = provider.get_required::<dyn TestService>();

        assert_eq!(service.value(), 1);
    }

    #[test]
    fn remove_should_remove_element_at_index() {
        let mut services = ServiceCollection::new();
        services.add(existing_test_service());

        drop(services.remove(0));

        assert!(services.is_empty());
    }

    #[test]
    fn service_collection_should_drop_existing_as_service() {
        let file = new_temp_file("drop1");

        // the collection is a temporary, so it (and the instance it owns) is dropped at the end of this statement
        ServiceCollection::new().add(existing_as_self(Droppable::new(file.clone())));

        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_collection_should_not_drop_service_if_never_instantiated() {
        let file = new_temp_file("drop4");
        let mut services = ServiceCollection::new();

        // the collection stays alive until the end of the test; the singleton is never resolved
        services
            .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
            .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))));

        let not_dropped = file.exists();
        remove_file(&file).ok();
        assert!(not_dropped);
    }
}