Skip to main content

graphrag_core/retrieval/
hybrid.rs

1use crate::{
2    core::KnowledgeGraph,
3    retrieval::{
4        bm25::{BM25Result, BM25Retriever},
5        ResultType,
6    },
7    vector::{EmbeddingGenerator, VectorIndex},
8    GraphRAGError, Result,
9};
10use std::collections::HashMap;
11
12/// Hybrid search result combining multiple retrieval strategies
13#[derive(Debug, Clone)]
14pub struct HybridSearchResult {
15    /// Unique identifier for this result
16    pub id: String,
17    /// Text content of the result
18    pub content: String,
19    /// Combined final score from fusion
20    pub score: f32,
21    /// Score from semantic similarity search
22    pub semantic_score: f32,
23    /// Score from keyword-based search
24    pub keyword_score: f32,
25    /// Type of result (entity, chunk, hybrid)
26    pub result_type: ResultType,
27    /// Entities mentioned in this result
28    pub entities: Vec<String>,
29    /// Source chunk identifiers for this result
30    pub source_chunks: Vec<String>,
31    /// Fusion method used to combine scores
32    pub fusion_method: FusionMethod,
33}
34
/// Strategy used to merge semantic and keyword scores into a single ranking.
///
/// The type is `Eq` and `Hash`, so it can be used as a key in maps and sets
/// (for example to group results or metrics per fusion method).
// `RRF` is the established acronym for Reciprocal Rank Fusion, so the all-caps
// variant name is kept and the acronym lint is silenced for this enum.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::upper_case_acronyms)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: scores depend only on each hit's rank
    RRF,
    /// Weighted sum of per-list max-normalized scores
    Weighted,
    /// CombSUM: sum of the scores from both lists
    CombSum,
    /// Keep the larger of the two scores
    MaxScore,
}
48
49/// Configuration for hybrid retrieval
50#[derive(Debug, Clone)]
51pub struct HybridConfig {
52    /// Weight for semantic search results (0.0 to 1.0)
53    pub semantic_weight: f32,
54    /// Weight for keyword search results (0.0 to 1.0)
55    pub keyword_weight: f32,
56    /// Fusion method to combine results
57    pub fusion_method: FusionMethod,
58    /// RRF parameter (used when fusion_method is RRF)
59    pub rrf_k: f32,
60    /// Maximum results to retrieve from each method before fusion
61    pub max_candidates: usize,
62    /// Minimum score threshold for final results
63    pub min_score_threshold: f32,
64}
65
66impl Default for HybridConfig {
67    fn default() -> Self {
68        Self {
69            semantic_weight: 0.7,
70            keyword_weight: 0.3,
71            fusion_method: FusionMethod::RRF,
72            rrf_k: 60.0,
73            max_candidates: 100,
74            min_score_threshold: 0.1,
75        }
76    }
77}
78
79/// Hybrid retriever that combines semantic and keyword search
80pub struct HybridRetriever {
81    /// Vector-based retrieval system
82    vector_index: VectorIndex,
83    /// Embedding generator
84    embedding_generator: EmbeddingGenerator,
85    /// BM25-based keyword retrieval
86    bm25_retriever: BM25Retriever,
87    /// Configuration for hybrid retrieval
88    config: HybridConfig,
89    /// Flag indicating whether the system is initialized
90    initialized: bool,
91}
92
93impl HybridRetriever {
94    /// Create a new hybrid retriever with default configuration
95    pub fn new() -> Self {
96        Self {
97            vector_index: VectorIndex::new(),
98            embedding_generator: EmbeddingGenerator::new(128),
99            bm25_retriever: BM25Retriever::new(),
100            config: HybridConfig::default(),
101            initialized: false,
102        }
103    }
104
105    /// Create a new hybrid retriever with custom configuration
106    pub fn with_config(config: HybridConfig) -> Self {
107        Self {
108            vector_index: VectorIndex::new(),
109            embedding_generator: EmbeddingGenerator::new(128),
110            bm25_retriever: BM25Retriever::new(),
111            config,
112            initialized: false,
113        }
114    }
115
116    /// Initialize the hybrid retriever with a knowledge graph
117    pub fn initialize_with_graph(&mut self, graph: &KnowledgeGraph) -> Result<()> {
118        // Index entities and chunks for vector search
119        for entity in graph.entities() {
120            if let Some(embedding) = &entity.embedding {
121                let id = format!("entity:{}", entity.id);
122                self.vector_index.add_vector(id, embedding.clone())?;
123            }
124        }
125
126        for chunk in graph.chunks() {
127            if let Some(embedding) = &chunk.embedding {
128                let id = format!("chunk:{}", chunk.id);
129                self.vector_index.add_vector(id, embedding.clone())?;
130            }
131        }
132
133        // Build vector index
134        if !self.vector_index.is_empty() {
135            self.vector_index.build_index()?;
136        }
137
138        // Index documents for BM25 search
139        let mut bm25_documents = Vec::new();
140
141        // Add entities as documents
142        for entity in graph.entities() {
143            let doc = crate::retrieval::bm25::Document {
144                id: format!("entity:{}", entity.id),
145                content: format!("{} {}", entity.name, entity.entity_type),
146                metadata: HashMap::new(),
147            };
148            bm25_documents.push(doc);
149        }
150
151        // Add chunks as documents
152        for chunk in graph.chunks() {
153            let doc = crate::retrieval::bm25::Document {
154                id: format!("chunk:{}", chunk.id),
155                content: chunk.content.clone(),
156                metadata: HashMap::new(),
157            };
158            bm25_documents.push(doc);
159        }
160
161        self.bm25_retriever.index_documents(&bm25_documents)?;
162        self.initialized = true;
163
164        Ok(())
165    }
166
167    /// Perform hybrid search combining semantic and keyword retrieval
168    pub fn search(&mut self, query: &str, limit: usize) -> Result<Vec<HybridSearchResult>> {
169        if !self.initialized {
170            return Err(GraphRAGError::Retrieval {
171                message: "Hybrid retriever not initialized. Call initialize_with_graph() first."
172                    .to_string(),
173            });
174        }
175
176        // Get semantic results
177        let semantic_results = self.semantic_search(query, self.config.max_candidates)?;
178
179        // Get keyword results
180        let keyword_results = self.keyword_search(query, self.config.max_candidates);
181
182        // Combine results using configured fusion method
183        let combined_results = self.combine_results(semantic_results, keyword_results, limit)?;
184
185        Ok(combined_results)
186    }
187
188    /// Perform semantic search using vector similarity
189    fn semantic_search(&mut self, query: &str, limit: usize) -> Result<Vec<(String, f32, String)>> {
190        let query_embedding = self.embedding_generator.generate_embedding(query);
191        let similar_vectors = self.vector_index.search(&query_embedding, limit)?;
192
193        let mut results = Vec::new();
194        for (id, score) in similar_vectors {
195            // For now, use the ID as content - in a real implementation,
196            // you would fetch the actual content from the knowledge graph
197            results.push((id.clone(), score, id));
198        }
199
200        Ok(results)
201    }
202
203    /// Perform keyword search using BM25
204    fn keyword_search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
205        self.bm25_retriever.search(query, limit)
206    }
207
208    /// Combine semantic and keyword results using the configured fusion method
209    fn combine_results(
210        &mut self,
211        semantic_results: Vec<(String, f32, String)>,
212        keyword_results: Vec<BM25Result>,
213        limit: usize,
214    ) -> Result<Vec<HybridSearchResult>> {
215        match self.config.fusion_method {
216            FusionMethod::RRF => {
217                self.reciprocal_rank_fusion(semantic_results, keyword_results, limit)
218            }
219            FusionMethod::Weighted => {
220                self.weighted_combination(semantic_results, keyword_results, limit)
221            }
222            FusionMethod::CombSum => self.comb_sum_fusion(semantic_results, keyword_results, limit),
223            FusionMethod::MaxScore => {
224                self.max_score_fusion(semantic_results, keyword_results, limit)
225            }
226        }
227    }
228
229    /// Reciprocal Rank Fusion (RRF)
230    fn reciprocal_rank_fusion(
231        &mut self,
232        semantic_results: Vec<(String, f32, String)>,
233        keyword_results: Vec<BM25Result>,
234        limit: usize,
235    ) -> Result<Vec<HybridSearchResult>> {
236        let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
237        let mut content_map: HashMap<String, String> = HashMap::new();
238
239        // Process semantic results
240        for (rank, (id, score, content)) in semantic_results.iter().enumerate() {
241            let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
242            combined_scores.insert(
243                id.clone(),
244                (rrf_score * self.config.semantic_weight, *score, 0.0),
245            );
246            content_map.insert(id.clone(), content.clone());
247        }
248
249        // Process keyword results
250        for (rank, result) in keyword_results.iter().enumerate() {
251            let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
252            let entry = combined_scores
253                .entry(result.doc_id.clone())
254                .or_insert((0.0, 0.0, 0.0));
255            entry.0 += rrf_score * self.config.keyword_weight;
256            entry.2 = result.score;
257            content_map.insert(result.doc_id.clone(), result.content.clone());
258        }
259
260        self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::RRF)
261    }
262
263    /// Weighted combination of scores
264    fn weighted_combination(
265        &mut self,
266        semantic_results: Vec<(String, f32, String)>,
267        keyword_results: Vec<BM25Result>,
268        limit: usize,
269    ) -> Result<Vec<HybridSearchResult>> {
270        let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
271        let mut content_map: HashMap<String, String> = HashMap::new();
272
273        // Normalize semantic scores
274        let max_semantic = semantic_results
275            .iter()
276            .map(|(_, score, _)| *score)
277            .fold(f32::NEG_INFINITY, f32::max);
278
279        for (id, score, content) in semantic_results {
280            let normalized_score = if max_semantic > 0.0 {
281                score / max_semantic
282            } else {
283                0.0
284            };
285            combined_scores.insert(
286                id.clone(),
287                (normalized_score * self.config.semantic_weight, score, 0.0),
288            );
289            content_map.insert(id, content);
290        }
291
292        // Normalize keyword scores
293        let max_keyword = keyword_results
294            .iter()
295            .map(|r| r.score)
296            .fold(f32::NEG_INFINITY, f32::max);
297
298        for result in keyword_results {
299            let normalized_score = if max_keyword > 0.0 {
300                result.score / max_keyword
301            } else {
302                0.0
303            };
304            let entry = combined_scores
305                .entry(result.doc_id.clone())
306                .or_insert((0.0, 0.0, 0.0));
307            entry.0 += normalized_score * self.config.keyword_weight;
308            entry.2 = result.score;
309            content_map.insert(result.doc_id.clone(), result.content.clone());
310        }
311
312        self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::Weighted)
313    }
314
315    /// CombSUM fusion (simple addition of normalized scores)
316    fn comb_sum_fusion(
317        &mut self,
318        semantic_results: Vec<(String, f32, String)>,
319        keyword_results: Vec<BM25Result>,
320        limit: usize,
321    ) -> Result<Vec<HybridSearchResult>> {
322        let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
323        let mut content_map: HashMap<String, String> = HashMap::new();
324
325        // Process semantic results
326        for (id, score, content) in semantic_results {
327            combined_scores.insert(id.clone(), (score, score, 0.0));
328            content_map.insert(id, content);
329        }
330
331        // Process keyword results
332        for result in keyword_results {
333            let entry = combined_scores
334                .entry(result.doc_id.clone())
335                .or_insert((0.0, 0.0, 0.0));
336            entry.0 += result.score;
337            entry.2 = result.score;
338            content_map.insert(result.doc_id.clone(), result.content.clone());
339        }
340
341        self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::CombSum)
342    }
343
344    /// Maximum score fusion
345    fn max_score_fusion(
346        &mut self,
347        semantic_results: Vec<(String, f32, String)>,
348        keyword_results: Vec<BM25Result>,
349        limit: usize,
350    ) -> Result<Vec<HybridSearchResult>> {
351        let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
352        let mut content_map: HashMap<String, String> = HashMap::new();
353
354        // Process semantic results
355        for (id, score, content) in semantic_results {
356            combined_scores.insert(id.clone(), (score, score, 0.0));
357            content_map.insert(id, content);
358        }
359
360        // Process keyword results
361        for result in keyword_results {
362            let entry = combined_scores
363                .entry(result.doc_id.clone())
364                .or_insert((0.0, 0.0, 0.0));
365            entry.0 = entry.0.max(result.score);
366            entry.2 = result.score;
367            content_map.insert(result.doc_id.clone(), result.content.clone());
368        }
369
370        self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::MaxScore)
371    }
372
373    /// Create hybrid results from combined scores
374    fn create_hybrid_results(
375        &self,
376        combined_scores: HashMap<String, (f32, f32, f32)>,
377        content_map: HashMap<String, String>,
378        limit: usize,
379        fusion_method: FusionMethod,
380    ) -> Result<Vec<HybridSearchResult>> {
381        let mut results: Vec<HybridSearchResult> = combined_scores
382            .into_iter()
383            .filter_map(|(id, (combined_score, semantic_score, keyword_score))| {
384                if combined_score >= self.config.min_score_threshold {
385                    let content = content_map.get(&id).cloned().unwrap_or_else(|| id.clone());
386
387                    // Determine result type based on ID prefix
388                    let result_type = if id.starts_with("entity:") {
389                        ResultType::Entity
390                    } else if id.starts_with("chunk:") {
391                        ResultType::Chunk
392                    } else {
393                        ResultType::Hybrid
394                    };
395
396                    // Extract entities (simplified)
397                    let entities = if result_type == ResultType::Entity {
398                        vec![content.clone()]
399                    } else {
400                        Vec::new()
401                    };
402
403                    Some(HybridSearchResult {
404                        id: id.clone(),
405                        content,
406                        score: combined_score,
407                        semantic_score,
408                        keyword_score,
409                        result_type,
410                        entities,
411                        source_chunks: vec![id],
412                        fusion_method: fusion_method.clone(),
413                    })
414                } else {
415                    None
416                }
417            })
418            .collect();
419
420        // Sort by combined score
421        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
422        results.truncate(limit);
423
424        Ok(results)
425    }
426
427    /// Get configuration
428    pub fn get_config(&self) -> &HybridConfig {
429        &self.config
430    }
431
432    /// Update configuration
433    pub fn set_config(&mut self, config: HybridConfig) {
434        self.config = config;
435    }
436
437    /// Check if the retriever is initialized
438    pub fn is_initialized(&self) -> bool {
439        self.initialized
440    }
441
442    /// Get statistics about the hybrid retriever
443    pub fn get_statistics(&self) -> HybridStatistics {
444        let vector_stats = self.vector_index.statistics();
445        let bm25_stats = self.bm25_retriever.get_statistics();
446
447        HybridStatistics {
448            vector_count: vector_stats.vector_count,
449            bm25_document_count: bm25_stats.total_documents,
450            bm25_term_count: bm25_stats.total_terms,
451            config: self.config.clone(),
452            initialized: self.initialized,
453        }
454    }
455
456    /// Clear all indexed data
457    pub fn clear(&mut self) {
458        self.vector_index = VectorIndex::new();
459        self.bm25_retriever.clear();
460        self.initialized = false;
461    }
462}
463
464impl Default for HybridRetriever {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470/// Statistics about the hybrid retriever
471#[derive(Debug, Clone)]
472pub struct HybridStatistics {
473    /// Number of vectors in index
474    pub vector_count: usize,
475    /// Number of documents in BM25 index
476    pub bm25_document_count: usize,
477    /// Number of unique terms in BM25 index
478    pub bm25_term_count: usize,
479    /// Configuration settings for hybrid retrieval
480    pub config: HybridConfig,
481    /// Whether the retriever has been initialized
482    pub initialized: bool,
483}
484
485impl HybridStatistics {
486    /// Print statistics
487    pub fn print(&self) {
488        println!("Hybrid Retriever Statistics:");
489        println!("  Initialized: {}", self.initialized);
490        println!("  Vector index: {} vectors", self.vector_count);
491        println!(
492            "  BM25 index: {} documents, {} terms",
493            self.bm25_document_count, self.bm25_term_count
494        );
495        println!("  Fusion method: {:?}", self.config.fusion_method);
496        println!(
497            "  Weights: semantic={:.2}, keyword={:.2}",
498            self.config.semantic_weight, self.config.keyword_weight
499        );
500        println!("  Score threshold: {:.3}", self.config.min_score_threshold);
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeGraph;

    /// A freshly constructed retriever reports itself as uninitialized.
    #[test]
    fn test_hybrid_retriever_creation() {
        assert!(!HybridRetriever::new().is_initialized());
    }

    /// The default configuration uses RRF with a 0.7 / 0.3 weight split.
    #[test]
    fn test_hybrid_config_default() {
        let HybridConfig {
            semantic_weight,
            keyword_weight,
            fusion_method,
            ..
        } = HybridConfig::default();

        assert_eq!(semantic_weight, 0.7);
        assert_eq!(keyword_weight, 0.3);
        assert_eq!(fusion_method, FusionMethod::RRF);
    }

    /// Fusion methods compare equal to themselves and unequal to other variants.
    #[test]
    fn test_fusion_method_variants() {
        let rrf = FusionMethod::RRF;
        assert_eq!(rrf, FusionMethod::RRF);
        assert_ne!(rrf, FusionMethod::Weighted);
    }

    /// Initializing from an empty graph succeeds and flips the initialized flag.
    #[test]
    fn test_hybrid_retriever_with_empty_graph() {
        let graph = KnowledgeGraph::new();
        let mut retriever = HybridRetriever::new();

        assert!(retriever.initialize_with_graph(&graph).is_ok());
        assert!(retriever.is_initialized());
    }

    /// Searching before initialization is rejected with an error.
    #[test]
    fn test_search_without_initialization() {
        let outcome = HybridRetriever::new().search("test", 10);
        assert!(outcome.is_err());
    }
}
545}