Skip to main content

affected_core/
graph.rs

1use petgraph::algo::{is_cyclic_directed, tarjan_scc};
2use petgraph::graph::{DiGraph, NodeIndex};
3use petgraph::visit::Bfs;
4use std::collections::{HashMap, HashSet, VecDeque};
5
6use crate::types::{PackageId, ProjectGraph};
7
/// Dependency graph wrapper around petgraph.
///
/// Nodes carry [`PackageId`]s and edges are unweighted. When the graph is
/// built through `DepGraph::from_project_graph`, an edge points from a
/// dependent package to the package it depends on.
#[non_exhaustive]
pub struct DepGraph {
    /// Underlying directed graph; every node weight is a package ID.
    graph: DiGraph<PackageId, ()>,
    /// Lookup from a package ID to the index of its node in `graph`.
    node_map: HashMap<PackageId, NodeIndex>,
}
14
15impl DepGraph {
16    /// Build from a ProjectGraph. Edges go from dependent → dependency.
17    pub fn from_project_graph(pg: &ProjectGraph) -> Self {
18        let mut graph = DiGraph::new();
19        let mut node_map = HashMap::new();
20
21        for id in pg.packages.keys() {
22            let idx = graph.add_node(id.clone());
23            node_map.insert(id.clone(), idx);
24        }
25
26        for (from, to) in &pg.edges {
27            if let (Some(&from_idx), Some(&to_idx)) = (node_map.get(from), node_map.get(to)) {
28                graph.add_edge(from_idx, to_idx, ());
29            }
30        }
31
32        Self { graph, node_map }
33    }
34
35    /// Given a set of directly changed packages, return all transitively
36    /// affected packages (changed + everything that depends on them).
37    ///
38    /// Uses BFS on the reversed graph: if A→B means "A depends on B",
39    /// then in the reversed graph B→A, and BFS from a changed node B
40    /// finds all packages that transitively depend on B.
41    pub fn affected_by(&self, changed: &HashSet<PackageId>) -> HashSet<PackageId> {
42        let reversed = petgraph::visit::Reversed(&self.graph);
43        let mut result = HashSet::new();
44
45        for pkg in changed {
46            if let Some(&start) = self.node_map.get(pkg) {
47                let mut bfs = Bfs::new(&reversed, start);
48                while let Some(node) = bfs.next(&reversed) {
49                    result.insert(self.graph[node].clone());
50                }
51            }
52        }
53
54        result
55    }
56
57    /// For each affected package, return the shortest dependency chain back to a
58    /// directly changed package. Uses BFS on the reversed graph, tracking parents.
59    pub fn explain_affected(
60        &self,
61        changed: &HashSet<PackageId>,
62        affected: &HashSet<PackageId>,
63    ) -> Vec<(PackageId, Vec<PackageId>)> {
64        let mut parent: HashMap<NodeIndex, Option<NodeIndex>> = HashMap::new();
65        let mut visited: HashSet<NodeIndex> = HashSet::new();
66        let mut queue: VecDeque<NodeIndex> = VecDeque::new();
67
68        // Initialize BFS from all changed packages
69        for pkg in changed {
70            if let Some(&idx) = self.node_map.get(pkg) {
71                if visited.insert(idx) {
72                    parent.insert(idx, None);
73                    queue.push_back(idx);
74                }
75            }
76        }
77
78        // BFS on reversed graph (follow incoming edges)
79        while let Some(current) = queue.pop_front() {
80            for neighbor in self
81                .graph
82                .neighbors_directed(current, petgraph::Direction::Incoming)
83            {
84                if visited.insert(neighbor) {
85                    parent.insert(neighbor, Some(current));
86                    queue.push_back(neighbor);
87                }
88            }
89        }
90
91        let mut results = Vec::new();
92        for pkg in affected {
93            if let Some(&idx) = self.node_map.get(pkg) {
94                if changed.contains(pkg) {
95                    results.push((pkg.clone(), vec![pkg.clone()]));
96                } else if parent.contains_key(&idx) {
97                    let mut chain = vec![pkg.clone()];
98                    let mut cur = idx;
99                    while let Some(Some(prev)) = parent.get(&cur) {
100                        chain.push(self.graph[*prev].clone());
101                        cur = *prev;
102                    }
103                    results.push((pkg.clone(), chain));
104                }
105            }
106        }
107
108        results.sort_by(|a, b| a.0.cmp(&b.0));
109        results
110    }
111
112    /// Check if the dependency graph contains any cycles.
113    pub fn has_cycles(&self) -> bool {
114        is_cyclic_directed(&self.graph)
115    }
116
117    /// Find and return all cycles in the graph (SCCs with size > 1).
118    pub fn find_cycles(&self) -> Vec<Vec<PackageId>> {
119        tarjan_scc(&self.graph)
120            .into_iter()
121            .filter(|scc| scc.len() > 1)
122            .map(|scc| scc.into_iter().map(|idx| self.graph[idx].clone()).collect())
123            .collect()
124    }
125
126    /// Enhanced DOT output where affected nodes are colored red and changed nodes are orange.
127    pub fn to_dot_with_affected(
128        &self,
129        changed: &HashSet<PackageId>,
130        affected: &HashSet<PackageId>,
131    ) -> String {
132        let mut lines = vec!["digraph dependencies {".to_string()];
133        for (pkg_id, &idx) in &self.node_map {
134            let label = &self.graph[idx].0;
135            if changed.contains(pkg_id) {
136                lines.push(format!(
137                    "    \"{}\" [style=filled, fillcolor=orange];",
138                    label
139                ));
140            } else if affected.contains(pkg_id) {
141                lines.push(format!(
142                    "    \"{}\" [style=filled, fillcolor=red, fontcolor=white];",
143                    label
144                ));
145            }
146        }
147        for edge in self.graph.edge_indices() {
148            let (a, b) = self
149                .graph
150                .edge_endpoints(edge)
151                .expect("edge from graph iteration must have endpoints");
152            lines.push(format!(
153                "    \"{}\" -> \"{}\";",
154                self.graph[a], self.graph[b]
155            ));
156        }
157        lines.push("}".to_string());
158        lines.join("\n")
159    }
160
161    /// Return all package IDs in the graph.
162    pub fn all_packages(&self) -> Vec<&PackageId> {
163        self.graph.node_weights().collect()
164    }
165
166    /// Generate DOT format output for graphviz visualization.
167    pub fn to_dot(&self) -> String {
168        let mut lines = vec!["digraph dependencies {".to_string()];
169        for edge in self.graph.edge_indices() {
170            let (a, b) = self
171                .graph
172                .edge_endpoints(edge)
173                .expect("edge from graph iteration must have endpoints");
174            lines.push(format!(
175                "    \"{}\" -> \"{}\";",
176                self.graph[a], self.graph[b]
177            ));
178        }
179        lines.push("}".to_string());
180        lines.join("\n")
181    }
182
183    /// Return all edges as (from, to) pairs for display.
184    pub fn edges(&self) -> Vec<(&PackageId, &PackageId)> {
185        self.graph
186            .edge_indices()
187            .map(|e| {
188                let (a, b) = self
189                    .graph
190                    .edge_endpoints(e)
191                    .expect("edge from graph iteration must have endpoints");
192                (&self.graph[a], &self.graph[b])
193            })
194            .collect()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Package;
    use std::path::PathBuf;

    /// Shorthand for turning a name into a `PackageId`.
    fn pkg(name: &str) -> PackageId {
        PackageId(name.to_string())
    }

    /// Build a `ProjectGraph` fixture with one package per name and the given
    /// `(dependent, dependency)` edges, rooted at `/`.
    fn make_graph(names: &[&str], edges: &[(&str, &str)]) -> ProjectGraph {
        let packages = names
            .iter()
            .map(|name| {
                let id = pkg(name);
                let package = Package {
                    id: id.clone(),
                    name: name.to_string(),
                    version: None,
                    path: PathBuf::from(format!("/{name}")),
                    manifest_path: PathBuf::from(format!("/{name}/Cargo.toml")),
                };
                (id, package)
            })
            .collect();
        let edges = edges
            .iter()
            .map(|(from, to)| (pkg(from), pkg(to)))
            .collect();
        ProjectGraph {
            packages,
            edges,
            root: PathBuf::from("/"),
        }
    }

    /// Build the `DepGraph` for `pg` and compute what a change to the single
    /// package `changed` affects.
    fn affected_from(pg: &ProjectGraph, changed: &str) -> HashSet<PackageId> {
        let changed: HashSet<PackageId> = std::iter::once(pkg(changed)).collect();
        DepGraph::from_project_graph(pg).affected_by(&changed)
    }

    #[test]
    fn test_linear_chain() {
        // cli -> api -> core
        let pg = make_graph(&["core", "api", "cli"], &[("api", "core"), ("cli", "api")]);

        let affected = affected_from(&pg, "core");

        // Changing the bottom of the chain affects the whole chain.
        for name in &["core", "api", "cli"] {
            assert!(affected.contains(&pkg(name)));
        }
        assert_eq!(affected.len(), 3);
    }

    #[test]
    fn test_leaf_change() {
        // cli -> api -> core
        let pg = make_graph(&["core", "api", "cli"], &[("api", "core"), ("cli", "api")]);

        let affected = affected_from(&pg, "cli");

        // Nothing depends on cli, so only cli itself is affected.
        assert!(affected.contains(&pkg("cli")));
        assert_eq!(affected.len(), 1);
    }

    #[test]
    fn test_diamond_dependency() {
        //   app
        //  /   \
        // ui   api
        //  \   /
        //  core
        let pg = make_graph(
            &["core", "ui", "api", "app"],
            &[
                ("ui", "core"),
                ("api", "core"),
                ("app", "ui"),
                ("app", "api"),
            ],
        );

        let affected = affected_from(&pg, "core");

        assert_eq!(affected.len(), 4);
    }

    #[test]
    fn test_isolated_package() {
        let pg = make_graph(&["core", "api", "standalone"], &[("api", "core")]);

        let affected = affected_from(&pg, "core");

        assert!(affected.contains(&pkg("core")));
        assert!(affected.contains(&pkg("api")));
        // A package with no path to the change must stay unaffected.
        assert!(!affected.contains(&pkg("standalone")));
    }
}