use crate::domain::edge::EdgeKind;
use crate::domain::node::Node;
use petgraph::graph::{DiGraph, NodeIndex};
use std::collections::HashMap;
/// Key type used to look up a graph node by symbol in `ContextGraph::symbol_to_node`.
pub type SymbolId = String;
/// A directed graph of [`Node`]s connected by [`EdgeKind`]-labelled edges,
/// paired with a lookup table from symbol ids to node indices.
pub struct ContextGraph {
    /// Underlying directed graph; node weights are [`Node`]s, edge weights are [`EdgeKind`]s.
    pub graph: DiGraph<Node, EdgeKind>,
    /// Maps each registered [`SymbolId`] to the index of a node in `graph`.
    pub symbol_to_node: HashMap<SymbolId, NodeIndex>,
}
impl Default for ContextGraph {
fn default() -> Self {
Self::new()
}
}
impl ContextGraph {
    /// Creates an empty graph with no nodes, edges, or symbol mappings.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            symbol_to_node: HashMap::new(),
        }
    }

    /// Adds `node` to the graph, maps `symbol` to it, and returns the new node's index.
    ///
    /// If `symbol` was already mapped, the mapping is overwritten to point at the
    /// new node. The previously added node stays in the graph but is no longer
    /// reachable through [`ContextGraph::get_node_by_symbol`].
    pub fn add_node(&mut self, symbol: SymbolId, node: Node) -> NodeIndex {
        let idx = self.graph.add_node(node);
        self.symbol_to_node.insert(symbol, idx);
        idx
    }

    /// Adds a directed edge of the given `kind` from `source` to `target`.
    ///
    /// No de-duplication is performed: calling this repeatedly for the same pair
    /// of nodes stores one edge per call.
    pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, kind: EdgeKind) {
        self.graph.add_edge(source, target, kind);
    }

    /// Returns the index of the node currently mapped to `symbol`, or `None` if
    /// no node has been registered under that symbol.
    pub fn get_node_by_symbol(&self, symbol: &str) -> Option<NodeIndex> {
        self.symbol_to_node.get(symbol).copied()
    }

    /// Returns a reference to the node stored at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not refer to a node in this graph.
    pub fn node(&self, idx: NodeIndex) -> &Node {
        &self.graph[idx]
    }

    /// Iterates over the outgoing edges of `idx`, yielding each edge's target
    /// node together with that edge's kind.
    ///
    /// Every outgoing edge produces exactly one item, so parallel edges to the
    /// same target each appear with their own [`EdgeKind`].
    pub fn neighbors(&self, idx: NodeIndex) -> impl Iterator<Item = (NodeIndex, &EdgeKind)> {
        // Walk the edges themselves rather than looking each neighbor's edge up
        // again: a (source, target) lookup can only ever return one edge, which
        // would report the same kind repeatedly when parallel edges exist.
        self.graph
            .edges_directed(idx, petgraph::Direction::Outgoing)
            .map(|edge| (petgraph::visit::EdgeRef::target(&edge), edge.weight()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::edge::EdgeKind;
    use crate::domain::node::{FunctionNode, Node, NodeCore, SourceSpan, Visibility};

    /// Builds a minimal function node with the given id, name and context size;
    /// every other attribute is a fixed placeholder.
    fn test_node(id: u32, name: &str, context_size: u32) -> Node {
        let span = SourceSpan {
            start_line: 0,
            start_column: 0,
            end_line: 1,
            end_column: 10,
        };
        let core = NodeCore::new(
            id,
            name.to_string(),
            None,
            context_size,
            span,
            0.5,
            false,
            "test.py".to_string(),
        );
        Node::Function(FunctionNode {
            core,
            param_count: 0,
            typed_param_count: 0,
            has_return_type: false,
            is_async: false,
            is_generator: false,
            visibility: Visibility::Public,
        })
    }

    #[test]
    fn test_create_empty_graph() {
        let graph = ContextGraph::new();
        assert_eq!(graph.graph.node_count(), 0);
        assert_eq!(graph.graph.edge_count(), 0);
        assert!(graph.symbol_to_node.is_empty());
    }

    #[test]
    fn test_add_node_returns_index() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        assert_eq!(graph.graph.node_count(), 1);
        assert_eq!(graph.graph[idx].core().id, 0);
    }

    #[test]
    fn test_add_edge_creates_connection() {
        let mut graph = ContextGraph::new();
        let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 20));
        graph.add_edge(idx_a, idx_b, EdgeKind::Call);
        assert_eq!(graph.graph.edge_count(), 1);

        let neighbors: Vec<_> = graph.neighbors(idx_a).collect();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].0, idx_b);
        assert!(matches!(neighbors[0].1, EdgeKind::Call));
    }

    #[test]
    fn test_get_node_by_symbol() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::foo".into(), test_node(0, "foo", 15));
        assert_eq!(graph.get_node_by_symbol("sym::foo"), Some(idx));
        assert_eq!(
            graph
                .node(graph.get_node_by_symbol("sym::foo").unwrap())
                .core()
                .name,
            "foo"
        );
    }

    #[test]
    fn test_neighbors_iterator() {
        let mut graph = ContextGraph::new();
        let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        let idx_c = graph.add_node("sym::c".into(), test_node(2, "c", 10));
        graph.add_edge(idx_a, idx_b, EdgeKind::Call);
        graph.add_edge(idx_a, idx_c, EdgeKind::Call);

        let mut out: Vec<_> = graph
            .neighbors(idx_a)
            .map(|(i, k)| (i, k.clone()))
            .collect();
        // Iteration order is not part of the contract, so sort before comparing.
        out.sort_by_key(|(i, _)| i.index());
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0, idx_b);
        assert_eq!(out[1].0, idx_c);
    }

    #[test]
    fn test_nonexistent_symbol_returns_none() {
        let graph = ContextGraph::new();
        assert_eq!(graph.get_node_by_symbol("nonexistent"), None);

        let mut g = ContextGraph::new();
        g.add_node("sym::x".into(), test_node(0, "x", 1));
        assert_eq!(g.get_node_by_symbol("sym::y"), None);
    }

    /// Re-registering a symbol keeps both nodes but points the symbol at the newer one.
    #[test]
    fn test_duplicate_symbol_overwrites() {
        let mut graph = ContextGraph::new();
        let n1 = test_node(0, "first", 10);
        let n2 = test_node(1, "second", 20);
        let _i1 = graph.add_node("sym::dup".into(), n1);
        let i2 = graph.add_node("sym::dup".into(), n2);
        assert_eq!(graph.graph.node_count(), 2);
        assert_eq!(graph.get_node_by_symbol("sym::dup"), Some(i2));
        assert_eq!(graph.node(i2).core().context_size, 20);
    }

    #[test]
    fn test_empty_neighbors() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::sink".into(), test_node(0, "sink", 5));
        let count = graph.neighbors(idx).count();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_node_content_preserved() {
        let mut graph = ContextGraph::new();
        let n = test_node(42, "preserved", 100);
        let idx = graph.add_node("sym::p".into(), n);
        let got = graph.node(idx);
        assert_eq!(got.core().id, 42);
        assert_eq!(got.core().name, "preserved");
        assert_eq!(got.core().context_size, 100);
    }

    /// Two edges between the same ordered pair of nodes are both stored, and
    /// each one shows up as its own entry when listing neighbors.
    #[test]
    fn test_multiple_edges_same_direction() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Call);
        graph.add_edge(a, b, EdgeKind::ParamType);
        assert_eq!(graph.graph.edge_count(), 2);

        let targets: Vec<_> = graph.neighbors(a).map(|(target, _)| target).collect();
        assert_eq!(targets, vec![b, b]);
    }

    #[test]
    fn test_different_edge_kinds() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Read);

        let neighbors: Vec<_> = graph.neighbors(a).collect();
        assert_eq!(neighbors.len(), 1);
        assert!(matches!(neighbors[0].1, EdgeKind::Read));
    }

    #[test]
    fn test_symbol_to_node_consistency() {
        let mut graph = ContextGraph::new();
        let symbols = ["sym::x", "sym::y", "sym::z"];
        let mut indices = Vec::new();
        for (i, &s) in symbols.iter().enumerate() {
            let idx = graph.add_node(s.into(), test_node(i as u32, s, 1));
            indices.push((s, idx));
        }
        for (sym, idx) in indices {
            assert_eq!(graph.get_node_by_symbol(sym), Some(idx));
            assert_eq!(graph.symbol_to_node.get(sym).copied(), Some(idx));
        }
    }

    #[test]
    fn test_neighbors_only_outgoing() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Call);
        assert_eq!(graph.neighbors(a).count(), 1);
        assert_eq!(graph.neighbors(b).count(), 0);
    }

    #[test]
    fn test_add_three_nodes_linear_chain() {
        let mut graph = ContextGraph::new();
        let i1 = graph.add_node("sym::1".into(), test_node(0, "n1", 10));
        let i2 = graph.add_node("sym::2".into(), test_node(1, "n2", 20));
        let i3 = graph.add_node("sym::3".into(), test_node(2, "n3", 30));
        graph.add_edge(i1, i2, EdgeKind::Call);
        graph.add_edge(i2, i3, EdgeKind::Call);
        assert_eq!(graph.graph.node_count(), 3);
        assert_eq!(graph.graph.edge_count(), 2);
        assert_eq!(graph.neighbors(i1).count(), 1);
        assert_eq!(graph.neighbors(i2).count(), 1);
        assert_eq!(graph.neighbors(i3).count(), 0);
    }
}