// rexis_rag/graph_retrieval/retriever.rs
use super::{
6 algorithms::GraphAlgorithms,
7 query_expansion::{ExpansionOptions, ExpansionStrategy, GraphQueryExpander, QueryExpander},
8 storage::GraphStorage,
9 GraphNode, KnowledgeGraph,
10};
11use crate::{
12 retrieval_core::{IndexStats, QueryType},
13 Document, DocumentChunk, Embedding, Retriever, RragResult, SearchQuery, SearchResult,
14};
15use async_trait::async_trait;
16use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet};
18
/// Graph-based retriever that scores documents by their links to
/// knowledge-graph entities (entity overlap, PageRank, query expansion).
pub struct GraphRetriever {
    /// In-memory knowledge graph of entity/document nodes and edges.
    graph: tokio::sync::RwLock<KnowledgeGraph>,

    /// Persistent backend; the graph is mirrored here on every mutation.
    storage: tokio::sync::RwLock<Box<dyn GraphStorage>>,

    /// Expands query text with graph-derived terms (see `expand_query`).
    query_expander: tokio::sync::RwLock<GraphQueryExpander>,

    /// Retrieval tuning knobs; immutable after construction.
    config: GraphRetrievalConfig,

    /// Lazily computed PageRank scores; `None` means stale or not yet computed.
    pagerank_cache: tokio::sync::RwLock<Option<HashMap<String, f32>>>,

    /// Entity id -> set of ids of the documents that entity appears in.
    entity_document_map: tokio::sync::RwLock<HashMap<String, HashSet<String>>>,
}
39
/// Configuration for graph-based retrieval.
#[derive(Debug, Clone)]
pub struct GraphRetrievalConfig {
    /// Expand the query with graph-derived terms before entity matching.
    pub enable_query_expansion: bool,

    /// Add a PageRank-based bonus to candidate document scores.
    pub enable_pagerank_scoring: bool,

    /// Enable path-based retrieval features.
    /// NOTE(review): not consulted by any code visible in this file — confirm usage.
    pub enable_path_based_retrieval: bool,

    /// Weight of the graph-derived score when combining signals.
    pub graph_weight: f32,

    /// Weight of the similarity score when combining signals.
    pub similarity_weight: f32,

    /// Maximum number of hops to traverse in the graph.
    pub max_graph_hops: usize,

    /// Minimum graph score a candidate must reach to be returned.
    pub min_graph_score: f32,

    /// Options passed to the query expander.
    pub expansion_options: ExpansionOptions,

    /// Parameters for the PageRank computation.
    pub pagerank_config: super::algorithms::PageRankConfig,

    /// Enable result diversification.
    /// NOTE(review): not consulted by any code visible in this file — confirm usage.
    pub enable_diversification: bool,

