Skip to main content

graphify_core/
graph.rs

1use std::collections::HashMap;
2
3use petgraph::Undirected;
4use petgraph::stable_graph::{NodeIndex, StableGraph};
5use serde_json::{Value, json};
6use tracing::warn;
7
8use crate::error::{GraphifyError, Result};
9use crate::model::{CommunityInfo, GraphEdge, GraphNode, Hyperedge};
10
11// ---------------------------------------------------------------------------
12// KnowledgeGraph
13// ---------------------------------------------------------------------------
14
15/// A knowledge graph backed by `petgraph::StableGraph`.
16///
17/// Provides ID-based node lookup and serialization to/from the
18/// NetworkX `node_link_data` JSON format for Python interoperability.
19#[derive(Debug)]
20pub struct KnowledgeGraph {
21    graph: StableGraph<GraphNode, GraphEdge, Undirected>,
22    index_map: HashMap<String, NodeIndex>,
23    pub communities: Vec<CommunityInfo>,
24    pub hyperedges: Vec<Hyperedge>,
25}
26
27impl Default for KnowledgeGraph {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl KnowledgeGraph {
34    pub fn new() -> Self {
35        Self {
36            graph: StableGraph::default(),
37            index_map: HashMap::new(),
38            communities: Vec::new(),
39            hyperedges: Vec::new(),
40        }
41    }
42
43    // -- Mutation --------------------------------------------------------
44
45    /// Add a node. Returns an error if a node with the same `id` already exists.
46    pub fn add_node(&mut self, node: GraphNode) -> Result<NodeIndex> {
47        if self.index_map.contains_key(&node.id) {
48            return Err(GraphifyError::DuplicateNode(node.id.clone()));
49        }
50        let id = node.id.clone();
51        let idx = self.graph.add_node(node);
52        self.index_map.insert(id, idx);
53        Ok(idx)
54    }
55
56    /// Add an edge between two nodes identified by their string IDs.
57    pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
58        let &src = self
59            .index_map
60            .get(&edge.source)
61            .ok_or_else(|| GraphifyError::NodeNotFound(edge.source.clone()))?;
62        let &tgt = self
63            .index_map
64            .get(&edge.target)
65            .ok_or_else(|| GraphifyError::NodeNotFound(edge.target.clone()))?;
66        self.graph.add_edge(src, tgt, edge);
67        Ok(())
68    }
69
70    // -- Query -----------------------------------------------------------
71
72    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
73        self.index_map
74            .get(id)
75            .and_then(|&idx| self.graph.node_weight(idx))
76    }
77
78    pub fn get_neighbors(&self, id: &str) -> Vec<&GraphNode> {
79        let Some(&idx) = self.index_map.get(id) else {
80            return Vec::new();
81        };
82        self.graph
83            .neighbors(idx)
84            .filter_map(|ni| self.graph.node_weight(ni))
85            .collect()
86    }
87
88    pub fn node_count(&self) -> usize {
89        self.graph.node_count()
90    }
91
92    pub fn edge_count(&self) -> usize {
93        self.graph.edge_count()
94    }
95
96    /// Replace the hyperedges list.
97    pub fn set_hyperedges(&mut self, h: Vec<Hyperedge>) {
98        self.hyperedges = h;
99    }
100
101    /// Iterate over all node IDs.
102    pub fn node_ids(&self) -> Vec<String> {
103        self.index_map.keys().cloned().collect()
104    }
105
106    /// Get the degree (number of edges) for a node by id.
107    pub fn degree(&self, id: &str) -> usize {
108        self.index_map
109            .get(id)
110            .map(|&idx| self.graph.edges(idx).count())
111            .unwrap_or(0)
112    }
113
114    /// Get neighbor IDs as strings.
115    pub fn neighbor_ids(&self, id: &str) -> Vec<String> {
116        self.get_neighbors(id)
117            .iter()
118            .map(|n| n.id.clone())
119            .collect()
120    }
121
122    /// Collect all nodes as a Vec.
123    pub fn nodes(&self) -> Vec<&GraphNode> {
124        self.graph
125            .node_indices()
126            .filter_map(|idx| self.graph.node_weight(idx))
127            .collect()
128    }
129
130    /// Iterate over all edges as `(source_id, target_id, &GraphEdge)`.
131    pub fn edges_with_endpoints(&self) -> Vec<(&str, &str, &GraphEdge)> {
132        self.graph
133            .edge_indices()
134            .filter_map(|idx| {
135                let (a, b) = self.graph.edge_endpoints(idx)?;
136                let na = self.graph.node_weight(a)?;
137                let nb = self.graph.node_weight(b)?;
138                let e = self.graph.edge_weight(idx)?;
139                Some((na.id.as_str(), nb.id.as_str(), e))
140            })
141            .collect()
142    }
143
144    /// Iterate over all edge weights.
145    pub fn edges(&self) -> Vec<&GraphEdge> {
146        self.graph
147            .edge_indices()
148            .filter_map(|idx| self.graph.edge_weight(idx))
149            .collect()
150    }
151
152    // -- Serialization ---------------------------------------------------
153
154    /// Serialize to the NetworkX `node_link_data` JSON format.
155    pub fn to_node_link_json(&self) -> Value {
156        let nodes: Vec<Value> = self
157            .graph
158            .node_indices()
159            .filter_map(|idx| {
160                let n = self.graph.node_weight(idx)?;
161                Some(serde_json::to_value(n).unwrap_or(Value::Null))
162            })
163            .collect();
164
165        let links: Vec<Value> = self
166            .graph
167            .edge_indices()
168            .filter_map(|idx| {
169                let e = self.graph.edge_weight(idx)?;
170                Some(serde_json::to_value(e).unwrap_or(Value::Null))
171            })
172            .collect();
173
174        json!({
175            "directed": false,
176            "multigraph": false,
177            "graph": {},
178            "nodes": nodes,
179            "links": links,
180        })
181    }
182
183    /// Deserialize from the NetworkX `node_link_data` JSON format.
184    pub fn from_node_link_json(value: &Value) -> Result<Self> {
185        let mut kg = Self::new();
186
187        // Nodes
188        if let Some(nodes) = value.get("nodes").and_then(|v| v.as_array()) {
189            for nv in nodes {
190                let node: GraphNode = serde_json::from_value(nv.clone())
191                    .map_err(GraphifyError::SerializationError)?;
192                if let Err(e) = kg.add_node(node) {
193                    warn!("skipping node during import: {e}");
194                }
195            }
196        }
197
198        // Edges (field name is "links" in node_link_data)
199        if let Some(links) = value.get("links").and_then(|v| v.as_array()) {
200            for lv in links {
201                let edge: GraphEdge = serde_json::from_value(lv.clone())
202                    .map_err(GraphifyError::SerializationError)?;
203                if let Err(e) = kg.add_edge(edge) {
204                    warn!("skipping edge during import: {e}");
205                }
206            }
207        }
208
209        Ok(kg)
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Tests
215// ---------------------------------------------------------------------------
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::Confidence;
    use crate::model::NodeType;

    /// Minimal `Class` node whose label mirrors its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.to_owned(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Fully confident `calls` edge from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.to_owned(),
            target: tgt.to_owned(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    /// Graph pre-populated with one node per id, added in the given order.
    fn graph_with(ids: &[&str]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for id in ids {
            graph.add_node(make_node(id)).unwrap();
        }
        graph
    }

    #[test]
    fn add_and_get_node() {
        let graph = graph_with(&["a"]);
        assert_eq!(graph.node_count(), 1);
        assert!(graph.get_node("a").is_some());
        assert!(graph.get_node("missing").is_none());
    }

    #[test]
    fn duplicate_node_error() {
        let mut graph = graph_with(&["a"]);
        let outcome = graph.add_node(make_node("a"));
        assert!(matches!(outcome, Err(GraphifyError::DuplicateNode(_))));
    }

    #[test]
    fn add_edge_and_neighbors() {
        let mut graph = graph_with(&["a", "b"]);
        graph.add_edge(make_edge("a", "b")).unwrap();

        assert_eq!(graph.edge_count(), 1);
        let adjacent = graph.get_neighbors("a");
        assert_eq!(adjacent.len(), 1);
        assert_eq!(adjacent[0].id, "b");
    }

    #[test]
    fn edge_missing_node() {
        let mut graph = graph_with(&["a"]);
        let outcome = graph.add_edge(make_edge("a", "missing"));
        assert!(matches!(outcome, Err(GraphifyError::NodeNotFound(_))));
    }

    #[test]
    fn node_link_roundtrip() {
        let mut graph = graph_with(&["x", "y"]);
        graph.add_edge(make_edge("x", "y")).unwrap();

        // Export: header flags plus one entry per node and per edge.
        let json = graph.to_node_link_json();
        assert_eq!(json["directed"], false);
        assert_eq!(json["multigraph"], false);
        assert_eq!(json["nodes"].as_array().map(Vec::len), Some(2));
        assert_eq!(json["links"].as_array().map(Vec::len), Some(1));

        // Import the exported document and compare the basic shape.
        let restored = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(restored.node_count(), 2);
        assert_eq!(restored.edge_count(), 1);
        assert!(restored.get_node("x").is_some());
    }

    #[test]
    fn empty_graph_json() {
        let json = KnowledgeGraph::new().to_node_link_json();
        assert_eq!(json["nodes"].as_array().map(Vec::len), Some(0));
        assert_eq!(json["links"].as_array().map(Vec::len), Some(0));
    }

    #[test]
    fn get_neighbors_missing_node() {
        let graph = KnowledgeGraph::new();
        assert_eq!(graph.get_neighbors("nope").len(), 0);
    }

    #[test]
    fn default_impl() {
        let graph = KnowledgeGraph::default();
        assert_eq!(graph.node_count(), 0);
    }
}