use crate::domain::edge::EdgeKind;
use crate::domain::node::Node;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use std::collections::HashMap;
5
/// Symbol identifier string used as the lookup key for nodes in a [`ContextGraph`].
pub type SymbolId = String;

/// A directed graph of [`Node`]s connected by [`EdgeKind`]-labelled edges, together
/// with an index from symbol identifiers to node indices.
pub struct ContextGraph {
    /// Underlying petgraph storage. Edges are directed; petgraph's `Graph` permits
    /// several (parallel) edges between the same pair of nodes.
    pub graph: DiGraph<Node, EdgeKind>,

    /// Maps each registered symbol to the index of its node in `graph`.
    /// `add_node` overwrites an existing entry for a repeated symbol, leaving the
    /// previously mapped node in `graph` without a symbol entry.
    pub symbol_to_node: HashMap<SymbolId, NodeIndex>,
}
17
18impl Default for ContextGraph {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl ContextGraph {
25 pub fn new() -> Self {
26 Self {
27 graph: DiGraph::new(),
28 symbol_to_node: HashMap::new(),
29 }
30 }
31
32 pub fn add_node(&mut self, symbol: SymbolId, node: Node) -> NodeIndex {
33 let idx = self.graph.add_node(node);
34 self.symbol_to_node.insert(symbol, idx);
35 idx
36 }
37
38 pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, kind: EdgeKind) {
39 self.graph.add_edge(source, target, kind);
40 }
41
42 pub fn get_node_by_symbol(&self, symbol: &str) -> Option<NodeIndex> {
43 self.symbol_to_node.get(symbol).copied()
44 }
45
46 pub fn node(&self, idx: NodeIndex) -> &Node {
47 &self.graph[idx]
48 }
49
50 pub fn neighbors(&self, idx: NodeIndex) -> impl Iterator<Item = (NodeIndex, &EdgeKind)> {
51 self.graph
52 .neighbors_directed(idx, petgraph::Direction::Outgoing)
53 .map(move |neighbor| {
54 let edge = self.graph.find_edge(idx, neighbor).unwrap();
55 (neighbor, self.graph.edge_weight(edge).unwrap())
56 })
57 }
58}
59
60#[cfg(test)]
61mod tests {
62 use super::*;
63 use crate::domain::edge::EdgeKind;
64 use crate::domain::node::{FunctionNode, Node, NodeCore, SourceSpan, Visibility};
65
66 fn test_node(id: u32, name: &str, context_size: u32) -> Node {
67 let span = SourceSpan {
68 start_line: 0,
69 start_column: 0,
70 end_line: 1,
71 end_column: 10,
72 };
73 let core = NodeCore::new(
74 id,
75 name.to_string(),
76 None,
77 context_size,
78 span,
79 0.5,
80 false,
81 "test.py".to_string(),
82 );
83 Node::Function(FunctionNode {
84 core,
85 param_count: 0,
86 typed_param_count: 0,
87 has_return_type: false,
88 is_async: false,
89 is_generator: false,
90 visibility: Visibility::Public,
91 })
92 }
93
94 #[test]
95 fn test_create_empty_graph() {
96 let graph = ContextGraph::new();
97 assert_eq!(graph.graph.node_count(), 0);
98 assert_eq!(graph.graph.edge_count(), 0);
99 assert!(graph.symbol_to_node.is_empty());
100 }
101
102 #[test]
103 fn test_add_node_returns_index() {
104 let mut graph = ContextGraph::new();
105 let idx = graph.add_node("sym::a".into(), test_node(0, "a", 10));
106 assert_eq!(graph.graph.node_count(), 1);
107 assert_eq!(graph.graph[idx].core().id, 0);
108 }
109
110 #[test]
111 fn test_add_edge_creates_connection() {
112 let mut graph = ContextGraph::new();
113 let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
114 let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 20));
115 graph.add_edge(idx_a, idx_b, EdgeKind::Call);
116 assert_eq!(graph.graph.edge_count(), 1);
117 let neighbors: Vec<_> = graph.neighbors(idx_a).collect();
118 assert_eq!(neighbors.len(), 1);
119 assert_eq!(neighbors[0].0, idx_b);
120 assert!(matches!(neighbors[0].1, EdgeKind::Call));
121 }
122
123 #[test]
124 fn test_get_node_by_symbol() {
125 let mut graph = ContextGraph::new();
126 let idx = graph.add_node("sym::foo".into(), test_node(0, "foo", 15));
127 assert_eq!(graph.get_node_by_symbol("sym::foo"), Some(idx));
128 assert_eq!(
129 graph
130 .node(graph.get_node_by_symbol("sym::foo").unwrap())
131 .core()
132 .name,
133 "foo"
134 );
135 }
136
137 #[test]
138 fn test_neighbors_iterator() {
139 let mut graph = ContextGraph::new();
140 let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
141 let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
142 let idx_c = graph.add_node("sym::c".into(), test_node(2, "c", 10));
143 graph.add_edge(idx_a, idx_b, EdgeKind::Call);
144 graph.add_edge(idx_a, idx_c, EdgeKind::Call);
145 let mut out: Vec<_> = graph
146 .neighbors(idx_a)
147 .map(|(i, k)| (i, k.clone()))
148 .collect();
149 out.sort_by_key(|(i, _)| i.index());
150 assert_eq!(out.len(), 2);
151 assert_eq!(out[0].0, idx_b);
152 assert_eq!(out[1].0, idx_c);
153 }
154
155 #[test]
156 fn test_nonexistent_symbol_returns_none() {
157 let graph = ContextGraph::new();
158 assert_eq!(graph.get_node_by_symbol("nonexistent"), None);
159 let mut g = ContextGraph::new();
160 g.add_node("sym::x".into(), test_node(0, "x", 1));
161 assert_eq!(g.get_node_by_symbol("sym::y"), None);
162 }
163
164 #[test]
165 fn test_duplicate_symbol_overwrites() {
166 let mut graph = ContextGraph::new();
167 let n1 = test_node(0, "first", 10);
168 let n2 = test_node(1, "second", 20);
169 let _i1 = graph.add_node("sym::dup".into(), n1);
170 let i2 = graph.add_node("sym::dup".into(), n2);
171 assert_eq!(graph.graph.node_count(), 2);
172 assert_eq!(graph.get_node_by_symbol("sym::dup"), Some(i2));
173 assert_eq!(graph.node(i2).core().context_size, 20);
174 }
175
176 #[test]
177 fn test_empty_neighbors() {
178 let mut graph = ContextGraph::new();
179 let idx = graph.add_node("sym::sink".into(), test_node(0, "sink", 5));
180 let count = graph.neighbors(idx).count();
181 assert_eq!(count, 0);
182 }
183
184 #[test]
185 fn test_node_content_preserved() {
186 let mut graph = ContextGraph::new();
187 let n = test_node(42, "preserved", 100);
188 let idx = graph.add_node("sym::p".into(), n);
189 let got = graph.node(idx);
190 assert_eq!(got.core().id, 42);
191 assert_eq!(got.core().name, "preserved");
192 assert_eq!(got.core().context_size, 100);
193 }
194
195 #[test]
196 fn test_multiple_edges_same_direction() {
197 let mut graph = ContextGraph::new();
198 let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
199 let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
200 graph.add_edge(a, b, EdgeKind::Call);
201 graph.add_edge(a, b, EdgeKind::ParamType); assert!(graph.graph.edge_count() >= 2);
203 }
204
205 #[test]
206 fn test_different_edge_kinds() {
207 let mut graph = ContextGraph::new();
208 let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
209 let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
210 graph.add_edge(a, b, EdgeKind::Read);
211 let neighbors: Vec<_> = graph.neighbors(a).collect();
212 assert_eq!(neighbors.len(), 1);
213 assert!(matches!(neighbors[0].1, EdgeKind::Read));
214 }
215
216 #[test]
217 fn test_symbol_to_node_consistency() {
218 let mut graph = ContextGraph::new();
219 let symbols = ["sym::x", "sym::y", "sym::z"];
220 let mut indices = Vec::new();
221 for (i, &s) in symbols.iter().enumerate() {
222 let idx = graph.add_node(s.into(), test_node(i as u32, s, 1));
223 indices.push((s, idx));
224 }
225 for (sym, idx) in indices {
226 assert_eq!(graph.get_node_by_symbol(sym), Some(idx));
227 assert_eq!(graph.symbol_to_node.get(sym).copied(), Some(idx));
228 }
229 }
230
231 #[test]
232 fn test_neighbors_only_outgoing() {
233 let mut graph = ContextGraph::new();
234 let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
235 let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
236 graph.add_edge(a, b, EdgeKind::Call);
237 assert_eq!(graph.neighbors(a).count(), 1);
238 assert_eq!(graph.neighbors(b).count(), 0);
239 }
240
241 #[test]
242 fn test_add_three_nodes_linear_chain() {
243 let mut graph = ContextGraph::new();
244 let i1 = graph.add_node("sym::1".into(), test_node(0, "n1", 10));
245 let i2 = graph.add_node("sym::2".into(), test_node(1, "n2", 20));
246 let i3 = graph.add_node("sym::3".into(), test_node(2, "n3", 30));
247 graph.add_edge(i1, i2, EdgeKind::Call);
248 graph.add_edge(i2, i3, EdgeKind::Call);
249 assert_eq!(graph.graph.node_count(), 3);
250 assert_eq!(graph.graph.edge_count(), 2);
251 assert_eq!(graph.neighbors(i1).count(), 1);
252 assert_eq!(graph.neighbors(i2).count(), 1);
253 assert_eq!(graph.neighbors(i3).count(), 0);
254 }
255}