1use petgraph::Graph;
2use petgraph::visit::DfsPostOrder;
3use std::error::Error;
4
/// A value that can be placed into a dependency graph.
///
/// `T` is the identifier type nodes are referred to by (compared with `==`
/// when resolving dependencies), and `E` is the error type reported when a
/// node's dependencies cannot be determined.
pub trait DependencyNode<T: PartialEq, E> {
    /// Returns the identifier of this node.
    fn id(&self) -> T;

    /// Returns the identifiers of the nodes this node depends on.
    ///
    /// # Errors
    ///
    /// Returns an `E` when the dependencies of this node cannot be determined.
    fn dependencies(&self) -> Result<Vec<T>, E>;
}
19
20pub(crate) fn create_dependency_graph<T, I, E>(
27 nodes: Vec<T>,
28) -> Result<Graph<T, ()>, CreateDependencyGraphError<I, E>>
29where
30 T: DependencyNode<I, E>,
31 I: PartialEq,
32 E: Error,
33{
34 let mut graph = Graph::new();
35
36 for node in nodes {
37 graph.add_node(node);
38 }
39
40 for idx in graph.node_indices() {
41 let node = &graph[idx];
42
43 let dependencies = node
44 .dependencies()
45 .map_err(CreateDependencyGraphError::GetNodeDependenciesError)?;
46
47 for dependency in dependencies {
48 let dependency_idx = graph
49 .node_indices()
50 .find(|idx| graph[*idx].id() == dependency)
51 .ok_or(CreateDependencyGraphError::MissingDependency(dependency))?;
52
53 graph.add_edge(idx, dependency_idx, ());
54 }
55 }
56
57 Ok(graph)
58}
59
60#[derive(thiserror::Error, Debug)]
62pub enum CreateDependencyGraphError<I, E: Error> {
63 #[error("Error while determining dependencies of a node: {0}")]
64 GetNodeDependenciesError(#[source] E),
65 #[error("Node references unknown dependency {0}")]
66 MissingDependency(I),
67}
68
69pub fn get_dependencies<'a, T, I, E>(
77 graph: &'a Graph<T, ()>,
78 root_nodes: &[&T],
79) -> Result<Vec<&'a T>, GetDependenciesError<I>>
80where
81 T: DependencyNode<I, E>,
82 I: PartialEq,
83{
84 let mut order: Vec<&T> = Vec::new();
85 let mut dfs = DfsPostOrder::empty(&graph);
86 for root_node in root_nodes {
87 let idx = graph
88 .node_indices()
89 .find(|idx| graph[*idx].id() == root_node.id())
90 .ok_or(GetDependenciesError::UnknownRootNode(root_node.id()))?;
91
92 dfs.move_to(idx);
93
94 while let Some(visited) = dfs.next(&graph) {
95 order.push(&graph[visited]);
96 }
97 }
98 Ok(order)
99}
100
101#[derive(thiserror::Error, Debug)]
103pub enum GetDependenciesError<I> {
104 #[error("Root node {0} is not in the dependency graph")]
105 UnknownRootNode(I),
106}
107
#[cfg(test)]
mod tests {
    use crate::dependency_graph::{
        CreateDependencyGraphError, DependencyNode, GetDependenciesError, create_dependency_graph,
        get_dependencies,
    };
    use std::convert::Infallible;

    /// Test fixture: a node is `(id, ids of its dependencies)`.
    impl DependencyNode<String, Infallible> for (&str, Vec<&str>) {
        fn id(&self) -> String {
            self.0.to_string()
        }

        fn dependencies(&self) -> Result<Vec<String>, Infallible> {
            Ok(self
                .1
                .iter()
                .map(std::string::ToString::to_string)
                .collect())
        }
    }

    #[test]
    fn test_get_dependencies_one_level_deep() {
        let a = ("a", Vec::new());
        let b = ("b", Vec::new());
        let c = ("c", vec!["a", "b"]);

        let graph = create_dependency_graph(vec![a.clone(), b.clone(), c.clone()]).unwrap();

        assert_eq!(get_dependencies(&graph, &[&a]).unwrap(), &[&a]);

        assert_eq!(get_dependencies(&graph, &[&b]).unwrap(), &[&b]);

        assert_eq!(get_dependencies(&graph, &[&c]).unwrap(), &[&a, &b, &c]);

        assert_eq!(
            &get_dependencies(&graph, &[&b, &c, &a]).unwrap(),
            &[&b, &a, &c]
        );
    }

    #[test]
    fn test_get_dependencies_two_levels_deep() {
        let a = ("a", Vec::new());
        let b = ("b", vec!["a"]);
        let c = ("c", vec!["b"]);

        let graph = create_dependency_graph(vec![a.clone(), b.clone(), c.clone()]).unwrap();

        assert_eq!(get_dependencies(&graph, &[&a]).unwrap(), &[&a]);

        assert_eq!(get_dependencies(&graph, &[&b]).unwrap(), &[&a, &b]);

        assert_eq!(get_dependencies(&graph, &[&c]).unwrap(), &[&a, &b, &c]);

        assert_eq!(
            &get_dependencies(&graph, &[&b, &c, &a]).unwrap(),
            &[&a, &b, &c]
        );
    }

    #[test]
    // Node variables are named after their single-letter ids so the expected
    // orders below read directly.
    #[allow(clippy::many_single_char_names)]
    fn test_get_dependencies_with_overlap() {
        let a = ("a", Vec::new());
        let b = ("b", Vec::new());
        let c = ("c", Vec::new());
        let d = ("d", vec!["a", "b"]);
        let e = ("e", vec!["b", "c"]);

        let graph =
            create_dependency_graph(vec![a.clone(), b.clone(), c.clone(), d.clone(), e.clone()])
                .unwrap();

        assert_eq!(
            get_dependencies(&graph, &[&d, &e, &a]).unwrap(),
            &[&a, &b, &d, &c, &e]
        );

        assert_eq!(
            get_dependencies(&graph, &[&e, &d, &a]).unwrap(),
            &[&b, &c, &e, &a, &d]
        );
    }

    #[test]
    fn test_create_dependency_graph_missing_dependency() {
        let a: (&str, Vec<&str>) = ("a", vec!["missing"]);

        let result: Result<_, CreateDependencyGraphError<String, Infallible>> =
            create_dependency_graph(vec![a]);

        assert!(matches!(
            result,
            Err(CreateDependencyGraphError::MissingDependency(id)) if id == "missing"
        ));
    }

    #[test]
    fn test_get_dependencies_unknown_root_node() {
        let a: (&str, Vec<&str>) = ("a", Vec::new());
        let unknown: (&str, Vec<&str>) = ("unknown", Vec::new());

        let graph = create_dependency_graph(vec![a]).unwrap();

        let result: Result<_, GetDependenciesError<String>> =
            get_dependencies(&graph, &[&unknown]);

        assert!(matches!(
            result,
            Err(GetDependenciesError::UnknownRootNode(id)) if id == "unknown"
        ));
    }
}