ferrunix_core/
cycle_detection.rs

//! Validation of the dependency graph used for dependency resolution: detects dependency cycles
//! and dependencies that were never registered.
2
3use std::any::TypeId;
4
5use crate::dependency_builder::{self, DepBuilder};
6use crate::types::{
7    HashMap, NonAsyncRwLock, Registerable, RegisterableSingleton, Visitor,
8};
9
/// All possible errors during validation.
///
/// This error only says *what kind* of problem was found; see [`FullValidationError`] for an
/// error type that also carries details about the failure.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ValidationError {
    /// A cycle between dependencies has been detected.
    Cycle,
    /// Dependencies are missing.
    Missing,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cycle => write!(fmt, "cycle detected!"),
            Self::Missing => write!(fmt, "dependencies missing!"),
        }
    }
}

// `source` keeps its default implementation (`None`): this error never wraps another error.
impl std::error::Error for ValidationError {}
35
36/// Detailed validation errors.
37#[derive(Debug, Clone, PartialEq, Hash)]
38#[non_exhaustive]
39pub enum FullValidationError {
40    /// A cycle between dependencies has been detected.
41    Cycle(Option<String>),
42    /// Dependencies are missing.
43    Missing(Vec<MissingDependencies>),
44}
45
46impl std::fmt::Display for FullValidationError {
47    #[allow(clippy::use_debug)]
48    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            Self::Cycle(node) => match node {
51                Some(node) => write!(fmt, "cycle detected at {node}"),
52                None => write!(fmt, "cycle detected!"),
53            },
54            Self::Missing(all_missing) => {
55                writeln!(fmt, "dependencies missing:")?;
56
57                for missing in all_missing {
58                    writeln!(
59                        fmt,
60                        "dependencies missing for {} ({:?}):",
61                        missing.ty.1, missing.ty.0
62                    )?;
63                    for (type_id, type_name) in &missing.deps {
64                        writeln!(fmt, " - {type_name} ({type_id:?})")?;
65                    }
66                    writeln!(fmt, "\n")?;
67                }
68
69                Ok(())
70            }
71        }
72    }
73}
74
75impl std::error::Error for FullValidationError {
76    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
77        None
78    }
79}
80
/// A type together with all of its dependencies that were never registered.
///
/// The dependent type is `ty`; the unregistered dependencies are collected in `deps`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MissingDependencies {
    /// The type whose dependencies are missing, as `(TypeId, type name)`.
    pub(crate) ty: (TypeId, &'static str),
    /// Every dependency of `ty` that is missing, as `(TypeId, type name)` pairs.
    pub(crate) deps: Vec<(TypeId, &'static str)>,
}

impl MissingDependencies {
    /// The type whose dependencies are missing, as a [`std::any::TypeId`] plus type name.
    ///
    /// The name is produced by [`std::any::type_name`], so it is best effort only: it is meant
    /// for diagnostics and might be inaccurate or differ between builds.
    pub fn ty(&self) -> &(TypeId, &'static str) {
        &self.ty
    }

    /// All dependencies of [`Self::ty`] that are not registered, each as a `TypeId` plus
    /// (best effort) type name.
    pub fn missing_dependencies(&self) -> &[(TypeId, &'static str)] {
        self.deps.as_slice()
    }
}
105
106/// Validation whether all dependencies are registered, and the dependency chain has no cycles.
107pub(crate) struct DependencyValidator {
108    /// The visitor callbacks. Those are necessary because we only want to register each type once
109    /// we have collected them all.
110    visitor: NonAsyncRwLock<HashMap<TypeId, Visitor>>,
111    /// Context for visitors.
112    context: NonAsyncRwLock<VisitorContext>,
113}
114
115impl DependencyValidator {
116    /// Create a new dependency validator.
117    pub(crate) fn new() -> Self {
118        Self {
119            visitor: NonAsyncRwLock::new(HashMap::new()),
120            context: NonAsyncRwLock::new(VisitorContext::new()),
121        }
122    }
123
124    /// Register a new transient, without any dependencies.
125    pub(crate) fn add_transient_no_deps<T>(&self)
126    where
127        T: Registerable,
128    {
129        let visitor = Visitor(|_this, _visitors, context| {
130            if let Some(index) = context.visited.get(&TypeId::of::<T>()) {
131                return *index;
132            }
133
134            let index = context.graph.add_node(std::any::type_name::<T>());
135
136            context.visited.insert(TypeId::of::<T>(), index);
137
138            index
139        });
140
141        {
142            let mut visitors = self.visitor.write();
143            visitors.insert(TypeId::of::<T>(), visitor);
144            {
145                let mut context = self.context.write();
146                context.reset();
147            }
148        }
149    }
150
151    /// Register a new singleton, without any dependencies.
152    pub(crate) fn add_singleton_no_deps<T>(&self)
153    where
154        T: RegisterableSingleton,
155    {
156        self.add_transient_no_deps::<T>();
157    }
158
159    /// Register a new transient, with dependencies specified via `Deps`.
160    pub(crate) fn add_transient_deps<
161        T: Registerable,
162        #[cfg(not(feature = "tokio"))] Deps: DepBuilder<T> + 'static,
163        #[cfg(feature = "tokio")] Deps: DepBuilder<T> + Sync + 'static,
164    >(
165        &self,
166    ) {
167        let visitor = Visitor(|this, visitors, context| {
168            // We already visited this type.
169            if let Some(index) = context.visited.get(&TypeId::of::<T>()) {
170                return *index;
171            }
172
173            let current = context.graph.add_node(std::any::type_name::<T>());
174
175            // We visited this type. This must be added before we visit dependencies.
176            {
177                context.visited.insert(TypeId::of::<T>(), current);
178            }
179
180            let type_ids =
181                Deps::as_typeids(dependency_builder::private::SealToken);
182
183            for (type_id, type_name) in &type_ids {
184                // We have been to the dependency type before, we don't need to do it again.
185                if let Some(index) = context.visited.get(type_id) {
186                    context.graph.add_edge(current, *index, ());
187                    continue;
188                }
189
190                // Never seen the type before, visit it.
191                if let Some(visitor) = visitors.get(type_id) {
192                    let index = (visitor.0)(this, visitors, context);
193                    context.graph.add_edge(current, index, ());
194                    continue;
195                }
196
197                {
198                    if let Some(ty) =
199                        context.missing.get_mut(&TypeId::of::<T>())
200                    {
201                        ty.deps.push((*type_id, type_name));
202                    } else {
203                        context.missing.insert(
204                            TypeId::of::<T>(),
205                            MissingDependencies {
206                                ty: (
207                                    TypeId::of::<T>(),
208                                    std::any::type_name::<T>(),
209                                ),
210                                deps: vec![(*type_id, type_name)],
211                            },
212                        );
213                    }
214                }
215
216                #[cfg(feature = "tracing")]
217                tracing::warn!(
218                    "couldn't add dependency of {}: {type_name}",
219                    std::any::type_name::<T>()
220                );
221            }
222
223            current
224        });
225
226        {
227            let mut visitors = self.visitor.write();
228            visitors.insert(TypeId::of::<T>(), visitor);
229            {
230                let mut context = self.context.write();
231                context.reset();
232            }
233        }
234    }
235
236    /// Register a new singleton, with dependencies specified via `Deps`.
237    pub(crate) fn add_singleton_deps<
238        T: RegisterableSingleton,
239        #[cfg(not(feature = "tokio"))] Deps: DepBuilder<T> + 'static,
240        #[cfg(feature = "tokio")] Deps: DepBuilder<T> + Sync + 'static,
241    >(
242        &self,
243    ) {
244        self.add_transient_deps::<T, Deps>();
245    }
246
247    /// Walk the dependency graph and validate that all types can be constructed, all dependencies
248    /// are fulfillable and there are no cycles in the graph.
249    pub(crate) fn validate_all(&self) -> Result<(), ValidationError> {
250        let read_context = self.context.read();
251        if Self::validate_context(&read_context)? {
252            // Validation result is still cached.
253            return Ok(());
254        }
255
256        // No validation result is cached, drop the read lock and acquire an exclusive lock to
257        // update the cached validation result.
258        drop(read_context);
259        let visitors = self.visitor.read();
260        let mut write_context = self.context.write();
261        if Self::validate_context(&write_context)? {
262            // Context was updated by another thread while we waited for the exclusive write lock
263            // to be acquired.
264            return Ok(());
265        }
266
267        // Validation did not run, we need to run it.
268        self.calculate_validation(&visitors, &mut write_context);
269
270        // Throws an error if our dependency graph is invalid.
271        Self::validate_context(&write_context)?;
272
273        Ok(())
274    }
275
276    /// Walk the dependency graph and validate that all types can be constructed, all dependencies
277    /// are fulfillable and there are no cycles in the graph.
278    pub(crate) fn validate_all_full(&self) -> Result<(), FullValidationError> {
279        let mut context = VisitorContext::new();
280        {
281            let visitors = self.visitor.read();
282            self.calculate_validation(&visitors, &mut context);
283        }
284
285        // Evaluate whether we want to make this available via an option? It takes ages to
286        // calculate!
287        // let tarjan = petgraph::algo::tarjan_scc(&context.graph);
288        // dbg!(&tarjan);
289
290        if !context.missing.is_empty() {
291            let mut vec = Vec::with_capacity(context.missing.len());
292            context.missing.iter().for_each(|(_, ty)| {
293                vec.push(ty.clone());
294            });
295            return Err(FullValidationError::Missing(vec));
296        }
297
298        if let Some(cached) = &context.validation_cache {
299            return match cached {
300                Ok(_) => Ok(()),
301                Err(err) => {
302                    let index = err.node_id();
303                    let node_name = context.graph.node_weight(index);
304                    return Err(FullValidationError::Cycle(
305                        node_name.map(|el| (*el).to_owned()),
306                    ));
307                }
308            };
309        }
310
311        unreachable!("this is a bug")
312    }
313
314    /// Inspect `context`, and return a [`ValidationError`] if there are errors in the dependency
315    /// graph.
316    ///
317    /// Returns `Ok(true)` if the validation result is cached.
318    /// Returns `Ok(false)` if the validation result is outdated and needs to be recalculated.
319    fn validate_context(
320        context: &VisitorContext,
321    ) -> Result<bool, ValidationError> {
322        if !context.missing.is_empty() {
323            return Err(ValidationError::Missing);
324        }
325
326        if let Some(cached) = &context.validation_cache {
327            return match cached {
328                Ok(_) => Ok(true),
329                Err(_) => Err(ValidationError::Cycle),
330            };
331        }
332
333        Ok(false)
334    }
335
336    /// Visit all visitors in `self.visitor`, and create the new dependency graph.
337    fn calculate_validation(
338        &self,
339        visitors: &HashMap<TypeId, Visitor>,
340        context: &mut VisitorContext,
341    ) {
342        {
343            for cb in visitors.values() {
344                // To avoid a dead lock due to other visitors needing to be called, we pass in the
345                // visitors hashmap.
346                (cb.0)(self, visitors, context);
347            }
348        }
349
350        // We only calculate whether we have
351        let mut space = petgraph::algo::DfsSpace::new(&context.graph);
352        context.validation_cache =
353            Some(petgraph::algo::toposort(&context.graph, Some(&mut space)));
354    }
355
356    /// Validate whether the type `T` is constructible.
357    pub(crate) fn validate<T>(&self) -> Result<(), ValidationError>
358    where
359        T: Registerable,
360    {
361        let _ = std::marker::PhantomData::<T>;
362        self.validate_all()
363    }
364
365    /// Return a string of the dependency graph visualized using graphviz's `dot` language.
366    pub(crate) fn dotgraph(&self) -> Result<String, ValidationError> {
367        self.validate_all()?;
368
369        let context = self.context.read();
370        let dot = petgraph::dot::Dot::with_config(
371            &context.graph,
372            &[petgraph::dot::Config::EdgeNoLabel],
373        );
374
375        Ok(format!("{dot:?}"))
376    }
377}
378
379/// Context that's passed into every `visitor`.
380pub(crate) struct VisitorContext {
381    /// Dependency graph.
382    graph: petgraph::Graph<&'static str, (), petgraph::Directed>,
383    /// All missing dependencies.
384    missing: HashMap<TypeId, MissingDependencies>,
385    /// Cache of all previously visited types. To avoid infinite recursion and as an optimization.
386    visited: HashMap<TypeId, petgraph::graph::NodeIndex>,
387    /// Cached validation result.
388    validation_cache: Option<
389        Result<
390            Vec<petgraph::graph::NodeIndex>,
391            petgraph::algo::Cycle<petgraph::graph::NodeIndex>,
392        >,
393    >,
394}
395
396impl VisitorContext {
397    /// Create a new default context.
398    pub fn new() -> Self {
399        Self {
400            graph: petgraph::Graph::new(),
401            missing: HashMap::new(),
402            visited: HashMap::new(),
403            validation_cache: None,
404        }
405    }
406
407    /// Reset the context.
408    pub fn reset(&mut self) {
409        self.graph.clear();
410        self.missing.clear();
411        self.visited.clear();
412        self.validation_cache = None;
413    }
414}