// rexis_rag/graph_retrieval/retriever.rs

1//! # Graph-Based Retriever
2//!
3//! Main retriever implementation that integrates graph-based algorithms with traditional retrieval methods.
4
5use super::{
6    algorithms::GraphAlgorithms,
7    query_expansion::{ExpansionOptions, ExpansionStrategy, GraphQueryExpander, QueryExpander},
8    storage::GraphStorage,
9    GraphNode, KnowledgeGraph,
10};
11use crate::{
12    retrieval_core::{IndexStats, QueryType},
13    Document, DocumentChunk, Embedding, Retriever, RragResult, SearchQuery, SearchResult,
14};
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18
19/// Graph-based retriever that combines traditional and graph-based search
20pub struct GraphRetriever {
21    /// Knowledge graph (using RwLock for interior mutability)
22    graph: tokio::sync::RwLock<KnowledgeGraph>,
23
24    /// Graph storage backend
25    storage: tokio::sync::RwLock<Box<dyn GraphStorage>>,
26
27    /// Query expander (using RwLock for interior mutability)
28    query_expander: tokio::sync::RwLock<GraphQueryExpander>,
29
30    /// Configuration
31    config: GraphRetrievalConfig,
32
33    /// PageRank scores cache
34    pagerank_cache: tokio::sync::RwLock<Option<HashMap<String, f32>>>,
35
36    /// Entity to document mapping (using RwLock for interior mutability)
37    entity_document_map: tokio::sync::RwLock<HashMap<String, HashSet<String>>>,
38}
39
40/// Graph retrieval configuration
41#[derive(Debug, Clone)]
42pub struct GraphRetrievalConfig {
43    /// Enable query expansion
44    pub enable_query_expansion: bool,
45
46    /// Enable PageRank scoring
47    pub enable_pagerank_scoring: bool,
48
49    /// Enable path-based retrieval
50    pub enable_path_based_retrieval: bool,
51
52    /// Weight for graph-based scores vs traditional similarity
53    pub graph_weight: f32,
54
55    /// Weight for traditional similarity scores
56    pub similarity_weight: f32,
57
58    /// Maximum number of graph hops for retrieval
59    pub max_graph_hops: usize,
60
61    /// Minimum graph score threshold
62    pub min_graph_score: f32,
63
64    /// Query expansion configuration
65    pub expansion_options: ExpansionOptions,
66
67    /// PageRank configuration
68    pub pagerank_config: super::algorithms::PageRankConfig,
69
70    /// Enable result diversification
71    pub enable_diversification: bool,
72
73    /// Diversification factor (0.0 to 1.0)
74    pub diversification_factor: f32,
75}
76
77impl Default for GraphRetrievalConfig {
78    fn default() -> Self {
79        Self {
80            enable_query_expansion: true,
81            enable_pagerank_scoring: true,
82            enable_path_based_retrieval: true,
83            graph_weight: 0.4,
84            similarity_weight: 0.6,
85            max_graph_hops: 3,
86            min_graph_score: 0.1,
87            expansion_options: ExpansionOptions {
88                strategies: vec![
89                    ExpansionStrategy::Semantic,
90                    ExpansionStrategy::Similarity,
91                    ExpansionStrategy::CoOccurrence,
92                ],
93                max_terms: Some(10),
94                min_confidence: 0.3,
95                ..Default::default()
96            },
97            pagerank_config: super::algorithms::PageRankConfig::default(),
98            enable_diversification: true,
99            diversification_factor: 0.3,
100        }
101    }
102}
103
104/// Graph search result with additional graph-specific information
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct GraphSearchResult {
107    /// Base search result
108    pub search_result: SearchResult,
109
110    /// Graph-based score
111    pub graph_score: f32,
112
113    /// PageRank score of associated entities
114    pub pagerank_score: f32,
115
116    /// Related entities found in the content
117    pub related_entities: Vec<String>,
118
119    /// Graph paths that led to this result
120    pub graph_paths: Vec<GraphPath>,
121
122    /// Expanded query terms that matched
123    pub matched_expansions: Vec<String>,
124}
125
126/// Graph path information
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct GraphPath {
129    /// Node IDs in the path
130    pub nodes: Vec<String>,
131
132    /// Path score
133    pub score: f32,
134
135    /// Path type/description
136    pub path_type: String,
137
138    /// Path length
139    pub length: usize,
140}
141
142impl GraphRetriever {
143    /// Create a new graph retriever
144    pub fn new(
145        graph: KnowledgeGraph,
146        storage: Box<dyn GraphStorage>,
147        config: GraphRetrievalConfig,
148    ) -> RragResult<Self> {
149        let query_expander = GraphQueryExpander::new(
150            graph.clone(),
151            super::query_expansion::ExpansionConfig::default(),
152        );
153
154        let mut entity_document_map = HashMap::new();
155
156        // Build entity-document mapping
157        for (_, node) in &graph.nodes {
158            for doc_id in &node.source_documents {
159                entity_document_map
160                    .entry(node.id.clone())
161                    .or_insert_with(HashSet::new)
162                    .insert(doc_id.clone());
163            }
164        }
165
166        let retriever = Self {
167            graph: tokio::sync::RwLock::new(graph),
168            storage: tokio::sync::RwLock::new(storage),
169            query_expander: tokio::sync::RwLock::new(query_expander),
170            config,
171            pagerank_cache: tokio::sync::RwLock::new(None),
172            entity_document_map: tokio::sync::RwLock::new(entity_document_map),
173        };
174
175        Ok(retriever)
176    }
177
178    /// Update the knowledge graph
179    pub async fn update_graph(&self, graph: KnowledgeGraph) -> RragResult<()> {
180        *self.graph.write().await = graph.clone();
181        self.query_expander
182            .write()
183            .await
184            .update_graph(graph.clone())
185            .await;
186
187        // Rebuild entity-document mapping
188        let mut entity_map = self.entity_document_map.write().await;
189        entity_map.clear();
190        for (_, node) in &graph.nodes {
191            for doc_id in &node.source_documents {
192                entity_map
193                    .entry(node.id.clone())
194                    .or_insert_with(HashSet::new)
195                    .insert(doc_id.clone());
196            }
197        }
198
199        // Clear PageRank cache
200        *self.pagerank_cache.write().await = None;
201
202        // Update storage
203        self.storage.write().await.store_graph(&graph).await?;
204
205        Ok(())
206    }
207
208    /// Get or compute PageRank scores
209    async fn get_pagerank_scores(&self) -> RragResult<HashMap<String, f32>> {
210        let mut cache = self.pagerank_cache.write().await;
211
212        if cache.is_none() {
213            let graph = self.graph.read().await;
214            let scores = GraphAlgorithms::pagerank(&*graph, &self.config.pagerank_config)?;
215            *cache = Some(scores);
216        }
217
218        Ok(cache.clone().unwrap())
219    }
220
221    /// Expand query using graph structure
222    async fn expand_query(&self, query: &str) -> RragResult<Vec<String>> {
223        if !self.config.enable_query_expansion {
224            return Ok(vec![query.to_string()]);
225        }
226
227        let expansion_result = self
228            .query_expander
229            .read()
230            .await
231            .expand_query(query, &self.config.expansion_options)
232            .await?;
233
234        let mut terms = vec![query.to_string()];
235        terms.extend(expansion_result.expanded_terms.into_iter().map(|t| t.term));
236
237        Ok(terms)
238    }
239
240    /// Find entities related to query
241    async fn find_query_entities(&self, query: &str) -> Vec<String> {
242        let query_lower = query.to_lowercase();
243        let mut entities = Vec::new();
244
245        let graph = self.graph.read().await;
246
247        // Find entities that match query terms
248        for (entity_id, node) in &graph.nodes {
249            let label_lower = node.label.to_lowercase();
250            if query_lower.contains(&label_lower) || label_lower.contains(&query_lower) {
251                entities.push(entity_id.clone());
252            }
253        }
254
255        entities
256    }
257
258    /// Add documents and their entities to the graph
259    pub async fn add_document_with_entities(
260        &self,
261        document: &Document,
262        entities: Vec<GraphNode>,
263        relationships: Vec<super::GraphEdge>,
264    ) -> RragResult<()> {
265        let mut graph = self.graph.write().await;
266
267        // Add document node
268        let doc_node = GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
269            .with_source_document(document.id.clone())
270            .with_attribute(
271                "title",
272                serde_json::Value::String(
273                    document
274                        .metadata
275                        .get("title")
276                        .and_then(|v| v.as_str())
277                        .unwrap_or(&document.id)
278                        .to_string(),
279                ),
280            );
281
282        graph.add_node(doc_node.clone())?;
283
284        // Add entities and connect them to the document
285        for entity in entities {
286            let entity_id = entity.id.clone();
287            graph.add_node(entity)?;
288
289            // Create containment edge from document to entity
290            let containment_edge = super::GraphEdge::new(
291                doc_node.id.clone(),
292                entity_id.clone(),
293                "contains",
294                super::EdgeType::Contains,
295            );
296            graph.add_edge(containment_edge)?;
297
298            // Update entity-document mapping
299            self.entity_document_map
300                .write()
301                .await
302                .entry(entity_id)
303                .or_insert_with(HashSet::new)
304                .insert(document.id.clone());
305        }
306
307        // Add relationships
308        for relationship in relationships {
309            graph.add_edge(relationship)?;
310        }
311
312        // Clear PageRank cache
313        *self.pagerank_cache.write().await = None;
314
315        // Update storage
316        self.storage.write().await.store_graph(&*graph).await?;
317
318        Ok(())
319    }
320}
321
322#[async_trait]
323impl Retriever for GraphRetriever {
324    fn name(&self) -> &str {
325        "graph_retriever"
326    }
327
328    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>> {
329        let query_text = match &query.query {
330            QueryType::Text(text) => text,
331            QueryType::Embedding(_) => {
332                // For embedding queries, we can't do text-based entity extraction
333                // Fall back to basic similarity search (would need embedding-based entity matching)
334                return Ok(Vec::new());
335            }
336        };
337
338        // Expand query if enabled
339        let expanded_terms = self.expand_query(query_text).await?;
340        let expanded_query = expanded_terms.join(" ");
341
342        // Find entities in the (expanded) query
343        let query_entities = self.find_query_entities(&expanded_query).await;
344
345        // For simplicity, return basic results based on entity matching
346        let mut results = Vec::new();
347
348        let entity_map = self.entity_document_map.read().await;
349        let pagerank_scores = if self.config.enable_pagerank_scoring {
350            self.get_pagerank_scores().await?
351        } else {
352            HashMap::new()
353        };
354
355        // Find documents connected to query entities
356        let mut candidate_docs = HashSet::new();
357        for entity_id in &query_entities {
358            if let Some(doc_ids) = entity_map.get(entity_id) {
359                candidate_docs.extend(doc_ids.clone());
360            }
361        }
362
363        // Create search results for candidate documents
364        for (rank, doc_id) in candidate_docs.iter().enumerate() {
365            // Calculate graph-based score
366            let mut graph_score = 0.5; // Base score
367
368            // Add PageRank contribution
369            for entity_id in &query_entities {
370                if let Some(doc_ids) = entity_map.get(entity_id) {
371                    if doc_ids.contains(doc_id) {
372                        let pagerank_score = pagerank_scores.get(entity_id).copied().unwrap_or(0.0);
373                        graph_score += pagerank_score * 0.3;
374                    }
375                }
376            }
377
378            if graph_score >= self.config.min_graph_score {
379                let result = SearchResult {
380                    id: doc_id.clone(),
381                    content: format!("Document {}", doc_id), // Placeholder
382                    score: graph_score,
383                    rank,
384                    metadata: {
385                        let mut metadata = HashMap::new();
386                        metadata.insert("graph_score".to_string(), serde_json::json!(graph_score));
387                        metadata
388                    },
389                    embedding: None,
390                };
391
392                results.push(result);
393            }
394        }
395
396        // Sort by score and apply limits
397        results.sort_by(|a, b| {
398            b.score
399                .partial_cmp(&a.score)
400                .unwrap_or(std::cmp::Ordering::Equal)
401        });
402        results.retain(|result| result.score >= query.min_score);
403        results.truncate(query.limit);
404
405        // Update ranks
406        for (i, result) in results.iter_mut().enumerate() {
407            result.rank = i;
408        }
409
410        Ok(results)
411    }
412
413    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()> {
414        // This would typically involve:
415        // 1. Extracting entities and relationships from documents
416        // 2. Adding them to the graph
417        // 3. Updating storage
418        // For now, just add document nodes
419
420        let mut graph = self.graph.write().await;
421        let mut nodes = Vec::new();
422
423        for (document, _embedding) in documents {
424            let doc_node =
425                GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
426                    .with_source_document(document.id.clone());
427
428            nodes.push(doc_node.clone());
429            graph.add_node(doc_node)?;
430        }
431
432        self.storage.write().await.store_nodes(&nodes).await?;
433
434        Ok(())
435    }
436
437    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()> {
438        // Similar to add_documents but for chunks
439        let mut graph = self.graph.write().await;
440        let mut nodes = Vec::new();
441
442        for (chunk, _embedding) in chunks {
443            let chunk_node = GraphNode::new(
444                format!("chunk_{}_{}", chunk.document_id, chunk.chunk_index),
445                super::NodeType::DocumentChunk,
446            )
447            .with_source_document(chunk.document_id.clone());
448
449            nodes.push(chunk_node.clone());
450            graph.add_node(chunk_node)?;
451        }
452
453        self.storage.write().await.store_nodes(&nodes).await?;
454
455        Ok(())
456    }
457
458    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()> {
459        let mut graph = self.graph.write().await;
460
461        // Remove document nodes and update entity mappings
462        let doc_node_ids: Vec<_> = document_ids
463            .iter()
464            .map(|doc_id| format!("doc_{}", doc_id))
465            .collect();
466
467        for node_id in &doc_node_ids {
468            graph.remove_node(node_id)?;
469        }
470
471        // Update entity-document mapping
472        let mut entity_map = self.entity_document_map.write().await;
473        for doc_id in document_ids {
474            for entity_docs in entity_map.values_mut() {
475                entity_docs.remove(doc_id);
476            }
477        }
478
479        self.storage
480            .write()
481            .await
482            .delete_nodes(&doc_node_ids)
483            .await?;
484
485        Ok(())
486    }
487
488    async fn clear(&self) -> RragResult<()> {
489        *self.graph.write().await = KnowledgeGraph::new();
490        self.entity_document_map.write().await.clear();
491        *self.pagerank_cache.write().await = None;
492        self.storage.write().await.clear().await?;
493        Ok(())
494    }
495
496    async fn stats(&self) -> RragResult<IndexStats> {
497        let storage_stats = self.storage.read().await.get_stats().await?;
498        let graph = self.graph.read().await;
499        let _graph_metrics = graph.calculate_metrics();
500
501        Ok(IndexStats {
502            total_items: storage_stats.total_nodes,
503            size_bytes: storage_stats.storage_size_bytes,
504            dimensions: 0, // Graph doesn't have fixed dimensions
505            index_type: "graph_based".to_string(),
506            last_updated: storage_stats.last_updated,
507        })
508    }
509
510    async fn health_check(&self) -> RragResult<bool> {
511        // Check graph consistency and storage health
512        let graph = self.graph.read().await;
513        Ok(!graph.nodes.is_empty() || self.storage.read().await.get_stats().await.is_ok())
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{storage::InMemoryGraphStorage, EdgeType, GraphEdge, NodeType};

    /// Build a retriever over `graph` with in-memory storage and the default config.
    fn retriever_over(graph: KnowledgeGraph) -> GraphRetriever {
        GraphRetriever::new(
            graph,
            Box::new(InMemoryGraphStorage::new()),
            GraphRetrievalConfig::default(),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_graph_retriever_creation() {
        let retriever = retriever_over(KnowledgeGraph::new());
        assert_eq!(retriever.name(), "graph_retriever");
    }

    #[tokio::test]
    async fn test_query_expansion() {
        let mut graph = KnowledgeGraph::new();

        // Two related concepts joined by a semantic "part_of" edge.
        let ml = GraphNode::new("machine learning", NodeType::Concept);
        let ai = GraphNode::new("artificial intelligence", NodeType::Concept);
        let edge = GraphEdge::new(
            ml.id.clone(),
            ai.id.clone(),
            "part_of",
            EdgeType::Semantic("part_of".to_string()),
        )
        .with_confidence(0.8);

        graph.add_node(ml).unwrap();
        graph.add_node(ai).unwrap();
        graph.add_edge(edge).unwrap();

        let retriever = retriever_over(graph);

        let expanded = retriever.expand_query("machine learning").await.unwrap();
        assert!(!expanded.is_empty());
        assert!(expanded.iter().any(|term| term == "machine learning"));
    }

    #[tokio::test]
    async fn test_find_query_entities() {
        let mut graph = KnowledgeGraph::new();
        let node = GraphNode::new("neural networks", NodeType::Concept);
        let expected_id = node.id.clone();
        graph.add_node(node).unwrap();

        let retriever = retriever_over(graph);

        let found = retriever
            .find_query_entities("neural networks deep learning")
            .await;
        assert!(!found.is_empty());
        assert!(found.contains(&expected_id));
    }

    #[tokio::test]
    async fn test_health_check() {
        let retriever = retriever_over(KnowledgeGraph::new());
        assert!(retriever.health_check().await.unwrap());
    }
}