di/
validation.rs

1use crate::{
2    ServiceCardinality, ServiceCollection, ServiceDependency, ServiceDescriptor, ServiceLifetime,
3    Type,
4};
5use std::collections::{HashMap, HashSet};
6use std::fmt::{Display, Formatter};
7
8fn expand_type(t: &Type) -> String {
9    let (name, key) = Type::deconstruct(t);
10
11    match key {
12        Some(val) => format!("'{}' with the key '{}'", name, val),
13        _ => format!("'{}'", name),
14    }
15}
16
/// The outcome of a single failed validation check.
#[derive(Clone, Debug)]
struct ValidationResult {
    /// Human-readable description of what failed.
    message: String,
}

impl ValidationResult {
    /// Creates a failed result whose message is an owned copy of `message`.
    fn fail<T: AsRef<str>>(message: T) -> Self {
        let message = message.as_ref().to_owned();
        Self { message }
    }
}
29
30/// Represents an validation error.
31#[derive(Clone, Debug)]
32pub struct ValidationError {
33    message: String,
34    results: Vec<ValidationResult>,
35}
36
37impl ValidationError {
38    fn fail(results: Vec<ValidationResult>) -> Self {
39        Self {
40            message: if results.is_empty() {
41                String::from("Validation failed.")
42            } else if results.len() == 1 {
43                results[0].message.clone()
44            } else {
45                String::from("One or more validation errors occurred.")
46            },
47            results,
48        }
49    }
50}
51
52impl Display for ValidationError {
53    fn fmt(&self, formatter: &mut Formatter) -> Result<(), std::fmt::Error> {
54        write!(formatter, "{}", self.message)?;
55
56        if self.results.len() > 1 {
57            for (i, result) in self.results.iter().enumerate() {
58                write!(formatter, "\n  [{}] {}", i + 1, result.message)?;
59            }
60        }
61
62        Ok(())
63    }
64}
65
66impl std::error::Error for ValidationError {
67    fn description(&self) -> &str {
68        "validation error"
69    }
70}
71
72trait ValidationRule<'a> {
73    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>);
74}
75
76struct MissingRequiredType<'a> {
77    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
78}
79
80impl<'a> MissingRequiredType<'a> {
81    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
82        Self { lookup }
83    }
84}
85
86impl<'a> ValidationRule<'a> for MissingRequiredType<'a> {
87    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
88        for dependency in descriptor.dependencies() {
89            if dependency.cardinality() == ServiceCardinality::ExactlyOne
90                && !self.lookup.contains_key(dependency.injected_type())
91            {
92                results.push(ValidationResult::fail(format!(
93                    "Service '{}' requires dependent service {}, which has not be registered",
94                    descriptor.implementation_type().name(),
95                    expand_type(dependency.injected_type())
96                )));
97            }
98        }
99    }
100}
101
102struct CircularDependency<'a> {
103    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
104    visited: HashSet<&'a Type>,
105    queue: Vec<&'a ServiceDependency>,
106}
107
108impl<'a> CircularDependency<'a> {
109    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
110        Self {
111            lookup,
112            visited: HashSet::new(),
113            queue: Vec::new(),
114        }
115    }
116
117    fn check_dependency_graph(
118        &mut self,
119        root: &'a ServiceDescriptor,
120        dependency: &'a ServiceDependency,
121        results: &mut Vec<ValidationResult>,
122    ) {
123        self.queue.clear();
124        self.queue.push(dependency);
125
126        while let Some(current) = self.queue.pop() {
127            if let Some(descriptors) = self.lookup.get(current.injected_type()) {
128                for descriptor in descriptors {
129                    if self.visited.insert(descriptor.service_type()) {
130                        self.queue.extend(descriptor.dependencies());
131                    }
132
133                    if descriptor.service_type() == root.service_type() {
134                        results.push(ValidationResult::fail(format!(
135                            "A circular dependency was detected for service {} on service '{}'",
136                            expand_type(descriptor.service_type()),
137                            root.implementation_type().name()
138                        )));
139                    }
140                }
141            }
142        }
143    }
144}
145
146impl<'a> ValidationRule<'a> for CircularDependency<'a> {
147    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
148        for dependency in descriptor.dependencies() {
149            self.visited.clear();
150            self.visited.insert(descriptor.service_type());
151            self.check_dependency_graph(descriptor, dependency, results);
152        }
153    }
154}
155
156struct SingletonDependsOnScoped<'a> {
157    lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>,
158    visited: HashSet<&'a Type>,
159    queue: Vec<&'a ServiceDescriptor>,
160}
161
162impl<'a> SingletonDependsOnScoped<'a> {
163    fn new(lookup: &'a HashMap<&'a Type, Vec<&'a ServiceDescriptor>>) -> Self {
164        Self {
165            lookup,
166            visited: HashSet::new(),
167            queue: Vec::new(),
168        }
169    }
170}
171
172impl<'a> ValidationRule<'a> for SingletonDependsOnScoped<'a> {
173    fn evaluate(&mut self, descriptor: &'a ServiceDescriptor, results: &mut Vec<ValidationResult>) {
174        if descriptor.lifetime() != ServiceLifetime::Singleton {
175            return;
176        }
177
178        let mut level = "";
179
180        self.visited.clear();
181        self.queue.clear();
182        self.queue.push(descriptor);
183
184        while let Some(current) = self.queue.pop() {
185            if !self.visited.insert(current.service_type()) {
186                continue;
187            }
188
189            for dependency in current.dependencies() {
190                if let Some(descriptors) = self.lookup.get(dependency.injected_type()) {
191                    for next in descriptors {
192                        self.queue.push(next);
193
194                        if next.lifetime() == ServiceLifetime::Scoped {
195                            results.push(ValidationResult::fail(format!(
196                                "The service {} has a singleton lifetime, \
197                                 but its {}dependency '{}' has a scoped lifetime",
198                                expand_type(descriptor.implementation_type()),
199                                level,
200                                next.service_type().name()
201                            )));
202                        }
203                    }
204                }
205            }
206
207            level = "transitive ";
208        }
209    }
210}
211
212/// Validates the specified [`ServiceCollection`](crate::ServiceCollection).
213///
214/// # Arguments
215///
216/// * `services` - The [`ServiceCollection`](crate::ServiceCollection) to validate
217pub fn validate(services: &ServiceCollection) -> Result<(), ValidationError> {
218    let mut lookup = HashMap::with_capacity(services.len());
219
220    for item in services.iter() {
221        let key = item.service_type();
222        let descriptors = lookup.entry(key).or_insert_with(Vec::new);
223        descriptors.push(item);
224    }
225
226    let mut results = Vec::new();
227    let mut missing_type = MissingRequiredType::new(&lookup);
228    let mut circular_dep = CircularDependency::new(&lookup);
229    let mut scoped_in_singleton = SingletonDependsOnScoped::new(&lookup);
230    let mut rules: Vec<&mut dyn ValidationRule> = vec![
231        &mut missing_type,
232        &mut circular_dep,
233        &mut scoped_in_singleton,
234    ];
235
236    for descriptor in services {
237        for rule in rules.iter_mut() {
238            rule.evaluate(descriptor, &mut results);
239        }
240    }
241
242    if results.is_empty() {
243        Ok(())
244    } else {
245        Err(ValidationError::fail(results))
246    }
247}
248
// Unit tests for `validate`. Each test builds a `ServiceCollection` from the
// shared fixtures in `crate::test` and asserts either success or the exact
// rendered error text. Any change to a rule's message wording must be mirrored here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test::*, *};

    // A required (`exactly_one`) dependency with no registration is reported.
    #[test]
    fn validate_should_report_missing_required_type() {
        // arrange
        let mut services = ServiceCollection::new();

        services.add(
            singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                .depends_on(exactly_one::<dyn TestService>())
                .from(|sp| {
                    Ref::new(OtherTestServiceImpl::new(
                        sp.get_required::<dyn TestService>(),
                    ))
                }),
        );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "Service 'di::test::OtherTestServiceImpl' requires dependent service \
             'dyn di::test::TestService', which has not be registered"
        );
    }

    // Only the Thing2-keyed registration exists, so the required Thing1-keyed
    // dependency is reported, and the message includes the key.
    #[test]
    fn validate_should_report_missing_required_keyed_type() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                singleton_as_self::<CatInTheHat>()
                    .depends_on(exactly_one_with_key::<key::Thing1, dyn Thing>())
                    .depends_on(zero_or_one_with_key::<key::Thing2, dyn Thing>())
                    .from(|sp| {
                        Ref::new(CatInTheHat::new(
                            sp.get_required_by_key::<key::Thing1, dyn Thing>(),
                            sp.get_by_key::<key::Thing2, dyn Thing>(),
                        ))
                    }),
            )
            .add(
                transient_with_key::<key::Thing2, dyn Thing, Thing2>()
                    .from(|_| Ref::new(Thing2::default())),
            );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "Service 'di::test::CatInTheHat' requires dependent service \
             'dyn di::test::Thing' with the key 'di::test::key::Thing1', which has not be registered"
        );
    }

    // An unregistered `zero_or_one` dependency is not an error.
    #[test]
    fn validate_should_ignore_missing_optional_type() {
        // arrange
        let mut services = ServiceCollection::new();

        services.add(
            singleton::<dyn OtherTestService, TestOptionalDepImpl>()
                .depends_on(zero_or_one::<dyn TestService>())
                .from(|sp| Ref::new(TestOptionalDepImpl::new(sp.get::<dyn TestService>()))),
        );

        // act
        let result = validate(&services);

        // assert
        assert!(result.is_ok());
    }

    // The optional Thing2-keyed dependency is missing, but the required
    // Thing1-keyed one is registered, so validation succeeds.
    #[test]
    fn validate_should_ignore_missing_optional_keyed_type() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                singleton_as_self::<CatInTheHat>()
                    .depends_on(exactly_one_with_key::<key::Thing1, dyn Thing>())
                    .depends_on(zero_or_one_with_key::<key::Thing2, dyn Thing>())
                    .from(|sp| {
                        Ref::new(CatInTheHat::new(
                            sp.get_required_by_key::<key::Thing1, dyn Thing>(),
                            sp.get_by_key::<key::Thing2, dyn Thing>(),
                        ))
                    }),
            )
            .add(
                transient_with_key::<key::Thing1, dyn Thing, Thing1>()
                    .from(|_| Ref::new(Thing1::default())),
            );

        // act
        let result = validate(&services);

        // assert
        assert!(result.is_ok());
    }

    // A service that depends on its own service type is a cycle.
    #[test]
    fn validate_should_report_circular_dependency() {
        // arrange
        let mut services = ServiceCollection::new();

        services.add(
            singleton::<dyn TestService, TestCircularDepImpl>()
                .depends_on(exactly_one::<dyn TestService>())
                .from(|sp| {
                    Ref::new(TestCircularDepImpl::new(
                        sp.get_required::<dyn TestService>(),
                    ))
                }),
        );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "A circular dependency was detected for service \
             'dyn di::test::TestService' on service 'di::test::TestCircularDepImpl'"
        );
    }

    // TestAllKindOfProblems is missing AnotherTestService and forms a cycle with
    // OtherTestServiceImpl. All failures are listed, numbered, in the order the
    // descriptors and rules were evaluated.
    #[test]
    fn validate_should_report_multiple_issues() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                singleton::<dyn TestService, TestAllKindOfProblems>()
                    .depends_on(exactly_one::<dyn OtherTestService>())
                    .depends_on(exactly_one::<dyn AnotherTestService>())
                    .from(|sp| {
                        Ref::new(TestAllKindOfProblems::new(
                            sp.get_required::<dyn OtherTestService>(),
                            sp.get_required::<dyn AnotherTestService>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                    .depends_on(exactly_one::<dyn TestService>())
                    .from(|sp| {
                        Ref::new(OtherTestServiceImpl::new(
                            sp.get_required::<dyn TestService>(),
                        ))
                    }),
            );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "One or more validation errors occurred.\n  \
              [1] Service 'di::test::TestAllKindOfProblems' requires dependent service 'dyn di::test::AnotherTestService', which has not be registered\n  \
              [2] A circular dependency was detected for service 'dyn di::test::TestService' on service 'di::test::TestAllKindOfProblems'\n  \
              [3] A circular dependency was detected for service 'dyn di::test::OtherTestService' on service 'di::test::OtherTestServiceImpl'");
    }

    // A singleton directly depending on a scoped service is reported without
    // the "transitive" qualifier.
    #[test]
    fn validate_should_report_scoped_service_in_singleton() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .add(
                singleton::<dyn OtherTestService, OtherTestServiceImpl>()
                    .depends_on(exactly_one::<dyn TestService>())
                    .from(|sp| {
                        Ref::new(OtherTestServiceImpl::new(
                            sp.get_required::<dyn TestService>(),
                        ))
                    }),
            );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "The service 'di::test::OtherTestServiceImpl' has a singleton lifetime, \
             but its dependency 'dyn di::test::TestService' has a scoped lifetime"
        );
    }

    // Singleton -> transient -> scoped is reported as a "transitive" dependency.
    #[test]
    fn validate_should_report_transitive_scoped_service_in_singleton() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                scoped::<dyn TestService, TestServiceImpl>()
                    .from(|_| Ref::new(TestServiceImpl::default())),
            )
            .add(
                transient::<dyn OtherTestService, OtherTestServiceImpl>()
                    .depends_on(exactly_one::<dyn TestService>())
                    .from(|sp| {
                        Ref::new(OtherTestServiceImpl::new(
                            sp.get_required::<dyn TestService>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn AnotherTestService, AnotherTestServiceImpl>()
                    .depends_on(exactly_one::<dyn OtherTestService>())
                    .from(|sp| {
                        Ref::new(AnotherTestServiceImpl::new(
                            sp.get_required::<dyn OtherTestService>(),
                        ))
                    }),
            );

        // act
        let result = validate(&services);

        // assert
        assert_eq!(
            &result.err().unwrap().to_string(),
            "The service 'di::test::AnotherTestServiceImpl' has a singleton lifetime, \
             but its transitive dependency 'dyn di::test::TestService' has a scoped lifetime"
        );
    }

    // ServiceM is reachable along several paths (e.g. Z -> M, Z -> A -> M,
    // Z -> X -> Y -> C -> M). Reaching it more than once must not be mistaken
    // for a cycle in this acyclic graph.
    #[test]
    fn validate_should_not_report_circular_dependency_when_visited_multiple_times() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                singleton::<dyn ServiceM, ServiceMImpl>().from(|_sp| Ref::new(ServiceMImpl)),
            )
            .add(
                singleton::<dyn ServiceB, ServiceBImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .from(|sp| {
                        Ref::new(ServiceBImpl::new(sp.get_required::<dyn ServiceM>()))
                    }),
            )
            .add(
                singleton::<dyn ServiceC, ServiceCImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .from(|sp| {
                        Ref::new(ServiceCImpl::new(sp.get_required::<dyn ServiceM>()))
                    }),
            )
            .add(
                singleton::<dyn ServiceA, ServiceAImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceB>())
                    .from(|sp| {
                        Ref::new(ServiceAImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceB>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceY, ServiceYImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceC>())
                    .from(|sp| {
                        Ref::new(ServiceYImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceC>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceX, ServiceXImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceY>())
                    .from(|sp| {
                        Ref::new(ServiceXImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceY>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceZ, ServiceZImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceA>())
                    .depends_on(exactly_one::<dyn ServiceX>())
                    .from(|sp| {
                        Ref::new(ServiceZImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceA>(),
                            sp.get_required::<dyn ServiceX>(),
                        ))
                    }),
            );

        // act
        let result = validate(&services);

        // assert
        assert!(result.is_ok());
    }

    // Same graph as above, except ServiceCWithCircleRefToXImpl makes C depend
    // on X, closing the loop C -> X -> Y -> C.
    #[test]
    fn validate_should_report_circular_dependency_in_complex_dependency_tree() {
        // arrange
        let mut services = ServiceCollection::new();

        services
            .add(
                singleton::<dyn ServiceM, ServiceMImpl>().from(|_sp| Ref::new(ServiceMImpl)),
            )
            .add(
                singleton::<dyn ServiceB, ServiceBImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .from(|sp| {
                        Ref::new(ServiceBImpl::new(sp.get_required::<dyn ServiceM>()))
                    }),
            )
            .add(
                singleton::<dyn ServiceC, ServiceCWithCircleRefToXImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceX>())
                    .from(|sp| {
                        Ref::new(ServiceCWithCircleRefToXImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceX>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceA, ServiceAImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceB>())
                    .from(|sp| {
                        Ref::new(ServiceAImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceB>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceY, ServiceYImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceC>())
                    .from(|sp| {
                        Ref::new(ServiceYImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceC>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceX, ServiceXImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceY>())
                    .from(|sp| {
                        Ref::new(ServiceXImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceY>(),
                        ))
                    }),
            )
            .add(
                singleton::<dyn ServiceZ, ServiceZImpl>()
                    .depends_on(exactly_one::<dyn ServiceM>())
                    .depends_on(exactly_one::<dyn ServiceA>())
                    .depends_on(exactly_one::<dyn ServiceX>())
                    .from(|sp| {
                        Ref::new(ServiceZImpl::new(
                            sp.get_required::<dyn ServiceM>(),
                            sp.get_required::<dyn ServiceA>(),
                            sp.get_required::<dyn ServiceX>(),
                        ))
                    }),
            );

        // act
        let result = validate(&services);

        // assert
        assert!(result.is_err());
    }
}