Skip to main content

arbor_graph/
graph.rs

//! Core graph data structure.
//!
//! [`ArborGraph`] wraps a petgraph `StableDiGraph` and adds indexes (by id, by
//! name, by file, plus a substring search index) for fast lookups. It is the
//! central data structure the rest of the crate works with.

6use crate::edge::{Edge, EdgeKind, GraphEdge};
7use crate::search_index::SearchIndex;
8use arbor_core::CodeNode;
9use petgraph::stable_graph::{NodeIndex, StableDiGraph};
10use petgraph::visit::{EdgeRef, IntoEdgeReferences}; // For edge_references
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Unique identifier for a node in the graph.
///
/// This is the petgraph index into the underlying `StableDiGraph`. It is distinct
/// from the string `id` stored on each `CodeNode`; use `ArborGraph::get_index` to
/// map a string id to its `NodeId`.
pub type NodeId = NodeIndex;
16
/// The code relationship graph.
///
/// This is the heart of Arbor. It stores all code entities as nodes
/// and their relationships as edges, with indexes for fast access.
///
/// Every index below is kept in sync by `add_node` and `remove_file`.
/// NOTE: `search_index` is not serialized; after deserialization call
/// `rebuild_search_index` before using `search`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArborGraph {
    /// The underlying petgraph graph: `CodeNode` weights on nodes, `Edge` weights on edges.
    pub(crate) graph: StableDiGraph<CodeNode, Edge>,

    /// Maps each node's string `id` to its graph index.
    id_index: HashMap<String, NodeId>,

    /// Maps a node name to every node carrying that name (names need not be unique).
    name_index: HashMap<String, Vec<NodeId>>,

    /// Maps a file path to the nodes defined in it; used by `remove_file` for
    /// incremental updates.
    file_index: HashMap<String, Vec<NodeId>>,

    /// Centrality scores for ranking, as supplied through `set_centrality`.
    /// Nodes without an entry are reported as `0.0`.
    centrality: HashMap<NodeId, f64>,

    /// Search index for fast substring queries over node names.
    /// Skipped by serde, so a deserialized graph gets `SearchIndex`'s default
    /// (presumably empty — TODO confirm) until `rebuild_search_index` runs.
    #[serde(skip)]
    search_index: SearchIndex,
}
42
43impl Default for ArborGraph {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ArborGraph {
50    /// Creates a new empty graph.
51    pub fn new() -> Self {
52        Self {
53            graph: StableDiGraph::new(),
54            id_index: HashMap::new(),
55            name_index: HashMap::new(),
56            file_index: HashMap::new(),
57            centrality: HashMap::new(),
58            search_index: SearchIndex::new(),
59        }
60    }
61
62    /// Rebuilds the search index from existing graph nodes.
63    /// Call after deserialization since search_index is not serialized.
64    pub fn rebuild_search_index(&mut self) {
65        self.search_index = SearchIndex::new();
66        for index in self.graph.node_indices() {
67            if let Some(node) = self.graph.node_weight(index) {
68                self.search_index.insert(&node.name, index);
69            }
70        }
71    }
72
73    /// Adds a code node to the graph.
74    ///
75    /// Returns the node's index for adding edges later.
76    pub fn add_node(&mut self, node: CodeNode) -> NodeId {
77        let id = node.id.clone();
78        let name = node.name.clone();
79        let file = node.file.clone();
80
81        let index = self.graph.add_node(node);
82
83        // Update indexes
84        self.id_index.insert(id, index);
85        self.name_index.entry(name.clone()).or_default().push(index);
86        self.file_index.entry(file).or_default().push(index);
87        self.search_index.insert(&name, index);
88
89        index
90    }
91
92    /// Adds an edge between two nodes.
93    pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
94        self.graph.add_edge(from, to, edge);
95    }
96
97    /// Gets a node by its string ID.
98    pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
99        let index = self.id_index.get(id)?;
100        self.graph.node_weight(*index)
101    }
102
103    /// Gets a node by its graph index.
104    pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
105        self.graph.node_weight(index)
106    }
107
108    /// Finds all nodes with a given name.
109    pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
110        self.name_index
111            .get(name)
112            .map(|indexes| {
113                indexes
114                    .iter()
115                    .filter_map(|idx| self.graph.node_weight(*idx))
116                    .collect()
117            })
118            .unwrap_or_default()
119    }
120
121    /// Finds all nodes in a file.
122    pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
123        self.file_index
124            .get(file)
125            .map(|indexes| {
126                indexes
127                    .iter()
128                    .filter_map(|idx| self.graph.node_weight(*idx))
129                    .collect()
130            })
131            .unwrap_or_default()
132    }
133
134    /// Searches for nodes whose name contains the query.
135    ///
136    /// Uses the search index for fast O(k) lookups where k is the number of matches,
137    /// instead of O(n) linear scan over all nodes.
138    pub fn search(&self, query: &str) -> Vec<&CodeNode> {
139        self.search_index
140            .search(query)
141            .iter()
142            .filter_map(|id| self.graph.node_weight(*id))
143            .collect()
144    }
145
146    /// Gets nodes that call the given node.
147    pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
148        self.graph
149            .neighbors_directed(index, petgraph::Direction::Incoming)
150            .filter_map(|idx| {
151                // Check if the edge is a call
152                let edge_idx = self.graph.find_edge(idx, index)?;
153                let edge = self.graph.edge_weight(edge_idx)?;
154                if edge.kind == EdgeKind::Calls {
155                    self.graph.node_weight(idx)
156                } else {
157                    None
158                }
159            })
160            .collect()
161    }
162
163    /// Gets nodes that this node calls.
164    pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
165        self.graph
166            .neighbors_directed(index, petgraph::Direction::Outgoing)
167            .filter_map(|idx| {
168                let edge_idx = self.graph.find_edge(index, idx)?;
169                let edge = self.graph.edge_weight(edge_idx)?;
170                if edge.kind == EdgeKind::Calls {
171                    self.graph.node_weight(idx)
172                } else {
173                    None
174                }
175            })
176            .collect()
177    }
178
179    /// Gets all nodes that depend on the given node (directly or transitively).
180    pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
181        let mut result = Vec::new();
182        let mut visited = std::collections::HashSet::new();
183        let mut queue = vec![(index, 0usize)];
184
185        while let Some((current, depth)) = queue.pop() {
186            if depth > max_depth || visited.contains(&current) {
187                continue;
188            }
189            visited.insert(current);
190
191            if current != index {
192                result.push((current, depth));
193            }
194
195            // Get incoming edges (callers)
196            for neighbor in self
197                .graph
198                .neighbors_directed(current, petgraph::Direction::Incoming)
199            {
200                if !visited.contains(&neighbor) {
201                    queue.push((neighbor, depth + 1));
202                }
203            }
204        }
205
206        result
207    }
208
209    /// Removes all nodes from a file. Used for incremental updates.
210    pub fn remove_file(&mut self, file: &str) {
211        if let Some(indexes) = self.file_index.remove(file) {
212            for index in indexes {
213                if let Some(node) = self.graph.node_weight(index) {
214                    // Remove from name index
215                    let name = node.name.clone();
216                    if let Some(name_list) = self.name_index.get_mut(&name) {
217                        name_list.retain(|&idx| idx != index);
218                    }
219                    // Remove from id index
220                    self.id_index.remove(&node.id);
221                    // Remove from search index
222                    self.search_index.remove(&name, index);
223                }
224                self.graph.remove_node(index);
225            }
226        }
227    }
228
229    /// Gets the centrality score for a node.
230    pub fn centrality(&self, index: NodeId) -> f64 {
231        self.centrality.get(&index).copied().unwrap_or(0.0)
232    }
233
234    /// Sets centrality scores (called after computation).
235    pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
236        self.centrality = scores;
237    }
238
239    /// Returns the number of nodes.
240    pub fn node_count(&self) -> usize {
241        self.graph.node_count()
242    }
243
244    /// Returns the number of edges.
245    pub fn edge_count(&self) -> usize {
246        self.graph.edge_count()
247    }
248
249    /// Iterates over all nodes.
250    pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
251        self.graph.node_weights()
252    }
253
254    /// Iterates over all edges.
255    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
256        self.graph.edge_weights()
257    }
258
259    /// Returns all edges with source and target IDs for export.
260    pub fn export_edges(&self) -> Vec<GraphEdge> {
261        (&self.graph)
262            .edge_references()
263            .filter_map(|edge_ref| {
264                let source = self.graph.node_weight(edge_ref.source())?.id.clone();
265                let target = self.graph.node_weight(edge_ref.target())?.id.clone();
266                let weight = edge_ref.weight(); // &Edge
267                Some(GraphEdge {
268                    source,
269                    target,
270                    kind: weight.kind,
271                })
272            })
273            .collect()
274    }
275
276    /// Iterates over all node indexes.
277    pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
278        self.graph.node_indices()
279    }
280
281    /// Finds the shortest path between two nodes.
282    pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
283        let path_indices = petgraph::algo::astar(
284            &self.graph,
285            from,
286            |finish| finish == to,
287            |_| 1, // weight of 1 for all edges (BFS-like)
288            |_| 0, // heuristic
289        )?;
290
291        Some(
292            path_indices
293                .1
294                .into_iter()
295                .filter_map(|idx| self.graph.node_weight(idx))
296                .collect(),
297        )
298    }
299
300    /// Gets the node index for a string ID.
301    pub fn get_index(&self, id: &str) -> Option<NodeId> {
302        self.id_index.get(id).copied()
303    }
304
305    /// Returns all nodes detected as production entry points.
306    pub fn list_entry_points(&self) -> Vec<&CodeNode> {
307        use crate::heuristics::HeuristicsMatcher;
308        self.graph
309            .node_weights()
310            .filter(|n| HeuristicsMatcher::is_likely_entry_point(n))
311            .collect()
312    }
313
314    /// Returns all nodes in a file and the call edges between them.
315    /// Edges returned as (caller_name, callee_name, edge_kind_debug_str) triples.
316    pub fn nodes_in_file_with_edges(
317        &self,
318        file: &str,
319    ) -> (Vec<&CodeNode>, Vec<(String, String, String)>) {
320        let node_ids: std::collections::HashSet<NodeId> = self
321            .file_index
322            .get(file)
323            .map(|ids| ids.iter().copied().collect())
324            .unwrap_or_default();
325
326        let nodes: Vec<&CodeNode> = node_ids
327            .iter()
328            .filter_map(|&id| self.graph.node_weight(id))
329            .collect();
330
331        let mut edges = Vec::new();
332        for &from in &node_ids {
333            for edge_ref in self
334                .graph
335                .edges_directed(from, petgraph::Direction::Outgoing)
336            {
337                let to = edge_ref.target();
338                if node_ids.contains(&to) {
339                    if let (Some(from_node), Some(to_node)) =
340                        (self.graph.node_weight(from), self.graph.node_weight(to))
341                    {
342                        edges.push((
343                            from_node.name.clone(),
344                            to_node.name.clone(),
345                            format!("{:?}", edge_ref.weight().kind),
346                        ));
347                    }
348                }
349            }
350        }
351        (nodes, edges)
352    }
353}
354
/// Graph statistics for the info endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges in the graph.
    pub edge_count: usize,
    /// Number of distinct file paths with an entry in the file index.
    pub files: usize,
}
362
363impl ArborGraph {
364    /// Returns graph statistics.
365    pub fn stats(&self) -> GraphStats {
366        GraphStats {
367            node_count: self.node_count(),
368            edge_count: self.edge_count(),
369            files: self.file_index.len(),
370        }
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a `Function` node called `name` in `file`; `name` is passed as both
    /// of the first two `CodeNode::new` arguments.
    fn make_node(name: &str, file: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, file)
    }

    #[test]
    fn test_graph_new_is_empty() {
        let g = ArborGraph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
        assert!(g.nodes().next().is_none());
    }

    #[test]
    fn test_graph_add_and_get_node() {
        let mut g = ArborGraph::new();
        let id = g.add_node(make_node("foo", "main.rs"));
        assert_eq!(g.node_count(), 1);

        let got = g.get(id).unwrap();
        assert_eq!(got.name, "foo");
    }

    #[test]
    fn test_graph_find_by_name() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("alpha", "a.rs"));
        g.add_node(make_node("beta", "b.rs"));

        let found = g.find_by_name("alpha");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "alpha");

        let not_found = g.find_by_name("gamma");
        assert!(not_found.is_empty());
    }

    #[test]
    fn test_graph_find_by_file() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "main.rs"));
        g.add_node(make_node("bar", "main.rs"));
        g.add_node(make_node("baz", "other.rs"));

        let main_nodes = g.find_by_file("main.rs");
        assert_eq!(main_nodes.len(), 2);

        let other_nodes = g.find_by_file("other.rs");
        assert_eq!(other_nodes.len(), 1);

        let empty = g.find_by_file("nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_graph_search_substring() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("validate_user", "a.rs"));
        g.add_node(make_node("validate_email", "b.rs"));
        g.add_node(make_node("send_email", "c.rs"));

        let results = g.search("validate");
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|n| n.name == "validate_user"));
        assert!(results.iter().any(|n| n.name == "validate_email"));
    }

    #[test]
    fn test_graph_callers_callees() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("caller", "a.rs"));
        let b = g.add_node(make_node("callee", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let callees = g.get_callees(a);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "callee");

        let callers = g.get_callers(b);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "caller");

        // The edge is one-way: `caller` has no callers and `callee` has no callees.
        assert!(g.get_callers(a).is_empty());
        assert!(g.get_callees(b).is_empty());
    }

    #[test]
    fn test_graph_get_dependents() {
        // a -> b -> c
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        let c = g.add_node(make_node("c", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Dependents of c within 2 hops: b (direct) and a (transitive).
        let deps = g.get_dependents(c, 2);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&(b, 1)));
        assert!(deps.contains(&(a, 2)));

        // With max_depth = 1 only the direct dependent is reported.
        assert_eq!(g.get_dependents(c, 1), vec![(b, 1)]);
    }

    #[test]
    fn test_graph_remove_file_cleanup() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "remove_me.rs"));
        g.add_node(make_node("bar", "remove_me.rs"));
        g.add_node(make_node("keep", "keep.rs"));

        assert_eq!(g.node_count(), 3);

        g.remove_file("remove_me.rs");

        // Nodes from the removed file are gone from the graph and its indexes.
        assert_eq!(g.node_count(), 1);
        assert!(g.find_by_name("foo").is_empty());
        assert!(g.find_by_name("bar").is_empty());
        assert!(g.find_by_file("remove_me.rs").is_empty());
        assert_eq!(g.stats().files, 1);
        // The node from the other file remains.
        assert_eq!(g.find_by_name("keep").len(), 1);
    }

    #[test]
    fn test_graph_find_path() {
        // a -> b -> c
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("start", "a.rs"));
        let b = g.add_node(make_node("middle", "b.rs"));
        let c = g.add_node(make_node("end", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let path = g.find_path(a, c).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].name, "start");
        assert_eq!(path[1].name, "middle");
        assert_eq!(path[2].name, "end");
    }

    #[test]
    fn test_graph_find_path_no_connection() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("island_a", "a.rs"));
        let b = g.add_node(make_node("island_b", "b.rs"));

        // No edges → no path
        assert!(g.find_path(a, b).is_none());
    }

    #[test]
    fn test_graph_export_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let exported = g.export_edges();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].kind, EdgeKind::Calls);
    }

    #[test]
    fn test_graph_stats() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("a", "x.rs"));
        g.add_node(make_node("b", "y.rs"));

        let stats = g.stats();
        assert_eq!(stats.node_count, 2);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.files, 2);
    }

    #[test]
    fn test_graph_get_index_and_get_by_id() {
        let mut g = ArborGraph::new();
        let node = make_node("lookup_me", "test.rs");
        let node_id_str = node.id.clone();
        let idx = g.add_node(node);

        assert_eq!(g.get_index(&node_id_str), Some(idx));
        assert!(g.get_by_id(&node_id_str).is_some());
        assert!(g.get_index("nonexistent").is_none());
        assert!(g.get_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_graph_centrality_default_zero() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));
        assert_eq!(g.centrality(idx), 0.0);
    }

    #[test]
    fn test_graph_set_centrality() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));

        let mut scores = HashMap::new();
        scores.insert(idx, 0.75);
        g.set_centrality(scores);

        assert!((g.centrality(idx) - 0.75).abs() < f64::EPSILON);
    }
}
585
#[cfg(test)]
mod new_query_tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a node of `kind` called `name` in `file`, using `"<file>::<name>"`
    /// as the second `CodeNode::new` argument.
    fn make_node(name: &str, kind: NodeKind, file: &str) -> CodeNode {
        CodeNode::new(name, format!("{}::{}", file, name), kind, file)
    }

    #[test]
    fn test_list_entry_points_returns_main() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("main", NodeKind::Function, "src/main.rs"));
        g.add_node(make_node("helper", NodeKind::Function, "src/util.rs"));
        let eps = g.list_entry_points();
        assert!(
            eps.iter().any(|n| n.name == "main"),
            "main must be an entry point"
        );
        assert!(
            !eps.iter().any(|n| n.name == "helper"),
            "helper must not be an entry point"
        );
    }

    #[test]
    fn test_nodes_in_file_with_edges_returns_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("foo", NodeKind::Function, "src/a.rs"));
        let b = g.add_node(make_node("bar", NodeKind::Function, "src/a.rs"));
        let _c = g.add_node(make_node("baz", NodeKind::Function, "src/b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let (nodes, edges) = g.nodes_in_file_with_edges("src/a.rs");
        assert_eq!(nodes.len(), 2);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].0, "foo");
        assert_eq!(edges[0].1, "bar");
        // The third element is the Debug rendering of the edge kind.
        assert_eq!(edges[0].2, format!("{:?}", EdgeKind::Calls));
    }

    #[test]
    fn test_nodes_in_file_with_edges_excludes_cross_file_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("foo", NodeKind::Function, "src/a.rs"));
        let c = g.add_node(make_node("baz", NodeKind::Function, "src/b.rs"));
        // Edge from a.rs to b.rs — must NOT appear in nodes_in_file_with_edges for a.rs.
        g.add_edge(a, c, Edge::new(EdgeKind::Calls));

        let (nodes, edges) = g.nodes_in_file_with_edges("src/a.rs");
        assert_eq!(nodes.len(), 1); // only foo
        assert_eq!(edges.len(), 0); // cross-file edge excluded
    }
}