Skip to main content

arbor_graph/
graph.rs

1//! Core graph data structure.
2//!
3//! The ArborGraph wraps petgraph and adds indexes for fast lookups.
4//! It's the central data structure that everything else works with.
5
6use crate::edge::{Edge, EdgeKind, GraphEdge};
7use crate::search_index::SearchIndex;
8use arbor_core::CodeNode;
9use petgraph::stable_graph::{NodeIndex, StableDiGraph};
10use petgraph::visit::{EdgeRef, IntoEdgeReferences}; // For edge_references
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
/// Unique identifier for a node in the graph.
///
/// This is an alias for petgraph's [`NodeIndex`], the handle type used by the
/// `StableDiGraph` that backs [`ArborGraph`].
pub type NodeId = NodeIndex;
16
/// The code relationship graph.
///
/// This is the heart of Arbor. It stores all code entities as nodes
/// and their relationships as edges, with indexes for fast access.
///
/// The lookup tables are filled by `add_node` and pruned by `remove_file`;
/// they always refer to nodes by their graph index ([`NodeId`]).
#[derive(Debug, Serialize, Deserialize)]
pub struct ArborGraph {
    /// The underlying petgraph graph.
    pub(crate) graph: StableDiGraph<CodeNode, Edge>,

    /// Maps string IDs to graph node indexes (one index per ID).
    id_index: HashMap<String, NodeId>,

    /// Maps node names to node IDs (for search). Several nodes may share a name.
    name_index: HashMap<String, Vec<NodeId>>,

    /// Maps file paths to node IDs (for incremental updates).
    file_index: HashMap<String, Vec<NodeId>>,

    /// Centrality scores for ranking. Nodes without an entry score 0.0.
    centrality: HashMap<NodeId, f64>,

    /// Search index for fast substring queries.
    ///
    /// Skipped by serde: it is not written out, and a deserialized graph
    /// receives `SearchIndex::default()` here.
    /// NOTE(review): nothing visible in this file re-populates the index after
    /// deserialization — confirm that loaders re-index, otherwise `search`
    /// finds nothing on a loaded graph.
    #[serde(skip)]
    search_index: SearchIndex,
}
42
/// The default graph is the empty graph produced by [`ArborGraph::new`].
impl Default for ArborGraph {
    fn default() -> Self {
        Self::new()
    }
}
48
49impl ArborGraph {
50    /// Creates a new empty graph.
51    pub fn new() -> Self {
52        Self {
53            graph: StableDiGraph::new(),
54            id_index: HashMap::new(),
55            name_index: HashMap::new(),
56            file_index: HashMap::new(),
57            centrality: HashMap::new(),
58            search_index: SearchIndex::new(),
59        }
60    }
61
62    /// Adds a code node to the graph.
63    ///
64    /// Returns the node's index for adding edges later.
65    pub fn add_node(&mut self, node: CodeNode) -> NodeId {
66        let id = node.id.clone();
67        let name = node.name.clone();
68        let file = node.file.clone();
69
70        let index = self.graph.add_node(node);
71
72        // Update indexes
73        self.id_index.insert(id, index);
74        self.name_index.entry(name.clone()).or_default().push(index);
75        self.file_index.entry(file).or_default().push(index);
76        self.search_index.insert(&name, index);
77
78        index
79    }
80
    /// Adds a directed edge from `from` to `to` carrying `edge`.
    ///
    /// The edge index returned by petgraph is discarded, and no check is made
    /// for an edge that already links the two nodes.
    ///
    /// NOTE(review): petgraph documents `add_edge` as panicking when either
    /// endpoint does not exist — pass only indexes obtained from this graph.
    pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
        self.graph.add_edge(from, to, edge);
    }
85
86    /// Gets a node by its string ID.
87    pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
88        let index = self.id_index.get(id)?;
89        self.graph.node_weight(*index)
90    }
91
    /// Gets a node by its graph index.
    ///
    /// Returns `None` when the graph holds no node at `index`.
    pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
        self.graph.node_weight(index)
    }
