Skip to main content

affected_core/
graph.rs

use petgraph::algo::{is_cyclic_directed, tarjan_scc};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::{Bfs, VisitMap};
use std::collections::{HashMap, HashSet, VecDeque};
5
6use crate::types::{PackageId, ProjectGraph};
7
8/// Dependency graph wrapper around petgraph.
9pub struct DepGraph {
10    graph: DiGraph<PackageId, ()>,
11    node_map: HashMap<PackageId, NodeIndex>,
12}
13
14impl DepGraph {
15    /// Build from a ProjectGraph. Edges go from dependent → dependency.
16    pub fn from_project_graph(pg: &ProjectGraph) -> Self {
17        let mut graph = DiGraph::new();
18        let mut node_map = HashMap::new();
19
20        for id in pg.packages.keys() {
21            let idx = graph.add_node(id.clone());
22            node_map.insert(id.clone(), idx);
23        }
24
25        for (from, to) in &pg.edges {
26            if let (Some(&from_idx), Some(&to_idx)) = (node_map.get(from), node_map.get(to)) {
27                graph.add_edge(from_idx, to_idx, ());
28            }
29        }
30
31        Self { graph, node_map }
32    }
33
34    /// Given a set of directly changed packages, return all transitively
35    /// affected packages (changed + everything that depends on them).
36    ///
37    /// Uses BFS on the reversed graph: if A→B means "A depends on B",
38    /// then in the reversed graph B→A, and BFS from a changed node B
39    /// finds all packages that transitively depend on B.
40    pub fn affected_by(&self, changed: &HashSet<PackageId>) -> HashSet<PackageId> {
41        let reversed = petgraph::visit::Reversed(&self.graph);
42        let mut result = HashSet::new();
43
44        for pkg in changed {
45            if let Some(&start) = self.node_map.get(pkg) {
46                let mut bfs = Bfs::new(&reversed, start);
47                while let Some(node) = bfs.next(&reversed) {
48                    result.insert(self.graph[node].clone());
49                }
50            }
51        }
52
53        result
54    }
55
56    /// For each affected package, return the shortest dependency chain back to a
57    /// directly changed package. Uses BFS on the reversed graph, tracking parents.
58    pub fn explain_affected(
59        &self,
60        changed: &HashSet<PackageId>,
61        affected: &HashSet<PackageId>,
62    ) -> Vec<(PackageId, Vec<PackageId>)> {
63        let mut parent: HashMap<NodeIndex, Option<NodeIndex>> = HashMap::new();
64        let mut visited: HashSet<NodeIndex> = HashSet::new();
65        let mut queue: VecDeque<NodeIndex> = VecDeque::new();
66
67        // Initialize BFS from all changed packages
68        for pkg in changed {
69            if let Some(&idx) = self.node_map.get(pkg) {
70                if visited.insert(idx) {
71                    parent.insert(idx, None);
72                    queue.push_back(idx);
73                }
74            }
75        }
76
77        // BFS on reversed graph (follow incoming edges)
78        while let Some(current) = queue.pop_front() {
79            for neighbor in self
80                .graph
81                .neighbors_directed(current, petgraph::Direction::Incoming)
82            {
83                if visited.insert(neighbor) {
84                    parent.insert(neighbor, Some(current));
85                    queue.push_back(neighbor);
86                }
87            }
88        }
89
90        let mut results = Vec::new();
91        for pkg in affected {
92            if let Some(&idx) = self.node_map.get(pkg) {
93                if changed.contains(pkg) {
94                    results.push((pkg.clone(), vec![pkg.clone()]));
95                } else if parent.contains_key(&idx) {
96                    let mut chain = vec![pkg.clone()];
97                    let mut cur = idx;
98                    while let Some(Some(prev)) = parent.get(&cur) {
99                        chain.push(self.graph[*prev].clone());
100                        cur = *prev;
101                    }
102                    results.push((pkg.clone(), chain));
103                }
104            }
105        }
106
107        results.sort_by(|a, b| a.0.cmp(&b.0));
108        results
109    }
110
111    /// Check if the dependency graph contains any cycles.
112    pub fn has_cycles(&self) -> bool {
113        is_cyclic_directed(&self.graph)
114    }
115
116    /// Find and return all cycles in the graph (SCCs with size > 1).
117    pub fn find_cycles(&self) -> Vec<Vec<PackageId>> {
118        tarjan_scc(&self.graph)
119            .into_iter()
120            .filter(|scc| scc.len() > 1)
121            .map(|scc| scc.into_iter().map(|idx| self.graph[idx].clone()).collect())
122            .collect()
123    }
124
125    /// Enhanced DOT output where affected nodes are colored red and changed nodes are orange.
126    pub fn to_dot_with_affected(
127        &self,
128        changed: &HashSet<PackageId>,
129        affected: &HashSet<PackageId>,
130    ) -> String {
131        let mut lines = vec!["digraph dependencies {".to_string()];
132        for (pkg_id, &idx) in &self.node_map {
133            let label = &self.graph[idx].0;
134            if changed.contains(pkg_id) {
135                lines.push(format!(
136                    "    \"{}\" [style=filled, fillcolor=orange];",
137                    label
138                ));
139            } else if affected.contains(pkg_id) {
140                lines.push(format!(
141                    "    \"{}\" [style=filled, fillcolor=red, fontcolor=white];",
142                    label
143                ));
144            }
145        }
146        for edge in self.graph.edge_indices() {
147            let (a, b) = self.graph.edge_endpoints(edge).unwrap();
148            lines.push(format!(
149                "    \"{}\" -> \"{}\";",
150                self.graph[a], self.graph[b]
151            ));
152        }
153        lines.push("}".to_string());
154        lines.join("\n")
155    }
156
157    /// Return all package IDs in the graph.
158    pub fn all_packages(&self) -> Vec<&PackageId> {
159        self.graph.node_weights().collect()
160    }
161
162    /// Generate DOT format output for graphviz visualization.
163    pub fn to_dot(&self) -> String {
164        let mut lines = vec!["digraph dependencies {".to_string()];
165        for edge in self.graph.edge_indices() {
166            let (a, b) = self.graph.edge_endpoints(edge).unwrap();
167            lines.push(format!(
168                "    \"{}\" -> \"{}\";",
169                self.graph[a], self.graph[b]
170            ));
171        }
172        lines.push("}".to_string());
173        lines.join("\n")
174    }
175
176    /// Return all edges as (from, to) pairs for display.
177    pub fn edges(&self) -> Vec<(&PackageId, &PackageId)> {
178        self.graph
179            .edge_indices()
180            .map(|e| {
181                let (a, b) = self.graph.edge_endpoints(e).unwrap();
182                (&self.graph[a], &self.graph[b])
183            })
184            .collect()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Package;
    use std::path::PathBuf;

    /// Shorthand for building a `PackageId` from a name.
    fn pid(name: &str) -> PackageId {
        PackageId(name.to_string())
    }

    /// Build a `ProjectGraph` with one package per name and the given
    /// `(dependent, dependency)` edges.
    fn make_graph(names: &[&str], edges: &[(&str, &str)]) -> ProjectGraph {
        let mut packages = HashMap::new();
        for name in names {
            let id = pid(name);
            packages.insert(
                id.clone(),
                Package {
                    id: id.clone(),
                    name: name.to_string(),
                    version: None,
                    path: PathBuf::from(format!("/{name}")),
                    manifest_path: PathBuf::from(format!("/{name}/Cargo.toml")),
                },
            );
        }
        let edges = edges.iter().map(|(a, b)| (pid(a), pid(b))).collect();
        ProjectGraph {
            packages,
            edges,
            root: PathBuf::from("/"),
        }
    }

    /// cli -> api -> core
    fn linear_chain() -> ProjectGraph {
        make_graph(&["core", "api", "cli"], &[("api", "core"), ("cli", "api")])
    }

    ///   app
    ///  /   \
    /// ui   api
    ///  \   /
    ///  core
    fn diamond() -> ProjectGraph {
        make_graph(
            &["core", "ui", "api", "app"],
            &[
                ("ui", "core"),
                ("api", "core"),
                ("app", "ui"),
                ("app", "api"),
            ],
        )
    }

    #[test]
    fn test_linear_chain() {
        let dg = DepGraph::from_project_graph(&linear_chain());

        let changed: HashSet<_> = [pid("core")].into();
        let affected = dg.affected_by(&changed);

        assert!(affected.contains(&pid("core")));
        assert!(affected.contains(&pid("api")));
        assert!(affected.contains(&pid("cli")));
        assert_eq!(affected.len(), 3);
    }

    #[test]
    fn test_leaf_change() {
        let dg = DepGraph::from_project_graph(&linear_chain());

        let changed: HashSet<_> = [pid("cli")].into();
        let affected = dg.affected_by(&changed);

        assert!(affected.contains(&pid("cli")));
        assert_eq!(affected.len(), 1);
    }

    #[test]
    fn test_diamond_dependency() {
        let dg = DepGraph::from_project_graph(&diamond());

        let changed: HashSet<_> = [pid("core")].into();
        let affected = dg.affected_by(&changed);

        assert_eq!(affected.len(), 4);
    }

    #[test]
    fn test_multiple_changed_packages() {
        let dg = DepGraph::from_project_graph(&diamond());

        // Both mid-level packages change: only they and `app` are affected.
        let changed: HashSet<_> = [pid("ui"), pid("api")].into();
        let affected = dg.affected_by(&changed);

        let expected: HashSet<_> = [pid("ui"), pid("api"), pid("app")].into();
        assert!(affected == expected);
    }

    #[test]
    fn test_isolated_package() {
        let pg = make_graph(&["core", "api", "standalone"], &[("api", "core")]);
        let dg = DepGraph::from_project_graph(&pg);

        let changed: HashSet<_> = [pid("core")].into();
        let affected = dg.affected_by(&changed);

        assert!(affected.contains(&pid("core")));
        assert!(affected.contains(&pid("api")));
        assert!(!affected.contains(&pid("standalone")));
    }

    #[test]
    fn test_unknown_or_empty_changed_set() {
        let pg = make_graph(&["core", "api"], &[("api", "core")]);
        let dg = DepGraph::from_project_graph(&pg);

        let changed: HashSet<_> = [pid("missing")].into();
        assert!(dg.affected_by(&changed).is_empty());
        assert!(dg.affected_by(&HashSet::new()).is_empty());
    }

    #[test]
    fn test_explain_affected_linear_chain() {
        let dg = DepGraph::from_project_graph(&linear_chain());

        let changed: HashSet<_> = [pid("core")].into();
        let affected = dg.affected_by(&changed);
        let explained: HashMap<PackageId, Vec<PackageId>> = dg
            .explain_affected(&changed, &affected)
            .into_iter()
            .collect();

        assert_eq!(explained.len(), 3);
        assert!(explained[&pid("core")] == vec![pid("core")]);
        assert!(explained[&pid("api")] == vec![pid("api"), pid("core")]);
        assert!(explained[&pid("cli")] == vec![pid("cli"), pid("api"), pid("core")]);
    }

    #[test]
    fn test_cycle_detection() {
        let pg = make_graph(&["a", "b", "c"], &[("a", "b"), ("b", "a"), ("c", "a")]);
        let dg = DepGraph::from_project_graph(&pg);

        assert!(dg.has_cycles());
        let cycles = dg.find_cycles();
        assert_eq!(cycles.len(), 1);
        assert_eq!(cycles[0].len(), 2);
        assert!(cycles[0].contains(&pid("a")));
        assert!(cycles[0].contains(&pid("b")));
    }

    #[test]
    fn test_acyclic_graph() {
        let dg = DepGraph::from_project_graph(&diamond());

        assert!(!dg.has_cycles());
        assert!(dg.find_cycles().is_empty());
    }

    #[test]
    fn test_listing_and_dot_output() {
        let dg = DepGraph::from_project_graph(&linear_chain());

        assert_eq!(dg.all_packages().len(), 3);
        assert_eq!(dg.edges().len(), 2);

        let dot = dg.to_dot();
        assert!(dot.starts_with("digraph dependencies {"));
        assert!(dot.ends_with('}'));
        assert_eq!(dot.matches(" -> ").count(), 2);

        let changed: HashSet<_> = [pid("core")].into();
        let affected = dg.affected_by(&changed);
        let colored = dg.to_dot_with_affected(&changed, &affected);
        assert_eq!(colored.matches("fillcolor=orange").count(), 1);
        assert_eq!(colored.matches("fillcolor=red").count(), 2);
        assert_eq!(colored.matches(" -> ").count(), 2);
    }
}