rexis_rag/graph_retrieval/
graph.rs

1//! # Knowledge Graph Core
2//!
3//! Core graph structures and operations for knowledge graph construction and management.
4
5use crate::{Embedding, RragResult};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use uuid::Uuid;
9
/// In-memory directed knowledge graph.
///
/// Nodes and edges live in ID-keyed maps. Two neighbour indexes
/// (`adjacency_list` and `reverse_adjacency_list`) are maintained next to them
/// by `add_node`, `add_edge` and `remove_node`. Every field is public, so code
/// that mutates the maps directly is responsible for keeping the indexes in
/// sync with `nodes` and `edges`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeGraph {
    /// Graph nodes, keyed by `GraphNode::id`.
    pub nodes: HashMap<String, GraphNode>,

    /// Graph edges, keyed by `GraphEdge::id`.
    pub edges: HashMap<String, GraphEdge>,

    /// Outgoing index: source node ID -> IDs of the nodes it points to.
    ///
    /// The value is a set of node IDs (not edge IDs), so several edges between
    /// the same ordered pair of nodes share a single entry.
    pub adjacency_list: HashMap<String, HashSet<String>>,

    /// Incoming index: target node ID -> IDs of the nodes that point to it.
    pub reverse_adjacency_list: HashMap<String, HashSet<String>>,

    /// Free-form graph-level metadata. `clear` leaves this map untouched.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Timestamp (UTC) taken when the graph was constructed by `new`.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Timestamp (UTC) refreshed by `add_node`, `add_edge`, `remove_node`
    /// and `clear`.
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
34
/// A vertex of the knowledge graph: an entity, concept, document, chunk,
/// keyword or custom item, as described by `node_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node identifier.
    ///
    /// `GraphNode::new` generates a UUID v4 string; `GraphNode::with_id`
    /// stores whatever the caller supplies.
    pub id: String,

    /// Human-readable label/name of the node.
    pub label: String,

    /// Category of the node.
    pub node_type: NodeType,

    /// Arbitrary JSON attributes, keyed by attribute name.
    pub attributes: HashMap<String, serde_json::Value>,

    /// Optional embedding for semantic operations; `None` until set
    /// (for example through `with_embedding`).
    pub embedding: Option<Embedding>,

    /// Identifiers of the documents this node was derived from.
    pub source_documents: HashSet<String>,

    /// Confidence score. Constructors start at `1.0` and `with_confidence`
    /// clamps to `[0.0, 1.0]`; direct field writes are not validated.
    pub confidence: f32,

    /// PageRank score. The constructors in this file initialise it to `None`
    /// and nothing here computes it.
    pub pagerank_score: Option<f32>,

    /// Timestamp (UTC) taken when the node value was constructed.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
65
/// A directed relationship from `source_id` to `target_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
    /// Unique edge identifier.
    ///
    /// `GraphEdge::new` generates a UUID v4 string; `GraphEdge::with_id`
    /// stores whatever the caller supplies.
    pub id: String,

    /// ID of the node the edge starts from.
    pub source_id: String,

    /// ID of the node the edge points to.
    pub target_id: String,

    /// Human-readable relationship label.
    pub label: String,

    /// Category of the relationship.
    pub edge_type: EdgeType,

    /// Arbitrary JSON attributes, keyed by attribute name.
    pub attributes: HashMap<String, serde_json::Value>,

    /// Edge weight/strength. Constructors start at `1.0` and `with_weight`
    /// floors the value at `0.0`; direct field writes are not validated.
    pub weight: f32,

    /// Confidence score. Constructors start at `1.0` and `with_confidence`
    /// clamps to `[0.0, 1.0]`; direct field writes are not validated.
    pub confidence: f32,

    /// Identifiers of the documents this edge was derived from.
    pub source_documents: HashSet<String>,

    /// Timestamp (UTC) taken when the edge value was constructed.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
