1use petgraph::algo::{is_cyclic_directed, tarjan_scc};
2use petgraph::graph::{DiGraph, NodeIndex};
3use petgraph::visit::Bfs;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6use crate::types::{PackageId, ProjectGraph};
7
8#[non_exhaustive]
10pub struct DepGraph {
11 graph: DiGraph<PackageId, ()>,
12 node_map: HashMap<PackageId, NodeIndex>,
13}
14
15impl DepGraph {
16 pub fn from_project_graph(pg: &ProjectGraph) -> Self {
18 let mut graph = DiGraph::new();
19 let mut node_map = HashMap::new();
20
21 for id in pg.packages.keys() {
22 let idx = graph.add_node(id.clone());
23 node_map.insert(id.clone(), idx);
24 }
25
26 for (from, to) in &pg.edges {
27 if let (Some(&from_idx), Some(&to_idx)) = (node_map.get(from), node_map.get(to)) {
28 graph.add_edge(from_idx, to_idx, ());
29 }
30 }
31
32 Self { graph, node_map }
33 }
34
35 pub fn affected_by(&self, changed: &HashSet<PackageId>) -> HashSet<PackageId> {
42 let reversed = petgraph::visit::Reversed(&self.graph);
43 let mut result = HashSet::new();
44
45 for pkg in changed {
46 if let Some(&start) = self.node_map.get(pkg) {
47 let mut bfs = Bfs::new(&reversed, start);
48 while let Some(node) = bfs.next(&reversed) {
49 result.insert(self.graph[node].clone());
50 }
51 }
52 }
53
54 result
55 }
56
57 pub fn explain_affected(
60 &self,
61 changed: &HashSet<PackageId>,
62 affected: &HashSet<PackageId>,
63 ) -> Vec<(PackageId, Vec<PackageId>)> {
64 let mut parent: HashMap<NodeIndex, Option<NodeIndex>> = HashMap::new();
65 let mut visited: HashSet<NodeIndex> = HashSet::new();
66 let mut queue: VecDeque<NodeIndex> = VecDeque::new();
67
68 for pkg in changed {
70 if let Some(&idx) = self.node_map.get(pkg) {
71 if visited.insert(idx) {
72 parent.insert(idx, None);
73 queue.push_back(idx);
74 }
75 }
76 }
77
78 while let Some(current) = queue.pop_front() {
80 for neighbor in self
81 .graph
82 .neighbors_directed(current, petgraph::Direction::Incoming)
83 {
84 if visited.insert(neighbor) {
85 parent.insert(neighbor, Some(current));
86 queue.push_back(neighbor);
87 }
88 }
89 }
90
91 let mut results = Vec::new();
92 for pkg in affected {
93 if let Some(&idx) = self.node_map.get(pkg) {
94 if changed.contains(pkg) {
95 results.push((pkg.clone(), vec![pkg.clone()]));
96 } else if parent.contains_key(&idx) {
97 let mut chain = vec![pkg.clone()];
98 let mut cur = idx;
99 while let Some(Some(prev)) = parent.get(&cur) {
100 chain.push(self.graph[*prev].clone());
101 cur = *prev;
102 }
103 results.push((pkg.clone(), chain));
104 }
105 }
106 }
107
108 results.sort_by(|a, b| a.0.cmp(&b.0));
109 results
110 }
111
112 pub fn has_cycles(&self) -> bool {
114 is_cyclic_directed(&self.graph)
115 }
116
117 pub fn find_cycles(&self) -> Vec<Vec<PackageId>> {
119 tarjan_scc(&self.graph)
120 .into_iter()
121 .filter(|scc| scc.len() > 1)
122 .map(|scc| scc.into_iter().map(|idx| self.graph[idx].clone()).collect())
123 .collect()
124 }
125
126 pub fn to_dot_with_affected(
128 &self,
129 changed: &HashSet<PackageId>,
130 affected: &HashSet<PackageId>,
131 ) -> String {
132 let mut lines = vec!["digraph dependencies {".to_string()];
133 for (pkg_id, &idx) in &self.node_map {
134 let label = &self.graph[idx].0;
135 if changed.contains(pkg_id) {
136 lines.push(format!(
137 " \"{}\" [style=filled, fillcolor=orange];",
138 label
139 ));
140 } else if affected.contains(pkg_id) {
141 lines.push(format!(
142 " \"{}\" [style=filled, fillcolor=red, fontcolor=white];",
143 label
144 ));
145 }
146 }
147 for edge in self.graph.edge_indices() {
148 let (a, b) = self
149 .graph
150 .edge_endpoints(edge)
151 .expect("edge from graph iteration must have endpoints");
152 lines.push(format!(
153 " \"{}\" -> \"{}\";",
154 self.graph[a], self.graph[b]
155 ));
156 }
157 lines.push("}".to_string());
158 lines.join("\n")
159 }
160
161 pub fn all_packages(&self) -> Vec<&PackageId> {
163 self.graph.node_weights().collect()
164 }
165
166 pub fn to_dot(&self) -> String {
168 let mut lines = vec!["digraph dependencies {".to_string()];
169 for edge in self.graph.edge_indices() {
170 let (a, b) = self
171 .graph
172 .edge_endpoints(edge)
173 .expect("edge from graph iteration must have endpoints");
174 lines.push(format!(
175 " \"{}\" -> \"{}\";",
176 self.graph[a], self.graph[b]
177 ));
178 }
179 lines.push("}".to_string());
180 lines.join("\n")
181 }
182
183 pub fn edges(&self) -> Vec<(&PackageId, &PackageId)> {
185 self.graph
186 .edge_indices()
187 .map(|e| {
188 let (a, b) = self
189 .graph
190 .edge_endpoints(e)
191 .expect("edge from graph iteration must have endpoints");
192 (&self.graph[a], &self.graph[b])
193 })
194 .collect()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Package;
    use std::path::PathBuf;

    /// Shorthand for building a `PackageId` from a string slice.
    fn pid(name: &str) -> PackageId {
        PackageId(name.to_string())
    }

    /// Builds a `ProjectGraph` with one package per entry in `names` (located
    /// at `/<name>`) and one edge per `(from, to)` pair in `edges`.
    fn make_graph(names: &[&str], edges: &[(&str, &str)]) -> ProjectGraph {
        let packages = names
            .iter()
            .map(|&name| {
                let package = Package {
                    id: pid(name),
                    name: name.to_string(),
                    version: None,
                    path: PathBuf::from(format!("/{name}")),
                    manifest_path: PathBuf::from(format!("/{name}/Cargo.toml")),
                };
                (pid(name), package)
            })
            .collect();
        ProjectGraph {
            packages,
            edges: edges.iter().map(|&(from, to)| (pid(from), pid(to))).collect(),
            root: PathBuf::from("/"),
        }
    }

    #[test]
    fn test_linear_chain() {
        let graph = DepGraph::from_project_graph(&make_graph(
            &["core", "api", "cli"],
            &[("api", "core"), ("cli", "api")],
        ));

        let hit = graph.affected_by(&HashSet::from([pid("core")]));

        for name in ["core", "api", "cli"] {
            assert!(hit.contains(&pid(name)));
        }
        assert_eq!(hit.len(), 3);
    }

    #[test]
    fn test_leaf_change() {
        let graph = DepGraph::from_project_graph(&make_graph(
            &["core", "api", "cli"],
            &[("api", "core"), ("cli", "api")],
        ));

        let hit = graph.affected_by(&HashSet::from([pid("cli")]));

        assert!(hit.contains(&pid("cli")));
        assert_eq!(hit.len(), 1);
    }

    #[test]
    fn test_diamond_dependency() {
        let diamond = [
            ("ui", "core"),
            ("api", "core"),
            ("app", "ui"),
            ("app", "api"),
        ];
        let graph =
            DepGraph::from_project_graph(&make_graph(&["core", "ui", "api", "app"], &diamond));

        let hit = graph.affected_by(&HashSet::from([pid("core")]));

        assert_eq!(hit.len(), 4);
    }

    #[test]
    fn test_isolated_package() {
        let graph = DepGraph::from_project_graph(&make_graph(
            &["core", "api", "standalone"],
            &[("api", "core")],
        ));

        let hit = graph.affected_by(&HashSet::from([pid("core")]));

        assert!(hit.contains(&pid("core")));
        assert!(hit.contains(&pid("api")));
        assert!(!hit.contains(&pid("standalone")));
    }
}