1use crate::edge::{Edge, EdgeKind, GraphEdge};
7use crate::search_index::SearchIndex;
8use arbor_core::CodeNode;
9use petgraph::stable_graph::{NodeIndex, StableDiGraph};
10use petgraph::visit::{EdgeRef, IntoEdgeReferences}; use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14pub type NodeId = NodeIndex;
16
17#[derive(Debug, Serialize, Deserialize)]
22pub struct ArborGraph {
23 pub(crate) graph: StableDiGraph<CodeNode, Edge>,
25
26 id_index: HashMap<String, NodeId>,
28
29 name_index: HashMap<String, Vec<NodeId>>,
31
32 file_index: HashMap<String, Vec<NodeId>>,
34
35 centrality: HashMap<NodeId, f64>,
37
38 #[serde(skip)]
40 search_index: SearchIndex,
41}
42
43impl Default for ArborGraph {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl ArborGraph {
50 pub fn new() -> Self {
52 Self {
53 graph: StableDiGraph::new(),
54 id_index: HashMap::new(),
55 name_index: HashMap::new(),
56 file_index: HashMap::new(),
57 centrality: HashMap::new(),
58 search_index: SearchIndex::new(),
59 }
60 }
61
62 pub fn add_node(&mut self, node: CodeNode) -> NodeId {
66 let id = node.id.clone();
67 let name = node.name.clone();
68 let file = node.file.clone();
69
70 let index = self.graph.add_node(node);
71
72 self.id_index.insert(id, index);
74 self.name_index.entry(name.clone()).or_default().push(index);
75 self.file_index.entry(file).or_default().push(index);
76 self.search_index.insert(&name, index);
77
78 index
79 }
80
81 pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
83 self.graph.add_edge(from, to, edge);
84 }
85
86 pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
88 let index = self.id_index.get(id)?;
89 self.graph.node_weight(*index)
90 }
91
92 pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
94 self.graph.node_weight(index)
95 }
96
97 pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
99 self.name_index
100 .get(name)
101 .map(|indexes| {
102 indexes
103 .iter()
104 .filter_map(|idx| self.graph.node_weight(*idx))
105 .collect()
106 })
107 .unwrap_or_default()
108 }
109
110 pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
112 self.file_index
113 .get(file)
114 .map(|indexes| {
115 indexes
116 .iter()
117 .filter_map(|idx| self.graph.node_weight(*idx))
118 .collect()
119 })
120 .unwrap_or_default()
121 }
122
123 pub fn search(&self, query: &str) -> Vec<&CodeNode> {
128 self.search_index
129 .search(query)
130 .iter()
131 .filter_map(|id| self.graph.node_weight(*id))
132 .collect()
133 }
134
135 pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
137 self.graph
138 .neighbors_directed(index, petgraph::Direction::Incoming)
139 .filter_map(|idx| {
140 let edge_idx = self.graph.find_edge(idx, index)?;
142 let edge = self.graph.edge_weight(edge_idx)?;
143 if edge.kind == EdgeKind::Calls {
144 self.graph.node_weight(idx)
145 } else {
146 None
147 }
148 })
149 .collect()
150 }
151
152 pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
154 self.graph
155 .neighbors_directed(index, petgraph::Direction::Outgoing)
156 .filter_map(|idx| {
157 let edge_idx = self.graph.find_edge(index, idx)?;
158 let edge = self.graph.edge_weight(edge_idx)?;
159 if edge.kind == EdgeKind::Calls {
160 self.graph.node_weight(idx)
161 } else {
162 None
163 }
164 })
165 .collect()
166 }
167
168 pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
170 let mut result = Vec::new();
171 let mut visited = std::collections::HashSet::new();
172 let mut queue = vec![(index, 0usize)];
173
174 while let Some((current, depth)) = queue.pop() {
175 if depth > max_depth || visited.contains(¤t) {
176 continue;
177 }
178 visited.insert(current);
179
180 if current != index {
181 result.push((current, depth));
182 }
183
184 for neighbor in self
186 .graph
187 .neighbors_directed(current, petgraph::Direction::Incoming)
188 {
189 if !visited.contains(&neighbor) {
190 queue.push((neighbor, depth + 1));
191 }
192 }
193 }
194
195 result
196 }
197
198 pub fn remove_file(&mut self, file: &str) {
200 if let Some(indexes) = self.file_index.remove(file) {
201 for index in indexes {
202 if let Some(node) = self.graph.node_weight(index) {
203 let name = node.name.clone();
205 if let Some(name_list) = self.name_index.get_mut(&name) {
206 name_list.retain(|&idx| idx != index);
207 }
208 self.id_index.remove(&node.id);
210 self.search_index.remove(&name, index);
212 }
213 self.graph.remove_node(index);
214 }
215 }
216 }
217
218 pub fn centrality(&self, index: NodeId) -> f64 {
220 self.centrality.get(&index).copied().unwrap_or(0.0)
221 }
222
223 pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
225 self.centrality = scores;
226 }
227
228 pub fn node_count(&self) -> usize {
230 self.graph.node_count()
231 }
232
233 pub fn edge_count(&self) -> usize {
235 self.graph.edge_count()
236 }
237
238 pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
240 self.graph.node_weights()
241 }
242
243 pub fn edges(&self) -> impl Iterator<Item = &Edge> {
245 self.graph.edge_weights()
246 }
247
248 pub fn export_edges(&self) -> Vec<GraphEdge> {
250 (&self.graph)
251 .edge_references()
252 .filter_map(|edge_ref| {
253 let source = self.graph.node_weight(edge_ref.source())?.id.clone();
254 let target = self.graph.node_weight(edge_ref.target())?.id.clone();
255 let weight = edge_ref.weight(); Some(GraphEdge {
257 source,
258 target,
259 kind: weight.kind,
260 })
261 })
262 .collect()
263 }
264
265 pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
267 self.graph.node_indices()
268 }
269
270 pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
272 let path_indices = petgraph::algo::astar(
273 &self.graph,
274 from,
275 |finish| finish == to,
276 |_| 1, |_| 0, )?;
279
280 Some(
281 path_indices
282 .1
283 .into_iter()
284 .filter_map(|idx| self.graph.node_weight(idx))
285 .collect(),
286 )
287 }
288
289 pub fn get_index(&self, id: &str) -> Option<NodeId> {
291 self.id_index.get(id).copied()
292 }
293}
294
295#[derive(Debug, Serialize, Deserialize)]
297pub struct GraphStats {
298 pub node_count: usize,
299 pub edge_count: usize,
300 pub files: usize,
301}
302
303impl ArborGraph {
304 pub fn stats(&self) -> GraphStats {
306 GraphStats {
307 node_count: self.node_count(),
308 edge_count: self.edge_count(),
309 files: self.file_index.len(),
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a `NodeKind::Function` node from `name`, declared in `file`.
    fn make_node(name: &str, file: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, file)
    }

    #[test]
    fn test_graph_new_is_empty() {
        let g = ArborGraph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
        assert!(g.nodes().next().is_none());
        assert!(g.edges().next().is_none());
    }

    #[test]
    fn test_graph_add_and_get_node() {
        let mut g = ArborGraph::new();
        let id = g.add_node(make_node("foo", "main.rs"));
        assert_eq!(g.node_count(), 1);

        let got = g.get(id).unwrap();
        assert_eq!(got.name, "foo");
        assert_eq!(got.file, "main.rs");
    }

    #[test]
    fn test_graph_find_by_name() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("alpha", "a.rs"));
        g.add_node(make_node("beta", "b.rs"));

        let found = g.find_by_name("alpha");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "alpha");

        let not_found = g.find_by_name("gamma");
        assert!(not_found.is_empty());
    }

    #[test]
    fn test_graph_find_by_file() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "main.rs"));
        g.add_node(make_node("bar", "main.rs"));
        g.add_node(make_node("baz", "other.rs"));

        let main_nodes = g.find_by_file("main.rs");
        assert_eq!(main_nodes.len(), 2);

        let other_nodes = g.find_by_file("other.rs");
        assert_eq!(other_nodes.len(), 1);

        let empty = g.find_by_file("nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_graph_search_substring() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("validate_user", "a.rs"));
        g.add_node(make_node("validate_email", "b.rs"));
        g.add_node(make_node("send_email", "c.rs"));

        let results = g.search("validate");
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|n| n.name == "validate_user"));
        assert!(results.iter().any(|n| n.name == "validate_email"));
    }

    #[test]
    fn test_graph_callers_callees() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("caller", "a.rs"));
        let b = g.add_node(make_node("callee", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let callees = g.get_callees(a);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "callee");

        let callers = g.get_callers(b);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "caller");

        assert!(g.get_callers(a).is_empty());
        assert!(g.get_callees(b).is_empty());
    }

    #[test]
    fn test_graph_get_dependents() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        let c = g.add_node(make_node("c", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let deps = g.get_dependents(c, 2);
        assert_eq!(deps.len(), 2);
        assert!(deps.iter().any(|&(idx, depth)| idx == b && depth == 1));
        assert!(deps.iter().any(|&(idx, depth)| idx == a && depth == 2));

        // The depth limit is inclusive and the start node is never reported.
        assert_eq!(g.get_dependents(c, 1), vec![(b, 1)]);
        assert!(g.get_dependents(c, 0).is_empty());
    }

    #[test]
    fn test_graph_remove_file_cleanup() {
        let mut g = ArborGraph::new();
        let foo = make_node("foo", "remove_me.rs");
        let foo_id = foo.id.clone();
        let foo_idx = g.add_node(foo);
        g.add_node(make_node("bar", "remove_me.rs"));
        let keep = g.add_node(make_node("keep", "keep.rs"));
        g.add_edge(keep, foo_idx, Edge::new(EdgeKind::Calls));

        assert_eq!(g.node_count(), 3);
        assert_eq!(g.edge_count(), 1);

        g.remove_file("remove_me.rs");

        assert_eq!(g.node_count(), 1);
        // Edges touching removed nodes are removed with them.
        assert_eq!(g.edge_count(), 0);
        assert!(g.find_by_name("foo").is_empty());
        assert!(g.find_by_name("bar").is_empty());
        assert_eq!(g.find_by_name("keep").len(), 1);
        assert!(g.find_by_file("remove_me.rs").is_empty());
        assert!(g.get_by_id(&foo_id).is_none());
        assert!(!g.search("foo").iter().any(|n| n.name == "foo"));
        assert!(g.get(keep).is_some());
        assert_eq!(g.stats().files, 1);
    }

    #[test]
    fn test_graph_remove_unknown_file_is_noop() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("a", "a.rs"));

        g.remove_file("missing.rs");

        assert_eq!(g.node_count(), 1);
        assert_eq!(g.find_by_file("a.rs").len(), 1);
    }

    #[test]
    fn test_graph_find_path() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("start", "a.rs"));
        let b = g.add_node(make_node("middle", "b.rs"));
        let c = g.add_node(make_node("end", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        let path = g.find_path(a, c).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].name, "start");
        assert_eq!(path[1].name, "middle");
        assert_eq!(path[2].name, "end");

        // Paths follow edge direction.
        assert!(g.find_path(c, a).is_none());
    }

    #[test]
    fn test_graph_find_path_no_connection() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("island_a", "a.rs"));
        let b = g.add_node(make_node("island_b", "b.rs"));

        assert!(g.find_path(a, b).is_none());
    }

    #[test]
    fn test_graph_export_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));

        let exported = g.export_edges();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].kind, EdgeKind::Calls);
    }

    #[test]
    fn test_graph_stats() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("a", "x.rs"));
        g.add_node(make_node("b", "y.rs"));

        let stats = g.stats();
        assert_eq!(stats.node_count, 2);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.files, 2);
    }

    #[test]
    fn test_graph_get_index_and_get_by_id() {
        let mut g = ArborGraph::new();
        let node = make_node("lookup_me", "test.rs");
        let node_id_str = node.id.clone();
        let idx = g.add_node(node);

        assert_eq!(g.get_index(&node_id_str), Some(idx));
        assert!(g.get_by_id(&node_id_str).is_some());
        assert!(g.get_index("nonexistent").is_none());
        assert!(g.get_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_graph_centrality_default_zero() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));
        assert_eq!(g.centrality(idx), 0.0);
    }

    #[test]
    fn test_graph_set_centrality() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));

        let mut scores = HashMap::new();
        scores.insert(idx, 0.75);
        g.set_centrality(scores);

        assert!((g.centrality(idx) - 0.75).abs() < f64::EPSILON);
    }
}