Skip to main content

di/
collection.rs

1use crate::{fmt, validate, Ref, ServiceDescriptor, ServiceDescriptorBuilder, ServiceProvider, Type, ValidationError};
2use std::any::Any;
3use std::collections::HashMap;
4use std::fmt::{Formatter, Result as FormatResult};
5use std::iter::{DoubleEndedIterator, ExactSizeIterator};
6use std::ops::Index;
7use std::slice::{Iter, IterMut};
8use std::vec::IntoIter;
9
macro_rules! decorate {
    (($($traits:tt)+), ($($bounds:tt)+)) => {
        /// Decorates an existing service descriptor with a new one that wraps the original.
        ///
        /// # Arguments
        ///
        /// * `activate` - The function that will be called to decorate the resolved service instance
        ///
        /// # Remarks
        ///
        /// This function will only decorate the last registered [ServiceDescriptor] for the specified service type. If
        /// there are multiple, the others are ignored. If you need to decorate all services of a particular service
        /// type, consider using [Self::decorate_all] instead. If the service to be decorated is not registered, this
        /// function does nothing. The decorator [ServiceDescriptor] is created with the same lifetime as the original.
        /// The implementation type of the decorator is determined by the generic parameter `TImpl`. If the original and
        /// decorator implementation types are the same, the original, decorated [ServiceDescriptor] is not replaced to
        /// prevent infinite recursion.
        ///
        /// # Example
        ///
        /// ```
        /// use di::{injectable, Injectable, ServiceCollection, Ref};
        ///
        /// trait Counter {
        ///     fn count(&self) -> usize;
        /// }
        ///
        /// #[injectable(Counter)]
        /// struct SingleCount;
        ///
        /// impl Counter for SingleCount {
        ///     fn count(&self) -> usize {
        ///         1
        ///     }
        /// }
        ///
        /// struct DoubleCount(Ref<dyn Counter>);
        ///
        /// impl Counter for DoubleCount {
        ///     fn count(&self) -> usize {
        ///         self.0.count() * 2
        ///     }
        /// }
        ///
        /// let provider = ServiceCollection::new()
        ///     .add(SingleCount::transient())
        ///     .decorate::<dyn Counter, DoubleCount>(|_, decorated| Ref::new(DoubleCount(decorated)))
        ///     .build_provider()
        ///     .unwrap();
        /// let counter = provider.get_required::<dyn Counter>();
        ///
        /// assert_eq!(counter.count(), 2);
        /// ```
        pub fn decorate<TSvc: ?Sized + $($traits)+, TImpl>(
            &mut self,
            activate: impl Fn(&ServiceProvider, Ref<TSvc>) -> Ref<TSvc> + $($bounds)+,
        ) -> &mut Self {
            let service_type = Type::of::<TSvc>();
            let impl_type = Type::of::<TImpl>();

            // searching from the end selects the most recently registered descriptor for the service type
            let last = self.items.iter_mut().rev().find(|item| item.service_type() == service_type);

            if let Some(item) = last {
                // a descriptor that already uses the decorator type is left alone to prevent infinite recursion
                if item.implementation_type() != impl_type {
                    let original = item.clone();
                    let builder = ServiceDescriptorBuilder::<TSvc, TImpl>::new(original.lifetime(), impl_type);

                    *item = builder.from(move |sp| {
                        let decorated = original
                            .get(sp)
                            .downcast_ref::<Ref<TSvc>>()
                            .expect("the decorated descriptor must resolve a Ref of the decorated service type")
                            .clone();
                        activate(sp, decorated)
                    });
                }
            }

            self
        }

        /// Decorates all existing service descriptors with a new one that wraps the original.
        ///
        /// # Arguments
        ///
        /// * `activate` - The function that will be called to decorate the resolved service instance
        ///
        /// # Remarks
        ///
        /// This function decorates all registered [ServiceDescriptor] for the specified service type. If there are none,
        /// this function does nothing. The decorator [ServiceDescriptor] is created with the same lifetime as the original.
        /// Any original [ServiceDescriptor] whose implementation type is the same as the decorator type `TImpl` is left
        /// undecorated to prevent infinite recursion.
        ///
        /// # Example
        ///
        /// ```
        /// use di::{injectable, Injectable, ServiceCollection, Ref};
        /// use std::sync::atomic::{AtomicUsize, Ordering};
        ///
        /// trait Feature {
        ///     fn show(&self);
        /// }
        ///
        /// #[injectable(Feature)]
        /// struct Feature1;
        ///
        /// impl Feature for Feature1 {
        ///     fn show(&self) {
        ///     }
        /// }
        ///
        /// #[injectable(Feature)]
        /// struct Feature2;
        ///
        /// impl Feature for Feature2 {
        ///     fn show(&self) {
        ///     }
        /// }
        ///
        /// #[injectable]
        /// struct Tracker(AtomicUsize);
        ///
        /// impl Tracker {
        ///     fn track(&self) {
        ///         self.0.fetch_add(1, Ordering::Relaxed);
        ///     }
        ///
        ///     fn count(&self) -> usize {
        ///         self.0.load(Ordering::Relaxed)
        ///     }
        /// }
        ///
        /// struct FeatureTracker {
        ///     feature: Ref<dyn Feature>,
        ///     tracker: Ref<Tracker>,
        /// }
        ///
        /// impl Feature for FeatureTracker {
        ///     fn show(&self) {
        ///         self.tracker.track();
        ///         self.feature.show();
        ///     }
        /// }
        ///
        /// let provider = ServiceCollection::new()
        ///     .add(Tracker::singleton())
        ///     .try_add_to_all(Feature1::transient())
        ///     .try_add_to_all(Feature2::transient())
        ///     .decorate_all::<dyn Feature, FeatureTracker>(|sp, decorated| {
        ///         Ref::new(FeatureTracker { feature: decorated, tracker: sp.get_required::<Tracker>() })
        ///     })
        ///     .build_provider()
        ///     .unwrap();
        /// let features = provider.get_all::<dyn Feature>();
        /// let tracker = provider.get_required::<Tracker>();
        ///
        /// for feature in features {
        ///     feature.show();
        /// }
        ///
        /// assert_eq!(tracker.count(), 2);
        /// ```
        pub fn decorate_all<TSvc: ?Sized + $($traits)+, TImpl>(
            &mut self,
            activate: impl Fn(&ServiceProvider, Ref<TSvc>) -> Ref<TSvc> + $($bounds)+,
        ) -> &mut Self {
            let service_type = Type::of::<TSvc>();
            // computed once; every decorator created below shares the same implementation type
            let impl_type = Type::of::<TImpl>();
            // the activation function is shared by every decorator created below
            let func = Ref::new(activate);

            for item in self.items.iter_mut() {
                // skip other services and descriptors that already use the decorator type (infinite recursion)
                if item.service_type() != service_type || item.implementation_type() == impl_type {
                    continue;
                }

                let original = item.clone();
                let activate = func.clone();
                let builder = ServiceDescriptorBuilder::<TSvc, TImpl>::new(original.lifetime(), impl_type.clone());

                *item = builder.from(move |sp| {
                    let decorated = original
                        .get(sp)
                        .downcast_ref::<Ref<TSvc>>()
                        .expect("the decorated descriptor must resolve a Ref of the decorated service type")
                        .clone();
                    (activate)(sp, decorated)
                });
            }

            self
        }
    };
}
203
204/// Represents a service collection.
205#[derive(Default)]
206pub struct ServiceCollection {
207    items: Vec<ServiceDescriptor>,
208}
209
210impl ServiceCollection {
211    /// Creates and returns a new instance of the service collection.
212    #[inline]
213    pub fn new() -> Self {
214        Self::default()
215    }
216
217    /// Returns true if the collection contains no elements.
218    #[inline]
219    pub fn is_empty(&self) -> bool {
220        self.items.is_empty()
221    }
222
223    /// Returns the number of elements in the collection.
224    #[inline]
225    pub fn len(&self) -> usize {
226        self.items.len()
227    }
228
229    /// Removes all elements from the collection.
230    #[inline]
231    pub fn clear(&mut self) {
232        self.items.clear()
233    }
234
235    /// Removes and returns the element at position index within the collection.
236    ///
237    /// # Argument
238    ///
239    /// * `index` - The index of the element to remove
240    ///
241    /// # Panics
242    ///
243    /// Panics if `index` is out of bounds.
244    #[inline]
245    pub fn remove(&mut self, index: usize) -> ServiceDescriptor {
246        self.items.remove(index)
247    }
248
249    /// Adds a service using the specified service descriptor.
250    ///
251    /// # Arguments
252    ///
253    /// * `descriptor` - The [ServiceDescriptor] to register
254    pub fn add<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
255        self.items.push(descriptor.into());
256        self
257    }
258
259    /// Adds a service using the specified service descriptor if the service has not already been registered.
260    ///
261    /// # Arguments
262    ///
263    /// * `descriptor` - The [ServiceDescriptor] to register
264    pub fn try_add<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
265        let new_item = descriptor.into();
266        let service_type = new_item.service_type();
267
268        for item in &self.items {
269            if item.service_type() == service_type {
270                return self;
271            }
272        }
273
274        self.items.push(new_item);
275        self
276    }
277
278    /// Adds a service using the specified service descriptor if the service with same service and
279    /// implementation type has not already been registered.
280    ///
281    /// # Arguments
282    ///
283    /// * `descriptor` - The [ServiceDescriptor] to register
284    pub fn try_add_to_all<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
285        let new_item = descriptor.into();
286        let service_type = new_item.service_type();
287        let implementation_type = new_item.implementation_type();
288
289        if service_type == implementation_type {
290            return self;
291        }
292
293        for item in &self.items {
294            if item.service_type() == service_type && item.implementation_type() == implementation_type {
295                return self;
296            }
297        }
298
299        self.items.push(new_item);
300        self
301    }
302
303    /// Adds the specified service descriptors if each of the services are not already registered
304    /// with the same service and implementation type.
305    ///
306    /// # Arguments
307    ///
308    /// * `descriptors` - The [ServiceDescriptor] sequence to register
309    pub fn try_add_all(&mut self, descriptors: impl IntoIterator<Item = ServiceDescriptor>) -> &mut Self {
310        for descriptor in descriptors {
311            self.try_add_to_all(descriptor);
312        }
313        self
314    }
315
316    /// Removes the first service descriptor with the same service type and adds the replacement.
317    ///
318    /// # Arguments
319    ///
320    /// * `descriptor` - The replacement [ServiceDescriptor]
321    pub fn replace<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
322        let new_item = descriptor.into();
323        let service_type = new_item.service_type();
324
325        for i in 0..self.items.len() {
326            if self.items[i].service_type() == service_type {
327                self.items.remove(i);
328                break;
329            }
330        }
331
332        self.items.push(new_item);
333        self
334    }
335
336    /// Adds or replaces a service with the specified descriptor if the service has not already been registered.
337    ///
338    /// # Arguments
339    ///
340    /// * `descriptor` - The replacement [ServiceDescriptor]
341    #[inline]
342    pub fn try_replace<T: Into<ServiceDescriptor>>(&mut self, descriptor: T) -> &mut Self {
343        self.try_add(descriptor)
344    }
345
346    /// Removes all specified descriptors of the specified type.
347    pub fn remove_all<T: Any + ?Sized>(&mut self) -> &mut Self {
348        let service_type = Type::of::<T>();
349
350        for i in (0..self.items.len()).rev() {
351            if self.items[i].service_type() == service_type {
352                self.items.remove(i);
353            }
354        }
355
356        self
357    }
358
359    /// Builds and returns a new [ServiceProvider].
360    pub fn build_provider(&self) -> Result<ServiceProvider, ValidationError> {
361        validate(self)?;
362
363        let mut services = HashMap::with_capacity(self.items.len());
364
365        for item in &self.items {
366            let key = item.service_type().clone();
367            let descriptors = services.entry(key).or_insert_with(Vec::new);
368
369            // dependencies are only interesting for validation. after a ServiceProvider is created, no further
370            // validation occurs. prevent copying unnecessary memory and allow it to potentially be freed if the
371            // ServiceCollection is dropped.
372            descriptors.push(item.clone_with(false));
373        }
374
375        for values in services.values_mut() {
376            values.shrink_to_fit();
377        }
378
379        services.shrink_to_fit();
380        Ok(ServiceProvider::new(services))
381    }
382
383    /// Gets a read-only iterator for the collection
384    #[inline]
385    pub fn iter(&self) -> impl ExactSizeIterator<Item = &ServiceDescriptor> + DoubleEndedIterator {
386        self.items.iter()
387    }
388
389    cfg_if::cfg_if! {
390        if #[cfg(feature = "async")] {
391            decorate!((Any + Send + Sync), (Send + Sync + 'static));
392        } else {
393            decorate!((Any), ('static));
394        }
395    }
396}
397
398impl<'a> IntoIterator for &'a ServiceCollection {
399    type Item = &'a ServiceDescriptor;
400    type IntoIter = Iter<'a, ServiceDescriptor>;
401
402    fn into_iter(self) -> Self::IntoIter {
403        self.items.iter()
404    }
405}
406
407impl<'a> IntoIterator for &'a mut ServiceCollection {
408    type Item = &'a mut ServiceDescriptor;
409    type IntoIter = IterMut<'a, ServiceDescriptor>;
410
411    fn into_iter(self) -> Self::IntoIter {
412        self.items.iter_mut()
413    }
414}
415
416impl IntoIterator for ServiceCollection {
417    type Item = ServiceDescriptor;
418    type IntoIter = IntoIter<Self::Item>;
419
420    fn into_iter(self) -> Self::IntoIter {
421        self.items.into_iter()
422    }
423}
424
425impl Index<usize> for ServiceCollection {
426    type Output = ServiceDescriptor;
427
428    fn index(&self, index: usize) -> &Self::Output {
429        &self.items[index]
430    }
431}
432
433impl std::fmt::Debug for ServiceCollection {
434    fn fmt(&self, f: &mut Formatter<'_>) -> FormatResult {
435        fmt::write(self, fmt::text::Renderer, f)
436    }
437}
438
439impl std::fmt::Display for ServiceCollection {
440    fn fmt(&self, f: &mut Formatter<'_>) -> FormatResult {
441        cfg_if::cfg_if! {
442            if #[cfg(feature = "fmt")] {
443                if f.alternate() {
444                    return fmt::write(self, fmt::terminal::Renderer, f);
445                }
446            }
447        }
448
449        fmt::write(self, fmt::text::Renderer, f)
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{existing, existing_as_self, singleton, singleton_as_self, test::*, transient};
    use std::fs::remove_file;
    use std::path::{Path, PathBuf};

    #[test]
    fn is_empty_should_return_true_when_empty() {
        // arrange
        let collection = ServiceCollection::default();

        // act
        let empty = collection.is_empty();

        // assert
        assert!(empty);
    }

    #[test]
    fn length_should_return_zero_when_empty() {
        // arrange
        let collection = ServiceCollection::default();

        // act
        let length = collection.len();

        // assert
        assert_eq!(length, 0);
    }

    #[test]
    fn is_empty_should_return_false_when_not_empty() {
        // arrange
        let descriptor = existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default()));
        let mut collection = ServiceCollection::new();

        collection.add(descriptor);

        // act
        let not_empty = !collection.is_empty();

        // assert
        assert!(not_empty);
    }

    #[test]
    fn length_should_return_count_when_not_empty() {
        // arrange
        let descriptor = existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default()));
        let mut collection = ServiceCollection::new();

        collection.add(descriptor);

        // act
        let length = collection.len();

        // assert
        assert_eq!(length, 1);
    }

    #[test]
    fn clear_should_remove_all_elements() {
        // arrange
        let descriptor = existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default()));
        let mut collection = ServiceCollection::new();

        collection.add(descriptor);

        // act
        collection.clear();

        // assert
        assert!(collection.is_empty());
    }

    #[test]
    fn try_add_should_do_nothing_when_service_is_registered() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection.add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // act
        collection
            .try_add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // assert
        assert_eq!(collection.len(), 1);
    }

    #[test]
    fn try_add_to_all_should_add_descriptor_when_implementation_is_unregistered() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection.add(existing::<dyn TestService, TestServiceImpl>(Box::new(
            TestServiceImpl::default(),
        )));

        collection.try_add_to_all(
            singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
        );

        // act
        let count = collection.len();

        // assert
        assert_eq!(count, 2);
    }

    #[test]
    fn try_add_to_all_should_not_add_descriptor_when_implementation_is_registered() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection.add(existing::<dyn TestService, TestServiceImpl>(Box::new(
            TestServiceImpl::default(),
        )));

        collection.try_add_to_all(
            transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())),
        );

        // act
        let count = collection.len();

        // assert
        assert_eq!(count, 1);
    }

    #[test]
    fn try_add_to_all_should_ignore_descriptor_when_service_is_its_own_implementation() {
        // arrange
        let mut collection = ServiceCollection::new();

        // act
        collection.try_add_to_all(singleton_as_self().from(|_| Ref::new(TestServiceImpl::default())));

        // assert
        assert!(collection.is_empty());
    }

    #[test]
    fn try_add_all_should_only_add_descriptors_for_unregistered_implementations() {
        // arrange
        let descriptors = vec![
            existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default())),
            transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())),
        ];
        let mut collection = ServiceCollection::new();

        collection.try_add_all(descriptors.into_iter());

        // act
        let count = collection.len();

        // assert
        assert_eq!(count, 1);
    }

    #[test]
    fn replace_should_replace_first_registered_service() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // act
        collection
            .replace(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // assert
        assert_eq!(collection.len(), 2);
    }

    #[test]
    fn remove_all_should_remove_registered_services() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // act
        collection.remove_all::<dyn TestService>();

        // assert
        assert!(collection.is_empty());
    }

    #[test]
    fn remove_all_should_keep_services_of_other_types() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
            )
            .add(transient::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())));

        // act
        collection.remove_all::<dyn TestService>();

        // assert
        assert_eq!(collection.len(), 1);
    }

    #[test]
    fn try_replace_should_do_nothing_when_service_is_registered() {
        // arrange
        let mut collection = ServiceCollection::new();

        collection
            .add(singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 1 })));

        // act
        collection.try_replace(
            singleton::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl { value: 2 })),
        );

        // assert
        let value = collection
            .build_provider()
            .unwrap()
            .get_required::<dyn TestService>()
            .value();
        assert_eq!(value, 1);
    }

    #[test]
    fn remove_should_remove_element_at_index() {
        // arrange
        let descriptor = existing::<dyn TestService, TestServiceImpl>(Box::new(TestServiceImpl::default()));
        let mut collection = ServiceCollection::new();

        collection.add(descriptor);

        // act
        let _ = collection.remove(0);

        // assert
        assert!(collection.is_empty());
    }

    #[test]
    fn service_collection_should_drop_existing_as_service() {
        // arrange
        let file = new_temp_file("drop1");

        // act
        {
            let mut services = ServiceCollection::new();
            services.add(existing_as_self(Droppable::new(file.clone())));
        }

        // assert
        let dropped = !file.exists();
        remove_file(&file).ok();
        assert!(dropped);
    }

    #[test]
    fn service_collection_should_not_drop_service_if_never_instantiated() {
        // arrange
        let file = new_temp_file("drop4");

        // act
        {
            // the collection must be created inside this scope so it is actually dropped before the assertion;
            // otherwise the file trivially still exists and the test proves nothing
            let mut services = ServiceCollection::new();

            services
                .add(existing::<Path, PathBuf>(file.clone().into_boxed_path()))
                .add(singleton_as_self().from(|sp| Ref::new(Droppable::new(sp.get_required::<Path>().to_path_buf()))));
        }

        // assert
        let not_dropped = file.exists();
        remove_file(&file).ok();
        assert!(not_dropped);
    }
}