99
/// Categories a [`GraphNode`] can belong to.
///
/// `calculate_metrics` groups nodes by the textual form of this value, so two
/// `Entity` (or `Custom`) nodes with different payload strings are counted
/// as different types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Named entity; the payload names the entity category
    /// (person, organization, location, etc.).
    Entity(String),

    /// Concept or topic.
    Concept,

    /// Whole-document node.
    Document,

    /// Document chunk/segment.
    DocumentChunk,

    /// Keyword or term.
    Keyword,

    /// Caller-defined node type; the payload is its name.
    Custom(String),
}
121
/// Categories a [`GraphEdge`] can belong to.
///
/// `calculate_metrics` groups edges by the textual form of this value, so two
/// `Semantic` (or `Custom`) edges with different payload strings are counted
/// as different types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Semantic relationship; the payload names the relation.
    Semantic(String),

    /// Hierarchical relationship.
    Hierarchical,

    /// Document containment relationship.
    Contains,

    /// Reference/citation relationship.
    References,

    /// Co-occurrence relationship.
    CoOccurs,

    /// Similarity relationship.
    Similar,

    /// Caller-defined edge type; the payload is its name.
    Custom(String),
}
146
/// Snapshot of structural statistics produced by
/// `KnowledgeGraph::calculate_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetrics {
    /// Number of entries in the graph's node map.
    pub node_count: usize,

    /// Number of entries in the graph's edge map.
    pub edge_count: usize,

    /// Number of connected components, found by traversing both outgoing and
    /// incoming links (edge direction is ignored).
    pub connected_components: usize,

    /// `edge_count / (node_count * (node_count - 1))`; `0.0` when the graph
    /// has fewer than two nodes.
    pub density: f32,

    /// Mean number of distinct outgoing neighbours per adjacency-list entry;
    /// `0.0` when the adjacency list is empty.
    pub average_degree: f32,

    /// Largest number of distinct outgoing neighbours of any node.
    pub max_degree: usize,

    /// Average local clustering coefficient over nodes that have at least
    /// two outgoing neighbours; `0.0` when no node qualifies.
    pub clustering_coefficient: f32,

    /// Diameter (longest shortest path). `calculate_metrics` always reports
    /// `None` because it does not compute this value.
    pub diameter: Option<usize>,

    /// Node count per node type, keyed by strings such as `"Concept"` or
    /// `"Entity(Person)"`.
    pub node_type_distribution: HashMap<String, usize>,

    /// Edge count per edge type, keyed by strings such as `"Similar"` or
    /// `"Semantic(knows)"`.
    pub edge_type_distribution: HashMap<String, usize>,
}
180
181impl KnowledgeGraph {
182    /// Create a new empty knowledge graph
183    pub fn new() -> Self {
184        let now = chrono::Utc::now();
185        Self {
186            nodes: HashMap::new(),
187            edges: HashMap::new(),
188            adjacency_list: HashMap::new(),
189            reverse_adjacency_list: HashMap::new(),
190            metadata: HashMap::new(),
191            created_at: now,
192            updated_at: now,
193        }
194    }
195
196    /// Add a node to the graph
197    pub fn add_node(&mut self, node: GraphNode) -> RragResult<()> {
198        let node_id = node.id.clone();
199
200        // Initialize adjacency lists
201        self.adjacency_list
202            .entry(node_id.clone())
203            .or_insert_with(HashSet::new);
204        self.reverse_adjacency_list
205            .entry(node_id.clone())
206            .or_insert_with(HashSet::new);
207
208        // Insert node
209        self.nodes.insert(node_id, node);
210        self.updated_at = chrono::Utc::now();
211
212        Ok(())
213    }
214
215    /// Add an edge to the graph
216    pub fn add_edge(&mut self, edge: GraphEdge) -> RragResult<()> {
217        let edge_id = edge.id.clone();
218        let source_id = edge.source_id.clone();
219        let target_id = edge.target_id.clone();
220
221        // Verify nodes exist
222        if !self.nodes.contains_key(&source_id) {
223            return Err(crate::RragError::retrieval(format!(
224                "Source node {} not found",
225                source_id
226            )));
227        }
228
229        if !self.nodes.contains_key(&target_id) {
230            return Err(crate::RragError::retrieval(format!(
231                "Target node {} not found",
232                target_id
233            )));
234        }
235
236        // Update adjacency lists
237        self.adjacency_list
238            .entry(source_id.clone())
239            .or_insert_with(HashSet::new)
240            .insert(target_id.clone());
241
242        self.reverse_adjacency_list
243            .entry(target_id.clone())
244            .or_insert_with(HashSet::new)
245            .insert(source_id.clone());
246
247        // Insert edge
248        self.edges.insert(edge_id, edge);
249        self.updated_at = chrono::Utc::now();
250
251        Ok(())
252    }
253
    /// Look up a node by ID.
    ///
    /// Returns `None` when no node is stored under `node_id`.
    pub fn get_node(&self, node_id: &str) -> Option<&GraphNode> {
        self.nodes.get(node_id)
    }
258
    /// Look up an edge by ID.
    ///
    /// Returns `None` when no edge is stored under `edge_id`.
    pub fn get_edge(&self, edge_id: &str) -> Option<&GraphEdge> {
        self.edges.get(edge_id)
    }
263
264    /// Get neighbors of a node
265    pub fn get_neighbors(&self, node_id: &str) -> Vec<&GraphNode> {
266        self.adjacency_list
267            .get(node_id)
268            .map(|neighbors| {
269                neighbors
270                    .iter()
271                    .filter_map(|neighbor_id| self.nodes.get(neighbor_id))
272                    .collect()
273            })
274            .unwrap_or_default()
275    }
276
277    /// Get incoming neighbors of a node
278    pub fn get_incoming_neighbors(&self, node_id: &str) -> Vec<&GraphNode> {
279        self.reverse_adjacency_list
280            .get(node_id)
281            .map(|neighbors| {
282                neighbors
283                    .iter()
284                    .filter_map(|neighbor_id| self.nodes.get(neighbor_id))
285                    .collect()
286            })
287            .unwrap_or_default()
288    }
289
290    /// Get edges connected to a node
291    pub fn get_node_edges(&self, node_id: &str) -> Vec<&GraphEdge> {
292        self.edges
293            .values()
294            .filter(|edge| edge.source_id == node_id || edge.target_id == node_id)
295            .collect()
296    }
297
298    /// Remove a node and its connected edges
299    pub fn remove_node(&mut self, node_id: &str) -> RragResult<()> {
300        // Remove from adjacency lists
301        if let Some(neighbors) = self.adjacency_list.remove(node_id) {
302            for neighbor in neighbors {
303                if let Some(reverse_neighbors) = self.reverse_adjacency_list.get_mut(&neighbor) {
304                    reverse_neighbors.remove(node_id);
305                }
306            }
307        }
308
309        if let Some(incoming_neighbors) = self.reverse_adjacency_list.remove(node_id) {
310            for neighbor in incoming_neighbors {
311                if let Some(outgoing_neighbors) = self.adjacency_list.get_mut(&neighbor) {
312                    outgoing_neighbors.remove(node_id);
313                }
314            }
315        }
316
317        // Remove connected edges
318        let edges_to_remove: Vec<String> = self
319            .edges
320            .iter()
321            .filter(|(_, edge)| edge.source_id == node_id || edge.target_id == node_id)
322            .map(|(edge_id, _)| edge_id.clone())
323            .collect();
324
325        for edge_id in edges_to_remove {
326            self.edges.remove(&edge_id);
327        }
328
329        // Remove node
330        self.nodes.remove(node_id);
331        self.updated_at = chrono::Utc::now();
332
333        Ok(())
334    }
335
336    /// Find nodes by type
337    pub fn find_nodes_by_type(&self, node_type: &NodeType) -> Vec<&GraphNode> {
338        self.nodes
339            .values()
340            .filter(|node| &node.node_type == node_type)
341            .collect()
342    }
343
344    /// Find edges by type
345    pub fn find_edges_by_type(&self, edge_type: &EdgeType) -> Vec<&GraphEdge> {
346        self.edges
347            .values()
348            .filter(|edge| &edge.edge_type == edge_type)
349            .collect()
350    }
351
352    /// Calculate graph metrics
353    pub fn calculate_metrics(&self) -> GraphMetrics {
354        let node_count = self.nodes.len();
355        let edge_count = self.edges.len();
356
357        // Calculate density
358        let max_edges = if node_count > 1 {
359            node_count * (node_count - 1)
360        } else {
361            0
362        };
363        let density = if max_edges > 0 {
364            edge_count as f32 / max_edges as f32
365        } else {
366            0.0
367        };
368
369        // Calculate degree statistics
370        let degrees: Vec<usize> = self
371            .adjacency_list
372            .values()
373            .map(|neighbors| neighbors.len())
374            .collect();
375
376        let average_degree = if !degrees.is_empty() {
377            degrees.iter().sum::<usize>() as f32 / degrees.len() as f32
378        } else {
379            0.0
380        };
381
382        let max_degree = degrees.iter().max().copied().unwrap_or(0);
383
384        // Calculate connected components
385        let connected_components = self.count_connected_components();
386
387        // Calculate clustering coefficient
388        let clustering_coefficient = self.calculate_clustering_coefficient();
389
390        // Node type distribution
391        let mut node_type_distribution = HashMap::new();
392        for node in self.nodes.values() {
393            let type_key = self.node_type_key(&node.node_type);
394            *node_type_distribution.entry(type_key).or_insert(0) += 1;
395        }
396
397        // Edge type distribution
398        let mut edge_type_distribution = HashMap::new();
399        for edge in self.edges.values() {
400            let type_key = self.edge_type_key(&edge.edge_type);
401            *edge_type_distribution.entry(type_key).or_insert(0) += 1;
402        }
403
404        GraphMetrics {
405            node_count,
406            edge_count,
407            connected_components,
408            density,
409            average_degree,
410            max_degree,
411            clustering_coefficient,
412            diameter: None, // Expensive to calculate, could be computed on demand
413            node_type_distribution,
414            edge_type_distribution,
415        }
416    }
417
418    /// Count connected components using DFS
419    fn count_connected_components(&self) -> usize {
420        let mut visited = HashSet::new();
421        let mut components = 0;
422
423        for node_id in self.nodes.keys() {
424            if !visited.contains(node_id) {
425                self.dfs_component(node_id, &mut visited);
426                components += 1;
427            }
428        }
429
430        components
431    }
432
433    /// DFS helper for connected components
434    fn dfs_component(&self, node_id: &str, visited: &mut HashSet<String>) {
435        visited.insert(node_id.to_string());
436
437        if let Some(neighbors) = self.adjacency_list.get(node_id) {
438            for neighbor in neighbors {
439                if !visited.contains(neighbor) {
440                    self.dfs_component(neighbor, visited);
441                }
442            }
443        }
444
445        if let Some(incoming_neighbors) = self.reverse_adjacency_list.get(node_id) {
446            for neighbor in incoming_neighbors {
447                if !visited.contains(neighbor) {
448                    self.dfs_component(neighbor, visited);
449                }
450            }
451        }
452    }
453
454    /// Calculate clustering coefficient
455    fn calculate_clustering_coefficient(&self) -> f32 {
456        let mut total_coefficient = 0.0;
457        let mut nodes_with_neighbors = 0;
458
459        for (_node_id, neighbors) in &self.adjacency_list {
460            if neighbors.len() < 2 {
461                continue;
462            }
463
464            let neighbor_count = neighbors.len();
465            let possible_edges = neighbor_count * (neighbor_count - 1) / 2;
466
467            // Count actual edges between neighbors
468            let mut actual_edges = 0;
469            let neighbor_vec: Vec<_> = neighbors.iter().collect();
470
471            for i in 0..neighbor_vec.len() {
472                for j in (i + 1)..neighbor_vec.len() {
473                    let neighbor1 = neighbor_vec[i];
474                    let neighbor2 = neighbor_vec[j];
475
476                    if let Some(neighbor1_neighbors) = self.adjacency_list.get(neighbor1) {
477                        if neighbor1_neighbors.contains(neighbor2) {
478                            actual_edges += 1;
479                        }
480                    }
481                }
482            }
483
484            if possible_edges > 0 {
485                total_coefficient += actual_edges as f32 / possible_edges as f32;
486                nodes_with_neighbors += 1;
487            }
488        }
489
490        if nodes_with_neighbors > 0 {
491            total_coefficient / nodes_with_neighbors as f32
492        } else {
493            0.0
494        }
495    }
496
497    /// Get string representation of node type
498    fn node_type_key(&self, node_type: &NodeType) -> String {
499        match node_type {
500            NodeType::Entity(entity_type) => format!("Entity({})", entity_type),
501            NodeType::Concept => "Concept".to_string(),
502            NodeType::Document => "Document".to_string(),
503            NodeType::DocumentChunk => "DocumentChunk".to_string(),
504            NodeType::Keyword => "Keyword".to_string(),
505            NodeType::Custom(custom_type) => format!("Custom({})", custom_type),
506        }
507    }
508
509    /// Get string representation of edge type
510    fn edge_type_key(&self, edge_type: &EdgeType) -> String {
511        match edge_type {
512            EdgeType::Semantic(relation) => format!("Semantic({})", relation),
513            EdgeType::Hierarchical => "Hierarchical".to_string(),
514            EdgeType::Contains => "Contains".to_string(),
515            EdgeType::References => "References".to_string(),
516            EdgeType::CoOccurs => "CoOccurs".to_string(),
517            EdgeType::Similar => "Similar".to_string(),
518            EdgeType::Custom(custom_type) => format!("Custom({})", custom_type),
519        }
520    }
521
522    /// Merge another graph into this one
523    pub fn merge(&mut self, other: &KnowledgeGraph) -> RragResult<()> {
524        // Add nodes
525        for (_, node) in &other.nodes {
526            if !self.nodes.contains_key(&node.id) {
527                self.add_node(node.clone())?;
528            }
529        }
530
531        // Add edges
532        for (_, edge) in &other.edges {
533            if !self.edges.contains_key(&edge.id) {
534                self.add_edge(edge.clone())?;
535            }
536        }
537
538        Ok(())
539    }
540
541    /// Clear the entire graph
542    pub fn clear(&mut self) {
543        self.nodes.clear();
544        self.edges.clear();
545        self.adjacency_list.clear();
546        self.reverse_adjacency_list.clear();
547        self.updated_at = chrono::Utc::now();
548    }
549}
550
551impl Default for KnowledgeGraph {
552    fn default() -> Self {
553        Self::new()
554    }
555}
556
557impl GraphNode {
558    /// Create a new graph node
559    pub fn new(label: impl Into<String>, node_type: NodeType) -> Self {
560        Self {
561            id: Uuid::new_v4().to_string(),
562            label: label.into(),
563            node_type,
564            attributes: HashMap::new(),
565            embedding: None,
566            source_documents: HashSet::new(),
567            confidence: 1.0,
568            pagerank_score: None,
569            created_at: chrono::Utc::now(),
570        }
571    }
572
573    /// Create node with specific ID
574    pub fn with_id(id: impl Into<String>, label: impl Into<String>, node_type: NodeType) -> Self {
575        Self {
576            id: id.into(),
577            label: label.into(),
578            node_type,
579            attributes: HashMap::new(),
580            embedding: None,
581            source_documents: HashSet::new(),
582            confidence: 1.0,
583            pagerank_score: None,
584            created_at: chrono::Utc::now(),
585        }
586    }
587
588    /// Add attribute using builder pattern
589    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
590        self.attributes.insert(key.into(), value);
591        self
592    }
593
594    /// Set embedding
595    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
596        self.embedding = Some(embedding);
597        self
598    }
599
600    /// Set confidence score
601    pub fn with_confidence(mut self, confidence: f32) -> Self {
602        self.confidence = confidence.clamp(0.0, 1.0);
603        self
604    }
605
606    /// Add source document
607    pub fn with_source_document(mut self, document_id: impl Into<String>) -> Self {
608        self.source_documents.insert(document_id.into());
609        self
610    }
611}
612
613impl GraphEdge {
614    /// Create a new graph edge
615    pub fn new(
616        source_id: impl Into<String>,
617        target_id: impl Into<String>,
618        label: impl Into<String>,
619        edge_type: EdgeType,
620    ) -> Self {
621        Self {
622            id: Uuid::new_v4().to_string(),
623            source_id: source_id.into(),
624            target_id: target_id.into(),
625            label: label.into(),
626            edge_type,
627            attributes: HashMap::new(),
628            weight: 1.0,
629            confidence: 1.0,
630            source_documents: HashSet::new(),
631            created_at: chrono::Utc::now(),
632        }
633    }
634
635    /// Create edge with specific ID
636    pub fn with_id(
637        id: impl Into<String>,
638        source_id: impl Into<String>,
639        target_id: impl Into<String>,
640        label: impl Into<String>,
641        edge_type: EdgeType,
642    ) -> Self {
643        Self {
644            id: id.into(),
645            source_id: source_id.into(),
646            target_id: target_id.into(),
647            label: label.into(),
648            edge_type,
649            attributes: HashMap::new(),
650            weight: 1.0,
651            confidence: 1.0,
652            source_documents: HashSet::new(),
653            created_at: chrono::Utc::now(),
654        }
655    }
656
657    /// Add attribute using builder pattern
658    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
659        self.attributes.insert(key.into(), value);
660        self
661    }
662
663    /// Set edge weight
664    pub fn with_weight(mut self, weight: f32) -> Self {
665        self.weight = weight.max(0.0);
666        self
667    }
668
669    /// Set confidence score
670    pub fn with_confidence(mut self, confidence: f32) -> Self {
671        self.confidence = confidence.clamp(0.0, 1.0);
672        self
673    }
674
675    /// Add source document
676    pub fn with_source_document(mut self, document_id: impl Into<String>) -> Self {
677        self.source_documents.insert(document_id.into());
678        self
679    }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts `count` concept nodes labelled `node1`, `node2`, ... and
    /// returns their IDs in insertion order.
    fn add_concepts(graph: &mut KnowledgeGraph, count: usize) -> Vec<String> {
        (1..=count)
            .map(|n| {
                let node = GraphNode::new(format!("node{}", n), NodeType::Concept);
                let id = node.id.clone();
                graph.add_node(node).unwrap();
                id
            })
            .collect()
    }

    /// Adds a `Similar` edge `from -> to` carrying the given label.
    fn link_similar(graph: &mut KnowledgeGraph, from: &str, to: &str, label: &str) {
        graph
            .add_edge(GraphEdge::new(from, to, label, EdgeType::Similar))
            .unwrap();
    }

    #[test]
    fn test_knowledge_graph_creation() {
        let graph = KnowledgeGraph::new();

        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_add_node() {
        let mut graph = KnowledgeGraph::new();
        let entity = GraphNode::new("test_entity", NodeType::Entity("Person".to_string()));
        let entity_id = entity.id.clone();

        graph.add_node(entity).unwrap();

        // The node is stored and has an adjacency entry.
        assert!(graph.nodes.contains_key(&entity_id));
        assert!(graph.adjacency_list.contains_key(&entity_id));
    }

    #[test]
    fn test_add_edge() {
        let mut graph = KnowledgeGraph::new();

        let first = GraphNode::new("person1", NodeType::Entity("Person".to_string()));
        let second = GraphNode::new("person2", NodeType::Entity("Person".to_string()));
        let (first_id, second_id) = (first.id.clone(), second.id.clone());

        graph.add_node(first).unwrap();
        graph.add_node(second).unwrap();

        let knows = GraphEdge::new(
            first_id.clone(),
            second_id.clone(),
            "knows",
            EdgeType::Semantic("knows".to_string()),
        );
        graph.add_edge(knows).unwrap();

        // One edge, visible from both neighbour indexes.
        assert_eq!(graph.edges.len(), 1);
        assert!(graph.adjacency_list[&first_id].contains(&second_id));
        assert!(graph.reverse_adjacency_list[&second_id].contains(&first_id));
    }

    #[test]
    fn test_get_neighbors() {
        let mut graph = KnowledgeGraph::new();
        let ids = add_concepts(&mut graph, 3);

        // node1 -> node2 and node1 -> node3
        link_similar(&mut graph, &ids[0], &ids[1], "connected");
        link_similar(&mut graph, &ids[0], &ids[2], "connected");

        let neighbors = graph.get_neighbors(&ids[0]);
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_graph_metrics() {
        let mut graph = KnowledgeGraph::new();

        // A simple chain: node1 -> node2 -> node3
        let ids = add_concepts(&mut graph, 3);
        link_similar(&mut graph, &ids[0], &ids[1], "edge1");
        link_similar(&mut graph, &ids[1], &ids[2], "edge2");

        let metrics = graph.calculate_metrics();
        assert_eq!(metrics.node_count, 3);
        assert_eq!(metrics.edge_count, 2);
        assert_eq!(metrics.connected_components, 1);
    }
}