1use crate::{Embedding, RragResult};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use uuid::Uuid;
9
/// An in-memory directed graph of [`GraphNode`]s connected by [`GraphEdge`]s.
///
/// `nodes` and `edges` hold the data; `adjacency_list` and `reverse_adjacency_list`
/// are lookup indexes maintained by `add_node`, `add_edge` and `remove_node`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeGraph {
    /// All nodes, keyed by node id.
    pub nodes: HashMap<String, GraphNode>,

    /// All edges, keyed by edge id. Nothing prevents several edges between the same
    /// pair of nodes.
    pub edges: HashMap<String, GraphEdge>,

    /// Outgoing index: node id -> ids of the target nodes recorded by `add_edge`.
    pub adjacency_list: HashMap<String, HashSet<String>>,

    /// Incoming index: node id -> ids of the source nodes recorded by `add_edge`.
    pub reverse_adjacency_list: HashMap<String, HashSet<String>>,

    /// Free-form graph-level metadata; not interpreted by the methods in this module.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Time the graph was created (set once by `new`).
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Time of the most recent mutation made through the graph's methods.
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
34
/// A vertex of a [`KnowledgeGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node id; `GraphNode::new` generates a UUID v4 string.
    pub id: String,

    /// Human-readable name of the node.
    pub label: String,

    /// Category of the node.
    pub node_type: NodeType,

    /// Arbitrary JSON attributes, added via `with_attribute`.
    pub attributes: HashMap<String, serde_json::Value>,

    /// Optional vector embedding; `None` unless set with `with_embedding`.
    pub embedding: Option<Embedding>,

    /// Identifiers of source documents, added via `with_source_document`.
    pub source_documents: HashSet<String>,

    /// Confidence score; defaults to `1.0`, and `with_confidence` clamps it to `[0.0, 1.0]`.
    pub confidence: f32,

    /// PageRank score if one has been computed; the constructors leave it `None`.
    pub pagerank_score: Option<f32>,

    /// Time this node value was constructed.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
65
/// A directed, labelled connection from `source_id` to `target_id` in a [`KnowledgeGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Unique edge id; `GraphEdge::new` generates a UUID v4 string.
    pub id: String,

    /// Id of the node the edge starts at; `KnowledgeGraph::add_edge` requires it to exist.
    pub source_id: String,

    /// Id of the node the edge points to; `KnowledgeGraph::add_edge` requires it to exist.
    pub target_id: String,

    /// Human-readable relation name.
    pub label: String,

    /// Category of the relation.
    pub edge_type: EdgeType,

    /// Arbitrary JSON attributes, added via `with_attribute`.
    pub attributes: HashMap<String, serde_json::Value>,

    /// Edge weight; defaults to `1.0`, and `with_weight` floors it at `0.0`.
    pub weight: f32,

    /// Confidence score; defaults to `1.0`, and `with_confidence` clamps it to `[0.0, 1.0]`.
    pub confidence: f32,

    /// Identifiers of source documents, added via `with_source_document`.
    pub source_documents: HashSet<String>,

    /// Time this edge value was constructed.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
