use std::collections::{HashMap, HashSet, VecDeque};
use crate::core::{CallEdge, CodeNode, EdgeConfidence, Language, NodeId};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::Direction;
/// Call graph of code entities: `CodeNode` weights connected by `CallEdge` weights,
/// plus lookup indices by id, name, and source file, and separate production and
/// test entry-point sets used as roots for reachability queries.
#[derive(Debug)]
pub struct CodeGraph {
// Directed graph; an edge `a -> b` means `a` calls `b` (outgoing = callees,
// incoming = callers).
graph: DiGraph<CodeNode, CallEdge>,
// Node id -> graph index. Re-adding an existing id overwrites the entry, so the
// most recently added node with that id wins.
id_to_index: HashMap<NodeId, NodeIndex>,
// Short `name` (and, conditionally, `full_name`) -> ids registered under it, in
// insertion order.
name_to_ids: HashMap<String, Vec<NodeId>>,
// `location.file` value -> ids of nodes declared in that file, in insertion order.
file_to_node_ids: HashMap<String, Vec<NodeId>>,
// Roots for production reachability.
entry_points: HashSet<NodeIndex>,
// Roots for test reachability.
test_entry_points: HashSet<NodeIndex>,
// Optional language for the whole graph, set via `with_language`.
language: Option<Language>,
}
impl CodeGraph {
    /// Creates an empty graph: no nodes, edges, entry points, or language.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            id_to_index: HashMap::new(),
            name_to_ids: HashMap::new(),
            file_to_node_ids: HashMap::new(),
            entry_points: HashSet::new(),
            test_entry_points: HashSet::new(),
            language: None,
        }
    }

    /// Builder-style setter: returns `self` with its language set to `language`.
    pub fn with_language(mut self, language: Language) -> Self {
        self.language = Some(language);
        self
    }

    /// Inserts `node` into the graph and returns its new index.
    ///
    /// The node is indexed by id, by short name, by qualified name (only when that
    /// adds a distinct key), and by `location.file`. Re-adding an id that already
    /// exists inserts a second graph node and repoints the id index at it; the
    /// earlier node stays in the graph.
    pub fn add_node(&mut self, node: CodeNode) -> NodeIndex {
        let id = node.id;
        let name = node.name.clone();
        let full_name = node.full_name.clone();
        let file = node.location.file.clone();
        // Only register the qualified name when it is a distinct key: if it equals the
        // short name, pushing again would list the same id twice under one key and make
        // `find_nodes_by_name` return duplicates.
        // NOTE(review): the id-string exclusion presumably filters a placeholder
        // full name — confirm against `CodeNode::new`.
        let index_full_name = full_name != name && full_name != id.to_string();
        let idx = self.graph.add_node(node);
        self.id_to_index.insert(id, idx);
        self.name_to_ids.entry(name).or_default().push(id);
        if index_full_name {
            self.name_to_ids.entry(full_name).or_default().push(id);
        }
        self.file_to_node_ids.entry(file).or_default().push(id);
        idx
    }

    /// Adds a call edge between the nodes identified by `edge.from` and `edge.to`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the missing id if either endpoint has not been added.
    pub fn add_edge(&mut self, edge: CallEdge) -> Result<(), String> {
        let from_idx = self
            .get_index(edge.from)
            .ok_or_else(|| format!("Source node {:?} not found", edge.from))?;
        let to_idx = self
            .get_index(edge.to)
            .ok_or_else(|| format!("Target node {:?} not found", edge.to))?;
        self.graph.add_edge(from_idx, to_idx, edge);
        Ok(())
    }

    /// Adds a `CallEdge::certain` edge between two existing node indices.
    ///
    /// # Panics
    ///
    /// Panics if `from` or `to` is not a valid index in this graph.
    pub fn add_edge_by_index(&mut self, from: NodeIndex, to: NodeIndex) {
        let edge = CallEdge::certain(self.graph[from].id, self.graph[to].id);
        self.graph.add_edge(from, to, edge);
    }

    /// Returns the node at `idx`, or `None` if the index is out of range.
    pub fn get_node(&self, idx: NodeIndex) -> Option<&CodeNode> {
        self.graph.node_weight(idx)
    }

    /// Mutable variant of [`CodeGraph::get_node`].
    ///
    /// Note: the name/file indices are built at insertion time and are not updated
    /// if `name`, `full_name`, or `location.file` are changed through this reference.
    pub fn get_node_mut(&mut self, idx: NodeIndex) -> Option<&mut CodeNode> {
        self.graph.node_weight_mut(idx)
    }

    /// Looks up the graph index of the node with id `id`.
    pub fn get_index(&self, id: NodeId) -> Option<NodeIndex> {
        self.id_to_index.get(&id).copied()
    }

    /// Returns the indices of all nodes registered under `name` (short or qualified),
    /// in insertion order; empty if none.
    pub fn find_nodes_by_name(&self, name: &str) -> Vec<NodeIndex> {
        self.name_to_ids
            .get(name)
            .map(|ids| {
                ids.iter()
                    .filter_map(|id| self.id_to_index.get(id).copied())
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Returns the first node registered under `name`, if any.
    pub fn find_node_by_name(&self, name: &str) -> Option<NodeIndex> {
        // Walk the id list directly rather than materializing every match.
        self.name_to_ids
            .get(name)?
            .iter()
            .find_map(|id| self.id_to_index.get(id).copied())
    }

    /// Returns the first node registered under `name` whose `location.file` is
    /// exactly `file`.
    pub fn find_node_by_name_in_file(&self, name: &str, file: &str) -> Option<NodeIndex> {
        let file_ids = self.file_to_node_ids.get(file)?;
        let id = self
            .name_to_ids
            .get(name)?
            .iter()
            .find(|&id| file_ids.contains(id))?;
        self.id_to_index.get(id).copied()
    }

    /// Tries each of `files` in order and returns the first node named `name` found
    /// in one of them.
    pub fn find_node_by_name_in_files(&self, name: &str, files: &[String]) -> Option<NodeIndex> {
        files
            .iter()
            .find_map(|file| self.find_node_by_name_in_file(name, file))
    }

    /// Iterates over every node together with its index, in index order.
    pub fn nodes(&self) -> impl Iterator<Item = (NodeIndex, &CodeNode)> {
        self.graph
            .node_indices()
            .map(move |idx| (idx, &self.graph[idx]))
    }

    /// Iterates over every edge weight, in edge-index order.
    pub fn edges(&self) -> impl Iterator<Item = &CallEdge> {
        self.graph.edge_weights()
    }

    /// Iterates over every edge as `(source, target, weight)`, in edge-index order.
    pub fn edges_with_endpoints(&self) -> impl Iterator<Item = (NodeIndex, NodeIndex, &CallEdge)> {
        // `raw_edges` carries the endpoints with each edge, so no fallible
        // `edge_endpoints` lookup (and its `unwrap`) is needed.
        self.graph
            .raw_edges()
            .iter()
            .map(|edge| (edge.source(), edge.target(), &edge.weight))
    }

    /// Direct callees of `idx`.
    pub fn calls_from(&self, idx: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
        self.graph.neighbors_directed(idx, Direction::Outgoing)
    }

    /// Direct callers of `idx`.
    pub fn callers_of(&self, idx: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
        self.graph.neighbors_directed(idx, Direction::Incoming)
    }

    /// Neighbours of `idx` in `direction` whose language `is_compatible_with` the
    /// language of `idx` itself.
    ///
    /// Panics if `idx` is not a valid index.
    fn compatible_neighbors(&self, idx: NodeIndex, direction: Direction) -> Vec<NodeIndex> {
        let lang = self.graph[idx].language;
        self.graph
            .neighbors_directed(idx, direction)
            .filter(|&n| self.graph[n].language.is_compatible_with(lang))
            .collect()
    }

    /// Callees of `idx` whose language is compatible with `idx`'s language.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index.
    pub fn calls_from_compatible(&self, idx: NodeIndex) -> Vec<NodeIndex> {
        self.compatible_neighbors(idx, Direction::Outgoing)
    }

    /// Callers of `idx` whose language is compatible with `idx`'s language.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index.
    pub fn callers_of_compatible(&self, idx: NodeIndex) -> Vec<NodeIndex> {
        self.compatible_neighbors(idx, Direction::Incoming)
    }

    /// Like [`CodeGraph::find_nodes_by_name`], keeping only nodes whose language is
    /// compatible with `language`.
    pub fn find_nodes_by_name_and_language(
        &self,
        name: &str,
        language: Language,
    ) -> Vec<NodeIndex> {
        self.find_nodes_by_name(name)
            .into_iter()
            .filter(|&idx| self.graph[idx].language.is_compatible_with(language))
            .collect()
    }

    /// Marks `idx` as a production entry point.
    pub fn add_entry_point(&mut self, idx: NodeIndex) {
        self.entry_points.insert(idx);
    }

    /// Marks `idx` as a test entry point.
    pub fn add_test_entry_point(&mut self, idx: NodeIndex) {
        self.test_entry_points.insert(idx);
    }

    /// Production entry points.
    pub fn entry_points(&self) -> &HashSet<NodeIndex> {
        &self.entry_points
    }

    /// Test entry points.
    pub fn test_entry_points(&self) -> &HashSet<NodeIndex> {
        &self.test_entry_points
    }

    /// Number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Number of edges in the graph (parallel edges counted individually).
    pub fn edge_count(&self) -> usize {
        self.graph.edge_count()
    }

    /// Graph-wide language, if one was set with [`CodeGraph::with_language`].
    pub fn language(&self) -> Option<Language> {
        self.language
    }

    /// Borrow of the underlying petgraph graph, for algorithms not wrapped here.
    pub fn inner(&self) -> &DiGraph<CodeNode, CallEdge> {
        &self.graph
    }

    /// Breadth-first search along outgoing (call) edges from every node in
    /// `entry_points`. Returns all visited nodes, including the entry points.
    pub fn compute_reachable(&self, entry_points: &HashSet<NodeIndex>) -> HashSet<NodeIndex> {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        for &ep in entry_points {
            if visited.insert(ep) {
                queue.push_back(ep);
            }
        }
        while let Some(current) = queue.pop_front() {
            for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
                if visited.insert(neighbor) {
                    queue.push_back(neighbor);
                }
            }
        }
        visited
    }

    /// Nodes reachable from the production entry points.
    pub fn compute_production_reachable(&self) -> HashSet<NodeIndex> {
        self.compute_reachable(&self.entry_points)
    }

    /// Nodes reachable from the test entry points.
    pub fn compute_test_reachable(&self) -> HashSet<NodeIndex> {
        self.compute_reachable(&self.test_entry_points)
    }

    /// Nodes not reachable from any production or test entry point, in index order.
    pub fn find_unreachable(&self) -> Vec<NodeIndex> {
        let all_entries: HashSet<NodeIndex> = self
            .entry_points
            .union(&self.test_entry_points)
            .copied()
            .collect();
        let reachable = self.compute_reachable(&all_entries);
        self.graph
            .node_indices()
            .filter(|idx| !reachable.contains(idx))
            .collect()
    }

    /// Returns `true` if `target` can be reached from `source` by following call
    /// edges. A node is always considered reachable from itself.
    pub fn is_reachable_from(&self, source: NodeIndex, target: NodeIndex) -> bool {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        visited.insert(source);
        queue.push_back(source);
        while let Some(current) = queue.pop_front() {
            if current == target {
                return true;
            }
            for neighbor in self.graph.neighbors_directed(current, Direction::Outgoing) {
                if visited.insert(neighbor) {
                    queue.push_back(neighbor);
                }
            }
        }
        false
    }

    /// Merges `other` into `self`.
    ///
    /// Nodes whose id already exists in `self` are reused; the rest are cloned in.
    /// Every edge of `other` is copied between the mapped endpoints, so an edge
    /// present in both graphs ends up duplicated. Production and test entry points
    /// of `other` are carried over. `self.language` is left unchanged.
    pub fn merge(&mut self, other: &CodeGraph) {
        self.graph.reserve_nodes(other.node_count());
        self.graph.reserve_edges(other.edge_count());
        // Index in `other` -> index in `self`; shared by edges and entry points so
        // both are remapped the same way.
        let mut index_map: HashMap<NodeIndex, NodeIndex> =
            HashMap::with_capacity(other.node_count());
        for (other_idx, node) in other.nodes() {
            let mapped = match self.get_index(node.id) {
                Some(existing) => existing,
                None => self.add_node(node.clone()),
            };
            index_map.insert(other_idx, mapped);
        }
        for edge in other.graph.raw_edges() {
            if let (Some(&from), Some(&to)) = (
                index_map.get(&edge.source()),
                index_map.get(&edge.target()),
            ) {
                self.graph.add_edge(from, to, edge.weight.clone());
            }
        }
        for &ep in &other.entry_points {
            if let Some(&new_idx) = index_map.get(&ep) {
                self.entry_points.insert(new_idx);
            }
        }
        for &ep in &other.test_entry_points {
            if let Some(&new_idx) = index_map.get(&ep) {
                self.test_entry_points.insert(new_idx);
            }
        }
    }

    /// Number of incoming call edges of `idx` (as yielded by `neighbors_directed`).
    pub fn in_degree(&self, idx: NodeIndex) -> usize {
        self.graph
            .neighbors_directed(idx, Direction::Incoming)
            .count()
    }

    /// Number of outgoing call edges of `idx` (as yielded by `neighbors_directed`).
    pub fn out_degree(&self, idx: NodeIndex) -> usize {
        self.graph
            .neighbors_directed(idx, Direction::Outgoing)
            .count()
    }

    /// Strongly connected components (Tarjan's algorithm via petgraph).
    pub fn strongly_connected_components(&self) -> Vec<Vec<NodeIndex>> {
        petgraph::algo::tarjan_scc(&self.graph)
    }

    /// Nodes with no outgoing call edges, in index order.
    pub fn leaf_nodes(&self) -> Vec<NodeIndex> {
        // Only need to know whether a neighbour exists, not how many there are.
        self.graph
            .node_indices()
            .filter(|&idx| {
                self.graph
                    .neighbors_directed(idx, Direction::Outgoing)
                    .next()
                    .is_none()
            })
            .collect()
    }

    /// Nodes with no incoming call edges, in index order.
    pub fn root_nodes(&self) -> Vec<NodeIndex> {
        self.graph
            .node_indices()
            .filter(|&idx| {
                self.graph
                    .neighbors_directed(idx, Direction::Incoming)
                    .next()
                    .is_none()
            })
            .collect()
    }

    /// Counts edges per `EdgeConfidence` value.
    pub fn edge_confidence_distribution(&self) -> HashMap<EdgeConfidence, usize> {
        let mut dist = HashMap::new();
        for edge in self.edges() {
            *dist.entry(edge.confidence).or_default() += 1;
        }
        dist
    }
}
impl Default for CodeGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{NodeKind, SourceLocation, Visibility};

    /// Builds a public Rust node named `name` located in `test.rs`.
    fn make_node(name: &str, kind: NodeKind) -> CodeNode {
        make_node_in_file(name, kind, "test.rs")
    }

    /// Builds a public Rust node named `name` located in `file`.
    fn make_node_in_file(name: &str, kind: NodeKind, file: &str) -> CodeNode {
        CodeNode::new(
            name.to_string(),
            kind,
            SourceLocation::new(file.to_string(), 1, 10, 0, 0),
            Language::Rust,
            Visibility::Public,
        )
    }

    #[test]
    fn test_add_nodes_and_edges() {
        let mut graph = CodeGraph::new();
        let main_node = make_node("main", NodeKind::Function);
        let main_id = main_node.id;
        let helper_node = make_node("helper", NodeKind::Function);
        let helper_id = helper_node.id;
        let main_idx = graph.add_node(main_node);
        let helper_idx = graph.add_node(helper_node);
        graph
            .add_edge(CallEdge::certain(main_id, helper_id))
            .unwrap();
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        let callees: Vec<_> = graph.calls_from(main_idx).collect();
        assert_eq!(callees, vec![helper_idx]);
        let callers: Vec<_> = graph.callers_of(helper_idx).collect();
        assert_eq!(callers, vec![main_idx]);
    }

    #[test]
    fn test_add_edge_rejects_unknown_endpoint() {
        let mut graph = CodeGraph::new();
        let a = make_node("a", NodeKind::Function);
        let missing = make_node("missing", NodeKind::Function);
        let a_id = a.id;
        graph.add_node(a);
        assert!(graph.add_edge(CallEdge::certain(a_id, missing.id)).is_err());
        assert!(graph.add_edge(CallEdge::certain(missing.id, a_id)).is_err());
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_find_by_name() {
        let mut graph = CodeGraph::new();
        graph.add_node(make_node("foo", NodeKind::Function));
        graph.add_node(make_node("bar", NodeKind::Function));
        assert!(graph.find_node_by_name("foo").is_some());
        assert!(graph.find_node_by_name("baz").is_none());
    }

    #[test]
    fn test_reachability() {
        let mut graph = CodeGraph::new();
        let a = make_node("a", NodeKind::Function);
        let b = make_node("b", NodeKind::Function);
        let c = make_node("c", NodeKind::Function);
        let (a_id, b_id, c_id) = (a.id, b.id, c.id);
        let a_idx = graph.add_node(a);
        let b_idx = graph.add_node(b);
        let c_idx = graph.add_node(c);
        let d_idx = graph.add_node(make_node("d", NodeKind::Function));
        graph.add_edge(CallEdge::certain(a_id, b_id)).unwrap();
        graph.add_edge(CallEdge::certain(b_id, c_id)).unwrap();
        graph.add_entry_point(a_idx);
        let reachable = graph.compute_production_reachable();
        assert!(reachable.contains(&a_idx));
        assert!(reachable.contains(&b_idx));
        assert!(reachable.contains(&c_idx));
        assert!(!reachable.contains(&d_idx));
        assert_eq!(graph.find_unreachable(), vec![d_idx]);
    }

    #[test]
    fn test_root_and_leaf_nodes() {
        let mut graph = CodeGraph::new();
        let a = make_node("a", NodeKind::Function);
        let b = make_node("b", NodeKind::Function);
        let (a_id, b_id) = (a.id, b.id);
        let a_idx = graph.add_node(a);
        let b_idx = graph.add_node(b);
        let c_idx = graph.add_node(make_node("c", NodeKind::Function));
        graph.add_edge(CallEdge::certain(a_id, b_id)).unwrap();
        assert_eq!(graph.root_nodes(), vec![a_idx, c_idx]);
        assert_eq!(graph.leaf_nodes(), vec![b_idx, c_idx]);
    }

    #[test]
    fn test_edges_with_endpoints() {
        let mut graph = CodeGraph::new();
        let a = make_node("a", NodeKind::Function);
        let b = make_node("b", NodeKind::Function);
        let (a_id, b_id) = (a.id, b.id);
        let a_idx = graph.add_node(a);
        let b_idx = graph.add_node(b);
        graph.add_edge(CallEdge::certain(a_id, b_id)).unwrap();
        let endpoints: Vec<_> = graph
            .edges_with_endpoints()
            .map(|(src, tgt, _)| (src, tgt))
            .collect();
        assert_eq!(endpoints, vec![(a_idx, b_idx)]);
    }

    #[test]
    fn test_merge() {
        let mut g1 = CodeGraph::new();
        let a_idx = g1.add_node(make_node("a", NodeKind::Function));
        g1.add_entry_point(a_idx);
        let mut g2 = CodeGraph::new();
        let b = make_node("b", NodeKind::Function);
        let b_id = b.id;
        let b_idx_in_g2 = g2.add_node(b);
        g2.add_test_entry_point(b_idx_in_g2);
        g1.merge(&g2);
        assert_eq!(g1.node_count(), 2);
        let b_idx = g1.find_node_by_name("b").expect("b merged into g1");
        assert_eq!(g1.get_index(b_id), Some(b_idx));
        assert!(g1.entry_points().contains(&a_idx));
        assert!(g1.test_entry_points().contains(&b_idx));
    }

    #[test]
    fn test_merge_shares_existing_nodes_and_copies_edges() {
        let a = make_node("a", NodeKind::Function);
        let b = make_node("b", NodeKind::Function);
        let (a_id, b_id) = (a.id, b.id);
        let mut g1 = CodeGraph::new();
        g1.add_node(a.clone());
        let mut g2 = CodeGraph::new();
        g2.add_node(a);
        g2.add_node(b);
        g2.add_edge(CallEdge::certain(a_id, b_id)).unwrap();
        g1.merge(&g2);
        // `a` exists in both graphs and must not be duplicated.
        assert_eq!(g1.node_count(), 2);
        assert_eq!(g1.edge_count(), 1);
        let a_idx = g1.get_index(a_id).unwrap();
        let b_idx = g1.get_index(b_id).unwrap();
        assert!(g1.is_reachable_from(a_idx, b_idx));
    }

    #[test]
    fn test_scc() {
        let mut graph = CodeGraph::new();
        let a = make_node("a", NodeKind::Function);
        let b = make_node("b", NodeKind::Function);
        let (a_id, b_id) = (a.id, b.id);
        graph.add_node(a);
        graph.add_node(b);
        graph.add_edge(CallEdge::certain(a_id, b_id)).unwrap();
        graph.add_edge(CallEdge::certain(b_id, a_id)).unwrap();
        let sccs = graph.strongly_connected_components();
        assert!(sccs.iter().any(|scc| scc.len() == 2));
    }

    #[test]
    fn test_find_node_by_name_in_file() {
        let mut graph = CodeGraph::new();
        graph.add_node(make_node_in_file("helper", NodeKind::Function, "src/a.rs"));
        graph.add_node(make_node_in_file("helper", NodeKind::Function, "src/b.rs"));
        assert!(graph.find_node_by_name("helper").is_some());

        let a_idx = graph.find_node_by_name_in_file("helper", "src/a.rs");
        assert!(a_idx.is_some(), "Should find helper in src/a.rs");
        let a_node = graph.get_node(a_idx.unwrap()).unwrap();
        assert_eq!(a_node.location.file, "src/a.rs");

        let b_idx = graph.find_node_by_name_in_file("helper", "src/b.rs");
        assert!(b_idx.is_some(), "Should find helper in src/b.rs");
        let b_node = graph.get_node(b_idx.unwrap()).unwrap();
        assert_eq!(b_node.location.file, "src/b.rs");

        assert!(graph
            .find_node_by_name_in_file("helper", "src/c.rs")
            .is_none());
    }

    #[test]
    fn test_find_node_by_name_in_files() {
        let mut graph = CodeGraph::new();
        graph.add_node(make_node_in_file("helper", NodeKind::Function, "src/a.rs"));
        graph.add_node(make_node_in_file("helper", NodeKind::Function, "src/b.rs"));
        graph.add_node(make_node_in_file("other", NodeKind::Function, "src/c.rs"));
        let result = graph.find_node_by_name_in_files(
            "helper",
            &["src/b.rs".to_string(), "src/c.rs".to_string()],
        );
        assert!(result.is_some(), "Should find helper in src/b.rs");
        let node = graph.get_node(result.unwrap()).unwrap();
        assert_eq!(node.location.file, "src/b.rs");
        let no_result = graph.find_node_by_name_in_files("helper", &["src/c.rs".to_string()]);
        assert!(no_result.is_none(), "helper is not in src/c.rs");
    }

    #[test]
    fn test_file_index_populated_on_add() {
        let mut graph = CodeGraph::new();
        graph.add_node(make_node_in_file("foo", NodeKind::Function, "src/main.rs"));
        assert!(graph
            .find_node_by_name_in_file("foo", "src/main.rs")
            .is_some());
        assert!(graph
            .find_node_by_name_in_file("bar", "src/main.rs")
            .is_none());
    }
}