    /// Strength of the diversification adjustment (0.0 = off).
    pub diversification_factor: f32,
}
76
// Defaults: all graph features enabled; similarity slightly outweighs graph evidence.
impl Default for GraphRetrievalConfig {
    fn default() -> Self {
        Self {
            enable_query_expansion: true,
            enable_pagerank_scoring: true,
            enable_path_based_retrieval: true,
            // Weights sum to 1.0; similarity counts slightly more than graph structure.
            graph_weight: 0.4,
            similarity_weight: 0.6,
            max_graph_hops: 3,
            min_graph_score: 0.1,
            expansion_options: ExpansionOptions {
                strategies: vec![
                    ExpansionStrategy::Semantic,
                    ExpansionStrategy::Similarity,
                    ExpansionStrategy::CoOccurrence,
                ],
                // Cap expansion to 10 reasonably confident terms.
                max_terms: Some(10),
                min_confidence: 0.3,
                ..Default::default()
            },
            pagerank_config: super::algorithms::PageRankConfig::default(),
            enable_diversification: true,
            diversification_factor: 0.3,
        }
    }
}
103
/// A search result enriched with graph-derived scoring signals.
/// NOTE(review): not produced by any code visible in this file — presumably the
/// public result type for downstream consumers; confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphSearchResult {
    /// The underlying search result (id, content, score, metadata).
    pub search_result: SearchResult,

    /// Score contribution derived from graph structure.
    pub graph_score: f32,

    /// PageRank-based importance component.
    pub pagerank_score: f32,

    /// Ids of graph entities related to this result.
    pub related_entities: Vec<String>,

    /// Graph paths connecting query entities to this result.
    pub graph_paths: Vec<GraphPath>,

    /// Expanded query terms that matched this result.
    pub matched_expansions: Vec<String>,
}
125
/// A path through the knowledge graph, serialized as an ordered node-id list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphPath {
    /// Node ids along the path, in traversal order.
    pub nodes: Vec<String>,

    /// Relevance/confidence score for the path.
    pub score: f32,

    /// Free-form label describing how the path was derived.
    pub path_type: String,

    /// Path length — presumably the hop (edge) count; confirm with the producer.
    pub length: usize,
}
141
142impl GraphRetriever {
143 pub fn new(
145 graph: KnowledgeGraph,
146 storage: Box<dyn GraphStorage>,
147 config: GraphRetrievalConfig,
148 ) -> RragResult<Self> {
149 let query_expander = GraphQueryExpander::new(
150 graph.clone(),
151 super::query_expansion::ExpansionConfig::default(),
152 );
153
154 let mut entity_document_map = HashMap::new();
155
156 for (_, node) in &graph.nodes {
158 for doc_id in &node.source_documents {
159 entity_document_map
160 .entry(node.id.clone())
161 .or_insert_with(HashSet::new)
162 .insert(doc_id.clone());
163 }
164 }
165
166 let retriever = Self {
167 graph: tokio::sync::RwLock::new(graph),
168 storage: tokio::sync::RwLock::new(storage),
169 query_expander: tokio::sync::RwLock::new(query_expander),
170 config,
171 pagerank_cache: tokio::sync::RwLock::new(None),
172 entity_document_map: tokio::sync::RwLock::new(entity_document_map),
173 };
174
175 Ok(retriever)
176 }
177
178 pub async fn update_graph(&self, graph: KnowledgeGraph) -> RragResult<()> {
180 *self.graph.write().await = graph.clone();
181 self.query_expander
182 .write()
183 .await
184 .update_graph(graph.clone())
185 .await;
186
187 let mut entity_map = self.entity_document_map.write().await;
189 entity_map.clear();
190 for (_, node) in &graph.nodes {
191 for doc_id in &node.source_documents {
192 entity_map
193 .entry(node.id.clone())
194 .or_insert_with(HashSet::new)
195 .insert(doc_id.clone());
196 }
197 }
198
199 *self.pagerank_cache.write().await = None;
201
202 self.storage.write().await.store_graph(&graph).await?;
204
205 Ok(())
206 }
207
208 async fn get_pagerank_scores(&self) -> RragResult<HashMap<String, f32>> {
210 let mut cache = self.pagerank_cache.write().await;
211
212 if cache.is_none() {
213 let graph = self.graph.read().await;
214 let scores = GraphAlgorithms::pagerank(&*graph, &self.config.pagerank_config)?;
215 *cache = Some(scores);
216 }
217
218 Ok(cache.clone().unwrap())
219 }
220
221 async fn expand_query(&self, query: &str) -> RragResult<Vec<String>> {
223 if !self.config.enable_query_expansion {
224 return Ok(vec![query.to_string()]);
225 }
226
227 let expansion_result = self
228 .query_expander
229 .read()
230 .await
231 .expand_query(query, &self.config.expansion_options)
232 .await?;
233
234 let mut terms = vec![query.to_string()];
235 terms.extend(expansion_result.expanded_terms.into_iter().map(|t| t.term));
236
237 Ok(terms)
238 }
239
240 async fn find_query_entities(&self, query: &str) -> Vec<String> {
242 let query_lower = query.to_lowercase();
243 let mut entities = Vec::new();
244
245 let graph = self.graph.read().await;
246
247 for (entity_id, node) in &graph.nodes {
249 let label_lower = node.label.to_lowercase();
250 if query_lower.contains(&label_lower) || label_lower.contains(&query_lower) {
251 entities.push(entity_id.clone());
252 }
253 }
254
255 entities
256 }
257
258 pub async fn add_document_with_entities(
260 &self,
261 document: &Document,
262 entities: Vec<GraphNode>,
263 relationships: Vec<super::GraphEdge>,
264 ) -> RragResult<()> {
265 let mut graph = self.graph.write().await;
266
267 let doc_node = GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
269 .with_source_document(document.id.clone())
270 .with_attribute(
271 "title",
272 serde_json::Value::String(
273 document
274 .metadata
275 .get("title")
276 .and_then(|v| v.as_str())
277 .unwrap_or(&document.id)
278 .to_string(),
279 ),
280 );
281
282 graph.add_node(doc_node.clone())?;
283
284 for entity in entities {
286 let entity_id = entity.id.clone();
287 graph.add_node(entity)?;
288
289 let containment_edge = super::GraphEdge::new(
291 doc_node.id.clone(),
292 entity_id.clone(),
293 "contains",
294 super::EdgeType::Contains,
295 );
296 graph.add_edge(containment_edge)?;
297
298 self.entity_document_map
300 .write()
301 .await
302 .entry(entity_id)
303 .or_insert_with(HashSet::new)
304 .insert(document.id.clone());
305 }
306
307 for relationship in relationships {
309 graph.add_edge(relationship)?;
310 }
311
312 *self.pagerank_cache.write().await = None;
314
315 self.storage.write().await.store_graph(&*graph).await?;
317
318 Ok(())
319 }
320}
321
322#[async_trait]
323impl Retriever for GraphRetriever {
324 fn name(&self) -> &str {
325 "graph_retriever"
326 }
327
328 async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>> {
329 let query_text = match &query.query {
330 QueryType::Text(text) => text,
331 QueryType::Embedding(_) => {
332 return Ok(Vec::new());
335 }
336 };
337
338 let expanded_terms = self.expand_query(query_text).await?;
340 let expanded_query = expanded_terms.join(" ");
341
342 let query_entities = self.find_query_entities(&expanded_query).await;
344
345 let mut results = Vec::new();
347
348 let entity_map = self.entity_document_map.read().await;
349 let pagerank_scores = if self.config.enable_pagerank_scoring {
350 self.get_pagerank_scores().await?
351 } else {
352 HashMap::new()
353 };
354
355 let mut candidate_docs = HashSet::new();
357 for entity_id in &query_entities {
358 if let Some(doc_ids) = entity_map.get(entity_id) {
359 candidate_docs.extend(doc_ids.clone());
360 }
361 }
362
363 for (rank, doc_id) in candidate_docs.iter().enumerate() {
365 let mut graph_score = 0.5; for entity_id in &query_entities {
370 if let Some(doc_ids) = entity_map.get(entity_id) {
371 if doc_ids.contains(doc_id) {
372 let pagerank_score = pagerank_scores.get(entity_id).copied().unwrap_or(0.0);
373 graph_score += pagerank_score * 0.3;
374 }
375 }
376 }
377
378 if graph_score >= self.config.min_graph_score {
379 let result = SearchResult {
380 id: doc_id.clone(),
381 content: format!("Document {}", doc_id), score: graph_score,
383 rank,
384 metadata: {
385 let mut metadata = HashMap::new();
386 metadata.insert("graph_score".to_string(), serde_json::json!(graph_score));
387 metadata
388 },
389 embedding: None,
390 };
391
392 results.push(result);
393 }
394 }
395
396 results.sort_by(|a, b| {
398 b.score
399 .partial_cmp(&a.score)
400 .unwrap_or(std::cmp::Ordering::Equal)
401 });
402 results.retain(|result| result.score >= query.min_score);
403 results.truncate(query.limit);
404
405 for (i, result) in results.iter_mut().enumerate() {
407 result.rank = i;
408 }
409
410 Ok(results)
411 }
412
413 async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()> {
414 let mut graph = self.graph.write().await;
421 let mut nodes = Vec::new();
422
423 for (document, _embedding) in documents {
424 let doc_node =
425 GraphNode::new(format!("doc_{}", document.id), super::NodeType::Document)
426 .with_source_document(document.id.clone());
427
428 nodes.push(doc_node.clone());
429 graph.add_node(doc_node)?;
430 }
431
432 self.storage.write().await.store_nodes(&nodes).await?;
433
434 Ok(())
435 }
436
437 async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()> {
438 let mut graph = self.graph.write().await;
440 let mut nodes = Vec::new();
441
442 for (chunk, _embedding) in chunks {
443 let chunk_node = GraphNode::new(
444 format!("chunk_{}_{}", chunk.document_id, chunk.chunk_index),
445 super::NodeType::DocumentChunk,
446 )
447 .with_source_document(chunk.document_id.clone());
448
449 nodes.push(chunk_node.clone());
450 graph.add_node(chunk_node)?;
451 }
452
453 self.storage.write().await.store_nodes(&nodes).await?;
454
455 Ok(())
456 }
457
458 async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()> {
459 let mut graph = self.graph.write().await;
460
461 let doc_node_ids: Vec<_> = document_ids
463 .iter()
464 .map(|doc_id| format!("doc_{}", doc_id))
465 .collect();
466
467 for node_id in &doc_node_ids {
468 graph.remove_node(node_id)?;
469 }
470
471 let mut entity_map = self.entity_document_map.write().await;
473 for doc_id in document_ids {
474 for entity_docs in entity_map.values_mut() {
475 entity_docs.remove(doc_id);
476 }
477 }
478
479 self.storage
480 .write()
481 .await
482 .delete_nodes(&doc_node_ids)
483 .await?;
484
485 Ok(())
486 }
487
488 async fn clear(&self) -> RragResult<()> {
489 *self.graph.write().await = KnowledgeGraph::new();
490 self.entity_document_map.write().await.clear();
491 *self.pagerank_cache.write().await = None;
492 self.storage.write().await.clear().await?;
493 Ok(())
494 }
495
496 async fn stats(&self) -> RragResult<IndexStats> {
497 let storage_stats = self.storage.read().await.get_stats().await?;
498 let graph = self.graph.read().await;
499 let _graph_metrics = graph.calculate_metrics();
500
501 Ok(IndexStats {
502 total_items: storage_stats.total_nodes,
503 size_bytes: storage_stats.storage_size_bytes,
504 dimensions: 0, index_type: "graph_based".to_string(),
506 last_updated: storage_stats.last_updated,
507 })
508 }
509
510 async fn health_check(&self) -> RragResult<bool> {
511 let graph = self.graph.read().await;
513 Ok(!graph.nodes.is_empty() || self.storage.read().await.get_stats().await.is_ok())
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{storage::InMemoryGraphStorage, EdgeType, GraphEdge, NodeType};

    /// Wraps `graph` in a retriever with in-memory storage and default config.
    fn make_retriever(graph: KnowledgeGraph) -> GraphRetriever {
        GraphRetriever::new(
            graph,
            Box::new(InMemoryGraphStorage::new()),
            GraphRetrievalConfig::default(),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_graph_retriever_creation() {
        let retriever = make_retriever(KnowledgeGraph::new());
        assert_eq!(retriever.name(), "graph_retriever");
    }

    #[tokio::test]
    async fn test_query_expansion() {
        let mut graph = KnowledgeGraph::new();

        let ml = GraphNode::new("machine learning", NodeType::Concept);
        let ai = GraphNode::new("artificial intelligence", NodeType::Concept);
        let ml_id = ml.id.clone();
        let ai_id = ai.id.clone();

        graph.add_node(ml).unwrap();
        graph.add_node(ai).unwrap();

        // ml --part_of--> ai, with moderate confidence.
        let edge = GraphEdge::new(
            ml_id.clone(),
            ai_id.clone(),
            "part_of",
            EdgeType::Semantic("part_of".to_string()),
        )
        .with_confidence(0.8);
        graph.add_edge(edge).unwrap();

        let retriever = make_retriever(graph);

        let expanded = retriever.expand_query("machine learning").await.unwrap();
        // The original query must survive expansion.
        assert!(!expanded.is_empty());
        assert!(expanded.contains(&"machine learning".to_string()));
    }

    #[tokio::test]
    async fn test_find_query_entities() {
        let mut graph = KnowledgeGraph::new();

        let concept = GraphNode::new("neural networks", NodeType::Concept);
        let concept_id = concept.id.clone();
        graph.add_node(concept).unwrap();

        let retriever = make_retriever(graph);

        // The query contains the entity label as a substring.
        let entities = retriever
            .find_query_entities("neural networks deep learning")
            .await;
        assert!(!entities.is_empty());
        assert!(entities.contains(&concept_id));
    }

    #[tokio::test]
    async fn test_health_check() {
        let retriever = make_retriever(KnowledgeGraph::new());
        // Empty graph but responsive storage still counts as healthy.
        assert!(retriever.health_check().await.unwrap());
    }
}