96
97    /// Finds all nodes with a given name.
98    pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
99        self.name_index
100            .get(name)
101            .map(|indexes| {
102                indexes
103                    .iter()
104                    .filter_map(|idx| self.graph.node_weight(*idx))
105                    .collect()
106            })
107            .unwrap_or_default()
108    }
109
110    /// Finds all nodes in a file.
111    pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
112        self.file_index
113            .get(file)
114            .map(|indexes| {
115                indexes
116                    .iter()
117                    .filter_map(|idx| self.graph.node_weight(*idx))
118                    .collect()
119            })
120            .unwrap_or_default()
121    }
122
123    /// Searches for nodes whose name contains the query.
124    ///
125    /// Uses the search index for fast O(k) lookups where k is the number of matches,
126    /// instead of O(n) linear scan over all nodes.
127    pub fn search(&self, query: &str) -> Vec<&CodeNode> {
128        self.search_index
129            .search(query)
130            .iter()
131            .filter_map(|id| self.graph.node_weight(*id))
132            .collect()
133    }
134
135    /// Gets nodes that call the given node.
136    pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
137        self.graph
138            .neighbors_directed(index, petgraph::Direction::Incoming)
139            .filter_map(|idx| {
140                // Check if the edge is a call
141                let edge_idx = self.graph.find_edge(idx, index)?;
142                let edge = self.graph.edge_weight(edge_idx)?;
143                if edge.kind == EdgeKind::Calls {
144                    self.graph.node_weight(idx)
145                } else {
146                    None
147                }
148            })
149            .collect()
150    }
151
152    /// Gets nodes that this node calls.
153    pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
154        self.graph
155            .neighbors_directed(index, petgraph::Direction::Outgoing)
156            .filter_map(|idx| {
157                let edge_idx = self.graph.find_edge(index, idx)?;
158                let edge = self.graph.edge_weight(edge_idx)?;
159                if edge.kind == EdgeKind::Calls {
160                    self.graph.node_weight(idx)
161                } else {
162                    None
163                }
164            })
165            .collect()
166    }
167
168    /// Gets all nodes that depend on the given node (directly or transitively).
169    pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
170        let mut result = Vec::new();
171        let mut visited = std::collections::HashSet::new();
172        let mut queue = vec![(index, 0usize)];
173
174        while let Some((current, depth)) = queue.pop() {
175            if depth > max_depth || visited.contains(&current) {
176                continue;
177            }
178            visited.insert(current);
179
180            if current != index {
181                result.push((current, depth));
182            }
183
184            // Get incoming edges (callers)
185            for neighbor in self
186                .graph
187                .neighbors_directed(current, petgraph::Direction::Incoming)
188            {
189                if !visited.contains(&neighbor) {
190                    queue.push((neighbor, depth + 1));
191                }
192            }
193        }
194
195        result
196    }
197
198    /// Removes all nodes from a file. Used for incremental updates.
199    pub fn remove_file(&mut self, file: &str) {
200        if let Some(indexes) = self.file_index.remove(file) {
201            for index in indexes {
202                if let Some(node) = self.graph.node_weight(index) {
203                    // Remove from name index
204                    let name = node.name.clone();
205                    if let Some(name_list) = self.name_index.get_mut(&name) {
206                        name_list.retain(|&idx| idx != index);
207                    }
208                    // Remove from id index
209                    self.id_index.remove(&node.id);
210                    // Remove from search index
211                    self.search_index.remove(&name, index);
212                }
213                self.graph.remove_node(index);
214            }
215        }
216    }
217
218    /// Gets the centrality score for a node.
219    pub fn centrality(&self, index: NodeId) -> f64 {
220        self.centrality.get(&index).copied().unwrap_or(0.0)
221    }
222
    /// Sets centrality scores (called after computation).
    ///
    /// The previous score table is replaced wholesale, not merged: nodes
    /// missing from `scores` have no entry afterwards.
    pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
        self.centrality = scores;
    }
227
    /// Returns the number of nodes currently stored in the underlying graph.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }
232
    /// Returns the number of edges currently stored in the underlying graph.
    pub fn edge_count(&self) -> usize {
        self.graph.edge_count()
    }
