// arbor_graph/graph.rs

//! Core graph data structure.
//!
//! The `ArborGraph` wraps petgraph and adds indexes for fast lookups.
//! It is the central data structure that everything else works with.

use crate::edge::{Edge, EdgeKind, GraphEdge};
use crate::search_index::SearchIndex;
use arbor_core::CodeNode;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef; // For edge_references, source(), target(), weight()
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};

14/// Unique identifier for a node in the graph.
15pub type NodeId = NodeIndex;
16
17/// The code relationship graph.
18///
19/// This is the heart of Arbor. It stores all code entities as nodes
20/// and their relationships as edges, with indexes for fast access.
21#[derive(Debug, Serialize, Deserialize)]
22pub struct ArborGraph {
23    /// The underlying petgraph graph.
24    pub(crate) graph: DiGraph<CodeNode, Edge>,
25
26    /// Maps string IDs to graph node indexes.
27    id_index: HashMap<String, NodeId>,
28
29    /// Maps node names to node IDs (for search).
30    name_index: HashMap<String, Vec<NodeId>>,
31
32    /// Maps file paths to node IDs (for incremental updates).
33    file_index: HashMap<String, Vec<NodeId>>,
34
35    /// Centrality scores for ranking.
36    centrality: HashMap<NodeId, f64>,
37
38    /// Search index for fast substring queries.
39    #[serde(skip)]
40    search_index: SearchIndex,
41}
42
43impl Default for ArborGraph {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl ArborGraph {
50    /// Creates a new empty graph.
51    pub fn new() -> Self {
52        Self {
53            graph: DiGraph::new(),
54            id_index: HashMap::new(),
55            name_index: HashMap::new(),
56            file_index: HashMap::new(),
57            centrality: HashMap::new(),
58            search_index: SearchIndex::new(),
59        }
60    }
61
62    /// Adds a code node to the graph.
63    ///
64    /// Returns the node's index for adding edges later.
65    pub fn add_node(&mut self, node: CodeNode) -> NodeId {
66        let id = node.id.clone();
67        let name = node.name.clone();
68        let file = node.file.clone();
69
70        let index = self.graph.add_node(node);
71
72        // Update indexes
73        self.id_index.insert(id, index);
74        self.name_index.entry(name.clone()).or_default().push(index);
75        self.file_index.entry(file).or_default().push(index);
76        self.search_index.insert(&name, index);
77
78        index
79    }
80
81    /// Adds an edge between two nodes.
82    pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
83        self.graph.add_edge(from, to, edge);
84    }
85
86    /// Gets a node by its string ID.
87    pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
88        let index = self.id_index.get(id)?;
89        self.graph.node_weight(*index)
90    }
91
92    /// Gets a node by its graph index.
93    pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
94        self.graph.node_weight(index)
95    }
96
97    /// Finds all nodes with a given name.
98    pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
99        self.name_index
100            .get(name)
101            .map(|indexes| {
102                indexes
103                    .iter()
104                    .filter_map(|idx| self.graph.node_weight(*idx))
105                    .collect()
106            })
107            .unwrap_or_default()
108    }
109
110    /// Finds all nodes in a file.
111    pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
112        self.file_index
113            .get(file)
114            .map(|indexes| {
115                indexes
116                    .iter()
117                    .filter_map(|idx| self.graph.node_weight(*idx))
118                    .collect()
119            })
120            .unwrap_or_default()
121    }
122
123    /// Searches for nodes whose name contains the query.
124    ///
125    /// Uses the search index for fast O(k) lookups where k is the number of matches,
126    /// instead of O(n) linear scan over all nodes.
127    pub fn search(&self, query: &str) -> Vec<&CodeNode> {
128        self.search_index
129            .search(query)
130            .iter()
131            .filter_map(|id| self.graph.node_weight(*id))
132            .collect()
133    }
134
135    /// Gets nodes that call the given node.
136    pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
137        self.graph
138            .neighbors_directed(index, petgraph::Direction::Incoming)
139            .filter_map(|idx| {
140                // Check if the edge is a call
141                let edge_idx = self.graph.find_edge(idx, index)?;
142                let edge = self.graph.edge_weight(edge_idx)?;
143                if edge.kind == EdgeKind::Calls {
144                    self.graph.node_weight(idx)
145                } else {
146                    None
147                }
148            })
149            .collect()
150    }
151
152    /// Gets nodes that this node calls.
153    pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
154        self.graph
155            .neighbors_directed(index, petgraph::Direction::Outgoing)
156            .filter_map(|idx| {
157                let edge_idx = self.graph.find_edge(index, idx)?;
158                let edge = self.graph.edge_weight(edge_idx)?;
159                if edge.kind == EdgeKind::Calls {
160                    self.graph.node_weight(idx)
161                } else {
162                    None
163                }
164            })
165            .collect()
166    }
167
168    /// Gets all nodes that depend on the given node (directly or transitively).
169    pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
170        let mut result = Vec::new();
171        let mut visited = std::collections::HashSet::new();
172        let mut queue = vec![(index, 0usize)];
173
174        while let Some((current, depth)) = queue.pop() {
175            if depth > max_depth || visited.contains(&current) {
176                continue;
177            }
178            visited.insert(current);
179
180            if current != index {
181                result.push((current, depth));
182            }
183
184            // Get incoming edges (callers)
185            for neighbor in self
186                .graph
187                .neighbors_directed(current, petgraph::Direction::Incoming)
188            {
189                if !visited.contains(&neighbor) {
190                    queue.push((neighbor, depth + 1));
191                }
192            }
193        }
194
195        result
196    }
197
198    /// Removes all nodes from a file. Used for incremental updates.
199    pub fn remove_file(&mut self, file: &str) {
200        if let Some(indexes) = self.file_index.remove(file) {
201            for index in indexes {
202                if let Some(node) = self.graph.node_weight(index) {
203                    // Remove from name index
204                    let name = node.name.clone();
205                    if let Some(name_list) = self.name_index.get_mut(&name) {
206                        name_list.retain(|&idx| idx != index);
207                    }
208                    // Remove from id index
209                    self.id_index.remove(&node.id);
210                    // Remove from search index
211                    self.search_index.remove(&name, index);
212                }
213                self.graph.remove_node(index);
214            }
215        }
216    }
217
218    /// Gets the centrality score for a node.
219    pub fn centrality(&self, index: NodeId) -> f64 {
220        self.centrality.get(&index).copied().unwrap_or(0.0)
221    }
222
223    /// Sets centrality scores (called after computation).
224    pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
225        self.centrality = scores;
226    }
227
228    /// Returns the number of nodes.
229    pub fn node_count(&self) -> usize {
230        self.graph.node_count()
231    }
232
233    /// Returns the number of edges.
234    pub fn edge_count(&self) -> usize {
235        self.graph.edge_count()
236    }
237
238    /// Iterates over all nodes.
239    pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
240        self.graph.node_weights()
241    }
242
243    /// Iterates over all edges.
244    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
245        self.graph.edge_weights()
246    }
247
248    /// Returns all edges with source and target IDs for export.
249    pub fn export_edges(&self) -> Vec<GraphEdge> {
250        self.graph
251            .edge_references()
252            .map(|edge_ref| {
253                let source = self
254                    .graph
255                    .node_weight(edge_ref.source())
256                    .unwrap()
257                    .id
258                    .clone();
259                let target = self
260                    .graph
261                    .node_weight(edge_ref.target())
262                    .unwrap()
263                    .id
264                    .clone();
265                let weight = edge_ref.weight(); // &Edge
266                GraphEdge {
267                    source,
268                    target,
269                    kind: weight.kind,
270                }
271            })
272            .collect()
273    }
274
275    /// Iterates over all node indexes.
276    pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
277        self.graph.node_indices()
278    }
279
280    /// Finds the shortest path between two nodes.
281    pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
282        let path_indices = petgraph::algo::astar(
283            &self.graph,
284            from,
285            |finish| finish == to,
286            |_| 1, // weight of 1 for all edges (BFS-like)
287            |_| 0, // heuristic
288        )?;
289
290        Some(
291            path_indices
292                .1
293                .into_iter()
294                .filter_map(|idx| self.graph.node_weight(idx))
295                .collect(),
296        )
297    }
298
299    /// Gets the node index for a string ID.
300    pub fn get_index(&self, id: &str) -> Option<NodeId> {
301        self.id_index.get(id).copied()
302    }
303}
304
305/// Graph statistics for the info endpoint.
306#[derive(Debug, Serialize, Deserialize)]
307pub struct GraphStats {
308    pub node_count: usize,
309    pub edge_count: usize,
310    pub files: usize,
311}
312
313impl ArborGraph {
314    /// Returns graph statistics.
315    pub fn stats(&self) -> GraphStats {
316        GraphStats {
317            node_count: self.node_count(),
318            edge_count: self.edge_count(),
319            files: self.file_index.len(),
320        }
321    }
322}