Skip to main content

agm_core/graph/
query.rs

1//! Transitive dependency queries and conflict detection.
2
3use std::collections::{BTreeSet, HashSet};
4
5use petgraph::Direction;
6use petgraph::graph::NodeIndex;
7use petgraph::visit::EdgeRef;
8
9use super::{AgmGraph, RelationKind};
10
11// ---------------------------------------------------------------------------
12// transitive_deps
13// ---------------------------------------------------------------------------
14
15/// Returns all transitive dependencies of the given node, following only
16/// `Depends` edges.
17///
18/// Performs a DFS from `node_id` along outgoing `Depends` edges.
19/// The result does NOT include `node_id` itself.
20///
21/// Returns an empty set if the node has no dependencies or does not exist.
22#[must_use]
23pub fn transitive_deps(graph: &AgmGraph, node_id: &str) -> HashSet<String> {
24    let Some(&start) = graph.index.get(node_id) else {
25        return HashSet::new();
26    };
27
28    let mut visited: HashSet<NodeIndex> = HashSet::new();
29    let mut stack: Vec<NodeIndex> = vec![start];
30    visited.insert(start);
31
32    while let Some(current) = stack.pop() {
33        for edge in graph.inner.edges(current) {
34            if *edge.weight() == RelationKind::Depends {
35                let target = edge.target();
36                if visited.insert(target) {
37                    stack.push(target);
38                }
39            }
40        }
41    }
42
43    // Exclude the start node itself.
44    visited.remove(&start);
45    visited
46        .into_iter()
47        .map(|idx| graph.inner[idx].clone())
48        .collect()
49}
50
51// ---------------------------------------------------------------------------
52// transitive_dependents
53// ---------------------------------------------------------------------------
54
55/// Returns all transitive dependents (reverse dependencies) of the given
56/// node, following `Depends` edges in reverse.
57///
58/// "What nodes transitively depend on `node_id`?"
59///
60/// Performs a DFS from `node_id` along incoming `Depends` edges.
61/// The result does NOT include `node_id` itself.
62///
63/// Returns an empty set if nothing depends on this node or it does not exist.
64#[must_use]
65pub fn transitive_dependents(graph: &AgmGraph, node_id: &str) -> HashSet<String> {
66    let Some(&start) = graph.index.get(node_id) else {
67        return HashSet::new();
68    };
69
70    let mut visited: HashSet<NodeIndex> = HashSet::new();
71    let mut stack: Vec<NodeIndex> = vec![start];
72    visited.insert(start);
73
74    while let Some(current) = stack.pop() {
75        for edge in graph.inner.edges_directed(current, Direction::Incoming) {
76            if *edge.weight() == RelationKind::Depends {
77                let source = edge.source();
78                if visited.insert(source) {
79                    stack.push(source);
80                }
81            }
82        }
83    }
84
85    // Exclude the start node itself.
86    visited.remove(&start);
87    visited
88        .into_iter()
89        .map(|idx| graph.inner[idx].clone())
90        .collect()
91}
92
93// ---------------------------------------------------------------------------
94// find_conflicts
95// ---------------------------------------------------------------------------
96
97/// Finds all pairs of nodes connected by `Conflicts` edges.
98///
99/// Returns a `Vec<(String, String)>` where each tuple is an ordered pair
100/// `(a, b)` with `a < b` lexicographically (to avoid duplicates, since
101/// conflicts are semantically bidirectional even though the edge is directed).
102///
103/// Returns each pair once regardless of edge direction.
104#[must_use]
105pub fn find_conflicts(graph: &AgmGraph) -> Vec<(String, String)> {
106    let mut seen: BTreeSet<(String, String)> = BTreeSet::new();
107
108    for edge in graph.inner.edge_references() {
109        if *edge.weight() == RelationKind::Conflicts {
110            let src = graph.inner[edge.source()].clone();
111            let tgt = graph.inner[edge.target()].clone();
112            let pair = if src <= tgt { (src, tgt) } else { (tgt, src) };
113            seen.insert(pair);
114        }
115    }
116
117    seen.into_iter().collect()
118}
119
120// ---------------------------------------------------------------------------
121// Tests
122// ---------------------------------------------------------------------------
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::build::build_graph;
    use crate::graph::test_helpers::*;

    /// Builds an owned `HashSet<String>` from string literals.
    fn ids(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn test_transitive_deps_linear_chain() {
        // a -> b -> c
        let mut a = make_node("a");
        let mut b = make_node("b");
        let c = make_node("c");
        a.depends = Some(vec!["b".to_owned()]);
        b.depends = Some(vec!["c".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b, c]));
        let deps = transitive_deps(&graph, "a");
        assert_eq!(deps, HashSet::from(["b".to_owned(), "c".to_owned()]));
    }

    #[test]
    fn test_transitive_deps_nonexistent_node_returns_empty() {
        let graph = build_graph(&make_file(vec![make_node("a")]));
        let deps = transitive_deps(&graph, "nonexistent");
        assert!(deps.is_empty());
    }

    #[test]
    fn test_transitive_deps_diamond_reports_each_node_once() {
        // a -> b -> d, a -> c -> d
        let mut a = make_node("a");
        let mut b = make_node("b");
        let mut c = make_node("c");
        let d = make_node("d");
        a.depends = Some(vec!["b".to_owned(), "c".to_owned()]);
        b.depends = Some(vec!["d".to_owned()]);
        c.depends = Some(vec!["d".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b, c, d]));
        assert_eq!(transitive_deps(&graph, "a"), ids(&["b", "c", "d"]));
        assert_eq!(transitive_dependents(&graph, "d"), ids(&["a", "b", "c"]));
    }

    #[test]
    fn test_transitive_queries_terminate_on_cycle_and_exclude_start() {
        // a -> b -> a
        let mut a = make_node("a");
        let mut b = make_node("b");
        a.depends = Some(vec!["b".to_owned()]);
        b.depends = Some(vec!["a".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b]));
        assert_eq!(transitive_deps(&graph, "a"), ids(&["b"]));
        assert_eq!(transitive_dependents(&graph, "a"), ids(&["b"]));
    }

    #[test]
    fn test_transitive_queries_ignore_conflicts_edges() {
        // a -> b (depends), a -- c (conflicts)
        let mut a = make_node("a");
        let b = make_node("b");
        let c = make_node("c");
        a.depends = Some(vec!["b".to_owned()]);
        a.conflicts = Some(vec!["c".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b, c]));
        assert_eq!(transitive_deps(&graph, "a"), ids(&["b"]));
        assert!(transitive_dependents(&graph, "c").is_empty());
    }

    #[test]
    fn test_transitive_dependents_returns_reverse() {
        // a -> b -> c; transitive_dependents(c) = {a, b}
        let mut a = make_node("a");
        let mut b = make_node("b");
        let c = make_node("c");
        a.depends = Some(vec!["b".to_owned()]);
        b.depends = Some(vec!["c".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b, c]));
        let dependents = transitive_dependents(&graph, "c");
        assert_eq!(dependents, HashSet::from(["a".to_owned(), "b".to_owned()]));
    }

    #[test]
    fn test_transitive_dependents_nonexistent_node_returns_empty() {
        let graph = build_graph(&make_file(vec![make_node("a")]));
        assert!(transitive_dependents(&graph, "nonexistent").is_empty());
    }

    #[test]
    fn test_find_conflicts_returns_ordered_pairs() {
        let mut a = make_node("a");
        let mut c = make_node("c");
        let b = make_node("b");
        let d = make_node("d");
        // a conflicts b, c conflicts d
        a.conflicts = Some(vec!["b".to_owned()]);
        c.conflicts = Some(vec!["d".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b, c, d]));
        let conflicts = find_conflicts(&graph);
        assert_eq!(conflicts.len(), 2);
        // Each pair should be lexicographically ordered.
        for (l, r) in &conflicts {
            assert!(l <= r, "pair ({l}, {r}) is not ordered");
        }
        assert!(conflicts.contains(&("a".to_owned(), "b".to_owned())));
        assert!(conflicts.contains(&("c".to_owned(), "d".to_owned())));
    }

    #[test]
    fn test_find_conflicts_reports_bidirectional_edges_once() {
        // Declared from both sides: a conflicts b AND b conflicts a.
        let mut a = make_node("a");
        let mut b = make_node("b");
        a.conflicts = Some(vec!["b".to_owned()]);
        b.conflicts = Some(vec!["a".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b]));
        assert_eq!(
            find_conflicts(&graph),
            vec![("a".to_owned(), "b".to_owned())]
        );
    }

    #[test]
    fn test_find_conflicts_without_conflicts_is_empty() {
        let mut a = make_node("a");
        let b = make_node("b");
        a.depends = Some(vec!["b".to_owned()]);
        let graph = build_graph(&make_file(vec![a, b]));
        assert!(find_conflicts(&graph).is_empty());
    }
}