// options/prelude/di.rs

1use super::Builder;
2use crate::{validation::Validate, *};
3use cfg_if::cfg_if;
4use di::{
5    exactly_one, scoped, singleton, singleton_as_self, transient, transient_factory, zero_or_more, ServiceCollection,
6    ServiceDescriptor, ServiceProvider,
7};
8
macro_rules! opts_ext {
    (($($bounds:tt)+)) => {
        /// Defines extension methods for the [ServiceCollection](::di::ServiceCollection) struct
        /// that register options types and the actions used to configure them.
        ///
        /// # Remarks
        ///
        /// With the `async` feature enabled, every closure accepted by these methods must be
        /// `Send + Sync + 'static`; otherwise it only needs to be `'static`.
        pub trait OptionsExt {
            /// Registers an unnamed options type that will have all of its associated services
            /// registered.
            ///
            /// This is equivalent to calling
            /// [`add_named_options`](OptionsExt::add_named_options) with an empty name.
            ///
            /// Returns a [Builder] that can be used to further configure the options.
            fn add_options<T: Value + Default + 'static>(&mut self) -> Builder<'_, T>;

            /// Registers a named options type that will have all of its associated services
            /// registered.
            ///
            /// # Arguments
            ///
            /// * `name` - The name associated with the options
            ///
            /// # Remarks
            ///
            /// Names are matched using case-insensitive ASCII characters.
            fn add_named_options<T: Value + Default + 'static>(
                &mut self,
                name: impl AsRef<str>,
            ) -> Builder<'_, T>;

            /// Registers an unnamed options type whose options factory is created by `factory`,
            /// along with all of its other associated services.
            ///
            /// Unlike [`add_options`](OptionsExt::add_options), `T` is not required to
            /// implement [Default].
            ///
            /// # Arguments
            ///
            /// * `factory` - The function used to create the associated options factory
            fn add_options_with<T, F>(&mut self, factory: F) -> Builder<'_, T>
            where
                T: Value,
                F: Fn(&ServiceProvider) -> Ref<dyn Factory<T>> + $($bounds)+;

            /// Registers a named options type whose options factory is created by `factory`,
            /// along with all of its other associated services.
            ///
            /// Unlike [`add_named_options`](OptionsExt::add_named_options), `T` is not required
            /// to implement [Default].
            ///
            /// # Arguments
            ///
            /// * `name` - The name associated with the options
            /// * `factory` - The function used to create the associated options factory
            ///
            /// # Remarks
            ///
            /// Names are matched using case-insensitive ASCII characters.
            fn add_named_options_with<T, F>(
                &mut self,
                name: impl AsRef<str>,
                factory: F,
            ) -> Builder<'_, T>
            where
                T: Value,
                F: Fn(&ServiceProvider) -> Ref<dyn Factory<T>> + $($bounds)+;

            /// Registers an action used to configure a particular type of unnamed options.
            ///
            /// The options type is registered as if by [`add_options`](OptionsExt::add_options),
            /// so it does not need to be added separately.
            ///
            /// # Arguments
            ///
            /// * `setup` - The setup action used to configure options
            fn configure_options<T, F>(&mut self, setup: F) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+;

            /// Registers an action used to configure a particular type of named options.
            ///
            /// The options type is registered as if by
            /// [`add_named_options`](OptionsExt::add_named_options), so it does not need to be
            /// added separately.
            ///
            /// # Arguments
            ///
            /// * `name` - The name associated with the options
            /// * `setup` - The setup action used to configure options
            ///
            /// # Remarks
            ///
            /// Names are matched using case-insensitive ASCII characters.
            fn configure_named_options<T, F>(
                &mut self,
                name: impl AsRef<str>,
                setup: F,
            ) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+;

            /// Registers an action used to post-configure a particular type of unnamed options.
            ///
            /// The action is registered through the builder's `post_configure` step rather than
            /// `configure`; the default options factory receives post-configuration actions
            /// separately from configuration actions. The options type is registered as if by
            /// [`add_options`](OptionsExt::add_options), so it does not need to be added
            /// separately.
            ///
            /// # Arguments
            ///
            /// * `setup` - The setup action used to post-configure options
            fn post_configure_options<T, F>(&mut self, setup: F) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+;

            /// Registers an action used to post-configure a particular type of named options.
            ///
            /// The action is registered through the builder's `post_configure` step rather than
            /// `configure`; the default options factory receives post-configuration actions
            /// separately from configuration actions. The options type is registered as if by
            /// [`add_named_options`](OptionsExt::add_named_options), so it does not need to be
            /// added separately.
            ///
            /// # Arguments
            ///
            /// * `name` - The name associated with the options
            /// * `setup` - The setup action used to post-configure options
            ///
            /// # Remarks
            ///
            /// Names are matched using case-insensitive ASCII characters.
            fn post_configure_named_options<T, F>(
                &mut self,
                name: impl AsRef<str>,
                setup: F,
            ) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+;
        }
    };
}
119
120fn _add_options<'a, T: Value>(
121    services: &'a mut ServiceCollection,
122    name: &str,
123    descriptor: ServiceDescriptor,
124) -> Builder<'a, T> {
125    services
126        .try_add(
127            singleton_as_self::<Manager<T>>()
128                .depends_on(exactly_one::<dyn Factory<T>>())
129                .from(|sp| Ref::new(Manager::new(sp.get_required::<dyn Factory<T>>()))),
130        )
131        .try_add(
132            singleton::<T, Manager<T>>()
133                .depends_on(exactly_one::<Manager<T>>())
134                .from(|sp| sp.get_required::<Manager<T>>().get_unchecked()),
135        )
136        .try_add(
137            scoped::<dyn Snapshot<T>, Manager<T>>()
138                .depends_on(exactly_one::<Manager<T>>())
139                .from(|sp| sp.get_required::<Manager<T>>()),
140        )
141        .try_add(
142            singleton::<dyn Monitor<T>, DefaultMonitor<T>>()
143                .depends_on(exactly_one::<Cache<T>>())
144                .depends_on(zero_or_more::<dyn ChangeTokenSource<T>>())
145                .depends_on(exactly_one::<dyn Factory<T>>())
146                .from(|sp| {
147                    Ref::new(DefaultMonitor::new(
148                        sp.get_required::<Cache<T>>(),
149                        sp.get_all::<dyn ChangeTokenSource<T>>().collect(),
150                        sp.get_required::<dyn Factory<T>>(),
151                    ))
152                }),
153        )
154        .try_add(descriptor)
155        .try_add(singleton_as_self::<Cache<T>>().from(|_| Ref::new(Cache::new())));
156
157    Builder::new(services, name)
158}
159
macro_rules! opts_ext_impl {
    (($($bounds:tt)+)) => {
        impl OptionsExt for ServiceCollection {
            #[inline]
            fn add_options<T: Value + Default + 'static>(&mut self) -> Builder<'_, T> {
                // Unnamed options are the options registered under the empty name; delegating
                // keeps a single definition of the default factory descriptor.
                self.add_named_options("")
            }

            fn add_named_options<T: Value + Default + 'static>(
                &mut self,
                name: impl AsRef<str>,
            ) -> Builder<'_, T> {
                // The default factory is built from every registered `Configure<T>`,
                // `PostConfigure<T>`, and `Validate<T>` service.
                let descriptor = transient::<dyn Factory<T>, DefaultFactory<T>>()
                    .depends_on(zero_or_more::<dyn Configure<T>>())
                    .depends_on(zero_or_more::<dyn PostConfigure<T>>())
                    .depends_on(zero_or_more::<dyn Validate<T>>())
                    .from(|sp| {
                        Ref::new(DefaultFactory::new(
                            sp.get_all::<dyn Configure<T>>().collect(),
                            sp.get_all::<dyn PostConfigure<T>>().collect(),
                            sp.get_all::<dyn Validate<T>>().collect(),
                        ))
                    });

                _add_options(self, name.as_ref(), descriptor)
            }

            #[inline]
            fn add_options_with<T, F>(&mut self, factory: F) -> Builder<'_, T>
            where
                T: Value,
                F: Fn(&ServiceProvider) -> Ref<dyn Factory<T>> + $($bounds)+,
            {
                // Unnamed options are the options registered under the empty name.
                self.add_named_options_with("", factory)
            }

            #[inline]
            fn add_named_options_with<T, F>(
                &mut self,
                name: impl AsRef<str>,
                factory: F,
            ) -> Builder<'_, T>
            where
                T: Value,
                F: Fn(&ServiceProvider) -> Ref<dyn Factory<T>> + $($bounds)+,
            {
                _add_options(self, name.as_ref(), transient_factory(factory))
            }

            #[inline]
            fn configure_options<T, F>(&mut self, setup: F) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+,
            {
                // Registering the options first lets callers configure without a prior
                // `add_options` call.
                self.add_options().configure(setup).into()
            }

            #[inline]
            fn configure_named_options<T, F>(
                &mut self,
                name: impl AsRef<str>,
                setup: F,
            ) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+,
            {
                self.add_named_options(name).configure(setup).into()
            }

            #[inline]
            fn post_configure_options<T, F>(&mut self, setup: F) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+,
            {
                self.add_options().post_configure(setup).into()
            }

            #[inline]
            fn post_configure_named_options<T, F>(
                &mut self,
                name: impl AsRef<str>,
                setup: F,
            ) -> &mut Self
            where
                T: Value + Default + 'static,
                F: Fn(&mut T) + $($bounds)+,
            {
                self.add_named_options(name).post_configure(setup).into()
            }
        }
    };
}
266
267cfg_if! {
268    if #[cfg(feature = "async")] {
269        opts_ext!((Send + Sync + 'static));
270        opts_ext_impl!((Send + Sync + 'static));
271    } else {
272        opts_ext!(('static));
273        opts_ext_impl!(('static));
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use di::{existing_as_self, transient};
    use std::sync::RwLock;

    #[derive(Default, Debug, PartialEq, Eq)]
    struct TestOptions {
        enabled: bool,
        setting: usize,
    }

    /// Rejects options that carry a non-zero `setting` while disabled.
    #[derive(Default)]
    struct TestValidation;

    impl Validate<TestOptions> for TestValidation {
        fn run(&self, _name: &str, options: &TestOptions) -> validation::Result {
            if !options.enabled && options.setting > 0 {
                validation::fail("Setting must be zero when disabled")
            } else {
                validation::success()
            }
        }
    }

    /// A counter whose `next` yields 1, 2, 3, ... so tests can observe how many times, and in
    /// which order, a dependency was used.
    struct TestService {
        value: RwLock<usize>,
    }

    impl TestService {
        /// Returns the current counter value, then advances it by one.
        fn next(&self) -> usize {
            let mut value = self.value.write().unwrap();
            let current = *value;
            *value += 1;
            current
        }

        /// Returns how many times `next` has been called.
        fn calls(&self) -> usize {
            *self.value.read().unwrap() - 1
        }
    }

    impl Default for TestService {
        // The counter starts at 1 so the first `next` call returns 1.
        fn default() -> Self {
            Self { value: RwLock::new(1) }
        }
    }

    #[test]
    fn get_should_resolve_service() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .build_provider()
            .unwrap();

        // act
        let result = provider.get::<TestOptions>();

        // assert
        assert!(result.is_some());
    }

    #[test]
    fn get_required_should_configure_options() {
        // arrange
        let provider = ServiceCollection::new()
            .configure_options(|o: &mut TestOptions| o.setting = 1)
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert_eq!(options.setting, 1);
    }

    #[test]
    fn get_required_should_post_configure_options() {
        // arrange
        let provider = ServiceCollection::new()
            .post_configure_options(|o: &mut TestOptions| o.setting = 1)
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert_eq!(options.setting, 1);
    }

    #[test]
    fn get_required_should_apply_all_configurations() {
        // arrange
        let provider = ServiceCollection::new()
            .configure_options(|o: &mut TestOptions| o.setting = 1)
            .configure_options(|o: &mut TestOptions| o.enabled = true)
            .post_configure_options(|o: &mut TestOptions| o.setting = 2)
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert!(options.enabled);
        assert_eq!(options.setting, 2);
    }

    #[test]
    fn get_required_should_not_panic_when_configured_options_are_valid() {
        // arrange
        let provider = ServiceCollection::new()
            .configure_options(|o: &mut TestOptions| {
                o.enabled = true;
                o.setting = 1;
            })
            .add(transient::<dyn Validate<TestOptions>, TestValidation>().from(|_| Ref::new(TestValidation::default())))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert_eq!(
            *options,
            TestOptions {
                enabled: true,
                setting: 1
            }
        );
    }

    #[test]
    #[should_panic(expected = "Setting must be zero when disabled")]
    fn get_required_should_panic_when_configured_options_are_invalid() {
        // arrange
        let provider = ServiceCollection::new()
            .configure_options(|o: &mut TestOptions| {
                o.enabled = false;
                o.setting = 1;
            })
            .add(transient::<dyn Validate<TestOptions>, TestValidation>().from(|_| Ref::new(TestValidation::default())))
            .build_provider()
            .unwrap();

        // act (resolution is expected to panic during validation)
        let _ = provider.get_required::<TestOptions>();
    }

    #[test]
    fn get_required_should_configure_options_with_1_dependency() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure1(|o, d1: Ref<TestService>| o.setting = d1.next())
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert_eq!(options.setting, 1);
    }

    #[test]
    fn get_required_should_configure_options_with_2_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure2(|o, d1: Ref<TestService>, d2: Ref<TestService>| o.setting = d1.next() + d2.next())
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: every dependency is the one registered instance, so 1 + 2
        assert_eq!(options.setting, 3);
    }

    #[test]
    fn get_required_should_configure_options_with_3_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure3(|o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>| {
                o.setting = d1.next() + d2.next() + d3.next()
            })
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3
        assert_eq!(options.setting, 6);
    }

    #[test]
    fn get_required_should_configure_options_with_4_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure4(
                |o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>, d4: Ref<TestService>| {
                    o.setting = d1.next() + d2.next() + d3.next() + d4.next()
                },
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3 + 4
        assert_eq!(options.setting, 10);
    }

    #[test]
    fn get_required_should_configure_options_with_5_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure5(
                |o,
                 d1: Ref<TestService>,
                 d2: Ref<TestService>,
                 d3: Ref<TestService>,
                 d4: Ref<TestService>,
                 d5: Ref<TestService>| {
                    o.setting = d1.next() + d2.next() + d3.next() + d4.next() + d5.next()
                },
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3 + 4 + 5
        assert_eq!(options.setting, 15);
    }

    #[test]
    fn get_required_should_post_configure_options_with_1_dependency() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .post_configure1(|o, d1: Ref<TestService>| o.setting = d1.next())
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert
        assert_eq!(options.setting, 1);
    }

    #[test]
    fn get_required_should_post_configure_options_with_2_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .post_configure2(|o, d1: Ref<TestService>, d2: Ref<TestService>| o.setting = d1.next() + d2.next())
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2
        assert_eq!(options.setting, 3);
    }

    #[test]
    fn get_required_should_post_configure_options_with_3_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .post_configure3(|o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>| {
                o.setting = d1.next() + d2.next() + d3.next()
            })
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3
        assert_eq!(options.setting, 6);
    }

    #[test]
    fn get_required_should_post_configure_options_with_4_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .post_configure4(
                |o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>, d4: Ref<TestService>| {
                    o.setting = d1.next() + d2.next() + d3.next() + d4.next()
                },
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3 + 4
        assert_eq!(options.setting, 10);
    }

    #[test]
    fn get_required_should_post_configure_options_with_5_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .post_configure5(
                |o,
                 d1: Ref<TestService>,
                 d2: Ref<TestService>,
                 d3: Ref<TestService>,
                 d4: Ref<TestService>,
                 d5: Ref<TestService>| {
                    o.setting = d1.next() + d2.next() + d3.next() + d4.next() + d5.next()
                },
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();

        // assert: 1 + 2 + 3 + 4 + 5
        assert_eq!(options.setting, 15);
    }

    #[test]
    fn get_required_should_validate_options_with_1_dependency() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure(|o| o.enabled = true)
            .validate1(
                |o, d1: Ref<TestService>| {
                    let _ = d1.next();
                    o.enabled
                },
                "Not enabled!",
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();
        let service = provider.get_required::<TestService>();

        // assert
        assert!(options.enabled);
        assert_eq!(service.calls(), 1);
    }

    #[test]
    fn get_required_should_validate_options_with_2_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure(|o| o.enabled = true)
            .validate2(
                |o, d1: Ref<TestService>, d2: Ref<TestService>| {
                    let _ = d1.next() + d2.next();
                    o.enabled
                },
                "Not enabled!",
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();
        let service = provider.get_required::<TestService>();

        // assert
        assert!(options.enabled);
        assert_eq!(service.calls(), 2);
    }

    #[test]
    fn get_required_should_validate_options_with_3_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure(|o| o.enabled = true)
            .validate3(
                |o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>| {
                    let _ = d1.next() + d2.next() + d3.next();
                    o.enabled
                },
                "Not enabled!",
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();
        let service = provider.get_required::<TestService>();

        // assert
        assert!(options.enabled);
        assert_eq!(service.calls(), 3);
    }

    #[test]
    fn get_required_should_validate_options_with_4_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure(|o| o.enabled = true)
            .validate4(
                |o, d1: Ref<TestService>, d2: Ref<TestService>, d3: Ref<TestService>, d4: Ref<TestService>| {
                    let _ = d1.next() + d2.next() + d3.next() + d4.next();
                    o.enabled
                },
                "Not enabled!",
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();
        let service = provider.get_required::<TestService>();

        // assert
        assert!(options.enabled);
        assert_eq!(service.calls(), 4);
    }

    #[test]
    fn get_required_should_validate_options_with_5_dependencies() {
        // arrange
        let provider = ServiceCollection::new()
            .add_options::<TestOptions>()
            .configure(|o| o.enabled = true)
            .validate5(
                |o,
                 d1: Ref<TestService>,
                 d2: Ref<TestService>,
                 d3: Ref<TestService>,
                 d4: Ref<TestService>,
                 d5: Ref<TestService>| {
                    let _ = d1.next() + d2.next() + d3.next() + d4.next() + d5.next();
                    o.enabled
                },
                "Not enabled!",
            )
            .add(existing_as_self(TestService::default()))
            .build_provider()
            .unwrap();

        // act
        let options = provider.get_required::<TestOptions>();
        let service = provider.get_required::<TestService>();

        // assert
        assert!(options.enabled);
        assert_eq!(service.calls(), 5);
    }
}