// arbor_graph/impact.rs

//! Impact analysis for code changes.
//!
//! This module runs two breadth-first traversals from a target node, one
//! over incoming edges (upstream) and one over outgoing edges (downstream),
//! to find all nodes affected by a change to that node. It answers the question:
//! "What breaks if I change this?"

use crate::edge::EdgeKind;
use crate::graph::{ArborGraph, NodeId};
use crate::query::NodeInfo;
use petgraph::visit::EdgeRef;
use petgraph::Direction;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Instant;

16/// Severity of impact based on hop distance from target.
17///
18/// Never construct directly — always use `from_hops()`.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
20pub enum ImpactSeverity {
21    /// 1 hop - immediate callers/callees
22    Direct = 0,
23    /// 2-3 hops - transitively connected
24    Transitive = 1,
25    /// 4+ hops - distantly connected
26    Distant = 2,
27}
28
29impl ImpactSeverity {
30    /// Derives severity from hop distance.
31    ///
32    /// This is the ONLY way to create an ImpactSeverity.
33    /// Thresholds: 1 hop = Direct, 2-3 = Transitive, 4+ = Distant
34    pub fn from_hops(hops: usize) -> Self {
35        match hops {
36            0 | 1 => ImpactSeverity::Direct,
37            2 | 3 => ImpactSeverity::Transitive,
38            _ => ImpactSeverity::Distant,
39        }
40    }
41
42    /// Returns a human-readable description.
43    pub fn as_str(&self) -> &'static str {
44        match self {
45            ImpactSeverity::Direct => "direct",
46            ImpactSeverity::Transitive => "transitive",
47            ImpactSeverity::Distant => "distant",
48        }
49    }
50}
51
52impl std::fmt::Display for ImpactSeverity {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "{}", self.as_str())
55    }
56}
57
58/// Direction of impact from the target node.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum ImpactDirection {
61    /// Nodes that depend on the target (incoming edges).
62    /// These break if the target's interface changes.
63    Upstream,
64    /// Nodes the target depends on (outgoing edges).
65    /// Changes here may require updating the target.
66    Downstream,
67}
68
69impl std::fmt::Display for ImpactDirection {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            ImpactDirection::Upstream => write!(f, "upstream"),
73            ImpactDirection::Downstream => write!(f, "downstream"),
74        }
75    }
76}
77
78/// A node affected by a change to the target.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct AffectedNode {
81    /// The node's graph index.
82    pub node_id: NodeId,
83    /// Full node information.
84    pub node_info: NodeInfo,
85    /// Severity derived from hop distance.
86    pub severity: ImpactSeverity,
87    /// Number of edges from target to this node.
88    pub hop_distance: usize,
89    /// The edge kind of the first hop that led to this node.
90    /// This explains why this node is in the impact set.
91    pub entry_edge: EdgeKind,
92    /// Whether this node is upstream or downstream of target.
93    pub direction: ImpactDirection,
94}
95
96/// Complete impact analysis result.
97#[derive(Debug, Serialize, Deserialize)]
98pub struct ImpactAnalysis {
99    /// The target node being analyzed.
100    pub target: NodeInfo,
101    /// Nodes that depend on the target (callers, importers, etc.)
102    pub upstream: Vec<AffectedNode>,
103    /// Nodes the target depends on (callees, imports, etc.)
104    pub downstream: Vec<AffectedNode>,
105    /// Total count of affected nodes.
106    pub total_affected: usize,
107    /// Maximum depth searched.
108    pub max_depth: usize,
109    /// Time taken in milliseconds.
110    pub query_time_ms: u64,
111}
112
113impl ImpactAnalysis {
114    /// Returns all affected nodes (upstream + downstream) sorted by severity.
115    pub fn all_affected(&self) -> Vec<&AffectedNode> {
116        let mut all: Vec<&AffectedNode> =
117            self.upstream.iter().chain(self.downstream.iter()).collect();
118
119        // Stable sort: severity → hop_distance → node_id
120        all.sort_by(|a, b| {
121            a.severity
122                .cmp(&b.severity)
123                .then_with(|| a.hop_distance.cmp(&b.hop_distance))
124                .then_with(|| a.node_info.id.cmp(&b.node_info.id))
125        });
126
127        all
128    }
129
130    /// Returns only direct (1-hop) affected nodes.
131    pub fn direct_only(&self) -> Vec<&AffectedNode> {
132        self.all_affected()
133            .into_iter()
134            .filter(|n| n.severity == ImpactSeverity::Direct)
135            .collect()
136    }
137
138    /// Returns a summary suitable for CLI output.
139    pub fn summary(&self) -> String {
140        let direct = self
141            .all_affected()
142            .iter()
143            .filter(|n| n.severity == ImpactSeverity::Direct)
144            .count();
145        let transitive = self
146            .all_affected()
147            .iter()
148            .filter(|n| n.severity == ImpactSeverity::Transitive)
149            .count();
150        let distant = self
151            .all_affected()
152            .iter()
153            .filter(|n| n.severity == ImpactSeverity::Distant)
154            .count();
155
156        format!(
157            "Blast Radius: {} nodes (direct: {}, transitive: {}, distant: {})",
158            self.total_affected, direct, transitive, distant
159        )
160    }
161}
162
163impl ArborGraph {
164    /// Analyzes the impact of changing a node.
165    ///
166    /// Performs bidirectional BFS from the target:
167    /// - Upstream: nodes that depend on target (would break if target changes)
168    /// - Downstream: nodes target depends on (may require target updates)
169    ///
170    /// # Arguments
171    /// * `target` - The node to analyze
172    /// * `max_depth` - Maximum hop distance to traverse (0 = unlimited)
173    ///
174    /// # Returns
175    /// Complete impact analysis with affected nodes sorted by severity.
176    pub fn analyze_impact(&self, target: NodeId, max_depth: usize) -> ImpactAnalysis {
177        let start = Instant::now();
178
179        let target_node = match self.get(target) {
180            Some(node) => NodeInfo::from(node),
181            None => {
182                return ImpactAnalysis {
183                    target: NodeInfo {
184                        id: String::new(),
185                        name: String::new(),
186                        qualified_name: String::new(),
187                        kind: String::new(),
188                        file: String::new(),
189                        line_start: 0,
190                        line_end: 0,
191                        signature: None,
192                        centrality: 0.0,
193                    },
194                    upstream: Vec::new(),
195                    downstream: Vec::new(),
196                    total_affected: 0,
197                    max_depth,
198                    query_time_ms: 0,
199                };
200            }
201        };
202
203        let effective_depth = if max_depth == 0 {
204            usize::MAX
205        } else {
206            max_depth
207        };
208
209        let upstream = self.bfs_impact(target, Direction::Incoming, effective_depth);
210        let downstream = self.bfs_impact(target, Direction::Outgoing, effective_depth);
211
212        let total = upstream.len() + downstream.len();
213        let elapsed = start.elapsed().as_millis() as u64;
214
215        ImpactAnalysis {
216            target: target_node,
217            upstream,
218            downstream,
219            total_affected: total,
220            max_depth,
221            query_time_ms: elapsed,
222        }
223    }
224
225    /// BFS traversal in one direction from target.
226    fn bfs_impact(
227        &self,
228        target: NodeId,
229        direction: Direction,
230        max_depth: usize,
231    ) -> Vec<AffectedNode> {
232        let mut result = Vec::new();
233        let mut visited: HashSet<NodeId> = HashSet::new();
234        let mut queue: VecDeque<(NodeId, usize, EdgeKind)> = VecDeque::new();
235
236        // Track entry edges for each node (first edge that reaches it)
237        let mut entry_edges: HashMap<NodeId, EdgeKind> = HashMap::new();
238
239        visited.insert(target);
240
241        // Seed queue with immediate neighbors
242        for edge_ref in self.graph.edges_directed(target, direction) {
243            let neighbor = match direction {
244                Direction::Incoming => edge_ref.source(),
245                Direction::Outgoing => edge_ref.target(),
246            };
247
248            if !visited.contains(&neighbor) {
249                let edge_kind = edge_ref.weight().kind;
250                queue.push_back((neighbor, 1, edge_kind));
251                entry_edges.insert(neighbor, edge_kind);
252            }
253        }
254
255        while let Some((current, depth, entry_edge)) = queue.pop_front() {
256            if depth > max_depth || visited.contains(&current) {
257                continue;
258            }
259
260            visited.insert(current);
261
262            if let Some(node) = self.get(current) {
263                let mut node_info = NodeInfo::from(node);
264                node_info.centrality = self.centrality(current);
265
266                let impact_direction = match direction {
267                    Direction::Incoming => ImpactDirection::Upstream,
268                    Direction::Outgoing => ImpactDirection::Downstream,
269                };
270
271                result.push(AffectedNode {
272                    node_id: current,
273                    node_info,
274                    severity: ImpactSeverity::from_hops(depth),
275                    hop_distance: depth,
276                    entry_edge,
277                    direction: impact_direction,
278                });
279            }
280
281            // Continue BFS if not at max depth
282            if depth < max_depth {
283                for edge_ref in self.graph.edges_directed(current, direction) {
284                    let neighbor = match direction {
285                        Direction::Incoming => edge_ref.source(),
286                        Direction::Outgoing => edge_ref.target(),
287                    };
288
289                    if !visited.contains(&neighbor) {
290                        let next_entry = *entry_edges.get(&neighbor).unwrap_or(&entry_edge);
291                        queue.push_back((neighbor, depth + 1, next_entry));
292
293                        // Store entry edge for first arrival
294                        entry_edges
295                            .entry(neighbor)
296                            .or_insert(edge_ref.weight().kind);
297                    }
298                }
299            }
300        }
301
302        // Sort by severity → hop_distance → id for stable ordering
303        result.sort_by(|a, b| {
304            a.severity
305                .cmp(&b.severity)
306                .then_with(|| a.hop_distance.cmp(&b.hop_distance))
307                .then_with(|| a.node_info.id.cmp(&b.node_info.id))
308        });
309
310        result
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::Edge;
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node in `test.rs` whose name and qualified name are `name`.
    fn make_node(name: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, "test.rs")
    }

    #[test]
    fn test_severity_from_hops() {
        assert_eq!(ImpactSeverity::from_hops(0), ImpactSeverity::Direct);
        assert_eq!(ImpactSeverity::from_hops(1), ImpactSeverity::Direct);
        assert_eq!(ImpactSeverity::from_hops(2), ImpactSeverity::Transitive);
        assert_eq!(ImpactSeverity::from_hops(3), ImpactSeverity::Transitive);
        assert_eq!(ImpactSeverity::from_hops(4), ImpactSeverity::Distant);
        assert_eq!(ImpactSeverity::from_hops(100), ImpactSeverity::Distant);
    }

    #[test]
    fn test_empty_graph() {
        let graph = ArborGraph::new();
        let result = graph.analyze_impact(NodeId::new(0), 5);
        assert_eq!(result.total_affected, 0);
        assert!(result.upstream.is_empty());
        assert!(result.downstream.is_empty());
    }

    #[test]
    fn test_single_node() {
        let mut graph = ArborGraph::new();
        let id = graph.add_node(make_node("lonely"));
        let result = graph.analyze_impact(id, 5);
        assert_eq!(result.total_affected, 0);
        assert_eq!(result.target.name, "lonely");
    }

    #[test]
    fn test_linear_chain() {
        // A → B → C
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Impact of B
        let result = graph.analyze_impact(b, 5);

        // Upstream: A calls B
        assert_eq!(result.upstream.len(), 1);
        assert_eq!(result.upstream[0].node_info.name, "a");
        assert_eq!(result.upstream[0].hop_distance, 1);
        assert_eq!(result.upstream[0].severity, ImpactSeverity::Direct);

        // Downstream: B calls C
        assert_eq!(result.downstream.len(), 1);
        assert_eq!(result.downstream[0].node_info.name, "c");
        assert_eq!(result.downstream[0].hop_distance, 1);

        // Both neighbors are one hop away, so the summary reports two direct nodes.
        assert_eq!(result.total_affected, 2);
        assert_eq!(
            result.summary(),
            "Blast Radius: 2 nodes (direct: 2, transitive: 0, distant: 0)"
        );
    }

    #[test]
    fn test_diamond_pattern() {
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(a, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, d, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));

        // Impact of A
        let result = graph.analyze_impact(a, 5);

        // Downstream should have B, C (depth 1) and D (depth 2), with D reported once
        assert_eq!(result.downstream.len(), 3);

        let names: Vec<&str> = result
            .downstream
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect();
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(names.contains(&"d"));

        // D should be transitive
        let d_node = result
            .downstream
            .iter()
            .find(|n| n.node_info.name == "d")
            .unwrap();
        assert_eq!(d_node.hop_distance, 2);
        assert_eq!(d_node.severity, ImpactSeverity::Transitive);
    }

    #[test]
    fn test_cycle_no_infinite_loop() {
        // A → B → C → A (cycle)
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, a, Edge::new(EdgeKind::Calls)); // Cycle back

        // Should not hang
        let result = graph.analyze_impact(a, 10);

        // Should find B and C downstream
        assert_eq!(result.downstream.len(), 2);

        // Upstream: C → A (depth 1), and B → C means B is also reachable at depth 2
        assert_eq!(result.upstream.len(), 2);
        let upstream_names: Vec<&str> = result
            .upstream
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect();
        assert!(upstream_names.contains(&"c"));
        assert!(upstream_names.contains(&"b"));
    }

    #[test]
    fn test_max_depth_limit() {
        // A → B → C → D → E
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));
        let e = graph.add_node(make_node("e"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));
        graph.add_edge(d, e, Edge::new(EdgeKind::Calls));

        // Impact of A with max_depth = 2
        let result = graph.analyze_impact(a, 2);

        // Should only find B (depth 1) and C (depth 2), not D or E
        assert_eq!(result.downstream.len(), 2);

        let names: Vec<&str> = result
            .downstream
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect();
        assert!(names.contains(&"b"));
        assert!(names.contains(&"c"));
        assert!(!names.contains(&"d"));
        assert!(!names.contains(&"e"));
    }

    #[test]
    fn test_max_depth_zero_is_unlimited() {
        // A → B → C → D → E
        let mut graph = ArborGraph::new();
        let a = graph.add_node(make_node("a"));
        let b = graph.add_node(make_node("b"));
        let c = graph.add_node(make_node("c"));
        let d = graph.add_node(make_node("d"));
        let e = graph.add_node(make_node("e"));

        graph.add_edge(a, b, Edge::new(EdgeKind::Calls));
        graph.add_edge(b, c, Edge::new(EdgeKind::Calls));
        graph.add_edge(c, d, Edge::new(EdgeKind::Calls));
        graph.add_edge(d, e, Edge::new(EdgeKind::Calls));

        // max_depth = 0 means "no limit", so the whole chain is reported
        let result = graph.analyze_impact(a, 0);
        assert_eq!(result.max_depth, 0);
        assert_eq!(result.downstream.len(), 4);

        let e_node = result
            .downstream
            .iter()
            .find(|n| n.node_info.name == "e")
            .unwrap();
        assert_eq!(e_node.hop_distance, 4);
        assert_eq!(e_node.severity, ImpactSeverity::Distant);
    }

    #[test]
    fn test_stable_ordering() {
        // Verify that results are deterministically ordered
        let mut graph = ArborGraph::new();
        let target = graph.add_node(make_node("target"));
        let z = graph.add_node(make_node("z_caller"));
        let a = graph.add_node(make_node("a_caller"));
        let m = graph.add_node(make_node("m_caller"));

        graph.add_edge(z, target, Edge::new(EdgeKind::Calls));
        graph.add_edge(a, target, Edge::new(EdgeKind::Calls));
        graph.add_edge(m, target, Edge::new(EdgeKind::Calls));

        let result = graph.analyze_impact(target, 5);

        let names: Vec<&str> = result
            .upstream
            .iter()
            .map(|n| n.node_info.name.as_str())
            .collect();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"a_caller"));
        assert!(names.contains(&"m_caller"));
        assert!(names.contains(&"z_caller"));

        // All callers share severity and hop distance, so the tie-breaker
        // (node_info.id) alone decides the order: ids must be non-decreasing.
        let ids: Vec<&str> = result
            .upstream
            .iter()
            .map(|n| n.node_info.id.as_str())
            .collect();
        assert!(ids.windows(2).all(|pair| pair[0] <= pair[1]));

        // Re-running the analysis yields the same order.
        let again = graph.analyze_impact(target, 5);
        let ids_again: Vec<&str> = again
            .upstream
            .iter()
            .map(|n| n.node_info.id.as_str())
            .collect();
        assert_eq!(ids, ids_again);
    }
}
509}