libcnb_package/
dependency_graph.rs

1use petgraph::Graph;
2use petgraph::visit::DfsPostOrder;
3use std::error::Error;
4
/// A node in a dependency graph, identified by a value of type `T`.
///
/// An implementor reports its own ID and the IDs of the nodes it depends on. `E` is the error
/// type produced when those dependencies cannot be determined.
pub trait DependencyNode<T: PartialEq, E> {
    /// The identifier of this node.
    fn id(&self) -> T;

    /// The IDs of the nodes this node depends on.
    ///
    /// # Errors
    ///
    /// Will return an `Err` if the dependencies can't be accessed
    fn dependencies(&self) -> Result<Vec<T>, E>;
}
19
20/// Create a [`Graph`] from [`DependencyNode`]s.
21///
22/// # Errors
23///
24/// Will return an `Err` if the graph contains references to missing dependencies or the
25/// dependencies of a [`DependencyNode`] couldn't be gathered.
26pub(crate) fn create_dependency_graph<T, I, E>(
27    nodes: Vec<T>,
28) -> Result<Graph<T, ()>, CreateDependencyGraphError<I, E>>
29where
30    T: DependencyNode<I, E>,
31    I: PartialEq,
32    E: Error,
33{
34    let mut graph = Graph::new();
35
36    for node in nodes {
37        graph.add_node(node);
38    }
39
40    for idx in graph.node_indices() {
41        let node = &graph[idx];
42
43        let dependencies = node
44            .dependencies()
45            .map_err(CreateDependencyGraphError::GetNodeDependenciesError)?;
46
47        for dependency in dependencies {
48            let dependency_idx = graph
49                .node_indices()
50                .find(|idx| graph[*idx].id() == dependency)
51                .ok_or(CreateDependencyGraphError::MissingDependency(dependency))?;
52
53            graph.add_edge(idx, dependency_idx, ());
54        }
55    }
56
57    Ok(graph)
58}
59
60/// An error that occurred while creating the dependency graph.
61#[derive(thiserror::Error, Debug)]
62pub enum CreateDependencyGraphError<I, E: Error> {
63    #[error("Error while determining dependencies of a node: {0}")]
64    GetNodeDependenciesError(#[source] E),
65    #[error("Node references unknown dependency {0}")]
66    MissingDependency(I),
67}
68
69/// Collects all the [`DependencyNode`] values found while traversing the given dependency graph
70/// using one or more `root_nodes` values as starting points for the traversal. The returned list
71/// will contain the given `root_nodes` values as well as all their dependencies in topological order.
72///
73/// # Errors
74///
75/// Will return an `Err` if the graph contains references to missing dependencies.
76pub fn get_dependencies<'a, T, I, E>(
77    graph: &'a Graph<T, ()>,
78    root_nodes: &[&T],
79) -> Result<Vec<&'a T>, GetDependenciesError<I>>
80where
81    T: DependencyNode<I, E>,
82    I: PartialEq,
83{
84    let mut order: Vec<&T> = Vec::new();
85    let mut dfs = DfsPostOrder::empty(&graph);
86    for root_node in root_nodes {
87        let idx = graph
88            .node_indices()
89            .find(|idx| graph[*idx].id() == root_node.id())
90            .ok_or(GetDependenciesError::UnknownRootNode(root_node.id()))?;
91
92        dfs.move_to(idx);
93
94        while let Some(visited) = dfs.next(&graph) {
95            order.push(&graph[visited]);
96        }
97    }
98    Ok(order)
99}
100
101/// An error from [`get_dependencies`]
102#[derive(thiserror::Error, Debug)]
103pub enum GetDependenciesError<I> {
104    #[error("Root node {0} is not in the dependency graph")]
105    UnknownRootNode(I),
106}
107
#[cfg(test)]
mod tests {
    use crate::dependency_graph::{
        CreateDependencyGraphError, DependencyNode, GetDependenciesError, create_dependency_graph,
        get_dependencies,
    };
    use std::convert::Infallible;

    // Test fixture: a node is `(id, ids_of_dependencies)`.
    impl DependencyNode<String, Infallible> for (&str, Vec<&str>) {
        fn id(&self) -> String {
            self.0.to_string()
        }

        fn dependencies(&self) -> Result<Vec<String>, Infallible> {
            Ok(self
                .1
                .iter()
                .map(std::string::ToString::to_string)
                .collect())
        }
    }

    #[test]
    fn test_get_dependencies_one_level_deep() {
        let a = ("a", Vec::new());
        let b = ("b", Vec::new());
        let c = ("c", vec!["a", "b"]);

        let graph = create_dependency_graph(vec![a.clone(), b.clone(), c.clone()]).unwrap();

        assert_eq!(get_dependencies(&graph, &[&a]).unwrap(), &[&a]);

        assert_eq!(get_dependencies(&graph, &[&b]).unwrap(), &[&b]);

        assert_eq!(get_dependencies(&graph, &[&c]).unwrap(), &[&a, &b, &c]);

        assert_eq!(
            &get_dependencies(&graph, &[&b, &c, &a]).unwrap(),
            &[&b, &a, &c]
        );
    }

    #[test]
    fn test_get_dependencies_two_levels_deep() {
        let a = ("a", Vec::new());
        let b = ("b", vec!["a"]);
        let c = ("c", vec!["b"]);

        let graph = create_dependency_graph(vec![a.clone(), b.clone(), c.clone()]).unwrap();

        assert_eq!(get_dependencies(&graph, &[&a]).unwrap(), &[&a]);

        assert_eq!(get_dependencies(&graph, &[&b]).unwrap(), &[&a, &b]);

        assert_eq!(get_dependencies(&graph, &[&c]).unwrap(), &[&a, &b, &c]);

        assert_eq!(
            &get_dependencies(&graph, &[&b, &c, &a]).unwrap(),
            &[&a, &b, &c]
        );
    }

    #[test]
    #[allow(clippy::many_single_char_names)]
    fn test_get_dependencies_with_overlap() {
        let a = ("a", Vec::new());
        let b = ("b", Vec::new());
        let c = ("c", Vec::new());
        let d = ("d", vec!["a", "b"]);
        let e = ("e", vec!["b", "c"]);

        let graph =
            create_dependency_graph(vec![a.clone(), b.clone(), c.clone(), d.clone(), e.clone()])
                .unwrap();

        assert_eq!(
            get_dependencies(&graph, &[&d, &e, &a]).unwrap(),
            &[&a, &b, &d, &c, &e]
        );

        assert_eq!(
            get_dependencies(&graph, &[&e, &d, &a]).unwrap(),
            &[&b, &c, &e, &a, &d]
        );
    }

    #[test]
    fn test_create_dependency_graph_missing_dependency() {
        let a = ("a", vec!["missing"]);

        let result: Result<_, CreateDependencyGraphError<String, Infallible>> =
            create_dependency_graph(vec![a]);

        assert!(matches!(
            result,
            Err(CreateDependencyGraphError::MissingDependency(id)) if id == "missing"
        ));
    }

    #[test]
    fn test_get_dependencies_unknown_root_node() {
        let a: (&str, Vec<&str>) = ("a", Vec::new());
        let b = ("b", vec!["a"]);

        // `b` is deliberately not part of the graph.
        let graph = create_dependency_graph(vec![a]).unwrap();

        let result: Result<_, GetDependenciesError<String>> = get_dependencies(&graph, &[&b]);

        assert!(matches!(
            result,
            Err(GetDependenciesError::UnknownRootNode(id)) if id == "b"
        ));
    }
}