1use std::collections::HashMap;
2use std::io::Write;
3
4use petgraph::Undirected;
5use petgraph::stable_graph::{NodeIndex, StableGraph};
6use serde_json::{Value, json};
7use tracing::warn;
8
9use crate::error::{GraphifyError, Result};
10use crate::model::{CommunityInfo, GraphEdge, GraphNode, Hyperedge};
11
/// An undirected graph of [`GraphNode`]s connected by [`GraphEdge`]s, with
/// lookup of nodes by their string id.
///
/// Nodes and edges live in a petgraph [`StableGraph`]; `index_map` translates
/// a node id into the graph's [`NodeIndex`].
#[derive(Debug)]
pub struct KnowledgeGraph {
    /// Backing undirected graph storage holding the node and edge payloads.
    graph: StableGraph<GraphNode, GraphEdge, Undirected>,
    /// Maps each node id to the index of that node in `graph`.
    index_map: HashMap<String, NodeIndex>,
    /// Community records associated with this graph. Starts empty; callers
    /// assign it directly through the public field.
    pub communities: Vec<CommunityInfo>,
    /// Hyperedges associated with this graph. Starts empty; replaced wholesale
    /// by `set_hyperedges` or assigned directly through the public field.
    pub hyperedges: Vec<Hyperedge>,
}
23
24impl Default for KnowledgeGraph {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl KnowledgeGraph {
31 pub fn new() -> Self {
32 Self {
33 graph: StableGraph::default(),
34 index_map: HashMap::new(),
35 communities: Vec::new(),
36 hyperedges: Vec::new(),
37 }
38 }
39
40 pub fn add_node(&mut self, node: GraphNode) -> Result<NodeIndex> {
42 if self.index_map.contains_key(&node.id) {
43 return Err(GraphifyError::DuplicateNode(node.id.clone()));
44 }
45 let id = node.id.clone();
46 let idx = self.graph.add_node(node);
47 self.index_map.insert(id, idx);
48 Ok(idx)
49 }
50
51 pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
53 let &src = self
54 .index_map
55 .get(&edge.source)
56 .ok_or_else(|| GraphifyError::NodeNotFound(edge.source.clone()))?;
57 let &tgt = self
58 .index_map
59 .get(&edge.target)
60 .ok_or_else(|| GraphifyError::NodeNotFound(edge.target.clone()))?;
61 self.graph.add_edge(src, tgt, edge);
62 Ok(())
63 }
64
65 pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
66 self.index_map
67 .get(id)
68 .and_then(|&idx| self.graph.node_weight(idx))
69 }
70
71 pub fn get_node_mut(&mut self, id: &str) -> Option<&mut GraphNode> {
73 self.index_map
74 .get(id)
75 .copied()
76 .and_then(|idx| self.graph.node_weight_mut(idx))
77 }
78
79 pub fn get_neighbors(&self, id: &str) -> Vec<&GraphNode> {
80 let Some(&idx) = self.index_map.get(id) else {
81 return Vec::new();
82 };
83 self.graph
84 .neighbors(idx)
85 .filter_map(|ni| self.graph.node_weight(ni))
86 .collect()
87 }
88
89 pub fn node_count(&self) -> usize {
90 self.graph.node_count()
91 }
92
93 pub fn edge_count(&self) -> usize {
94 self.graph.edge_count()
95 }
96
97 pub fn set_hyperedges(&mut self, h: Vec<Hyperedge>) {
99 self.hyperedges = h;
100 }
101
102 pub fn node_ids(&self) -> Vec<String> {
104 self.index_map.keys().cloned().collect()
105 }
106
107 pub fn degree(&self, id: &str) -> usize {
109 self.index_map
110 .get(id)
111 .map_or(0, |&idx| self.graph.edges(idx).count())
112 }
113
114 pub fn neighbor_ids(&self, id: &str) -> Vec<String> {
116 self.get_neighbors(id)
117 .iter()
118 .map(|n| n.id.clone())
119 .collect()
120 }
121
122 pub fn nodes(&self) -> Vec<&GraphNode> {
124 self.graph
125 .node_indices()
126 .filter_map(|idx| self.graph.node_weight(idx))
127 .collect()
128 }
129
130 pub fn edges_with_endpoints(&self) -> Vec<(&str, &str, &GraphEdge)> {
132 self.graph
133 .edge_indices()
134 .filter_map(|idx| {
135 let (a, b) = self.graph.edge_endpoints(idx)?;
136 let na = self.graph.node_weight(a)?;
137 let nb = self.graph.node_weight(b)?;
138 let e = self.graph.edge_weight(idx)?;
139 Some((na.id.as_str(), nb.id.as_str(), e))
140 })
141 .collect()
142 }
143
144 pub fn edges(&self) -> Vec<&GraphEdge> {
146 self.graph
147 .edge_indices()
148 .filter_map(|idx| self.graph.edge_weight(idx))
149 .collect()
150 }
151
152 pub fn to_node_link_json(&self) -> Value {
154 let nodes: Vec<Value> = self
155 .graph
156 .node_indices()
157 .filter_map(|idx| {
158 let n = self.graph.node_weight(idx)?;
159 Some(serde_json::to_value(n).unwrap_or(Value::Null))
160 })
161 .collect();
162
163 let links: Vec<Value> = self
164 .graph
165 .edge_indices()
166 .filter_map(|idx| {
167 let e = self.graph.edge_weight(idx)?;
168 Some(serde_json::to_value(e).unwrap_or(Value::Null))
169 })
170 .collect();
171
172 json!({
173 "directed": false,
174 "multigraph": false,
175 "graph": {},
176 "nodes": nodes,
177 "links": links,
178 })
179 }
180
181 pub fn write_node_link_json<W: Write>(&self, writer: W) -> serde_json::Result<()> {
188 use serde::ser::SerializeMap;
189 use serde_json::ser::{PrettyFormatter, Serializer};
190
191 let formatter = PrettyFormatter::with_indent(b" ");
192 let mut ser = Serializer::with_formatter(writer, formatter);
193 let mut map = serde::Serializer::serialize_map(&mut ser, Some(5))?;
194
195 map.serialize_entry("directed", &false)?;
196 map.serialize_entry("multigraph", &false)?;
197 map.serialize_entry("graph", &serde_json::Map::new())?;
198
199 let nodes: Vec<&GraphNode> = self
200 .graph
201 .node_indices()
202 .filter_map(|idx| self.graph.node_weight(idx))
203 .collect();
204 map.serialize_entry("nodes", &nodes)?;
205
206 let links: Vec<&GraphEdge> = self
207 .graph
208 .edge_indices()
209 .filter_map(|idx| self.graph.edge_weight(idx))
210 .collect();
211 map.serialize_entry("links", &links)?;
212
213 map.end()
214 }
215
216 pub fn from_node_link_json(value: &Value) -> Result<Self> {
218 let mut kg = Self::new();
219
220 if let Some(nodes) = value.get("nodes").and_then(|v| v.as_array()) {
221 for nv in nodes {
222 let node: GraphNode = serde_json::from_value(nv.clone())
223 .map_err(GraphifyError::SerializationError)?;
224 if let Err(e) = kg.add_node(node) {
225 warn!("skipping node during import: {e}");
226 }
227 }
228 }
229
230 if let Some(links) = value.get("links").and_then(|v| v.as_array()) {
231 for lv in links {
232 let edge: GraphEdge = serde_json::from_value(lv.clone())
233 .map_err(GraphifyError::SerializationError)?;
234 if let Err(e) = kg.add_edge(edge) {
235 warn!("skipping edge during import: {e}");
236 }
237 }
238 }
239
240 Ok(kg)
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::Confidence;
    use crate::model::NodeType;

    /// Builds a minimal `Class` node whose label is the same as its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Builds a fully-confident `calls` edge from `src` to `tgt`.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    /// Builds a graph with nodes `first` and `second` joined by one edge.
    fn linked_pair(first: &str, second: &str) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(make_node(first)).unwrap();
        graph.add_node(make_node(second)).unwrap();
        graph.add_edge(make_edge(first, second)).unwrap();
        graph
    }

    #[test]
    fn add_and_get_node() {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(make_node("a")).unwrap();

        assert_eq!(graph.node_count(), 1);
        assert!(graph.get_node("a").is_some());
        assert!(graph.get_node("missing").is_none());
    }

    #[test]
    fn duplicate_node_error() {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(make_node("a")).unwrap();

        let failure = graph.add_node(make_node("a")).unwrap_err();
        assert!(matches!(failure, GraphifyError::DuplicateNode(_)));
    }

    #[test]
    fn add_edge_and_neighbors() {
        let graph = linked_pair("a", "b");

        assert_eq!(graph.edge_count(), 1);
        let adjacent = graph.get_neighbors("a");
        assert_eq!(adjacent.len(), 1);
        assert_eq!(adjacent[0].id, "b");
    }

    #[test]
    fn edge_missing_node() {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(make_node("a")).unwrap();

        let failure = graph.add_edge(make_edge("a", "missing")).unwrap_err();
        assert!(matches!(failure, GraphifyError::NodeNotFound(_)));
    }

    #[test]
    fn node_link_roundtrip() {
        let original = linked_pair("x", "y");

        let doc = original.to_node_link_json();
        assert_eq!(doc["directed"], false);
        assert_eq!(doc["multigraph"], false);
        assert_eq!(doc["nodes"].as_array().unwrap().len(), 2);
        assert_eq!(doc["links"].as_array().unwrap().len(), 1);

        let restored = KnowledgeGraph::from_node_link_json(&doc).unwrap();
        assert_eq!(restored.node_count(), 2);
        assert_eq!(restored.edge_count(), 1);
        assert!(restored.get_node("x").is_some());
    }

    #[test]
    fn empty_graph_json() {
        let doc = KnowledgeGraph::new().to_node_link_json();

        assert!(doc["nodes"].as_array().unwrap().is_empty());
        assert!(doc["links"].as_array().unwrap().is_empty());
    }

    #[test]
    fn get_neighbors_missing_node() {
        let graph = KnowledgeGraph::new();
        assert!(graph.get_neighbors("nope").is_empty());
    }

    #[test]
    fn default_impl() {
        let graph = KnowledgeGraph::default();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn get_node_mut_updates_community() {
        let mut graph = KnowledgeGraph::new();
        graph.add_node(make_node("a")).unwrap();
        assert!(graph.get_node("a").unwrap().community.is_none());

        graph.get_node_mut("a").unwrap().community = Some(42);
        assert_eq!(graph.get_node("a").unwrap().community, Some(42));
    }

    #[test]
    fn get_node_mut_missing_returns_none() {
        let mut graph = KnowledgeGraph::new();
        assert!(graph.get_node_mut("nope").is_none());
    }

    #[test]
    fn write_node_link_json_matches_to_node_link_json() {
        let graph = linked_pair("a", "b");

        let mut bytes = Vec::new();
        graph.write_node_link_json(&mut bytes).unwrap();
        let streamed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

        let in_mem = graph.to_node_link_json();

        // Scalar flags must agree exactly; arrays are compared by length.
        for key in ["directed", "multigraph"] {
            assert_eq!(streamed[key], in_mem[key]);
        }
        for key in ["nodes", "links"] {
            assert_eq!(
                streamed[key].as_array().unwrap().len(),
                in_mem[key].as_array().unwrap().len()
            );
        }
    }

    #[test]
    fn write_node_link_json_empty_graph() {
        let graph = KnowledgeGraph::new();

        let mut bytes = Vec::new();
        graph.write_node_link_json(&mut bytes).unwrap();
        let doc: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

        assert!(doc["nodes"].as_array().unwrap().is_empty());
        assert!(doc["links"].as_array().unwrap().is_empty());
    }
}