237
    /// Iterates over all nodes, yielding a reference to each stored [`CodeNode`].
    pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
        self.graph.node_weights()
    }
242
    /// Iterates over all edges, yielding each edge's weight only.
    ///
    /// Endpoints are not included; use `export_edges` when source and target
    /// IDs are needed.
    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
        self.graph.edge_weights()
    }
247
248    /// Returns all edges with source and target IDs for export.
249    pub fn export_edges(&self) -> Vec<GraphEdge> {
250        (&self.graph)
251            .edge_references()
252            .filter_map(|edge_ref| {
253                let source = self.graph.node_weight(edge_ref.source())?.id.clone();
254                let target = self.graph.node_weight(edge_ref.target())?.id.clone();
255                let weight = edge_ref.weight(); // &Edge
256                Some(GraphEdge {
257                    source,
258                    target,
259                    kind: weight.kind,
260                })
261            })
262            .collect()
263    }
264
    /// Iterates over all node indexes.
    ///
    /// The iterator borrows the graph, hence the `'_` bound on the return type.
    pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
        self.graph.node_indices()
    }
269
270    /// Finds the shortest path between two nodes.
271    pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
272        let path_indices = petgraph::algo::astar(
273            &self.graph,
274            from,
275            |finish| finish == to,
276            |_| 1, // weight of 1 for all edges (BFS-like)
277            |_| 0, // heuristic
278        )?;
279
280        Some(
281            path_indices
282                .1
283                .into_iter()
284                .filter_map(|idx| self.graph.node_weight(idx))
285                .collect(),
286        )
287    }
288
    /// Gets the node index for a string ID.
    ///
    /// Returns `None` when the ID is not present in the ID index.
    pub fn get_index(&self, id: &str) -> Option<NodeId> {
        self.id_index.get(id).copied()
    }
