Skip to main content

context_footprint/domain/
graph.rs

use crate::domain::edge::EdgeKind;
use crate::domain::node::Node;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use std::collections::HashMap;

6/// Symbol identifier (SCIP symbol string)
7pub type SymbolId = String;
8
9/// Context Graph - the core data structure
10pub struct ContextGraph {
11    /// The directed graph of nodes and edges
12    pub graph: DiGraph<Node, EdgeKind>,
13
14    /// Mapping from symbol to node index
15    pub symbol_to_node: HashMap<SymbolId, NodeIndex>,
16}
17
18impl Default for ContextGraph {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl ContextGraph {
25    pub fn new() -> Self {
26        Self {
27            graph: DiGraph::new(),
28            symbol_to_node: HashMap::new(),
29        }
30    }
31
32    pub fn add_node(&mut self, symbol: SymbolId, node: Node) -> NodeIndex {
33        let idx = self.graph.add_node(node);
34        self.symbol_to_node.insert(symbol, idx);
35        idx
36    }
37
38    pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, kind: EdgeKind) {
39        self.graph.add_edge(source, target, kind);
40    }
41
42    pub fn get_node_by_symbol(&self, symbol: &str) -> Option<NodeIndex> {
43        self.symbol_to_node.get(symbol).copied()
44    }
45
46    pub fn node(&self, idx: NodeIndex) -> &Node {
47        &self.graph[idx]
48    }
49
50    pub fn neighbors(&self, idx: NodeIndex) -> impl Iterator<Item = (NodeIndex, &EdgeKind)> {
51        self.graph
52            .neighbors_directed(idx, petgraph::Direction::Outgoing)
53            .map(move |neighbor| {
54                let edge = self.graph.find_edge(idx, neighbor).unwrap();
55                (neighbor, self.graph.edge_weight(edge).unwrap())
56            })
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::edge::EdgeKind;
    use crate::domain::node::{FunctionNode, Node, NodeCore, SourceSpan, Visibility};

    /// Builds a minimal public, non-async function node for use in tests.
    fn test_node(id: u32, name: &str, context_size: u32) -> Node {
        let span = SourceSpan {
            start_line: 0,
            start_column: 0,
            end_line: 1,
            end_column: 10,
        };
        let core = NodeCore::new(
            id,
            name.to_string(),
            None,
            context_size,
            span,
            0.5,
            false,
            "test.py".to_string(),
        );
        Node::Function(FunctionNode {
            core,
            param_count: 0,
            typed_param_count: 0,
            has_return_type: false,
            is_async: false,
            is_generator: false,
            visibility: Visibility::Public,
        })
    }

    #[test]
    fn test_create_empty_graph() {
        let graph = ContextGraph::new();
        assert_eq!(graph.graph.node_count(), 0);
        assert_eq!(graph.graph.edge_count(), 0);
        assert!(graph.symbol_to_node.is_empty());
    }

    #[test]
    fn test_add_node_returns_index() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        assert_eq!(graph.graph.node_count(), 1);
        assert_eq!(graph.graph[idx].core().id, 0);
        assert_eq!(graph.get_node_by_symbol("sym::a"), Some(idx));
    }

    #[test]
    fn test_add_edge_creates_connection() {
        let mut graph = ContextGraph::new();
        let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 20));
        graph.add_edge(idx_a, idx_b, EdgeKind::Call);
        assert_eq!(graph.graph.edge_count(), 1);
        let neighbors: Vec<_> = graph.neighbors(idx_a).collect();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].0, idx_b);
        assert!(matches!(neighbors[0].1, EdgeKind::Call));
    }

    #[test]
    fn test_get_node_by_symbol() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::foo".into(), test_node(0, "foo", 15));
        assert_eq!(graph.get_node_by_symbol("sym::foo"), Some(idx));
        assert_eq!(
            graph
                .node(graph.get_node_by_symbol("sym::foo").unwrap())
                .core()
                .name,
            "foo"
        );
    }

    #[test]
    fn test_neighbors_iterator() {
        let mut graph = ContextGraph::new();
        let idx_a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let idx_b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        let idx_c = graph.add_node("sym::c".into(), test_node(2, "c", 10));
        graph.add_edge(idx_a, idx_b, EdgeKind::Call);
        graph.add_edge(idx_a, idx_c, EdgeKind::Call);
        let mut out: Vec<_> = graph
            .neighbors(idx_a)
            .map(|(i, k)| (i, k.clone()))
            .collect();
        out.sort_by_key(|(i, _)| i.index());
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0, idx_b);
        assert_eq!(out[1].0, idx_c);
    }

    #[test]
    fn test_nonexistent_symbol_returns_none() {
        let graph = ContextGraph::new();
        assert_eq!(graph.get_node_by_symbol("nonexistent"), None);
        let mut g = ContextGraph::new();
        g.add_node("sym::x".into(), test_node(0, "x", 1));
        assert_eq!(g.get_node_by_symbol("sym::y"), None);
    }

    #[test]
    fn test_duplicate_symbol_overwrites() {
        let mut graph = ContextGraph::new();
        let n1 = test_node(0, "first", 10);
        let n2 = test_node(1, "second", 20);
        let i1 = graph.add_node("sym::dup".into(), n1);
        let i2 = graph.add_node("sym::dup".into(), n2);
        assert_eq!(graph.graph.node_count(), 2);
        assert_eq!(graph.get_node_by_symbol("sym::dup"), Some(i2));
        assert_eq!(graph.node(i2).core().context_size, 20);
        // The shadowed node is still stored; it is only unreachable by symbol.
        assert_eq!(graph.node(i1).core().context_size, 10);
    }

    #[test]
    fn test_empty_neighbors() {
        let mut graph = ContextGraph::new();
        let idx = graph.add_node("sym::sink".into(), test_node(0, "sink", 5));
        let count = graph.neighbors(idx).count();
        assert_eq!(count, 0);
    }

    #[test]
    fn test_node_content_preserved() {
        let mut graph = ContextGraph::new();
        let n = test_node(42, "preserved", 100);
        let idx = graph.add_node("sym::p".into(), n);
        let got = graph.node(idx);
        assert_eq!(got.core().id, 42);
        assert_eq!(got.core().name, "preserved");
        assert_eq!(got.core().context_size, 100);
    }

    #[test]
    fn test_multiple_edges_same_direction() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Call);
        graph.add_edge(a, b, EdgeKind::ParamType); // petgraph allows multi-edges
        // Parallel edges are kept as two distinct edges, never merged.
        assert_eq!(graph.graph.edge_count(), 2);
        assert_eq!(graph.neighbors(a).count(), 2);
    }

    #[test]
    fn test_different_edge_kinds() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Read);
        let neighbors: Vec<_> = graph.neighbors(a).collect();
        assert_eq!(neighbors.len(), 1);
        assert!(matches!(neighbors[0].1, EdgeKind::Read));
    }

    #[test]
    fn test_symbol_to_node_consistency() {
        let mut graph = ContextGraph::new();
        let symbols = ["sym::x", "sym::y", "sym::z"];
        let mut indices = Vec::new();
        for (i, &s) in symbols.iter().enumerate() {
            let idx = graph.add_node(s.into(), test_node(i as u32, s, 1));
            indices.push((s, idx));
        }
        for (sym, idx) in indices {
            assert_eq!(graph.get_node_by_symbol(sym), Some(idx));
            assert_eq!(graph.symbol_to_node.get(sym).copied(), Some(idx));
        }
    }

    #[test]
    fn test_neighbors_only_outgoing() {
        let mut graph = ContextGraph::new();
        let a = graph.add_node("sym::a".into(), test_node(0, "a", 10));
        let b = graph.add_node("sym::b".into(), test_node(1, "b", 10));
        graph.add_edge(a, b, EdgeKind::Call);
        assert_eq!(graph.neighbors(a).count(), 1);
        assert_eq!(graph.neighbors(b).count(), 0);
    }

    #[test]
    fn test_add_three_nodes_linear_chain() {
        let mut graph = ContextGraph::new();
        let i1 = graph.add_node("sym::1".into(), test_node(0, "n1", 10));
        let i2 = graph.add_node("sym::2".into(), test_node(1, "n2", 20));
        let i3 = graph.add_node("sym::3".into(), test_node(2, "n3", 30));
        graph.add_edge(i1, i2, EdgeKind::Call);
        graph.add_edge(i2, i3, EdgeKind::Call);
        assert_eq!(graph.graph.node_count(), 3);
        assert_eq!(graph.graph.edge_count(), 2);
        assert_eq!(graph.neighbors(i1).count(), 1);
        assert_eq!(graph.neighbors(i2).count(), 1);
        assert_eq!(graph.neighbors(i3).count(), 0);
    }
}
255}