use crate::edge::{Edge, EdgeKind, GraphEdge};
use crate::search_index::SearchIndex;
use arbor_core::CodeNode;
use petgraph::stable_graph::{NodeIndex, StableDiGraph};
use petgraph::visit::{EdgeRef, IntoEdgeReferences}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Handle to a node stored in an [`ArborGraph`]; an alias for petgraph's [`NodeIndex`].
pub type NodeId = NodeIndex;
/// In-memory code graph: a petgraph [`StableDiGraph`] whose node weights are
/// [`CodeNode`]s and whose edge weights are [`Edge`]s, plus lookup tables keyed
/// by node id, node name and file.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArborGraph {
/// Underlying directed graph holding the nodes and edges.
pub(crate) graph: StableDiGraph<CodeNode, Edge>,
/// Maps a node's `id` string to its graph index (filled by `add_node`).
id_index: HashMap<String, NodeId>,
/// Maps a node name to every graph index registered under that name.
name_index: HashMap<String, Vec<NodeId>>,
/// Maps a node's `file` value to every graph index added for that file.
file_index: HashMap<String, Vec<NodeId>>,
/// Per-node centrality scores; a node without an entry reads as `0.0`.
centrality: HashMap<NodeId, f64>,
/// Name search index. It is skipped by serde, so it is never serialized;
/// `rebuild_search_index` repopulates it from the nodes in `graph`.
#[serde(skip)]
search_index: SearchIndex,
}
/// `Default` produces the same empty graph as [`ArborGraph::new`].
impl Default for ArborGraph {
/// Delegates to [`ArborGraph::new`].
fn default() -> Self {
Self::new()
}
}
impl ArborGraph {
/// Creates a graph with no nodes, no edges and empty lookup tables.
pub fn new() -> Self {
    let graph = StableDiGraph::new();
    let search_index = SearchIndex::new();
    Self {
        graph,
        id_index: HashMap::new(),
        name_index: HashMap::new(),
        file_index: HashMap::new(),
        centrality: HashMap::new(),
        search_index,
    }
}
/// Discards the current search index and rebuilds it from the graph.
///
/// Every node currently stored in the graph is re-inserted under its `name`.
pub fn rebuild_search_index(&mut self) {
    let mut rebuilt = SearchIndex::new();
    let named_nodes = self
        .graph
        .node_indices()
        .filter_map(|index| self.graph.node_weight(index).map(|node| (index, node)));
    for (index, node) in named_nodes {
        rebuilt.insert(&node.name, index);
    }
    self.search_index = rebuilt;
}
/// Inserts `node` into the graph and registers it in every lookup table.
///
/// The node is indexed by its `id`, `name` and `file`, and its name is added
/// to the search index. If another node was already registered under the same
/// `id`, the id lookup now points at the new node.
///
/// # Returns
/// The graph index assigned to the new node.
pub fn add_node(&mut self, node: CodeNode) -> NodeId {
    // Copy the keys out first: `node` is moved into the graph below.
    let (id, name, file) = (node.id.clone(), node.name.clone(), node.file.clone());
    let index = self.graph.add_node(node);

    self.id_index.insert(id, index);
    self.search_index.insert(&name, index);
    self.name_index.entry(name).or_default().push(index);
    self.file_index.entry(file).or_default().push(index);
    index
}
/// Adds a directed `edge` from `from` to `to` in the underlying graph.
///
/// The edge index returned by petgraph is discarded.
pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
self.graph.add_edge(from, to, edge);
}
/// Looks a node up by its `id` string.
///
/// Returns `None` when the id is not registered or the graph holds no node
/// at the registered index.
pub fn get_by_id(&self, id: &str) -> Option<&CodeNode> {
    self.id_index
        .get(id)
        .and_then(|&index| self.graph.node_weight(index))
}
/// Returns the node stored at `index`, or `None` if the graph has no node there.
pub fn get(&self, index: NodeId) -> Option<&CodeNode> {
self.graph.node_weight(index)
}
/// Returns every node registered under exactly `name`.
///
/// The result is empty when the name is unknown. Indexes that no longer
/// resolve to a node in the graph are skipped.
pub fn find_by_name(&self, name: &str) -> Vec<&CodeNode> {
    match self.name_index.get(name) {
        Some(indexes) => indexes
            .iter()
            .filter_map(|&idx| self.graph.node_weight(idx))
            .collect(),
        None => Vec::new(),
    }
}
/// Returns every node that was added for `file`.
///
/// The result is empty when the file is unknown. Indexes that no longer
/// resolve to a node in the graph are skipped.
pub fn find_by_file(&self, file: &str) -> Vec<&CodeNode> {
    self.file_index
        .get(file)
        .into_iter()
        .flatten()
        .filter_map(|&idx| self.graph.node_weight(idx))
        .collect()
}
/// Runs `query` against the search index and resolves the hits to nodes.
///
/// Hits whose index no longer resolves to a node in the graph are skipped.
pub fn search(&self, query: &str) -> Vec<&CodeNode> {
    let hits = self.search_index.search(query);
    let mut nodes = Vec::new();
    for id in hits.iter() {
        if let Some(node) = self.graph.node_weight(*id) {
            nodes.push(node);
        }
    }
    nodes
}
/// Returns the nodes that call the node at `index`.
///
/// A node counts as a caller when at least one incoming edge from it has kind
/// [`EdgeKind::Calls`]. Every incoming edge is inspected individually, so a
/// `Calls` edge is found even when the same pair of nodes is also connected
/// by edges of other kinds. Each caller is listed once, however many `Calls`
/// edges it has to this node.
pub fn get_callers(&self, index: NodeId) -> Vec<&CodeNode> {
    let mut seen = std::collections::HashSet::new();
    let mut callers = Vec::new();
    for edge_ref in self
        .graph
        .edges_directed(index, petgraph::Direction::Incoming)
    {
        if edge_ref.weight().kind != EdgeKind::Calls {
            continue;
        }
        let caller = edge_ref.source();
        // `insert` returns false for a caller already reported via a parallel edge.
        if seen.insert(caller) {
            if let Some(node) = self.graph.node_weight(caller) {
                callers.push(node);
            }
        }
    }
    callers
}
/// Returns the nodes called by the node at `index`.
///
/// A node counts as a callee when at least one outgoing edge to it has kind
/// [`EdgeKind::Calls`]. Every outgoing edge is inspected individually, so a
/// `Calls` edge is found even when the same pair of nodes is also connected
/// by edges of other kinds. Each callee is listed once, however many `Calls`
/// edges point at it from this node.
pub fn get_callees(&self, index: NodeId) -> Vec<&CodeNode> {
    let mut seen = std::collections::HashSet::new();
    let mut callees = Vec::new();
    for edge_ref in self
        .graph
        .edges_directed(index, petgraph::Direction::Outgoing)
    {
        if edge_ref.weight().kind != EdgeKind::Calls {
            continue;
        }
        let callee = edge_ref.target();
        // `insert` returns false for a callee already reported via a parallel edge.
        if seen.insert(callee) {
            if let Some(node) = self.graph.node_weight(callee) {
                callees.push(node);
            }
        }
    }
    callees
}
/// Collects the nodes that reach `index` by following edges backwards, up to
/// `max_depth` hops away.
///
/// Edges of every kind are followed. The traversal is breadth-first, so the
/// depth reported for a node is the smallest number of hops to it, and a node
/// reachable within `max_depth` is never lost because a longer route reached
/// one of its neighbours first.
///
/// # Returns
/// `(node, depth)` pairs in breadth-first order, with `depth >= 1`. The
/// starting node itself is not included. `max_depth == 0` yields an empty list.
pub fn get_dependents(&self, index: NodeId, max_depth: usize) -> Vec<(NodeId, usize)> {
    let mut result = Vec::new();
    let mut visited = std::collections::HashSet::new();
    let mut queue = std::collections::VecDeque::new();

    visited.insert(index);
    queue.push_back((index, 0usize));

    while let Some((current, depth)) = queue.pop_front() {
        if current != index {
            result.push((current, depth));
        }
        // Nodes sitting exactly at the limit are reported but not expanded.
        if depth >= max_depth {
            continue;
        }
        for neighbor in self
            .graph
            .neighbors_directed(current, petgraph::Direction::Incoming)
        {
            // Marking on enqueue keeps the first (shortest) depth for each node.
            if visited.insert(neighbor) {
                queue.push_back((neighbor, depth + 1));
            }
        }
    }
    result
}
/// Removes every node that was added for `file`, together with its entries in
/// the lookup tables.
///
/// For each removed node this clears its name bucket entry (dropping the
/// bucket once it is empty), its id mapping, its search-index entry and its
/// centrality score. Removing the node from the graph also removes the edges
/// attached to it. Unknown files are a no-op.
pub fn remove_file(&mut self, file: &str) {
    if let Some(indexes) = self.file_index.remove(file) {
        for index in indexes {
            // A stale score must not survive: the stable graph may hand this
            // index to a node added later.
            self.centrality.remove(&index);

            // `remove_node` returns the weight, so no clone of the name is needed.
            if let Some(node) = self.graph.remove_node(index) {
                let bucket_empty = match self.name_index.get_mut(&node.name) {
                    Some(name_list) => {
                        name_list.retain(|&idx| idx != index);
                        name_list.is_empty()
                    }
                    None => false,
                };
                if bucket_empty {
                    self.name_index.remove(&node.name);
                }

                // Only drop the id mapping if it still refers to this node; a
                // later node registered under the same id keeps its entry.
                if self.id_index.get(&node.id) == Some(&index) {
                    self.id_index.remove(&node.id);
                }

                self.search_index.remove(&node.name, index);
            }
        }
    }
}
/// Returns the centrality score stored for `index`, or `0.0` when none is stored.
pub fn centrality(&self, index: NodeId) -> f64 {
    match self.centrality.get(&index) {
        Some(&score) => score,
        None => 0.0,
    }
}
/// Replaces the whole centrality table with `scores`.
///
/// Scores stored earlier for nodes that are absent from `scores` are dropped.
pub fn set_centrality(&mut self, scores: HashMap<NodeId, f64>) {
self.centrality = scores;
}
/// Returns the number of nodes currently in the graph.
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
/// Returns the number of edges currently in the graph.
pub fn edge_count(&self) -> usize {
self.graph.edge_count()
}
/// Iterates over every node weight stored in the graph.
pub fn nodes(&self) -> impl Iterator<Item = &CodeNode> {
self.graph.node_weights()
}
/// Iterates over every edge weight stored in the graph (endpoints are not included).
pub fn edges(&self) -> impl Iterator<Item = &Edge> {
self.graph.edge_weights()
}
/// Exports every edge as a [`GraphEdge`] that names its endpoints by node id.
///
/// An edge is skipped when either endpoint does not resolve to a node.
pub fn export_edges(&self) -> Vec<GraphEdge> {
    let mut exported = Vec::new();
    for edge_ref in (&self.graph).edge_references() {
        let endpoints = (
            self.graph.node_weight(edge_ref.source()),
            self.graph.node_weight(edge_ref.target()),
        );
        if let (Some(from_node), Some(to_node)) = endpoints {
            exported.push(GraphEdge {
                source: from_node.id.clone(),
                target: to_node.id.clone(),
                kind: edge_ref.weight().kind,
            });
        }
    }
    exported
}
/// Iterates over the index of every node currently in the graph.
pub fn node_indexes(&self) -> impl Iterator<Item = NodeId> + '_ {
self.graph.node_indices()
}
/// Finds a path with the fewest edges from `from` to `to`, following edge direction.
///
/// Implemented with petgraph's A* using a cost of 1 per edge and a zero
/// heuristic. Edges of every kind are followed.
///
/// # Returns
/// The nodes along the path, start and end included, or `None` when `to`
/// cannot be reached from `from`.
pub fn find_path(&self, from: NodeId, to: NodeId) -> Option<Vec<&CodeNode>> {
    let (_cost, route) = petgraph::algo::astar(
        &self.graph,
        from,
        |candidate| candidate == to,
        |_| 1,
        |_| 0,
    )?;
    let nodes = route
        .into_iter()
        .filter_map(|idx| self.graph.node_weight(idx))
        .collect();
    Some(nodes)
}
/// Returns the graph index registered for the node id `id`, if any.
pub fn get_index(&self, id: &str) -> Option<NodeId> {
self.id_index.get(id).copied()
}
/// Returns every node that `HeuristicsMatcher::is_likely_entry_point` accepts.
pub fn list_entry_points(&self) -> Vec<&CodeNode> {
    use crate::heuristics::HeuristicsMatcher;

    let mut entry_points = Vec::new();
    for node in self.graph.node_weights() {
        if HeuristicsMatcher::is_likely_entry_point(node) {
            entry_points.push(node);
        }
    }
    entry_points
}
/// Returns the nodes added for `file` together with the edges that stay
/// inside that file.
///
/// Nodes are returned in the order they were added for the file, and edges
/// are grouped by source node in that same order, so the output is stable
/// from run to run. Edges whose target belongs to another file are left out.
///
/// # Returns
/// `(nodes, edges)`, where each edge is `(source name, target name, kind)`
/// and `kind` is the `Debug` rendering of the edge's `EdgeKind`. Both lists
/// are empty when the file is unknown.
pub fn nodes_in_file_with_edges(
    &self,
    file: &str,
) -> (Vec<&CodeNode>, Vec<(String, String, String)>) {
    // Insertion-ordered list drives iteration; the set is only for membership tests.
    let ordered: &[NodeId] = self
        .file_index
        .get(file)
        .map(Vec::as_slice)
        .unwrap_or_default();
    let members: std::collections::HashSet<NodeId> = ordered.iter().copied().collect();

    let nodes: Vec<&CodeNode> = ordered
        .iter()
        .filter_map(|&id| self.graph.node_weight(id))
        .collect();

    let mut edges = Vec::new();
    for &from in ordered {
        let from_node = match self.graph.node_weight(from) {
            Some(node) => node,
            None => continue,
        };
        for edge_ref in self
            .graph
            .edges_directed(from, petgraph::Direction::Outgoing)
        {
            let to = edge_ref.target();
            if !members.contains(&to) {
                continue;
            }
            if let Some(to_node) = self.graph.node_weight(to) {
                edges.push((
                    from_node.name.clone(),
                    to_node.name.clone(),
                    format!("{:?}", edge_ref.weight().kind),
                ));
            }
        }
    }
    (nodes, edges)
}
}
/// Summary counts describing an [`ArborGraph`], as produced by `ArborGraph::stats`.
#[derive(Debug, Serialize, Deserialize)]
pub struct GraphStats {
/// Number of nodes in the graph.
pub node_count: usize,
/// Number of edges in the graph.
pub edge_count: usize,
/// Number of distinct keys in the graph's file index.
pub files: usize,
}
impl ArborGraph {
    /// Summarises the graph: node count, edge count and number of indexed files.
    pub fn stats(&self) -> GraphStats {
        let node_count = self.node_count();
        let edge_count = self.edge_count();
        let files = self.file_index.len();
        GraphStats {
            node_count,
            edge_count,
            files,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a function node named `name` for `file`; the name is passed for
    /// both of `CodeNode::new`'s first two arguments.
    fn make_node(name: &str, file: &str) -> CodeNode {
        CodeNode::new(name, name, NodeKind::Function, file)
    }

    #[test]
    fn test_graph_new_is_empty() {
        let g = ArborGraph::new();
        assert_eq!(g.node_count(), 0);
        assert_eq!(g.edge_count(), 0);
        assert!(g.nodes().next().is_none());
    }

    #[test]
    fn test_graph_add_and_get_node() {
        let mut g = ArborGraph::new();
        let node = make_node("foo", "main.rs");
        // The node is not used again, so it is moved rather than cloned.
        let id = g.add_node(node);
        assert_eq!(g.node_count(), 1);
        let got = g.get(id).unwrap();
        assert_eq!(got.name, "foo");
    }

    #[test]
    fn test_graph_find_by_name() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("alpha", "a.rs"));
        g.add_node(make_node("beta", "b.rs"));
        let found = g.find_by_name("alpha");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "alpha");
        let not_found = g.find_by_name("gamma");
        assert!(not_found.is_empty());
    }

    #[test]
    fn test_graph_find_by_file() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "main.rs"));
        g.add_node(make_node("bar", "main.rs"));
        g.add_node(make_node("baz", "other.rs"));
        let main_nodes = g.find_by_file("main.rs");
        assert_eq!(main_nodes.len(), 2);
        let other_nodes = g.find_by_file("other.rs");
        assert_eq!(other_nodes.len(), 1);
        let empty = g.find_by_file("nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_graph_search_substring() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("validate_user", "a.rs"));
        g.add_node(make_node("validate_email", "b.rs"));
        g.add_node(make_node("send_email", "c.rs"));
        let results = g.search("validate");
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|n| n.name == "validate_user"));
        assert!(results.iter().any(|n| n.name == "validate_email"));
    }

    #[test]
    fn test_graph_callers_callees() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("caller", "a.rs"));
        let b = g.add_node(make_node("callee", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        let callees = g.get_callees(a);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "callee");
        let callers = g.get_callers(b);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "caller");
        assert!(g.get_callers(a).is_empty());
        assert!(g.get_callees(b).is_empty());
    }

    #[test]
    fn test_graph_get_dependents() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        let c = g.add_node(make_node("c", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));

        // Chain a -> b -> c: b is one hop from c, a is two hops.
        let deps = g.get_dependents(c, 2);
        assert_eq!(deps.len(), 2);
        assert!(deps
            .iter()
            .any(|(idx, depth)| g.get(*idx).unwrap().name == "b" && *depth == 1));
        assert!(deps
            .iter()
            .any(|(idx, depth)| g.get(*idx).unwrap().name == "a" && *depth == 2));

        // The depth limit cuts the chain off after the direct dependent.
        let shallow = g.get_dependents(c, 1);
        assert_eq!(shallow.len(), 1);
        assert_eq!(g.get(shallow[0].0).unwrap().name, "b");
    }

    #[test]
    fn test_graph_remove_file_cleanup() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("foo", "remove_me.rs"));
        g.add_node(make_node("bar", "remove_me.rs"));
        g.add_node(make_node("keep", "keep.rs"));
        assert_eq!(g.node_count(), 3);
        g.remove_file("remove_me.rs");
        // The nodes must be gone from the graph itself, not just from the indexes.
        assert_eq!(g.node_count(), 1);
        assert!(g.find_by_name("foo").is_empty());
        assert!(g.find_by_name("bar").is_empty());
        assert_eq!(g.find_by_name("keep").len(), 1);
        assert!(g.find_by_file("remove_me.rs").is_empty());
    }

    #[test]
    fn test_graph_find_path() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("start", "a.rs"));
        let b = g.add_node(make_node("middle", "b.rs"));
        let c = g.add_node(make_node("end", "c.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        g.add_edge(b, c, Edge::new(EdgeKind::Calls));
        let path = g.find_path(a, c).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].name, "start");
        assert_eq!(path[1].name, "middle");
        assert_eq!(path[2].name, "end");
    }

    #[test]
    fn test_graph_find_path_no_connection() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("island_a", "a.rs"));
        let b = g.add_node(make_node("island_b", "b.rs"));
        assert!(g.find_path(a, b).is_none());
    }

    #[test]
    fn test_graph_export_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("a", "a.rs"));
        let b = g.add_node(make_node("b", "b.rs"));
        g.add_edge(a, b, Edge::new(EdgeKind::Calls));
        let exported = g.export_edges();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].kind, EdgeKind::Calls);
    }

    #[test]
    fn test_graph_stats() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("a", "x.rs"));
        g.add_node(make_node("b", "y.rs"));
        let stats = g.stats();
        assert_eq!(stats.node_count, 2);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.files, 2);
    }

    #[test]
    fn test_graph_get_index_and_get_by_id() {
        let mut g = ArborGraph::new();
        let node = make_node("lookup_me", "test.rs");
        let node_id_str = node.id.clone();
        let idx = g.add_node(node);
        assert_eq!(g.get_index(&node_id_str), Some(idx));
        assert!(g.get_by_id(&node_id_str).is_some());
        assert!(g.get_index("nonexistent").is_none());
        assert!(g.get_by_id("nonexistent").is_none());
    }

    #[test]
    fn test_graph_centrality_default_zero() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));
        assert_eq!(g.centrality(idx), 0.0);
    }

    #[test]
    fn test_graph_set_centrality() {
        let mut g = ArborGraph::new();
        let idx = g.add_node(make_node("a", "a.rs"));
        let mut scores = HashMap::new();
        scores.insert(idx, 0.75);
        g.set_centrality(scores);
        assert!((g.centrality(idx) - 0.75).abs() < f64::EPSILON);
    }
}
#[cfg(test)]
mod new_query_tests {
    use super::*;
    use crate::edge::{Edge, EdgeKind};
    use arbor_core::{CodeNode, NodeKind};