293}
294
/// Graph statistics for the info endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of edges in the graph.
    pub edge_count: usize,
    /// Number of distinct file paths that have an entry in the file index.
    pub files: usize,
}
302
303impl ArborGraph {
304    /// Returns graph statistics.
305    pub fn stats(&self) -> GraphStats {
306        GraphStats {
307            node_count: self.node_count(),
308            edge_count: self.edge_count(),
309            files: self.file_index.len(),
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node for `file`, passing `name` as both of the first
    /// two constructor arguments.
    fn make_node(name: &str, file: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, file)
    }

    #[test]
    fn test_graph_new_is_empty() {
        let g = ArborGraph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
        assert!(g.nodes().next().is_none());
    }

    #[test]
    fn test_graph_add_and_get_node() {
        let mut g = ArborGraph::new();
        let node = make_node("foo", "main.rs");
        let id = g.add_node(node);
        assert_eq!(g.node_count(), 1);

        let got = g.get(id).unwrap();
        assert_eq!(got.name, "foo");
    }

    #[test]
    fn test_graph_find_by_name() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("alpha", "a.rs"));
        g.add_node(make_node("beta", "b.rs"));

        let found = g.find_by_name("alpha");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "alpha");

        let not_found = g.find_by_name("gamma");
        assert!(not_found.is_empty());
    }

    #[test]
    fn test_graph_find_by_file() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "main.rs"));
        g.add_node(make_node("bar", "main.rs"));
        g.add_node(make_node("baz", "other.rs"));

        let main_nodes = g.find_by_file("main.rs");
        assert_eq!(main_nodes.len(), 2);

        let other_nodes = g.find_by_file("other.rs");
        assert_eq!(other_nodes.len(), 1);

        let empty = g.find_by_file("nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_graph_search_substring() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("validate_user", "a.rs"));
        g.add_node(make_node("validate_email", "b.rs"));
        g.add_node(make_node("send_email", "c.rs"));

        let results = g.search("validate");
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|n| n.name == "validate_user"));
        assert!(results.iter().any(|n| n.name == "validate_email"));
    }

    #[test]
    fn test_graph_callers_callees() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("caller", "a.rs"));
        let b = g.add_node(make_node("callee", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let callees = g.get_callees(a);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "callee");

        let callers = g.get_callers(b);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "caller");

        // The edge is directed: nothing calls `a`, and `b` calls nothing.
        assert!(g.get_callers(a).is_empty());
        assert!(g.get_callees(b).is_empty());
    }

    #[test]
    fn test_graph_get_dependents() {
        // a -> b -> c
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        let c = g.add_node(make_node("c", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Dependents of c at depth 2 are b (one hop away) and a (two hops away).
        let deps = g.get_dependents(c, 2);
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&(b, 1)));
        assert!(deps.contains(&(a, 2)));

        // The start node is never reported as its own dependent.
        assert!(deps.iter().all(|(idx, _)| *idx != c));

        // A depth limit of 1 stops before reaching a.
        assert_eq!(g.get_dependents(c, 1), vec![(b, 1)]);

        // A depth limit of 0 reports nothing.
        assert!(g.get_dependents(c, 0).is_empty());
    }

    #[test]
    fn test_graph_remove_file_cleanup() {
        let mut g = ArborGraph::new();
        let foo = make_node("foo", "remove_me.rs");
        let foo_id = foo.id.clone();
        g.add_node(foo);
        g.add_node(make_node("bar", "remove_me.rs"));
        g.add_node(make_node("keep", "keep.rs"));

        assert_eq!(g.node_count(), 3);

        g.remove_file("remove_me.rs");

        // Nodes from removed file are gone
        assert!(g.find_by_name("foo").is_empty());
        assert!(g.find_by_name("bar").is_empty());
        // Node from other file remains
        assert_eq!(g.find_by_name("keep").len(), 1);
        assert!(g.find_by_file("remove_me.rs").is_empty());

        // The graph itself, the file count and the ID index shrink as well.
        assert_eq!(g.node_count(), 1);
        assert_eq!(g.stats().files, 1);
        assert!(g.get_by_id(&foo_id).is_none());
        assert!(g.get_index(&foo_id).is_none());

        // Removing a file that is not indexed changes nothing.
        g.remove_file("never_added.rs");
        assert_eq!(g.node_count(), 1);
    }

    #[test]
    fn test_graph_find_path() {
        // a -> b -> c
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("start", "a.rs"));
        let b = g.add_node(make_node("middle", "b.rs"));
        let c = g.add_node(make_node("end", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let path = g.find_path(a, c).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].name, "start");
        assert_eq!(path[1].name, "middle");
        assert_eq!(path[2].name, "end");
    }

    #[test]
    fn test_graph_find_path_no_connection() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("island_a", "a.rs"));
        let b = g.add_node(make_node("island_b", "b.rs"));

        // No edges → no path
        assert!(g.find_path(a, b).is_none());
    }

    #[test]
    fn test_graph_export_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let exported = g.export_edges();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].kind, EdgeKind::Calls);

        // Endpoints are exported as the string IDs of the connected nodes.
        assert_eq!(exported[0].source, g.get(a).unwrap().id);
        assert_eq!(exported[0].target, g.get(b).unwrap().id);
    }

    #[test]
    fn test_graph_stats() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("a", "x.rs"));
        g.add_node(make_node("b", "y.rs"));

        let stats = g.stats();
        assert_eq!(stats.node_count, 2);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.files, 2);
    }

    #[test]
    fn test_graph_get_index_and_get_by_id() {
        let mut g = ArborGraph::new();
        let node = make_node("lookup_me", "test.rs");
        let node_id_str = node.id.clone();
        let idx = g.add_node(node);

        assert_eq!(g.get_index(&node_id_str), Some(idx));
        assert!(g.get_by_id(&node_id_str).is_some());
        assert!(g.get_index("nonexistent").is_none());
        assert!(g.get_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_graph_centrality_default_zero() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));
        assert_eq!(g.centrality(idx), 0.0);
    }

    #[test]
    fn test_graph_set_centrality() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));

        let mut scores = HashMap::new();
        scores.insert(idx, 0.75);
        g.set_centrality(scores);

        assert!((g.centrality(idx) - 0.75).abs() < f64::EPSILON);
    }
}