arbor_graph/
graph.rs

//! Core graph data structure.
//!
//! [`ArborGraph`] wraps a petgraph `DiGraph` and adds indexes for fast
//! lookups by node ID, name, and file path. It's the central data structure
//! that everything else works with.

6use crate::edge::{Edge, EdgeKind, GraphEdge};
7use arbor_core::CodeNode;
8use petgraph::graph::{DiGraph, NodeIndex};
9use petgraph::visit::EdgeRef; // For edge_references
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// Unique identifier for a node in the graph.
14pub type NodeId = NodeIndex;
15
16/// The code relationship graph.
17///
18/// This is the heart of Arbor. It stores all code entities as nodes
19/// and their relationships as edges, with indexes for fast access.
20#[derive(Debug, Serialize, Deserialize)]
21pub struct ArborGraph {
22    /// The underlying petgraph graph.
23    graph: DiGraph<CodeNode, Edge>,
24
25    /// Maps string IDs to graph node indexes.
26    id_index: HashMap<String, NodeId>,
27
28    /// Maps node names to node IDs (for search).
29    name_index: HashMap<String, Vec<NodeId>>,
30
31    /// Maps file paths to node IDs (for incremental updates).
32    file_index: HashMap<String, Vec<NodeId>>,
33
34    /// Centrality scores for ranking.
35    centrality: HashMap<NodeId, f64>,
36}
37
38impl Default for ArborGraph {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl ArborGraph {
45    /// Creates a new empty graph.
46    pub fn new() -> Self {
47        Self {
48            graph: DiGraph::new(),
49            id_index: HashMap::new(),
50            name_index: HashMap::new(),
51            file_index: HashMap::new(),
52            centrality: HashMap::new(),
53        }
54    }
55
56    /// Adds a code node to the graph.
57    ///
58    /// Returns the node's index for adding edges later.
59    pub fn add_node(&mut self, node: CodeNode) -> NodeId {
60        let id = node.id.clone();
61        let name = node.name.clone();
62        let file = node.file.clone();
63
64        let index = self.graph.add_node(node);
65
66        // Update indexes
67        self.id_index.insert(id, index);
68        self.name_index.entry(name).or_default().push(index);
69        self.file_index.entry(file).or_default().push(index);
70
71        index
72    }
73
74    /// Adds an edge between two nodes.
75    pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
76        self.graph.add_edge(from, to, edge);
77    }
78
79    /// Gets a node by its string ID.
80    pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
81        let index = self.id_index.get(id)?;
82        self.graph.node_weight(*index)
83    }
84
85    /// Gets a node by its graph index.
86    pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
87        self.graph.node_weight(index)
88    }
89
90    /// Finds all nodes with a given name.
91    pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
92        self.name_index
93            .get(name)
94            .map(|indexes| {
95                indexes
96                    .iter()
97                    .filter_map(|idx| self.graph.node_weight(*idx))
98                    .collect()
99            })
100            .unwrap_or_default()
101    }
102
103    /// Finds all nodes in a file.
104    pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
105        self.file_index
106            .get(file)
107            .map(|indexes| {
108                indexes
109                    .iter()
110                    .filter_map(|idx| self.graph.node_weight(*idx))
111                    .collect()
112            })
113            .unwrap_or_default()
114    }
115
116    /// Searches for nodes whose name contains the query.
117    pub fn search(&self, query: &str) -> Vec<&CodeNode> {
118        let query_lower = query.to_lowercase();
119        self.graph
120            .node_weights()
121            .filter(|node| node.name.to_lowercase().contains(&query_lower))
122            .collect()
123    }
124
125    /// Gets nodes that call the given node.
126    pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
127        self.graph
128            .neighbors_directed(index, petgraph::Direction::Incoming)
129            .filter_map(|idx| {
130                // Check if the edge is a call
131                let edge_idx = self.graph.find_edge(idx, index)?;
132                let edge = self.graph.edge_weight(edge_idx)?;
133                if edge.kind == EdgeKind::Calls {
134                    self.graph.node_weight(idx)
135                } else {
136                    None
137                }
138            })
139            .collect()
140    }
141
142    /// Gets nodes that this node calls.
143    pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
144        self.graph
145            .neighbors_directed(index, petgraph::Direction::Outgoing)
146            .filter_map(|idx| {
147                let edge_idx = self.graph.find_edge(index, idx)?;
148                let edge = self.graph.edge_weight(edge_idx)?;
149                if edge.kind == EdgeKind::Calls {
150                    self.graph.node_weight(idx)
151                } else {
152                    None
153                }
154            })
155            .collect()
156    }
157
158    /// Gets all nodes that depend on the given node (directly or transitively).
159    pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
160        let mut result = Vec::new();
161        let mut visited = std::collections::HashSet::new();
162        let mut queue = vec![(index, 0usize)];
163
164        while let Some((current, depth)) = queue.pop() {
165            if depth > max_depth || visited.contains(&current) {
166                continue;
167            }
168            visited.insert(current);
169
170            if current != index {
171                result.push((current, depth));
172            }
173
174            // Get incoming edges (callers)
175            for neighbor in self
176                .graph
177                .neighbors_directed(current, petgraph::Direction::Incoming)
178            {
179                if !visited.contains(&neighbor) {
180                    queue.push((neighbor, depth + 1));
181                }
182            }
183        }
184
185        result
186    }
187
188    /// Removes all nodes from a file. Used for incremental updates.
189    pub fn remove_file(&mut self, file: &str) {
190        if let Some(indexes) = self.file_index.remove(file) {
191            for index in indexes {
192                if let Some(node) = self.graph.node_weight(index) {
193                    // Remove from name index
194                    if let Some(name_list) = self.name_index.get_mut(&node.name) {
195                        name_list.retain(|&idx| idx != index);
196                    }
197                    // Remove from id index
198                    self.id_index.remove(&node.id);
199                }
200                self.graph.remove_node(index);
201            }
202        }
203    }
204
205    /// Gets the centrality score for a node.
206    pub fn centrality(&self, index: NodeId) -> f64 {
207        self.centrality.get(&index).copied().unwrap_or(0.0)
208    }
209
210    /// Sets centrality scores (called after computation).
211    pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
212        self.centrality = scores;
213    }
214
215    /// Returns the number of nodes.
216    pub fn node_count(&self) -> usize {
217        self.graph.node_count()
218    }
219
220    /// Returns the number of edges.
221    pub fn edge_count(&self) -> usize {
222        self.graph.edge_count()
223    }
224
225    /// Iterates over all nodes.
226    pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
227        self.graph.node_weights()
228    }
229
230    /// Iterates over all edges.
231    pub fn edges(&self) -> impl Iterator<Item = &Edge> {
232        self.graph.edge_weights()
233    }
234
235    /// Returns all edges with source and target IDs for export.
236    pub fn export_edges(&self) -> Vec<GraphEdge> {
237        self.graph
238            .edge_references()
239            .map(|edge_ref| {
240                let source = self
241                    .graph
242                    .node_weight(edge_ref.source())
243                    .unwrap()
244                    .id
245                    .clone();
246                let target = self
247                    .graph
248                    .node_weight(edge_ref.target())
249                    .unwrap()
250                    .id
251                    .clone();
252                let weight = edge_ref.weight(); // &Edge
253                GraphEdge {
254                    source,
255                    target,
256                    kind: weight.kind,
257                }
258            })
259            .collect()
260    }
261
262    /// Iterates over all node indexes.
263    pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
264        self.graph.node_indices()
265    }
266
267    /// Finds the shortest path between two nodes.
268    pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
269        let path_indices = petgraph::algo::astar(
270            &self.graph,
271            from,
272            |finish| finish == to,
273            |_| 1, // weight of 1 for all edges (BFS-like)
274            |_| 0, // heuristic
275        )?;
276
277        Some(
278            path_indices
279                .1
280                .into_iter()
281                .filter_map(|idx| self.graph.node_weight(idx))
282                .collect(),
283        )
284    }
285
286    /// Gets the node index for a string ID.
287    pub fn get_index(&self, id: &str) -> Option<NodeId> {
288        self.id_index.get(id).copied()
289    }
290}
291
292/// Graph statistics for the info endpoint.
293#[derive(Debug, Serialize, Deserialize)]
294pub struct GraphStats {
295    pub node_count: usize,
296    pub edge_count: usize,
297    pub files: usize,
298}
299
300impl ArborGraph {
301    /// Returns graph statistics.
302    pub fn stats(&self) -> GraphStats {
303        GraphStats {
304            node_count: self.node_count(),
305            edge_count: self.edge_count(),
306            files: self.file_index.len(),
307        }
308    }
309}