Skip to main content

graphify_core/
graph.rs

1use std::collections::HashMap;
2use std::io::Write;
3
4use petgraph::Undirected;
5use petgraph::stable_graph::{NodeIndex, StableGraph};
6use serde_json::{Value, json};
7use tracing::warn;
8
9use crate::error::{GraphifyError, Result};
10use crate::model::{CommunityInfo, GraphEdge, GraphNode, Hyperedge};
11
12// ---------------------------------------------------------------------------
13// KnowledgeGraph
14// ---------------------------------------------------------------------------
15
/// A knowledge graph backed by `petgraph::StableGraph`.
///
/// Provides ID-based node lookup and serialization to/from the
/// NetworkX `node_link_data` JSON format for Python interoperability.
///
/// Nodes are addressed by their string `GraphNode::id`; edges are undirected.
#[derive(Debug)]
pub struct KnowledgeGraph {
    // Undirected node/edge storage. Node and edge weights are the model
    // types themselves.
    graph: StableGraph<GraphNode, GraphEdge, Undirected>,
    // Maps each node's string `id` to its index in `graph`; filled by
    // `add_node` and used for every ID-based lookup.
    index_map: HashMap<String, NodeIndex>,
    // Community metadata. Not included in the node-link JSON produced by
    // `to_node_link_json` / `write_node_link_json`.
    pub communities: Vec<CommunityInfo>,
    // Hyperedges (replaceable wholesale via `set_hyperedges`). Also not
    // included in the node-link JSON export.
    pub hyperedges: Vec<Hyperedge>,
}
27
28impl Default for KnowledgeGraph {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl KnowledgeGraph {
35    pub fn new() -> Self {
36        Self {
37            graph: StableGraph::default(),
38            index_map: HashMap::new(),
39            communities: Vec::new(),
40            hyperedges: Vec::new(),
41        }
42    }
43
44    // -- Mutation --------------------------------------------------------
45
46    /// Add a node. Returns an error if a node with the same `id` already exists.
47    pub fn add_node(&mut self, node: GraphNode) -> Result<NodeIndex> {
48        if self.index_map.contains_key(&node.id) {
49            return Err(GraphifyError::DuplicateNode(node.id.clone()));
50        }
51        let id = node.id.clone();
52        let idx = self.graph.add_node(node);
53        self.index_map.insert(id, idx);
54        Ok(idx)
55    }
56
57    /// Add an edge between two nodes identified by their string IDs.
58    pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
59        let &src = self
60            .index_map
61            .get(&edge.source)
62            .ok_or_else(|| GraphifyError::NodeNotFound(edge.source.clone()))?;
63        let &tgt = self
64            .index_map
65            .get(&edge.target)
66            .ok_or_else(|| GraphifyError::NodeNotFound(edge.target.clone()))?;
67        self.graph.add_edge(src, tgt, edge);
68        Ok(())
69    }
70
71    // -- Query -----------------------------------------------------------
72
73    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
74        self.index_map
75            .get(id)
76            .and_then(|&idx| self.graph.node_weight(idx))
77    }
78
79    /// Get a mutable reference to a node by its string ID.
80    pub fn get_node_mut(&mut self, id: &str) -> Option<&mut GraphNode> {
81        self.index_map
82            .get(id)
83            .copied()
84            .and_then(|idx| self.graph.node_weight_mut(idx))
85    }
86
87    pub fn get_neighbors(&self, id: &str) -> Vec<&GraphNode> {
88        let Some(&idx) = self.index_map.get(id) else {
89            return Vec::new();
90        };
91        self.graph
92            .neighbors(idx)
93            .filter_map(|ni| self.graph.node_weight(ni))
94            .collect()
95    }
96
97    pub fn node_count(&self) -> usize {
98        self.graph.node_count()
99    }
100
101    pub fn edge_count(&self) -> usize {
102        self.graph.edge_count()
103    }
104
105    /// Replace the hyperedges list.
106    pub fn set_hyperedges(&mut self, h: Vec<Hyperedge>) {
107        self.hyperedges = h;
108    }
109
110    /// Iterate over all node IDs.
111    pub fn node_ids(&self) -> Vec<String> {
112        self.index_map.keys().cloned().collect()
113    }
114
115    /// Get the degree (number of edges) for a node by id.
116    pub fn degree(&self, id: &str) -> usize {
117        self.index_map
118            .get(id)
119            .map(|&idx| self.graph.edges(idx).count())
120            .unwrap_or(0)
121    }
122
123    /// Get neighbor IDs as strings.
124    pub fn neighbor_ids(&self, id: &str) -> Vec<String> {
125        self.get_neighbors(id)
126            .iter()
127            .map(|n| n.id.clone())
128            .collect()
129    }
130
131    /// Collect all nodes as a Vec.
132    pub fn nodes(&self) -> Vec<&GraphNode> {
133        self.graph
134            .node_indices()
135            .filter_map(|idx| self.graph.node_weight(idx))
136            .collect()
137    }
138
139    /// Iterate over all edges as `(source_id, target_id, &GraphEdge)`.
140    pub fn edges_with_endpoints(&self) -> Vec<(&str, &str, &GraphEdge)> {
141        self.graph
142            .edge_indices()
143            .filter_map(|idx| {
144                let (a, b) = self.graph.edge_endpoints(idx)?;
145                let na = self.graph.node_weight(a)?;
146                let nb = self.graph.node_weight(b)?;
147                let e = self.graph.edge_weight(idx)?;
148                Some((na.id.as_str(), nb.id.as_str(), e))
149            })
150            .collect()
151    }
152
153    /// Iterate over all edge weights.
154    pub fn edges(&self) -> Vec<&GraphEdge> {
155        self.graph
156            .edge_indices()
157            .filter_map(|idx| self.graph.edge_weight(idx))
158            .collect()
159    }
160
161    // -- Serialization ---------------------------------------------------
162
163    /// Serialize to the NetworkX `node_link_data` JSON format.
164    pub fn to_node_link_json(&self) -> Value {
165        let nodes: Vec<Value> = self
166            .graph
167            .node_indices()
168            .filter_map(|idx| {
169                let n = self.graph.node_weight(idx)?;
170                Some(serde_json::to_value(n).unwrap_or(Value::Null))
171            })
172            .collect();
173
174        let links: Vec<Value> = self
175            .graph
176            .edge_indices()
177            .filter_map(|idx| {
178                let e = self.graph.edge_weight(idx)?;
179                Some(serde_json::to_value(e).unwrap_or(Value::Null))
180            })
181            .collect();
182
183        json!({
184            "directed": false,
185            "multigraph": false,
186            "graph": {},
187            "nodes": nodes,
188            "links": links,
189        })
190    }
191
192    /// Stream the graph as NetworkX `node_link_data` JSON directly to a writer.
193    ///
194    /// Unlike `to_node_link_json()`, this avoids building the entire JSON tree
195    /// in memory. For a 50K-node graph this saves ~500 MB of intermediate
196    /// allocations.
197    pub fn write_node_link_json<W: Write>(&self, writer: W) -> serde_json::Result<()> {
198        use serde::ser::SerializeMap;
199        use serde_json::ser::{PrettyFormatter, Serializer};
200
201        let formatter = PrettyFormatter::with_indent(b"  ");
202        let mut ser = Serializer::with_formatter(writer, formatter);
203        let mut map = serde::Serializer::serialize_map(&mut ser, Some(5))?;
204
205        map.serialize_entry("directed", &false)?;
206        map.serialize_entry("multigraph", &false)?;
207        map.serialize_entry("graph", &serde_json::Map::new())?;
208
209        // Stream nodes one by one
210        let nodes: Vec<&GraphNode> = self
211            .graph
212            .node_indices()
213            .filter_map(|idx| self.graph.node_weight(idx))
214            .collect();
215        map.serialize_entry("nodes", &nodes)?;
216
217        // Stream edges one by one
218        let links: Vec<&GraphEdge> = self
219            .graph
220            .edge_indices()
221            .filter_map(|idx| self.graph.edge_weight(idx))
222            .collect();
223        map.serialize_entry("links", &links)?;
224
225        map.end()
226    }
227
228    /// Deserialize from the NetworkX `node_link_data` JSON format.
229    pub fn from_node_link_json(value: &Value) -> Result<Self> {
230        let mut kg = Self::new();
231
232        // Nodes
233        if let Some(nodes) = value.get("nodes").and_then(|v| v.as_array()) {
234            for nv in nodes {
235                let node: GraphNode = serde_json::from_value(nv.clone())
236                    .map_err(GraphifyError::SerializationError)?;
237                if let Err(e) = kg.add_node(node) {
238                    warn!("skipping node during import: {e}");
239                }
240            }
241        }
242
243        // Edges (field name is "links" in node_link_data)
244        if let Some(links) = value.get("links").and_then(|v| v.as_array()) {
245            for lv in links {
246                let edge: GraphEdge = serde_json::from_value(lv.clone())
247                    .map_err(GraphifyError::SerializationError)?;
248                if let Err(e) = kg.add_edge(edge) {
249                    warn!("skipping edge during import: {e}");
250                }
251            }
252        }
253
254        Ok(kg)
255    }
256}
257
258// ---------------------------------------------------------------------------
259// Tests
260// ---------------------------------------------------------------------------
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::Confidence;
    use crate::model::NodeType;

    /// Build a minimal `Class` node whose label equals its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Build a `"calls"` edge with extracted confidence and unit weight.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn add_and_get_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert_eq!(kg.node_count(), 1);
        assert!(kg.get_node("a").is_some());
        assert!(kg.get_node("missing").is_none());
    }

    #[test]
    fn duplicate_node_error() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_node(make_node("a")).unwrap_err();
        assert!(matches!(err, GraphifyError::DuplicateNode(_)));
        // The rejected insert must not leave a second node behind.
        assert_eq!(kg.node_count(), 1);
    }

    #[test]
    fn add_edge_and_neighbors() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        assert_eq!(kg.edge_count(), 1);
        let neighbors = kg.get_neighbors("a");
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, "b");
    }

    #[test]
    fn edge_missing_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_edge(make_edge("a", "missing")).unwrap_err();
        assert!(matches!(err, GraphifyError::NodeNotFound(_)));
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn degree_and_neighbor_ids() {
        let mut kg = KnowledgeGraph::new();
        for id in ["a", "b", "c"] {
            kg.add_node(make_node(id)).unwrap();
        }
        kg.add_edge(make_edge("a", "b")).unwrap();
        kg.add_edge(make_edge("c", "a")).unwrap();

        assert_eq!(kg.degree("a"), 2);
        assert_eq!(kg.degree("b"), 1);
        assert_eq!(kg.degree("missing"), 0);

        // Undirected: "a" sees "c" even though that edge was added as c -> a.
        let mut ids = kg.neighbor_ids("a");
        ids.sort();
        assert_eq!(ids, ["b", "c"]);
        assert!(kg.neighbor_ids("missing").is_empty());

        let mut all = kg.node_ids();
        all.sort();
        assert_eq!(all, ["a", "b", "c"]);
    }

    #[test]
    fn edges_with_endpoints_reports_ids() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        let ends: Vec<(&str, &str)> = kg
            .edges_with_endpoints()
            .into_iter()
            .map(|(s, t, _)| (s, t))
            .collect();
        assert_eq!(ends, [("a", "b")]);
        assert_eq!(kg.edges().len(), 1);
        assert_eq!(kg.nodes().len(), 2);
    }

    #[test]
    fn node_link_roundtrip() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        kg.add_node(make_node("y")).unwrap();
        kg.add_edge(make_edge("x", "y")).unwrap();

        let json = kg.to_node_link_json();
        assert_eq!(json["directed"], false);
        assert_eq!(json["multigraph"], false);
        assert_eq!(json["nodes"].as_array().unwrap().len(), 2);
        assert_eq!(json["links"].as_array().unwrap().len(), 1);

        // Reconstruct
        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 2);
        assert_eq!(kg2.edge_count(), 1);
        assert!(kg2.get_node("x").is_some());
    }

    #[test]
    fn from_node_link_json_skips_duplicates_and_dangling_edges() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        kg.add_node(make_node("y")).unwrap();
        kg.add_edge(make_edge("x", "y")).unwrap();

        let mut json = kg.to_node_link_json();
        let dup = json["nodes"][0].clone();
        json["nodes"].as_array_mut().unwrap().push(dup);
        json["links"]
            .as_array_mut()
            .unwrap()
            .push(serde_json::to_value(make_edge("x", "ghost")).unwrap());

        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 2);
        assert_eq!(kg2.edge_count(), 1);
    }

    #[test]
    fn from_node_link_json_rejects_malformed_node() {
        let err = KnowledgeGraph::from_node_link_json(&json!({ "nodes": [42] })).unwrap_err();
        assert!(matches!(err, GraphifyError::SerializationError(_)));
    }

    #[test]
    fn from_node_link_json_missing_keys_is_empty() {
        let kg = KnowledgeGraph::from_node_link_json(&json!({})).unwrap();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn empty_graph_json() {
        let kg = KnowledgeGraph::new();
        let json = kg.to_node_link_json();
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }

    #[test]
    fn get_neighbors_missing_node() {
        let kg = KnowledgeGraph::new();
        assert!(kg.get_neighbors("nope").is_empty());
    }

    #[test]
    fn default_impl() {
        let kg = KnowledgeGraph::default();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn get_node_mut_updates_community() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert!(kg.get_node("a").unwrap().community.is_none());

        kg.get_node_mut("a").unwrap().community = Some(42);
        assert_eq!(kg.get_node("a").unwrap().community, Some(42));
    }

    #[test]
    fn get_node_mut_missing_returns_none() {
        let mut kg = KnowledgeGraph::new();
        assert!(kg.get_node_mut("nope").is_none());
    }

    #[test]
    fn write_node_link_json_matches_to_node_link_json() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        // Streaming write to buffer
        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let streamed: serde_json::Value = serde_json::from_slice(&buf).unwrap();

        // The two serializers must produce the same document, not merely
        // documents of the same size.
        assert_eq!(streamed, kg.to_node_link_json());
    }

    #[test]
    fn write_node_link_json_empty_graph() {
        let kg = KnowledgeGraph::new();
        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let json: serde_json::Value = serde_json::from_slice(&buf).unwrap();
        assert_eq!(json["directed"], false);
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }
}