    /// Builds a node of the given `kind`, passing `"<file>::<name>"` as the
    /// second argument of `CodeNode::new`.
    fn make_node(name: &str, kind: NodeKind, file: &str) -> CodeNode {
        CodeNode::new(name, format!("{}::{}", file, name), kind, file)
    }

    /// A `Calls` edge with no source location attached.
    fn calls_edge() -> Edge {
        Edge {
            kind: EdgeKind::Calls,
            file: None,
            line: None,
        }
    }

    #[test]
    fn test_list_entry_points_returns_main() {
        let mut g = ArborGraph::new();
        g.add_node(make_node("main", NodeKind::Function, "src/main.rs"));
        g.add_node(make_node("helper", NodeKind::Function, "src/util.rs"));
        let eps = g.list_entry_points();
        assert!(
            eps.iter().any(|n| n.name == "main"),
            "main must be an entry point"
        );
        assert!(
            !eps.iter().any(|n| n.name == "helper"),
            "helper must not be an entry point"
        );
    }

    #[test]
    fn test_nodes_in_file_with_edges_returns_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("foo", NodeKind::Function, "src/a.rs"));
        let b = g.add_node(make_node("bar", NodeKind::Function, "src/a.rs"));
        let _c = g.add_node(make_node("baz", NodeKind::Function, "src/b.rs"));
        g.add_edge(a, b, calls_edge());
        let (nodes, edges) = g.nodes_in_file_with_edges("src/a.rs");
        assert_eq!(nodes.len(), 2);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].0, "foo");
        assert_eq!(edges[0].1, "bar");
    }

    #[test]
    fn test_nodes_in_file_with_edges_excludes_cross_file_edges() {
        let mut g = ArborGraph::new();
        let a = g.add_node(make_node("foo", NodeKind::Function, "src/a.rs"));
        let c = g.add_node(make_node("baz", NodeKind::Function, "src/b.rs"));
        // The only edge leaves src/a.rs, so it must not be reported for that file.
        g.add_edge(a, c, calls_edge());
        let (nodes, edges) = g.nodes_in_file_with_edges("src/a.rs");
        assert_eq!(nodes.len(), 1);
        assert_eq!(edges.len(), 0);
    }
}