// graphify_core/graph.rs

use std::collections::{HashMap, HashSet, hash_map::Entry};
use std::io::Write;

use petgraph::Undirected;
use petgraph::stable_graph::{NodeIndex, StableGraph};
use serde::Deserialize;
use serde_json::{Value, json};
use tracing::warn;

use crate::error::{GraphifyError, Result};
use crate::model::{CommunityInfo, GraphEdge, GraphNode, Hyperedge};

12/// A knowledge graph backed by `petgraph::StableGraph`.
13///
14/// Provides ID-based node lookup and serialization to/from the
15/// NetworkX `node_link_data` JSON format for Python interoperability.
16#[derive(Debug)]
17pub struct KnowledgeGraph {
18    graph: StableGraph<GraphNode, GraphEdge, Undirected>,
19    index_map: HashMap<String, NodeIndex>,
20    pub communities: Vec<CommunityInfo>,
21    pub hyperedges: Vec<Hyperedge>,
22}
23
24impl Default for KnowledgeGraph {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl KnowledgeGraph {
31    pub fn new() -> Self {
32        Self {
33            graph: StableGraph::default(),
34            index_map: HashMap::new(),
35            communities: Vec::new(),
36            hyperedges: Vec::new(),
37        }
38    }
39
40    /// Add a node. Returns an error if a node with the same `id` already exists.
41    pub fn add_node(&mut self, node: GraphNode) -> Result<NodeIndex> {
42        if self.index_map.contains_key(&node.id) {
43            return Err(GraphifyError::DuplicateNode(node.id.clone()));
44        }
45        let id = node.id.clone();
46        let idx = self.graph.add_node(node);
47        self.index_map.insert(id, idx);
48        Ok(idx)
49    }
50
51    /// Add an edge between two nodes identified by their string IDs.
52    pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
53        let &src = self
54            .index_map
55            .get(&edge.source)
56            .ok_or_else(|| GraphifyError::NodeNotFound(edge.source.clone()))?;
57        let &tgt = self
58            .index_map
59            .get(&edge.target)
60            .ok_or_else(|| GraphifyError::NodeNotFound(edge.target.clone()))?;
61        self.graph.add_edge(src, tgt, edge);
62        Ok(())
63    }
64
65    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
66        self.index_map
67            .get(id)
68            .and_then(|&idx| self.graph.node_weight(idx))
69    }
70
71    /// Get a mutable reference to a node by its string ID.
72    pub fn get_node_mut(&mut self, id: &str) -> Option<&mut GraphNode> {
73        self.index_map
74            .get(id)
75            .copied()
76            .and_then(|idx| self.graph.node_weight_mut(idx))
77    }
78
79    pub fn get_neighbors(&self, id: &str) -> Vec<&GraphNode> {
80        let Some(&idx) = self.index_map.get(id) else {
81            return Vec::new();
82        };
83        self.graph
84            .neighbors(idx)
85            .filter_map(|ni| self.graph.node_weight(ni))
86            .collect()
87    }
88
89    pub fn node_count(&self) -> usize {
90        self.graph.node_count()
91    }
92
93    pub fn edge_count(&self) -> usize {
94        self.graph.edge_count()
95    }
96
97    /// Replace the hyperedges list.
98    pub fn set_hyperedges(&mut self, h: Vec<Hyperedge>) {
99        self.hyperedges = h;
100    }
101
102    /// Iterate over all node IDs.
103    pub fn node_ids(&self) -> Vec<String> {
104        self.index_map.keys().cloned().collect()
105    }
106
107    /// Get the degree (number of edges) for a node by id.
108    pub fn degree(&self, id: &str) -> usize {
109        self.index_map
110            .get(id)
111            .map_or(0, |&idx| self.graph.edges(idx).count())
112    }
113
114    /// Get neighbor IDs as strings.
115    pub fn neighbor_ids(&self, id: &str) -> Vec<String> {
116        self.get_neighbors(id)
117            .iter()
118            .map(|n| n.id.clone())
119            .collect()
120    }
121
122    /// Collect all nodes as a Vec.
123    pub fn nodes(&self) -> Vec<&GraphNode> {
124        self.graph
125            .node_indices()
126            .filter_map(|idx| self.graph.node_weight(idx))
127            .collect()
128    }
129
130    /// Iterate over all edges as `(source_id, target_id, &GraphEdge)`.
131    pub fn edges_with_endpoints(&self) -> Vec<(&str, &str, &GraphEdge)> {
132        self.graph
133            .edge_indices()
134            .filter_map(|idx| {
135                let (a, b) = self.graph.edge_endpoints(idx)?;
136                let na = self.graph.node_weight(a)?;
137                let nb = self.graph.node_weight(b)?;
138                let e = self.graph.edge_weight(idx)?;
139                Some((na.id.as_str(), nb.id.as_str(), e))
140            })
141            .collect()
142    }
143
144    /// Iterate over all edge weights.
145    pub fn edges(&self) -> Vec<&GraphEdge> {
146        self.graph
147            .edge_indices()
148            .filter_map(|idx| self.graph.edge_weight(idx))
149            .collect()
150    }
151
152    /// Serialize to the NetworkX `node_link_data` JSON format.
153    pub fn to_node_link_json(&self) -> Value {
154        let nodes: Vec<Value> = self
155            .graph
156            .node_indices()
157            .filter_map(|idx| {
158                let n = self.graph.node_weight(idx)?;
159                Some(serde_json::to_value(n).unwrap_or(Value::Null))
160            })
161            .collect();
162
163        let links: Vec<Value> = self
164            .graph
165            .edge_indices()
166            .filter_map(|idx| {
167                let e = self.graph.edge_weight(idx)?;
168                Some(serde_json::to_value(e).unwrap_or(Value::Null))
169            })
170            .collect();
171
172        json!({
173            "directed": false,
174            "multigraph": false,
175            "graph": {},
176            "nodes": nodes,
177            "links": links,
178        })
179    }
180
181    /// Stream the graph as NetworkX `node_link_data` JSON directly to a writer.
182    ///
183    /// Serialize to the NetworkX `node_link_data` JSON format, writing to
184    /// the provided writer. Uses a streaming serializer to avoid building
185    /// an intermediate JSON Value tree, but still collects node/edge
186    /// references into a Vec for serialization.
187    pub fn write_node_link_json<W: Write>(&self, writer: W) -> serde_json::Result<()> {
188        use serde::ser::SerializeMap;
189        use serde_json::ser::{PrettyFormatter, Serializer};
190
191        let formatter = PrettyFormatter::with_indent(b"  ");
192        let mut ser = Serializer::with_formatter(writer, formatter);
193        let mut map = serde::Serializer::serialize_map(&mut ser, Some(5))?;
194
195        map.serialize_entry("directed", &false)?;
196        map.serialize_entry("multigraph", &false)?;
197        map.serialize_entry("graph", &serde_json::Map::new())?;
198
199        let nodes: Vec<&GraphNode> = self
200            .graph
201            .node_indices()
202            .filter_map(|idx| self.graph.node_weight(idx))
203            .collect();
204        map.serialize_entry("nodes", &nodes)?;
205
206        let links: Vec<&GraphEdge> = self
207            .graph
208            .edge_indices()
209            .filter_map(|idx| self.graph.edge_weight(idx))
210            .collect();
211        map.serialize_entry("links", &links)?;
212
213        map.end()
214    }
215
216    /// Deserialize from the NetworkX `node_link_data` JSON format.
217    pub fn from_node_link_json(value: &Value) -> Result<Self> {
218        let mut kg = Self::new();
219
220        if let Some(nodes) = value.get("nodes").and_then(|v| v.as_array()) {
221            for nv in nodes {
222                let node: GraphNode = serde_json::from_value(nv.clone())
223                    .map_err(GraphifyError::SerializationError)?;
224                if let Err(e) = kg.add_node(node) {
225                    warn!("skipping node during import: {e}");
226                }
227            }
228        }
229
230        if let Some(links) = value.get("links").and_then(|v| v.as_array()) {
231            for lv in links {
232                let edge: GraphEdge = serde_json::from_value(lv.clone())
233                    .map_err(GraphifyError::SerializationError)?;
234                if let Err(e) = kg.add_edge(edge) {
235                    warn!("skipping edge during import: {e}");
236                }
237            }
238        }
239
240        Ok(kg)
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::Confidence;
    use crate::model::NodeType;

    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn add_and_get_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert_eq!(kg.node_count(), 1);
        assert!(kg.get_node("a").is_some());
        assert!(kg.get_node("missing").is_none());
    }

    #[test]
    fn duplicate_node_error() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_node(make_node("a")).unwrap_err();
        assert!(matches!(err, GraphifyError::DuplicateNode(_)));
        // A rejected duplicate must not change the graph.
        assert_eq!(kg.node_count(), 1);
    }

    #[test]
    fn add_edge_and_neighbors() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        assert_eq!(kg.edge_count(), 1);
        let neighbors = kg.get_neighbors("a");
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, "b");
    }

    #[test]
    fn edge_missing_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_edge(make_edge("a", "missing")).unwrap_err();
        assert!(matches!(err, GraphifyError::NodeNotFound(_)));
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn degree_and_neighbor_ids() {
        let mut kg = KnowledgeGraph::new();
        for id in ["a", "b", "c"] {
            kg.add_node(make_node(id)).unwrap();
        }
        kg.add_edge(make_edge("a", "b")).unwrap();
        kg.add_edge(make_edge("c", "a")).unwrap();

        assert_eq!(kg.degree("a"), 2);
        assert_eq!(kg.degree("b"), 1);
        assert_eq!(kg.degree("missing"), 0);

        let mut ids = kg.neighbor_ids("a");
        ids.sort();
        assert_eq!(ids, ["b", "c"]);
        assert!(kg.neighbor_ids("missing").is_empty());
    }

    #[test]
    fn node_ids_lists_every_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("b")).unwrap();
        kg.add_node(make_node("a")).unwrap();

        let mut ids = kg.node_ids();
        ids.sort();
        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn edges_with_endpoints_reports_ids() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        let edges = kg.edges_with_endpoints();
        assert_eq!(edges.len(), 1);
        let (src, tgt, edge) = edges[0];
        assert_eq!((src, tgt), ("a", "b"));
        assert_eq!(edge.source, "a");
    }

    #[test]
    fn node_link_roundtrip() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        kg.add_node(make_node("y")).unwrap();
        kg.add_edge(make_edge("x", "y")).unwrap();

        let json = kg.to_node_link_json();
        assert_eq!(json["directed"], false);
        assert_eq!(json["multigraph"], false);
        assert_eq!(json["nodes"].as_array().unwrap().len(), 2);
        assert_eq!(json["links"].as_array().unwrap().len(), 1);

        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 2);
        assert_eq!(kg2.edge_count(), 1);
        assert!(kg2.get_node("x").is_some());
    }

    #[test]
    fn import_skips_duplicate_nodes_and_dangling_edges() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        let mut json = kg.to_node_link_json();

        // Inject a duplicate node and an edge to a node that does not exist.
        let dup = json["nodes"][0].clone();
        json["nodes"].as_array_mut().unwrap().push(dup);
        let dangling = serde_json::to_value(make_edge("x", "ghost")).unwrap();
        json["links"].as_array_mut().unwrap().push(dangling);

        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 1);
        assert_eq!(kg2.edge_count(), 0);
    }

    #[test]
    fn import_without_lists_yields_empty_graph() {
        let kg = KnowledgeGraph::from_node_link_json(&serde_json::json!({})).unwrap();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn empty_graph_json() {
        let kg = KnowledgeGraph::new();
        let json = kg.to_node_link_json();
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }

    #[test]
    fn get_neighbors_missing_node() {
        let kg = KnowledgeGraph::new();
        assert!(kg.get_neighbors("nope").is_empty());
    }

    #[test]
    fn default_impl() {
        let kg = KnowledgeGraph::default();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
        assert!(kg.communities.is_empty());
        assert!(kg.hyperedges.is_empty());
    }

    #[test]
    fn get_node_mut_updates_community() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert!(kg.get_node("a").unwrap().community.is_none());

        kg.get_node_mut("a").unwrap().community = Some(42);
        assert_eq!(kg.get_node("a").unwrap().community, Some(42));
    }

    #[test]
    fn get_node_mut_missing_returns_none() {
        let mut kg = KnowledgeGraph::new();
        assert!(kg.get_node_mut("nope").is_none());
    }

    #[test]
    fn write_node_link_json_matches_to_node_link_json() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let streamed: serde_json::Value = serde_json::from_slice(&buf).unwrap();

        // Compare whole documents rather than list lengths, so any field-level
        // divergence between the two serializers is caught.
        assert_eq!(streamed, kg.to_node_link_json());
    }

    #[test]
    fn write_node_link_json_empty_graph() {
        let kg = KnowledgeGraph::new();
        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let json: serde_json::Value = serde_json::from_slice(&buf).unwrap();
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }
}