1use std::collections::HashMap;
2use std::io::Write;
3
4use petgraph::Undirected;
5use petgraph::stable_graph::{NodeIndex, StableGraph};
6use serde_json::{Value, json};
7use tracing::warn;
8
9use crate::error::{GraphifyError, Result};
10use crate::model::{CommunityInfo, GraphEdge, GraphNode, Hyperedge};
11
/// In-memory knowledge graph: an undirected petgraph [`StableGraph`] whose
/// nodes carry [`GraphNode`] weights and whose edges carry [`GraphEdge`]
/// weights, plus a lookup table from node id to graph index and graph-level
/// community / hyperedge metadata.
#[derive(Debug)]
pub struct KnowledgeGraph {
    /// Undirected node/edge storage; only reachable through `KnowledgeGraph` methods.
    graph: StableGraph<GraphNode, GraphEdge, Undirected>,
    /// Maps each node's `id` to its index in `graph`; kept in sync by `add_node`.
    index_map: HashMap<String, NodeIndex>,
    /// Community metadata. No method in this impl populates it.
    /// NOTE(review): presumably filled by a clustering pass elsewhere — confirm.
    pub communities: Vec<CommunityInfo>,
    /// Hyperedges attached to the graph; replaced wholesale by `set_hyperedges`.
    pub hyperedges: Vec<Hyperedge>,
}
27
28impl Default for KnowledgeGraph {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl KnowledgeGraph {
35 pub fn new() -> Self {
36 Self {
37 graph: StableGraph::default(),
38 index_map: HashMap::new(),
39 communities: Vec::new(),
40 hyperedges: Vec::new(),
41 }
42 }
43
44 pub fn add_node(&mut self, node: GraphNode) -> Result<NodeIndex> {
48 if self.index_map.contains_key(&node.id) {
49 return Err(GraphifyError::DuplicateNode(node.id.clone()));
50 }
51 let id = node.id.clone();
52 let idx = self.graph.add_node(node);
53 self.index_map.insert(id, idx);
54 Ok(idx)
55 }
56
57 pub fn add_edge(&mut self, edge: GraphEdge) -> Result<()> {
59 let &src = self
60 .index_map
61 .get(&edge.source)
62 .ok_or_else(|| GraphifyError::NodeNotFound(edge.source.clone()))?;
63 let &tgt = self
64 .index_map
65 .get(&edge.target)
66 .ok_or_else(|| GraphifyError::NodeNotFound(edge.target.clone()))?;
67 self.graph.add_edge(src, tgt, edge);
68 Ok(())
69 }
70
71 pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
74 self.index_map
75 .get(id)
76 .and_then(|&idx| self.graph.node_weight(idx))
77 }
78
79 pub fn get_node_mut(&mut self, id: &str) -> Option<&mut GraphNode> {
81 self.index_map
82 .get(id)
83 .copied()
84 .and_then(|idx| self.graph.node_weight_mut(idx))
85 }
86
87 pub fn get_neighbors(&self, id: &str) -> Vec<&GraphNode> {
88 let Some(&idx) = self.index_map.get(id) else {
89 return Vec::new();
90 };
91 self.graph
92 .neighbors(idx)
93 .filter_map(|ni| self.graph.node_weight(ni))
94 .collect()
95 }
96
97 pub fn node_count(&self) -> usize {
98 self.graph.node_count()
99 }
100
101 pub fn edge_count(&self) -> usize {
102 self.graph.edge_count()
103 }
104
105 pub fn set_hyperedges(&mut self, h: Vec<Hyperedge>) {
107 self.hyperedges = h;
108 }
109
110 pub fn node_ids(&self) -> Vec<String> {
112 self.index_map.keys().cloned().collect()
113 }
114
115 pub fn degree(&self, id: &str) -> usize {
117 self.index_map
118 .get(id)
119 .map(|&idx| self.graph.edges(idx).count())
120 .unwrap_or(0)
121 }
122
123 pub fn neighbor_ids(&self, id: &str) -> Vec<String> {
125 self.get_neighbors(id)
126 .iter()
127 .map(|n| n.id.clone())
128 .collect()
129 }
130
131 pub fn nodes(&self) -> Vec<&GraphNode> {
133 self.graph
134 .node_indices()
135 .filter_map(|idx| self.graph.node_weight(idx))
136 .collect()
137 }
138
139 pub fn edges_with_endpoints(&self) -> Vec<(&str, &str, &GraphEdge)> {
141 self.graph
142 .edge_indices()
143 .filter_map(|idx| {
144 let (a, b) = self.graph.edge_endpoints(idx)?;
145 let na = self.graph.node_weight(a)?;
146 let nb = self.graph.node_weight(b)?;
147 let e = self.graph.edge_weight(idx)?;
148 Some((na.id.as_str(), nb.id.as_str(), e))
149 })
150 .collect()
151 }
152
153 pub fn edges(&self) -> Vec<&GraphEdge> {
155 self.graph
156 .edge_indices()
157 .filter_map(|idx| self.graph.edge_weight(idx))
158 .collect()
159 }
160
161 pub fn to_node_link_json(&self) -> Value {
165 let nodes: Vec<Value> = self
166 .graph
167 .node_indices()
168 .filter_map(|idx| {
169 let n = self.graph.node_weight(idx)?;
170 Some(serde_json::to_value(n).unwrap_or(Value::Null))
171 })
172 .collect();
173
174 let links: Vec<Value> = self
175 .graph
176 .edge_indices()
177 .filter_map(|idx| {
178 let e = self.graph.edge_weight(idx)?;
179 Some(serde_json::to_value(e).unwrap_or(Value::Null))
180 })
181 .collect();
182
183 json!({
184 "directed": false,
185 "multigraph": false,
186 "graph": {},
187 "nodes": nodes,
188 "links": links,
189 })
190 }
191
192 pub fn write_node_link_json<W: Write>(&self, writer: W) -> serde_json::Result<()> {
198 use serde::ser::SerializeMap;
199 use serde_json::ser::{PrettyFormatter, Serializer};
200
201 let formatter = PrettyFormatter::with_indent(b" ");
202 let mut ser = Serializer::with_formatter(writer, formatter);
203 let mut map = serde::Serializer::serialize_map(&mut ser, Some(5))?;
204
205 map.serialize_entry("directed", &false)?;
206 map.serialize_entry("multigraph", &false)?;
207 map.serialize_entry("graph", &serde_json::Map::new())?;
208
209 let nodes: Vec<&GraphNode> = self
211 .graph
212 .node_indices()
213 .filter_map(|idx| self.graph.node_weight(idx))
214 .collect();
215 map.serialize_entry("nodes", &nodes)?;
216
217 let links: Vec<&GraphEdge> = self
219 .graph
220 .edge_indices()
221 .filter_map(|idx| self.graph.edge_weight(idx))
222 .collect();
223 map.serialize_entry("links", &links)?;
224
225 map.end()
226 }
227
228 pub fn from_node_link_json(value: &Value) -> Result<Self> {
230 let mut kg = Self::new();
231
232 if let Some(nodes) = value.get("nodes").and_then(|v| v.as_array()) {
234 for nv in nodes {
235 let node: GraphNode = serde_json::from_value(nv.clone())
236 .map_err(GraphifyError::SerializationError)?;
237 if let Err(e) = kg.add_node(node) {
238 warn!("skipping node during import: {e}");
239 }
240 }
241 }
242
243 if let Some(links) = value.get("links").and_then(|v| v.as_array()) {
245 for lv in links {
246 let edge: GraphEdge = serde_json::from_value(lv.clone())
247 .map_err(GraphifyError::SerializationError)?;
248 if let Err(e) = kg.add_edge(edge) {
249 warn!("skipping edge during import: {e}");
250 }
251 }
252 }
253
254 Ok(kg)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::confidence::Confidence;
    use crate::model::NodeType;

    /// Builds a minimal class node whose label mirrors its id.
    fn make_node(id: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: id.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Builds an extracted "calls" edge with full confidence and unit weight.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn add_and_get_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert_eq!(kg.node_count(), 1);
        assert!(kg.get_node("a").is_some());
        assert!(kg.get_node("missing").is_none());
    }

    #[test]
    fn duplicate_node_error() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_node(make_node("a")).unwrap_err();
        assert!(matches!(err, GraphifyError::DuplicateNode(_)));
        // A rejected insert must not grow the graph.
        assert_eq!(kg.node_count(), 1);
    }

    #[test]
    fn add_edge_and_neighbors() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        assert_eq!(kg.edge_count(), 1);
        let neighbors = kg.get_neighbors("a");
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].id, "b");
    }

    #[test]
    fn edge_missing_node() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        let err = kg.add_edge(make_edge("a", "missing")).unwrap_err();
        assert!(matches!(err, GraphifyError::NodeNotFound(_)));
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn degree_and_neighbor_ids_are_undirected() {
        let mut kg = KnowledgeGraph::new();
        for id in ["a", "b", "c"] {
            kg.add_node(make_node(id)).unwrap();
        }
        kg.add_edge(make_edge("a", "b")).unwrap();
        kg.add_edge(make_edge("c", "a")).unwrap();

        assert_eq!(kg.degree("a"), 2);
        assert_eq!(kg.degree("b"), 1);
        assert_eq!(kg.degree("missing"), 0);

        let mut ids = kg.neighbor_ids("a");
        ids.sort();
        assert_eq!(ids, ["b", "c"]);
        // The edge was added as a -> b but must be visible from b as well.
        assert_eq!(kg.neighbor_ids("b"), ["a"]);
        assert!(kg.neighbor_ids("missing").is_empty());
    }

    #[test]
    fn node_ids_lists_every_node() {
        let mut kg = KnowledgeGraph::new();
        for id in ["c", "a", "b"] {
            kg.add_node(make_node(id)).unwrap();
        }
        let mut ids = kg.node_ids();
        ids.sort();
        assert_eq!(ids, ["a", "b", "c"]);
    }

    #[test]
    fn edges_with_endpoints_reports_ids() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        kg.add_node(make_node("y")).unwrap();
        kg.add_edge(make_edge("x", "y")).unwrap();

        let edges = kg.edges_with_endpoints();
        assert_eq!(edges.len(), 1);
        assert_eq!((edges[0].0, edges[0].1), ("x", "y"));
        assert_eq!(kg.edges().len(), 1);
    }

    #[test]
    fn node_link_roundtrip() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("x")).unwrap();
        kg.add_node(make_node("y")).unwrap();
        kg.add_edge(make_edge("x", "y")).unwrap();

        let json = kg.to_node_link_json();
        assert_eq!(json["directed"], false);
        assert_eq!(json["multigraph"], false);
        assert_eq!(json["nodes"].as_array().unwrap().len(), 2);
        assert_eq!(json["links"].as_array().unwrap().len(), 1);

        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 2);
        assert_eq!(kg2.edge_count(), 1);
        assert!(kg2.get_node("x").is_some());
        assert!(kg2.get_node("y").is_some());
        assert_eq!(kg2.neighbor_ids("x"), ["y"]);
    }

    #[test]
    fn import_skips_duplicate_nodes_and_dangling_edges() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        let mut json = kg.to_node_link_json();
        let dup = json["nodes"][0].clone();
        json["nodes"].as_array_mut().unwrap().push(dup);
        let dangling = serde_json::to_value(make_edge("a", "ghost")).unwrap();
        json["links"].as_array_mut().unwrap().push(dangling);

        let kg2 = KnowledgeGraph::from_node_link_json(&json).unwrap();
        assert_eq!(kg2.node_count(), 2);
        assert_eq!(kg2.edge_count(), 1);
    }

    #[test]
    fn import_rejects_malformed_node() {
        let json = serde_json::json!({ "nodes": [42] });
        let err = KnowledgeGraph::from_node_link_json(&json).unwrap_err();
        assert!(matches!(err, GraphifyError::SerializationError(_)));
    }

    #[test]
    fn import_of_empty_document_is_empty_graph() {
        let kg = KnowledgeGraph::from_node_link_json(&serde_json::json!({})).unwrap();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn empty_graph_json() {
        let kg = KnowledgeGraph::new();
        let json = kg.to_node_link_json();
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }

    #[test]
    fn get_neighbors_missing_node() {
        let kg = KnowledgeGraph::new();
        assert!(kg.get_neighbors("nope").is_empty());
    }

    #[test]
    fn default_impl() {
        let kg = KnowledgeGraph::default();
        assert_eq!(kg.node_count(), 0);
        assert_eq!(kg.edge_count(), 0);
    }

    #[test]
    fn get_node_mut_updates_community() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        assert!(kg.get_node("a").unwrap().community.is_none());

        kg.get_node_mut("a").unwrap().community = Some(42);
        assert_eq!(kg.get_node("a").unwrap().community, Some(42));
    }

    #[test]
    fn get_node_mut_missing_returns_none() {
        let mut kg = KnowledgeGraph::new();
        assert!(kg.get_node_mut("nope").is_none());
    }

    #[test]
    fn write_node_link_json_matches_to_node_link_json() {
        let mut kg = KnowledgeGraph::new();
        kg.add_node(make_node("a")).unwrap();
        kg.add_node(make_node("b")).unwrap();
        kg.add_edge(make_edge("a", "b")).unwrap();

        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let streamed: serde_json::Value = serde_json::from_slice(&buf).unwrap();

        let in_mem = kg.to_node_link_json();

        // The streamed document must be structurally identical to the
        // in-memory one, not merely have arrays of the same length.
        assert_eq!(streamed, in_mem);
    }

    #[test]
    fn write_node_link_json_empty_graph() {
        let kg = KnowledgeGraph::new();
        let mut buf = Vec::new();
        kg.write_node_link_json(&mut buf).unwrap();
        let json: serde_json::Value = serde_json::from_slice(&buf).unwrap();
        assert_eq!(json["directed"], false);
        assert!(json["nodes"].as_array().unwrap().is_empty());
        assert!(json["links"].as_array().unwrap().is_empty());
    }
}