Skip to main content

di/
validation.rs

1use crate::{ServiceCardinality, ServiceCollection, ServiceDependency, ServiceDescriptor, ServiceLifetime, Type};
2use std::collections::{HashMap, HashSet};
3use std::fmt::{Display, Formatter};
4
5fn expand_type(t: &Type) -> String {
6    let (name, key) = Type::deconstruct(t);
7
8    match key {
9        Some(val) => format!("'{name}' with the key '{val}'"),
10        _ => format!("'{name}'"),
11    }
12}
13
/// A single failed validation check and its human-readable message.
#[derive(Clone, Debug)]
struct ValidationResult {
    message: String,
}

impl ValidationResult {
    /// Creates a failed result carrying `message`.
    ///
    /// Accepts anything convertible into a [String]. An owned `String` (such as the output
    /// of `format!`, which is what every rule passes) is moved in rather than copied.
    fn fail<T: Into<String>>(message: T) -> Self {
        Self {
            message: message.into(),
        }
    }
}
26
27/// Represents an validation error.
28#[derive(Clone, Debug)]
29pub struct ValidationError {
30    message: String,
31    results: Vec<ValidationResult>,
32}
33
34impl ValidationError {
35    fn fail(results: Vec<ValidationResult>) -> Self {
36        Self {
37            message: if results.is_empty() {
38                String::from("Validation failed.")
39            } else if results.len() == 1 {
40                results[0].message.clone()
41            } else {
42                String::from("One or more validation errors occurred.")
43            },
44            results,
45        }
46    }
47}
48
49impl Display for ValidationError {
50    fn fmt(&self, formatter: &mut Formatter) -> Result<(), std::fmt::Error> {
51        write!(formatter, "{}", self.message)?;
52
53        if self.results.len() > 1 {
54            for (i, result) in self.results.iter().enumerate() {
55                write!(formatter, "\n  [{}] {}", i + 1, result.message)?;
56            }
57        }
58
59        Ok(())
60    }
61}
62
63impl std::error::Error for ValidationError {
64    fn description(&self) -> &str {
65        "validation error"
66    }
67}
68
/// A single validation check that `validate` applies to every registered [ServiceDescriptor].
///
/// `'a` is the lifetime of the borrowed service collection; the rules in this module hold a
/// lookup of descriptors borrowed for that lifetime.
trait ValidationRule<'a> {
    /// Checks `descriptor` and pushes a [ValidationResult] onto `results` for each problem found.
    ///
    /// Takes `&mut self` so a rule can reuse scratch buffers (visited sets, work queues)
    /// between calls.
    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>);
}
72
73struct MissingRequiredType<'a> {
74    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
75}
76
77impl<'a> MissingRequiredType<'a> {
78    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
79        Self { lookup }
80    }
81}
82
83impl<'a> ValidationRule<'a> for MissingRequiredType<'a> {
84    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
85        for dependency in descriptor.dependencies() {
86            if dependency.cardinality() == ServiceCardinality::ExactlyOne
87                && !self.lookup.contains_key(dependency.injected_type())
88            {
89                results.push(ValidationResult::fail(format!(
90                    "Service '{}' requires dependent service {}, which has not be registered",
91                    descriptor.implementation_type().name(),
92                    expand_type(dependency.injected_type())
93                )));
94            }
95        }
96    }
97}
98
99struct CircularDependency<'a> {
100    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
101    visited: HashSet<&'a Type>,
102    queue: Vec<&'a ServiceDependency>,
103}
104
105impl<'a> CircularDependency<'a> {
106    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
107        Self {
108            lookup,
109            visited: HashSet::new(),
110            queue: Vec::new(),
111        }
112    }
113
114    fn check_dependency_graph(
115        &mut self,
116        root: &'a ServiceDescriptor,
117        dependency: &'a ServiceDependency,
118        results: &mut Vec<ValidationResult>,
119    ) {
120        self.queue.clear();
121        self.queue.push(dependency);
122
123        while let Some(current) = self.queue.pop() {
124            if let Some(descriptors) = self.lookup.get(current.injected_type()) {
125                for descriptor in descriptors {
126                    if self.visited.insert(descriptor.service_type()) {
127                        self.queue.extend(descriptor.dependencies());
128                    }
129
130                    if descriptor.service_type() == root.service_type() {
131                        results.push(ValidationResult::fail(format!(
132                            "A circular dependency was detected for service {} on service '{}'",
133                            expand_type(descriptor.service_type()),
134                            root.implementation_type().name()
135                        )));
136                    }
137                }
138            }
139        }
140    }
141}
142
143impl<'a> ValidationRule<'a> for CircularDependency<'a> {
144    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
145        for dependency in descriptor.dependencies() {
146            self.visited.clear();
147            self.visited.insert(descriptor.service_type());
148            self.check_dependency_graph(descriptor, dependency, results);
149        }
150    }
151}
152
153struct SingletonDependsOnScoped<'a> {
154    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
155    visited: HashSet<&'a Type>,
156    queue: Vec<&'a ServiceDescriptor>,
157}
158
159impl<'a> SingletonDependsOnScoped<'a> {
160    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
161        Self {
162            lookup,
163            visited: HashSet::new(),
164            queue: Vec::new(),
165        }
166    }
167}
168
169impl<'a> ValidationRule<'a> for SingletonDependsOnScoped<'a> {
170    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
171        if descriptor.lifetime() != ServiceLifetime::Singleton {
172            return;
173        }
174
175        let mut level = "";
176
177        self.visited.clear();
178        self.queue.clear();
179        self.queue.push(descriptor);
180
181        while let Some(current) = self.queue.pop() {
182            if !self.visited.insert(current.service_type()) {
183                continue;
184            }
185
186            for dependency in current.dependencies() {
187                if let Some(descriptors) = self.lookup.get(dependency.injected_type()) {
188                    for next in descriptors {
189                        self.queue.push(next);
190
191                        if next.lifetime() == ServiceLifetime::Scoped {
192                            results.push(ValidationResult::fail(format!(
193                                "The service {} has a singleton lifetime, \
194                                 but its {level}dependency '{}' has a scoped lifetime",
195                                expand_type(descriptor.implementation_type()),
196                                next.service_type().name()
197                            )));
198                        }
199                    }
200                }
201            }
202
203            level = "transitive ";
204        }
205    }
206}
207
208/// Validates the specified [ServiceCollection].
209///
210/// # Arguments
211///
212/// * `services` - The [service collection](ServiceCollection) to validate
213pub fn validate(services: &ServiceCollection) -> Result<(), ValidationError> {
214    let mut lookup = HashMap::with_capacity(services.len());
215
216    for item in services.iter() {
217        let key = item.service_type();
218        let descriptors = lookup.entry(key).or_insert_with(Vec::new);
219        descriptors.push(item);
220    }
221
222    let mut results = Vec::new();
223    let mut missing_type = MissingRequiredType::new(&lookup);
224    let mut circular_dep = CircularDependency::new(&lookup);
225    let mut scoped_in_singleton = SingletonDependsOnScoped::new(&lookup);
226    let mut rules: Vec<&mut dyn ValidationRule> = vec![&mut missing_type, &mut circular_dep, &mut scoped_in_singleton];
227
228    for descriptor in services {
229        for rule in rules.iter_mut() {
230            rule.evaluate(descriptor, &mut results);
231        }
232    }
233
234    if results.is_empty() {
235        Ok(())
236    } else {
237        Err(ValidationError::fail(results))
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244    use crate::{
245        exactly_one, exactly_one_with_key, scoped, singleton, singleton_as_self, test::*, transient,
246        transient_with_key, zero_or_one, zero_or_one_with_key, Ref,
247    };
248
249    #[test]
250    fn validate_should_report_missing_required_type() {
251        // arrange
252        let mut services = ServiceCollection::new();
253
254        services.add(
255            singleton::<dyn OtherTestService, OtherTestServiceImpl>()
256                .depends_on(exactly_one::<dyn TestService>())
257                .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
258        );
259
260        // act
261        let result = validate(&services);
262
263        // assert
264        assert_eq!(
265            &result.err().unwrap().to_string(),
266            "Service 'di::test::OtherTestServiceImpl' requires dependent service \
267             'dyn di::test::TestService', which has not be registered"
268        );
269    }
270
271    #[test]
272    fn validate_should_report_missing_required_keyed_type() {
273        // arrange
274        let mut services = ServiceCollection::new();
275
276        services
277            .add(
278                singleton_as_self::<CatInTheHat>()
279                    .depends_on(exactly_one_with_key::<key::Thing1, dyn Thing>())
280                    .depends_on(zero_or_one_with_key::<key::Thing2, dyn Thing>())
281                    .from(|sp| {
282                        Ref::new(CatInTheHat::new(
283                            sp.get_required_by_key::<key::Thing1, dyn Thing>(),
284                            sp.get_by_key::<key::Thing2, dyn Thing>(),
285                        ))
286                    }),
287            )
288            .add(transient_with_key::<key::Thing2, dyn Thing, Thing2>().from(|_| Ref::new(Thing2::default())));
289
290        // act
291        let result = validate(&services);
292
293        // assert
294        assert_eq!(
295            &result.err().unwrap().to_string(),
296            "Service 'di::test::CatInTheHat' requires dependent service \
297             'dyn di::test::Thing' with the key 'di::test::key::Thing1', which has not be registered"
298        );
299    }
300
301    #[test]
302    fn validate_should_ignore_missing_optional_type() {
303        // arrange
304        let mut services = ServiceCollection::new();
305
306        services.add(
307            singleton::<dyn OtherTestService, TestOptionalDepImpl>()
308                .depends_on(zero_or_one::<dyn TestService>())
309                .from(|sp| Ref::new(TestOptionalDepImpl::new(sp.get::<dyn TestService>()))),
310        );
311
312        // act
313        let result = validate(&services);
314
315        // assert
316        assert!(result.is_ok());
317    }
318
319    #[test]
320    fn validate_should_ignore_missing_optional_keyed_type() {
321        // arrange
322        let mut services = ServiceCollection::new();
323
324        services
325            .add(
326                singleton_as_self::<CatInTheHat>()
327                    .depends_on(exactly_one_with_key::<key::Thing1, dyn Thing>())
328                    .depends_on(zero_or_one_with_key::<key::Thing2, dyn Thing>())
329                    .from(|sp| {
330                        Ref::new(CatInTheHat::new(
331                            sp.get_required_by_key::<key::Thing1, dyn Thing>(),
332                            sp.get_by_key::<key::Thing2, dyn Thing>(),
333                        ))
334                    }),
335            )
336            .add(transient_with_key::<key::Thing1, dyn Thing, Thing1>().from(|_| Ref::new(Thing1::default())));
337
338        // act
339        let result = validate(&services);
340
341        // assert
342        assert!(result.is_ok());
343    }
344
345    #[test]
346    fn validate_should_report_circular_dependency() {
347        // arrange
348        let mut services = ServiceCollection::new();
349
350        services.add(
351            singleton::<dyn TestService, TestCircularDepImpl>()
352                .depends_on(exactly_one::<dyn TestService>())
353                .from(|sp| Ref::new(TestCircularDepImpl::new(sp.get_required::<dyn TestService>()))),
354        );
355
356        // act
357        let result = validate(&services);
358
359        // assert
360        assert_eq!(
361            &result.err().unwrap().to_string(),
362            "A circular dependency was detected for service \
363             'dyn di::test::TestService' on service 'di::test::TestCircularDepImpl'"
364        );
365    }
366
367    #[test]
368    fn validate_should_report_multiple_issues() {
369        // arrange
370        let mut services = ServiceCollection::new();
371
372        services
373            .add(
374                singleton::<dyn TestService, TestAllKindOfProblems>()
375                    .depends_on(exactly_one::<dyn OtherTestService>())
376                    .depends_on(exactly_one::<dyn AnotherTestService>())
377                    .from(|sp| {
378                        Ref::new(TestAllKindOfProblems::new(
379                            sp.get_required::<dyn OtherTestService>(),
380                            sp.get_required::<dyn AnotherTestService>(),
381                        ))
382                    }),
383            )
384            .add(
385                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
386                    .depends_on(exactly_one::<dyn TestService>())
387                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
388            );
389
390        // act
391        let result = validate(&services);
392
393        // assert
394        assert_eq!(
395            &result.err().unwrap().to_string(),
396            "One or more validation errors occurred.\n  \
397              [1] Service 'di::test::TestAllKindOfProblems' requires dependent service 'dyn di::test::AnotherTestService', which has not be registered\n  \
398              [2] A circular dependency was detected for service 'dyn di::test::TestService' on service 'di::test::TestAllKindOfProblems'\n  \
399              [3] A circular dependency was detected for service 'dyn di::test::OtherTestService' on service 'di::test::OtherTestServiceImpl'");
400    }
401
402    #[test]
403    fn validate_should_report_scoped_service_in_singleton() {
404        // arrange
405        let mut services = ServiceCollection::new();
406
407        services
408            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
409            .add(
410                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
411                    .depends_on(exactly_one::<dyn TestService>())
412                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
413            );
414
415        // act
416        let result = validate(&services);
417
418        // assert
419        assert_eq!(
420            &result.err().unwrap().to_string(),
421            "The service 'di::test::OtherTestServiceImpl' has a singleton lifetime, \
422             but its dependency 'dyn di::test::TestService' has a scoped lifetime"
423        );
424    }
425
426    #[test]
427    fn validate_should_report_transitive_scoped_service_in_singleton() {
428        // arrange
429        let mut services = ServiceCollection::new();
430
431        services
432            .add(scoped::<dyn TestService, TestServiceImpl>().from(|_| Ref::new(TestServiceImpl::default())))
433            .add(
434                transient::<dyn OtherTestService, OtherTestServiceImpl>()
435                    .depends_on(exactly_one::<dyn TestService>())
436                    .from(|sp| Ref::new(OtherTestServiceImpl::new(sp.get_required::<dyn TestService>()))),
437            )
438            .add(
439                singleton::<dyn AnotherTestService, AnotherTestServiceImpl>()
440                    .depends_on(exactly_one::<dyn OtherTestService>())
441                    .from(|sp| Ref::new(AnotherTestServiceImpl::new(sp.get_required::<dyn OtherTestService>()))),
442            );
443
444        // act
445        let result = validate(&services);
446
447        // assert
448        assert_eq!(
449            &result.err().unwrap().to_string(),
450            "The service 'di::test::AnotherTestServiceImpl' has a singleton lifetime, \
451             but its transitive dependency 'dyn di::test::TestService' has a scoped lifetime"
452        );
453    }
454
455    #[test]
456    fn validate_should_not_report_circular_dependency_when_visited_multiple_times() {
457        // arrange
458        let mut services = ServiceCollection::new();
459
460        services
461            .add(singleton::<dyn ServiceM, ServiceMImpl>().from(|_sp| Ref::new(ServiceMImpl)))
462            .add(
463                singleton::<dyn ServiceB, ServiceBImpl>()
464                    .depends_on(exactly_one::<dyn ServiceM>())
465                    .from(|sp| Ref::new(ServiceBImpl::new(sp.get_required::<dyn ServiceM>()))),
466            )
467            .add(
468                singleton::<dyn ServiceC, ServiceCImpl>()
469                    .depends_on(exactly_one::<dyn ServiceM>())
470                    .from(|sp| Ref::new(ServiceCImpl::new(sp.get_required::<dyn ServiceM>()))),
471            )
472            .add(
473                singleton::<dyn ServiceA, ServiceAImpl>()
474                    .depends_on(exactly_one::<dyn ServiceM>())
475                    .depends_on(exactly_one::<dyn ServiceB>())
476                    .from(|sp| {
477                        Ref::new(ServiceAImpl::new(
478                            sp.get_required::<dyn ServiceM>(),
479                            sp.get_required::<dyn ServiceB>(),
480                        ))
481                    }),
482            )
483            .add(
484                singleton::<dyn ServiceY, ServiceYImpl>()
485                    .depends_on(exactly_one::<dyn ServiceM>())
486                    .depends_on(exactly_one::<dyn ServiceC>())
487                    .from(|sp| {
488                        Ref::new(ServiceYImpl::new(
489                            sp.get_required::<dyn ServiceM>(),
490                            sp.get_required::<dyn ServiceC>(),
491                        ))
492                    }),
493            )
494            .add(
495                singleton::<dyn ServiceX, ServiceXImpl>()
496                    .depends_on(exactly_one::<dyn ServiceM>())
497                    .depends_on(exactly_one::<dyn ServiceY>())
498                    .from(|sp| {
499                        Ref::new(ServiceXImpl::new(
500                            sp.get_required::<dyn ServiceM>(),
501                            sp.get_required::<dyn ServiceY>(),
502                        ))
503                    }),
504            )
505            .add(
506                singleton::<dyn ServiceZ, ServiceZImpl>()
507                    .depends_on(exactly_one::<dyn ServiceM>())
508                    .depends_on(exactly_one::<dyn ServiceA>())
509                    .depends_on(exactly_one::<dyn ServiceX>())
510                    .from(|sp| {
511                        Ref::new(ServiceZImpl::new(
512                            sp.get_required::<dyn ServiceM>(),
513                            sp.get_required::<dyn ServiceA>(),
514                            sp.get_required::<dyn ServiceX>(),
515                        ))
516                    }),
517            );
518
519        // act
520        let result = validate(&services);
521
522        // assert
523        assert!(result.is_ok());
524    }
525
526    #[test]
527    fn validate_should_report_circular_dependency_in_complex_dependency_tree() {
528        // arrange
529        let mut services = ServiceCollection::new();
530
531        services
532            .add(singleton::<dyn ServiceM, ServiceMImpl>().from(|_sp| Ref::new(ServiceMImpl)))
533            .add(
534                singleton::<dyn ServiceB, ServiceBImpl>()
535                    .depends_on(exactly_one::<dyn ServiceM>())
536                    .from(|sp| Ref::new(ServiceBImpl::new(sp.get_required::<dyn ServiceM>()))),
537            )
538            .add(
539                singleton::<dyn ServiceC, ServiceCWithCircleRefToXImpl>()
540                    .depends_on(exactly_one::<dyn ServiceM>())
541                    .depends_on(exactly_one::<dyn ServiceX>())
542                    .from(|sp| {
543                        Ref::new(ServiceCWithCircleRefToXImpl::new(
544                            sp.get_required::<dyn ServiceM>(),
545                            sp.get_required::<dyn ServiceX>(),
546                        ))
547                    }),
548            )
549            .add(
550                singleton::<dyn ServiceA, ServiceAImpl>()
551                    .depends_on(exactly_one::<dyn ServiceM>())
552                    .depends_on(exactly_one::<dyn ServiceB>())
553                    .from(|sp| {
554                        Ref::new(ServiceAImpl::new(
555                            sp.get_required::<dyn ServiceM>(),
556                            sp.get_required::<dyn ServiceB>(),
557                        ))
558                    }),
559            )
560            .add(
561                singleton::<dyn ServiceY, ServiceYImpl>()
562                    .depends_on(exactly_one::<dyn ServiceM>())
563                    .depends_on(exactly_one::<dyn ServiceC>())
564                    .from(|sp| {
565                        Ref::new(ServiceYImpl::new(
566                            sp.get_required::<dyn ServiceM>(),
567                            sp.get_required::<dyn ServiceC>(),
568                        ))
569                    }),
570            )
571            .add(
572                singleton::<dyn ServiceX, ServiceXImpl>()
573                    .depends_on(exactly_one::<dyn ServiceM>())
574                    .depends_on(exactly_one::<dyn ServiceY>())
575                    .from(|sp| {
576                        Ref::new(ServiceXImpl::new(
577                            sp.get_required::<dyn ServiceM>(),
578                            sp.get_required::<dyn ServiceY>(),
579                        ))
580                    }),
581            )
582            .add(
583                singleton::<dyn ServiceZ, ServiceZImpl>()
584                    .depends_on(exactly_one::<dyn ServiceM>())
585                    .depends_on(exactly_one::<dyn ServiceA>())
586                    .depends_on(exactly_one::<dyn ServiceX>())
587                    .from(|sp| {
588                        Ref::new(ServiceZImpl::new(
589                            sp.get_required::<dyn ServiceM>(),
590                            sp.get_required::<dyn ServiceA>(),
591                            sp.get_required::<dyn ServiceX>(),
592                        ))
593                    }),
594            );
595
596        // act
597        let result = validate(&services);
598
599        // assert
600        assert!(result.is_err());
601    }
602}