99
/// Category of a [`GraphNode`].
///
/// Variant names are part of the serde representation, so renaming one presumably
/// breaks previously serialized graphs — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// A named entity; the string is its entity class (the tests use `"Person"`).
    Entity(String),

    /// A concept node.
    Concept,

    /// A document node.
    Document,

    /// A document-chunk node.
    DocumentChunk,

    /// A keyword node.
    Keyword,

    /// A caller-defined category named by the string.
    Custom(String),
}
121
/// Category of a [`GraphEdge`].
///
/// Variant names are part of the serde representation, so renaming one presumably
/// breaks previously serialized graphs — confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// A named relation; the string is the relation name (the tests use `"knows"`).
    Semantic(String),

    /// A hierarchical relation.
    Hierarchical,

    /// A containment relation.
    Contains,

    /// A reference relation.
    References,

    /// A co-occurrence relation.
    CoOccurs,

    /// A similarity relation.
    Similar,

    /// A caller-defined relation category named by the string.
    Custom(String),
}
146
/// Summary statistics produced by `KnowledgeGraph::calculate_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetrics {
    /// Number of nodes.
    pub node_count: usize,

    /// Number of stored edges; parallel edges are each counted.
    pub edge_count: usize,

    /// Number of connected components when edge direction is ignored.
    pub connected_components: usize,

    /// `edge_count / (n * (n - 1))` for `n > 1` nodes, otherwise `0.0`. Because parallel
    /// edges and self-loops are counted, this can exceed `1.0`.
    pub density: f32,

    /// Mean number of distinct outgoing neighbours per adjacency entry.
    pub average_degree: f32,

    /// Largest number of distinct outgoing neighbours of any node.
    pub max_degree: usize,

    /// Mean local clustering coefficient over nodes with at least two outgoing
    /// neighbours; `0.0` when there are none.
    pub clustering_coefficient: f32,

    /// Longest shortest path; not computed yet, so `calculate_metrics` always sets `None`.
    pub diameter: Option<usize>,

    /// Node counts keyed by type label, e.g. `"Concept"` or `"Entity(Person)"`.
    pub node_type_distribution: HashMap<String, usize>,

    /// Edge counts keyed by type label, e.g. `"Similar"` or `"Semantic(knows)"`.
    pub edge_type_distribution: HashMap<String, usize>,
}
180
181impl KnowledgeGraph {
182 pub fn new() -> Self {
184 let now = chrono::Utc::now();
185 Self {
186 nodes: HashMap::new(),
187 edges: HashMap::new(),
188 adjacency_list: HashMap::new(),
189 reverse_adjacency_list: HashMap::new(),
190 metadata: HashMap::new(),
191 created_at: now,
192 updated_at: now,
193 }
194 }
195
196 pub fn add_node(&mut self, node: GraphNode) -> RragResult<()> {
198 let node_id = node.id.clone();
199
200 self.adjacency_list
202 .entry(node_id.clone())
203 .or_insert_with(HashSet::new);
204 self.reverse_adjacency_list
205 .entry(node_id.clone())
206 .or_insert_with(HashSet::new);
207
208 self.nodes.insert(node_id, node);
210 self.updated_at = chrono::Utc::now();
211
212 Ok(())
213 }
214
215 pub fn add_edge(&mut self, edge: GraphEdge) -> RragResult<()> {
217 let edge_id = edge.id.clone();
218 let source_id = edge.source_id.clone();
219 let target_id = edge.target_id.clone();
220
221 if !self.nodes.contains_key(&source_id) {
223 return Err(crate::RragError::retrieval(format!(
224 "Source node {} not found",
225 source_id
226 )));
227 }
228
229 if !self.nodes.contains_key(&target_id) {
230 return Err(crate::RragError::retrieval(format!(
231 "Target node {} not found",
232 target_id
233 )));
234 }
235
236 self.adjacency_list
238 .entry(source_id.clone())
239 .or_insert_with(HashSet::new)
240 .insert(target_id.clone());
241
242 self.reverse_adjacency_list
243 .entry(target_id.clone())
244 .or_insert_with(HashSet::new)
245 .insert(source_id.clone());
246
247 self.edges.insert(edge_id, edge);
249 self.updated_at = chrono::Utc::now();
250
251 Ok(())
252 }
253
254 pub fn get_node(&self, node_id: &str) -> Option<&GraphNode> {
256 self.nodes.get(node_id)
257 }
258
259 pub fn get_edge(&self, edge_id: &str) -> Option<&GraphEdge> {
261 self.edges.get(edge_id)
262 }
263
264 pub fn get_neighbors(&self, node_id: &str) -> Vec<&GraphNode> {
266 self.adjacency_list
267 .get(node_id)
268 .map(|neighbors| {
269 neighbors
270 .iter()
271 .filter_map(|neighbor_id| self.nodes.get(neighbor_id))
272 .collect()
273 })
274 .unwrap_or_default()
275 }
276
277 pub fn get_incoming_neighbors(&self, node_id: &str) -> Vec<&GraphNode> {
279 self.reverse_adjacency_list
280 .get(node_id)
281 .map(|neighbors| {
282 neighbors
283 .iter()
284 .filter_map(|neighbor_id| self.nodes.get(neighbor_id))
285 .collect()
286 })
287 .unwrap_or_default()
288 }
289
290 pub fn get_node_edges(&self, node_id: &str) -> Vec<&GraphEdge> {
292 self.edges
293 .values()
294 .filter(|edge| edge.source_id == node_id || edge.target_id == node_id)
295 .collect()
296 }
297
298 pub fn remove_node(&mut self, node_id: &str) -> RragResult<()> {
300 if let Some(neighbors) = self.adjacency_list.remove(node_id) {
302 for neighbor in neighbors {
303 if let Some(reverse_neighbors) = self.reverse_adjacency_list.get_mut(&neighbor) {
304 reverse_neighbors.remove(node_id);
305 }
306 }
307 }
308
309 if let Some(incoming_neighbors) = self.reverse_adjacency_list.remove(node_id) {
310 for neighbor in incoming_neighbors {
311 if let Some(outgoing_neighbors) = self.adjacency_list.get_mut(&neighbor) {
312 outgoing_neighbors.remove(node_id);
313 }
314 }
315 }
316
317 let edges_to_remove: Vec<String> = self
319 .edges
320 .iter()
321 .filter(|(_, edge)| edge.source_id == node_id || edge.target_id == node_id)
322 .map(|(edge_id, _)| edge_id.clone())
323 .collect();
324
325 for edge_id in edges_to_remove {
326 self.edges.remove(&edge_id);
327 }
328
329 self.nodes.remove(node_id);
331 self.updated_at = chrono::Utc::now();
332
333 Ok(())
334 }
335
336 pub fn find_nodes_by_type(&self, node_type: &NodeType) -> Vec<&GraphNode> {
338 self.nodes
339 .values()
340 .filter(|node| &node.node_type == node_type)
341 .collect()
342 }
343
344 pub fn find_edges_by_type(&self, edge_type: &EdgeType) -> Vec<&GraphEdge> {
346 self.edges
347 .values()
348 .filter(|edge| &edge.edge_type == edge_type)
349 .collect()
350 }
351
352 pub fn calculate_metrics(&self) -> GraphMetrics {
354 let node_count = self.nodes.len();
355 let edge_count = self.edges.len();
356
357 let max_edges = if node_count > 1 {
359 node_count * (node_count - 1)
360 } else {
361 0
362 };
363 let density = if max_edges > 0 {
364 edge_count as f32 / max_edges as f32
365 } else {
366 0.0
367 };
368
369 let degrees: Vec<usize> = self
371 .adjacency_list
372 .values()
373 .map(|neighbors| neighbors.len())
374 .collect();
375
376 let average_degree = if !degrees.is_empty() {
377 degrees.iter().sum::<usize>() as f32 / degrees.len() as f32
378 } else {
379 0.0
380 };
381
382 let max_degree = degrees.iter().max().copied().unwrap_or(0);
383
384 let connected_components = self.count_connected_components();
386
387 let clustering_coefficient = self.calculate_clustering_coefficient();
389
390 let mut node_type_distribution = HashMap::new();
392 for node in self.nodes.values() {
393 let type_key = self.node_type_key(&node.node_type);
394 *node_type_distribution.entry(type_key).or_insert(0) += 1;
395 }
396
397 let mut edge_type_distribution = HashMap::new();
399 for edge in self.edges.values() {
400 let type_key = self.edge_type_key(&edge.edge_type);
401 *edge_type_distribution.entry(type_key).or_insert(0) += 1;
402 }
403
404 GraphMetrics {
405 node_count,
406 edge_count,
407 connected_components,
408 density,
409 average_degree,
410 max_degree,
411 clustering_coefficient,
412 diameter: None, node_type_distribution,
414 edge_type_distribution,
415 }
416 }
417
418 fn count_connected_components(&self) -> usize {
420 let mut visited = HashSet::new();
421 let mut components = 0;
422
423 for node_id in self.nodes.keys() {
424 if !visited.contains(node_id) {
425 self.dfs_component(node_id, &mut visited);
426 components += 1;
427 }
428 }
429
430 components
431 }
432
433 fn dfs_component(&self, node_id: &str, visited: &mut HashSet<String>) {
435 visited.insert(node_id.to_string());
436
437 if let Some(neighbors) = self.adjacency_list.get(node_id) {
438 for neighbor in neighbors {
439 if !visited.contains(neighbor) {
440 self.dfs_component(neighbor, visited);
441 }
442 }
443 }
444
445 if let Some(incoming_neighbors) = self.reverse_adjacency_list.get(node_id) {
446 for neighbor in incoming_neighbors {
447 if !visited.contains(neighbor) {
448 self.dfs_component(neighbor, visited);
449 }
450 }
451 }
452 }
453
454 fn calculate_clustering_coefficient(&self) -> f32 {
456 let mut total_coefficient = 0.0;
457 let mut nodes_with_neighbors = 0;
458
459 for (_node_id, neighbors) in &self.adjacency_list {
460 if neighbors.len() < 2 {
461 continue;
462 }
463
464 let neighbor_count = neighbors.len();
465 let possible_edges = neighbor_count * (neighbor_count - 1) / 2;
466
467 let mut actual_edges = 0;
469 let neighbor_vec: Vec<_> = neighbors.iter().collect();
470
471 for i in 0..neighbor_vec.len() {
472 for j in (i + 1)..neighbor_vec.len() {
473 let neighbor1 = neighbor_vec[i];
474 let neighbor2 = neighbor_vec[j];
475
476 if let Some(neighbor1_neighbors) = self.adjacency_list.get(neighbor1) {
477 if neighbor1_neighbors.contains(neighbor2) {
478 actual_edges += 1;
479 }
480 }
481 }
482 }
483
484 if possible_edges > 0 {
485 total_coefficient += actual_edges as f32 / possible_edges as f32;
486 nodes_with_neighbors += 1;
487 }
488 }
489
490 if nodes_with_neighbors > 0 {
491 total_coefficient / nodes_with_neighbors as f32
492 } else {
493 0.0
494 }
495 }
496
497 fn node_type_key(&self, node_type: &NodeType) -> String {
499 match node_type {
500 NodeType::Entity(entity_type) => format!("Entity({})", entity_type),
501 NodeType::Concept => "Concept".to_string(),
502 NodeType::Document => "Document".to_string(),
503 NodeType::DocumentChunk => "DocumentChunk".to_string(),
504 NodeType::Keyword => "Keyword".to_string(),
505 NodeType::Custom(custom_type) => format!("Custom({})", custom_type),
506 }
507 }
508
509 fn edge_type_key(&self, edge_type: &EdgeType) -> String {
511 match edge_type {
512 EdgeType::Semantic(relation) => format!("Semantic({})", relation),
513 EdgeType::Hierarchical => "Hierarchical".to_string(),
514 EdgeType::Contains => "Contains".to_string(),
515 EdgeType::References => "References".to_string(),
516 EdgeType::CoOccurs => "CoOccurs".to_string(),
517 EdgeType::Similar => "Similar".to_string(),
518 EdgeType::Custom(custom_type) => format!("Custom({})", custom_type),
519 }
520 }
521
522 pub fn merge(&mut self, other: &KnowledgeGraph) -> RragResult<()> {
524 for (_, node) in &other.nodes {
526 if !self.nodes.contains_key(&node.id) {
527 self.add_node(node.clone())?;
528 }
529 }
530
531 for (_, edge) in &other.edges {
533 if !self.edges.contains_key(&edge.id) {
534 self.add_edge(edge.clone())?;
535 }
536 }
537
538 Ok(())
539 }
540
541 pub fn clear(&mut self) {
543 self.nodes.clear();
544 self.edges.clear();
545 self.adjacency_list.clear();
546 self.reverse_adjacency_list.clear();
547 self.updated_at = chrono::Utc::now();
548 }
549}
550
551impl Default for KnowledgeGraph {
552 fn default() -> Self {
553 Self::new()
554 }
555}
556
557impl GraphNode {
558 pub fn new(label: impl Into<String>, node_type: NodeType) -> Self {
560 Self {
561 id: Uuid::new_v4().to_string(),
562 label: label.into(),
563 node_type,
564 attributes: HashMap::new(),
565 embedding: None,
566 source_documents: HashSet::new(),
567 confidence: 1.0,
568 pagerank_score: None,
569 created_at: chrono::Utc::now(),
570 }
571 }
572
573 pub fn with_id(id: impl Into<String>, label: impl Into<String>, node_type: NodeType) -> Self {
575 Self {
576 id: id.into(),
577 label: label.into(),
578 node_type,
579 attributes: HashMap::new(),
580 embedding: None,
581 source_documents: HashSet::new(),
582 confidence: 1.0,
583 pagerank_score: None,
584 created_at: chrono::Utc::now(),
585 }
586 }
587
588 pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
590 self.attributes.insert(key.into(), value);
591 self
592 }
593
594 pub fn with_embedding(mut self, embedding: Embedding) -> Self {
596 self.embedding = Some(embedding);
597 self
598 }
599
600 pub fn with_confidence(mut self, confidence: f32) -> Self {
602 self.confidence = confidence.clamp(0.0, 1.0);
603 self
604 }
605
606 pub fn with_source_document(mut self, document_id: impl Into<String>) -> Self {
608 self.source_documents.insert(document_id.into());
609 self
610 }
611}
612
613impl GraphEdge {
614 pub fn new(
616 source_id: impl Into<String>,
617 target_id: impl Into<String>,
618 label: impl Into<String>,
619 edge_type: EdgeType,
620 ) -> Self {
621 Self {
622 id: Uuid::new_v4().to_string(),
623 source_id: source_id.into(),
624 target_id: target_id.into(),
625 label: label.into(),
626 edge_type,
627 attributes: HashMap::new(),
628 weight: 1.0,
629 confidence: 1.0,
630 source_documents: HashSet::new(),
631 created_at: chrono::Utc::now(),
632 }
633 }
634
635 pub fn with_id(
637 id: impl Into<String>,
638 source_id: impl Into<String>,
639 target_id: impl Into<String>,
640 label: impl Into<String>,
641 edge_type: EdgeType,
642 ) -> Self {
643 Self {
644 id: id.into(),
645 source_id: source_id.into(),
646 target_id: target_id.into(),
647 label: label.into(),
648 edge_type,
649 attributes: HashMap::new(),
650 weight: 1.0,
651 confidence: 1.0,
652 source_documents: HashSet::new(),
653 created_at: chrono::Utc::now(),
654 }
655 }
656
657 pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
659 self.attributes.insert(key.into(), value);
660 self
661 }
662
663 pub fn with_weight(mut self, weight: f32) -> Self {
665 self.weight = weight.max(0.0);
666 self
667 }
668
669 pub fn with_confidence(mut self, confidence: f32) -> Self {
671 self.confidence = confidence.clamp(0.0, 1.0);
672 self
673 }
674
675 pub fn with_source_document(mut self, document_id: impl Into<String>) -> Self {
677 self.source_documents.insert(document_id.into());
678 self
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds one `NodeType::Concept` node per label and returns the generated ids in order.
    fn add_concepts(graph: &mut KnowledgeGraph, labels: &[&str]) -> Vec<String> {
        labels
            .iter()
            .map(|label| {
                let node = GraphNode::new(*label, NodeType::Concept);
                let id = node.id.clone();
                graph.add_node(node).unwrap();
                id
            })
            .collect()
    }

    /// Adds an `EdgeType::Similar` edge between two existing nodes.
    fn link(graph: &mut KnowledgeGraph, source: &str, target: &str, label: &str) {
        graph
            .add_edge(GraphEdge::new(source, target, label, EdgeType::Similar))
            .unwrap();
    }

    #[test]
    fn test_knowledge_graph_creation() {
        let graph = KnowledgeGraph::new();
        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());
        assert!(KnowledgeGraph::default().nodes.is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut graph = KnowledgeGraph::new();
        let node = GraphNode::new("test_entity", NodeType::Entity("Person".to_string()));
        let node_id = node.id.clone();

        graph.add_node(node).unwrap();
        assert!(graph.nodes.contains_key(&node_id));
        assert!(graph.adjacency_list.contains_key(&node_id));
        assert!(graph.reverse_adjacency_list.contains_key(&node_id));
    }

    #[test]
    fn test_add_edge() {
        let mut graph = KnowledgeGraph::new();

        let node1 = GraphNode::new("person1", NodeType::Entity("Person".to_string()));
        let node2 = GraphNode::new("person2", NodeType::Entity("Person".to_string()));
        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();

        graph.add_node(node1).unwrap();
        graph.add_node(node2).unwrap();

        let edge = GraphEdge::new(
            node1_id.clone(),
            node2_id.clone(),
            "knows",
            EdgeType::Semantic("knows".to_string()),
        );

        graph.add_edge(edge).unwrap();

        assert_eq!(graph.edges.len(), 1);
        assert!(graph.adjacency_list[&node1_id].contains(&node2_id));
        assert!(graph.reverse_adjacency_list[&node2_id].contains(&node1_id));
    }

    #[test]
    fn test_add_edge_rejects_unknown_endpoints() {
        let mut graph = KnowledgeGraph::new();
        let ids = add_concepts(&mut graph, &["only"]);

        let dangling_target = GraphEdge::new(ids[0].clone(), "missing", "rel", EdgeType::Similar);
        assert!(graph.add_edge(dangling_target).is_err());
        let dangling_source = GraphEdge::new("missing", ids[0].clone(), "rel", EdgeType::Similar);
        assert!(graph.add_edge(dangling_source).is_err());

        assert!(graph.edges.is_empty());
        assert!(graph.adjacency_list[&ids[0]].is_empty());
        assert!(graph.reverse_adjacency_list[&ids[0]].is_empty());
    }

    #[test]
    fn test_get_neighbors() {
        let mut graph = KnowledgeGraph::new();

        let node1 = GraphNode::new("node1", NodeType::Concept);
        let node2 = GraphNode::new("node2", NodeType::Concept);
        let node3 = GraphNode::new("node3", NodeType::Concept);

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();
        let node3_id = node3.id.clone();

        graph.add_node(node1).unwrap();
        graph.add_node(node2).unwrap();
        graph.add_node(node3).unwrap();

        graph
            .add_edge(GraphEdge::new(
                node1_id.clone(),
                node2_id.clone(),
                "connected",
                EdgeType::Similar,
            ))
            .unwrap();

        graph
            .add_edge(GraphEdge::new(
                node1_id.clone(),
                node3_id.clone(),
                "connected",
                EdgeType::Similar,
            ))
            .unwrap();

        let neighbors = graph.get_neighbors(&node1_id);
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_incoming_neighbors_and_node_edges() {
        let mut graph = KnowledgeGraph::new();
        let ids = add_concepts(&mut graph, &["a", "b", "c"]);
        link(&mut graph, &ids[0], &ids[1], "a->b");
        link(&mut graph, &ids[2], &ids[1], "c->b");

        assert_eq!(graph.get_incoming_neighbors(&ids[1]).len(), 2);
        assert!(graph.get_neighbors(&ids[1]).is_empty());
        assert_eq!(graph.get_node_edges(&ids[1]).len(), 2);
        assert_eq!(graph.get_node_edges(&ids[0]).len(), 1);
        assert!(graph.get_neighbors("unknown").is_empty());
        assert!(graph.get_incoming_neighbors("unknown").is_empty());
    }

    #[test]
    fn test_remove_node_cleans_up_edges_and_adjacency() {
        let mut graph = KnowledgeGraph::new();
        let ids = add_concepts(&mut graph, &["a", "b", "c"]);
        link(&mut graph, &ids[0], &ids[1], "a->b");
        link(&mut graph, &ids[1], &ids[2], "b->c");

        graph.remove_node(&ids[1]).unwrap();

        assert_eq!(graph.nodes.len(), 2);
        assert!(graph.edges.is_empty());
        assert!(!graph.adjacency_list.contains_key(&ids[1]));
        assert!(!graph.reverse_adjacency_list.contains_key(&ids[1]));
        assert!(graph.adjacency_list[&ids[0]].is_empty());
        assert!(graph.reverse_adjacency_list[&ids[2]].is_empty());
        assert_eq!(graph.calculate_metrics().connected_components, 2);
    }

    #[test]
    fn test_find_by_type() {
        let mut graph = KnowledgeGraph::new();
        let person = GraphNode::new("alice", NodeType::Entity("Person".to_string()));
        let person_id = person.id.clone();
        graph.add_node(person).unwrap();
        let ids = add_concepts(&mut graph, &["idea"]);
        graph
            .add_edge(GraphEdge::new(
                person_id.clone(),
                ids[0].clone(),
                "knows",
                EdgeType::Semantic("knows".to_string()),
            ))
            .unwrap();

        assert_eq!(
            graph
                .find_nodes_by_type(&NodeType::Entity("Person".to_string()))
                .len(),
            1
        );
        assert!(graph
            .find_nodes_by_type(&NodeType::Entity("Place".to_string()))
            .is_empty());
        assert_eq!(graph.find_nodes_by_type(&NodeType::Concept).len(), 1);
        assert_eq!(
            graph
                .find_edges_by_type(&EdgeType::Semantic("knows".to_string()))
                .len(),
            1
        );
        assert!(graph.find_edges_by_type(&EdgeType::Similar).is_empty());
    }

    #[test]
    fn test_graph_metrics() {
        let mut graph = KnowledgeGraph::new();

        let node1 = GraphNode::new("node1", NodeType::Concept);
        let node2 = GraphNode::new("node2", NodeType::Concept);
        let node3 = GraphNode::new("node3", NodeType::Concept);

        let node1_id = node1.id.clone();
        let node2_id = node2.id.clone();
        let node3_id = node3.id.clone();

        graph.add_node(node1).unwrap();
        graph.add_node(node2).unwrap();
        graph.add_node(node3).unwrap();

        graph
            .add_edge(GraphEdge::new(
                node1_id.clone(),
                node2_id.clone(),
                "edge1",
                EdgeType::Similar,
            ))
            .unwrap();

        graph
            .add_edge(GraphEdge::new(
                node2_id.clone(),
                node3_id.clone(),
                "edge2",
                EdgeType::Similar,
            ))
            .unwrap();

        let metrics = graph.calculate_metrics();
        assert_eq!(metrics.node_count, 3);
        assert_eq!(metrics.edge_count, 2);
        assert_eq!(metrics.connected_components, 1);
        assert_eq!(metrics.max_degree, 1);
        assert!((metrics.density - 1.0 / 3.0).abs() < 1e-6);
        assert!((metrics.average_degree - 2.0 / 3.0).abs() < 1e-6);
        assert_eq!(metrics.clustering_coefficient, 0.0);
        assert_eq!(metrics.diameter, None);
        assert_eq!(metrics.node_type_distribution.get("Concept"), Some(&3));
        assert_eq!(metrics.edge_type_distribution.get("Similar"), Some(&2));
    }

    #[test]
    fn test_merge_adds_only_missing_items() {
        let mut graph = KnowledgeGraph::new();
        graph
            .add_node(GraphNode::with_id("a", "A", NodeType::Concept))
            .unwrap();

        let mut other = KnowledgeGraph::new();
        other
            .add_node(GraphNode::with_id("a", "A (other)", NodeType::Concept))
            .unwrap();
        other
            .add_node(GraphNode::with_id("b", "B", NodeType::Concept))
            .unwrap();
        other
            .add_edge(GraphEdge::with_id("ab", "a", "b", "rel", EdgeType::Similar))
            .unwrap();

        graph.merge(&other).unwrap();
        graph.merge(&other).unwrap();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.nodes["a"].label, "A");
        assert!(graph.adjacency_list["a"].contains("b"));
    }

    #[test]
    fn test_clear() {
        let mut graph = KnowledgeGraph::new();
        let ids = add_concepts(&mut graph, &["a", "b"]);
        link(&mut graph, &ids[0], &ids[1], "a->b");

        graph.clear();

        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());
        assert!(graph.adjacency_list.is_empty());
        assert!(graph.reverse_adjacency_list.is_empty());
        assert_eq!(graph.calculate_metrics().connected_components, 0);
    }

    #[test]
    fn test_builders_clamp_values() {
        let node = GraphNode::new("n", NodeType::Keyword).with_confidence(1.5);
        assert_eq!(node.confidence, 1.0);

        let edge = GraphEdge::new("a", "b", "rel", EdgeType::Contains)
            .with_weight(-2.0)
            .with_confidence(-0.5);
        assert_eq!(edge.weight, 0.0);
        assert_eq!(edge.confidence, 0.0);
    }
}