rexis_rag/retrieval/
hybrid.rs

1//! # Hybrid Retrieval System
2//!
3//! Combines semantic and keyword-based retrieval for optimal performance.
4//! Implements multiple fusion strategies and adaptive weighting.
5
6use super::{
7    BM25Config, BM25Retriever, RankFusion, ReciprocalRankFusion, SemanticConfig, SemanticRetriever,
8    WeightedFusion,
9};
10use crate::{Document, EmbeddingProvider, RragResult, SearchResult};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::RwLock;
15
16/// Hybrid retriever configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HybridConfig {
19    /// BM25 configuration
20    pub bm25_config: BM25Config,
21
22    /// Semantic search configuration
23    pub semantic_config: SemanticConfig,
24
25    /// Fusion strategy to use
26    pub fusion_strategy: FusionStrategy,
27
28    /// Whether to use adaptive weighting
29    pub adaptive_weights: bool,
30
31    /// Initial weight for semantic search (0.0 to 1.0)
32    pub semantic_weight: f32,
33
34    /// Whether to run retrievers in parallel
35    pub parallel_retrieval: bool,
36
37    /// Minimum confidence score to include results
38    pub min_confidence: f32,
39
40    /// Enable query analysis for better routing
41    pub enable_query_analysis: bool,
42}
43
44impl Default for HybridConfig {
45    fn default() -> Self {
46        Self {
47            bm25_config: BM25Config::default(),
48            semantic_config: SemanticConfig::default(),
49            fusion_strategy: FusionStrategy::ReciprocalRankFusion,
50            adaptive_weights: true,
51            semantic_weight: 0.6,
52            parallel_retrieval: true,
53            min_confidence: 0.0,
54            enable_query_analysis: true,
55        }
56    }
57}
58
59/// Fusion strategies for combining results
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub enum FusionStrategy {
62    /// Reciprocal Rank Fusion
63    ReciprocalRankFusion,
64
65    /// Weighted linear combination
66    WeightedCombination,
67
68    /// Learned fusion with ML model
69    LearnedFusion,
70
71    /// Custom fusion function
72    Custom,
73}
74
/// Heuristic query features used to adapt the semantic weight.
#[derive(Debug, Clone)]
struct QueryCharacteristics {
    /// Number of whitespace-separated tokens in the query.
    // Computed by both analyzers but not consumed by the weighting logic yet,
    // so the lint is silenced for this field only.
    #[allow(dead_code)]
    num_tokens: usize,

    /// Whether the query appears to contain named entities (heuristic).
    has_entities: bool,

    /// Whether the query looks like a question.
    is_question: bool,

    /// Whether the query contains known technical terms.
    has_technical_terms: bool,

    /// Complexity score in `[0.0, 1.0]` (token count / 10, capped at 1.0).
    complexity: f32,
}
94
/// Hybrid retriever combining BM25 keyword search and semantic vector search,
/// fusing both ranked lists into a single result list.
pub struct HybridRetriever {
    /// Configuration, shared immutably after construction.
    config: Arc<HybridConfig>,

    /// BM25 keyword retriever
    bm25_retriever: Arc<BM25Retriever>,

    /// Semantic vector retriever
    semantic_retriever: Arc<SemanticRetriever>,

    /// Fusion algorithm selected from `config.fusion_strategy` at construction.
    /// `search` bypasses it for `WeightedCombination` and builds a per-query
    /// weighted fusion instead.
    fusion: Arc<dyn RankFusion>,

    /// Adaptive weight history fed by `record_feedback` (only feedback above 0.7);
    /// the 50 oldest entries are dropped once it exceeds 100 entries.
    weight_history: Arc<RwLock<Vec<(f32, f32)>>>, // (semantic_weight_used, user_satisfaction)

    /// Query performance metrics, one entry appended per `search` call.
    /// NOTE(review): nothing in this file trims this vector, so it grows without
    /// bound in long-running processes.
    query_metrics: Arc<RwLock<Vec<QueryMetrics>>>,
}
115
116/// Query performance metrics for learning
117#[derive(Debug, Clone)]
118#[allow(dead_code)]
119struct QueryMetrics {
120    query: String,
121    characteristics: QueryCharacteristics,
122    semantic_weight_used: f32,
123    response_time_ms: u64,
124    user_satisfaction: Option<f32>, // Optional user feedback
125}
126
127impl HybridRetriever {
128    /// Create a new hybrid retriever
129    pub fn new(config: HybridConfig, embedding_service: Arc<dyn EmbeddingProvider>) -> Self {
130        let bm25_retriever = Arc::new(BM25Retriever::new(config.bm25_config.clone()));
131        let semantic_retriever = Arc::new(SemanticRetriever::new(
132            config.semantic_config.clone(),
133            embedding_service,
134        ));
135
136        let fusion: Arc<dyn RankFusion> = match &config.fusion_strategy {
137            FusionStrategy::ReciprocalRankFusion => Arc::new(ReciprocalRankFusion::default()),
138            FusionStrategy::WeightedCombination => Arc::new(WeightedFusion::new(vec![
139                1.0 - config.semantic_weight,
140                config.semantic_weight,
141            ])),
142            _ => Arc::new(ReciprocalRankFusion::default()),
143        };
144
145        Self {
146            config: Arc::new(config),
147            bm25_retriever,
148            semantic_retriever,
149            fusion,
150            weight_history: Arc::new(RwLock::new(Vec::new())),
151            query_metrics: Arc::new(RwLock::new(Vec::new())),
152        }
153    }
154
155    /// Index a document in both retrievers
156    pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
157        if self.config.parallel_retrieval {
158            // Index in parallel
159            let (bm25_result, semantic_result) = tokio::join!(
160                self.bm25_retriever.index_document(doc),
161                self.semantic_retriever.index_document(doc)
162            );
163
164            bm25_result?;
165            semantic_result?;
166        } else {
167            // Index sequentially
168            self.bm25_retriever.index_document(doc).await?;
169            self.semantic_retriever.index_document(doc).await?;
170        }
171
172        Ok(())
173    }
174
175    /// Batch index multiple documents
176    pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
177        if self.config.parallel_retrieval {
178            let (bm25_result, semantic_result) = tokio::join!(
179                self.bm25_retriever.index_batch(documents.clone()),
180                self.semantic_retriever.index_batch(documents)
181            );
182
183            bm25_result?;
184            semantic_result?;
185        } else {
186            self.bm25_retriever.index_batch(documents.clone()).await?;
187            self.semantic_retriever.index_batch(documents).await?;
188        }
189
190        Ok(())
191    }
192
193    /// Perform hybrid search
194    pub async fn search(&self, query: &str, limit: usize) -> RragResult<Vec<SearchResult>> {
195        let start_time = Instant::now();
196
197        // Analyze query characteristics
198        let characteristics = if self.config.enable_query_analysis {
199            self.analyze_query(query)
200        } else {
201            self.simple_query_analysis(query)
202        };
203
204        // Determine weights based on query characteristics and history
205        let semantic_weight = if self.config.adaptive_weights {
206            self.calculate_adaptive_weight(&characteristics).await
207        } else {
208            self.config.semantic_weight
209        };
210
211        // Perform searches
212        let (bm25_results, semantic_results) = if self.config.parallel_retrieval {
213            tokio::join!(
214                self.bm25_retriever.search(query, limit * 2),
215                self.semantic_retriever
216                    .search(query, limit * 2, Some(self.config.min_confidence))
217            )
218        } else {
219            let bm25 = self.bm25_retriever.search(query, limit * 2).await;
220            let semantic = self
221                .semantic_retriever
222                .search(query, limit * 2, Some(self.config.min_confidence))
223                .await;
224            (bm25, semantic)
225        };
226
227        let bm25_results = bm25_results?;
228        let semantic_results = semantic_results?;
229
230        // Combine results using fusion strategy
231        let fused_results = match self.config.fusion_strategy {
232            FusionStrategy::WeightedCombination => {
233                let fusion = WeightedFusion::new(vec![1.0 - semantic_weight, semantic_weight]);
234                fusion.fuse(vec![bm25_results, semantic_results], limit)?
235            }
236            _ => self
237                .fusion
238                .fuse(vec![bm25_results, semantic_results], limit)?,
239        };
240
241        // Record metrics
242        let elapsed = start_time.elapsed().as_millis() as u64;
243        let metrics = QueryMetrics {
244            query: query.to_string(),
245            characteristics,
246            semantic_weight_used: semantic_weight,
247            response_time_ms: elapsed,
248            user_satisfaction: None,
249        };
250
251        let mut query_metrics = self.query_metrics.write().await;
252        query_metrics.push(metrics);
253
254        Ok(fused_results)
255    }
256
257    /// Advanced search with multiple strategies
258    pub async fn advanced_search(
259        &self,
260        query: &str,
261        limit: usize,
262        strategies: Vec<SearchStrategy>,
263    ) -> RragResult<Vec<SearchResult>> {
264        let mut all_results = Vec::new();
265
266        for strategy in strategies {
267            let results = match strategy {
268                SearchStrategy::ExactMatch => {
269                    // Boost BM25 for exact matches
270                    self.bm25_retriever.search(query, limit).await?
271                }
272                SearchStrategy::Semantic => {
273                    // Pure semantic search
274                    self.semantic_retriever.search(query, limit, None).await?
275                }
276                SearchStrategy::Hybrid => {
277                    // Standard hybrid search
278                    self.search(query, limit).await?
279                }
280                SearchStrategy::QueryExpansion => {
281                    // Expand query with synonyms and search
282                    let expanded = self.expand_query(query);
283                    self.search(&expanded, limit).await?
284                }
285            };
286
287            all_results.push(results);
288        }
289
290        // Fuse all strategy results
291        self.fusion.fuse(all_results, limit)
292    }
293
294    /// Analyze query characteristics
295    fn analyze_query(&self, query: &str) -> QueryCharacteristics {
296        let tokens: Vec<&str> = query.split_whitespace().collect();
297        let num_tokens = tokens.len();
298
299        // Check if it's a question
300        let is_question = query.contains('?')
301            || query.starts_with("what")
302            || query.starts_with("how")
303            || query.starts_with("why")
304            || query.starts_with("when")
305            || query.starts_with("where")
306            || query.starts_with("who");
307
308        // Simple entity detection (could use NER model)
309        let has_entities = tokens
310            .iter()
311            .any(|t| t.chars().next().map_or(false, |c| c.is_uppercase()));
312
313        // Technical term detection (simplified)
314        let technical_terms = [
315            "algorithm",
316            "function",
317            "method",
318            "system",
319            "protocol",
320            "framework",
321        ];
322        let has_technical_terms = tokens
323            .iter()
324            .any(|t| technical_terms.contains(&t.to_lowercase().as_str()));
325
326        // Calculate complexity
327        let complexity = (num_tokens as f32 / 10.0).min(1.0);
328
329        QueryCharacteristics {
330            num_tokens,
331            has_entities,
332            is_question,
333            has_technical_terms,
334            complexity,
335        }
336    }
337
338    /// Simple query analysis without NLP
339    fn simple_query_analysis(&self, query: &str) -> QueryCharacteristics {
340        let num_tokens = query.split_whitespace().count();
341
342        QueryCharacteristics {
343            num_tokens,
344            has_entities: false,
345            is_question: query.contains('?'),
346            has_technical_terms: false,
347            complexity: (num_tokens as f32 / 10.0).min(1.0),
348        }
349    }
350
351    /// Calculate adaptive weight based on query characteristics and history
352    async fn calculate_adaptive_weight(&self, characteristics: &QueryCharacteristics) -> f32 {
353        let mut base_weight = self.config.semantic_weight;
354
355        // Adjust based on query characteristics
356        if characteristics.is_question {
357            base_weight += 0.1; // Questions benefit from semantic understanding
358        }
359
360        if characteristics.has_entities {
361            base_weight -= 0.1; // Named entities benefit from exact matching
362        }
363
364        if characteristics.has_technical_terms {
365            base_weight -= 0.05; // Technical terms often need exact matches
366        }
367
368        // Adjust based on query complexity
369        base_weight += characteristics.complexity * 0.1;
370
371        // Learn from history if available
372        let history = self.weight_history.read().await;
373        if history.len() > 10 {
374            // Simple moving average of successful weights
375            let recent_weights: Vec<f32> = history
376                .iter()
377                .rev()
378                .take(10)
379                .filter(|(_, score)| *score > 0.7)
380                .map(|(weight, _)| *weight)
381                .collect();
382
383            if !recent_weights.is_empty() {
384                let avg_weight: f32 =
385                    recent_weights.iter().sum::<f32>() / recent_weights.len() as f32;
386                base_weight = 0.7 * base_weight + 0.3 * avg_weight;
387            }
388        }
389
390        // Clamp to valid range
391        base_weight.max(0.0).min(1.0)
392    }
393
394    /// Expand query with synonyms and related terms
395    fn expand_query(&self, query: &str) -> String {
396        // Simple query expansion (in production, use WordNet or embeddings)
397        let expansions = vec![
398            ("ML", "machine learning"),
399            ("AI", "artificial intelligence"),
400            ("NLP", "natural language processing"),
401            ("DB", "database"),
402        ];
403
404        let mut expanded = query.to_string();
405        for (abbr, full) in expansions {
406            if query.contains(abbr) && !query.contains(full) {
407                expanded.push_str(&format!(" {}", full));
408            }
409        }
410
411        expanded
412    }
413
414    /// Record user feedback for learning
415    pub async fn record_feedback(&self, query: &str, satisfaction: f32) -> RragResult<()> {
416        let mut metrics = self.query_metrics.write().await;
417
418        // Find the most recent query matching this text
419        if let Some(metric) = metrics.iter_mut().rev().find(|m| m.query == query) {
420            metric.user_satisfaction = Some(satisfaction);
421
422            // Update weight history if satisfied
423            if satisfaction > 0.7 {
424                let mut history = self.weight_history.write().await;
425                history.push((metric.semantic_weight_used, satisfaction));
426
427                // Keep only recent history
428                if history.len() > 100 {
429                    history.drain(0..50);
430                }
431            }
432        }
433
434        Ok(())
435    }
436
437    /// Get retrieval statistics
438    pub async fn stats(&self) -> HybridStats {
439        let bm25_stats = self.bm25_retriever.stats().await;
440        let semantic_stats = self.semantic_retriever.stats().await;
441        let metrics = self.query_metrics.read().await;
442
443        let avg_response_time = if metrics.is_empty() {
444            0
445        } else {
446            metrics.iter().map(|m| m.response_time_ms).sum::<u64>() / metrics.len() as u64
447        };
448
449        HybridStats {
450            bm25_stats,
451            semantic_stats,
452            total_queries: metrics.len(),
453            avg_response_time_ms: avg_response_time,
454            fusion_strategy: format!("{:?}", self.config.fusion_strategy),
455        }
456    }
457}
458
/// Search strategies for `HybridRetriever::advanced_search`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchStrategy {
    /// BM25 keyword search only.
    ExactMatch,
    /// Semantic search only, with no confidence floor.
    Semantic,
    /// Standard hybrid search.
    Hybrid,
    /// Abbreviation expansion followed by hybrid search.
    QueryExpansion,
}
471
472/// Hybrid retriever statistics
473#[derive(Debug, Serialize)]
474pub struct HybridStats {
475    pub bm25_stats: std::collections::HashMap<String, serde_json::Value>,
476    pub semantic_stats: std::collections::HashMap<String, serde_json::Value>,
477    pub total_queries: usize,
478    pub avg_response_time_ms: u64,
479    pub fusion_strategy: String,
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingService;

    /// Small three-document corpus shared by the hybrid search tests.
    fn sample_corpus() -> Vec<Document> {
        vec![
            Document::with_id("1", "The quick brown fox jumps over the lazy dog"),
            Document::with_id("2", "Machine learning is a subset of artificial intelligence"),
            Document::with_id("3", "Natural language processing enables text understanding"),
        ]
    }

    #[tokio::test]
    async fn test_hybrid_search() {
        let hybrid = HybridRetriever::new(
            HybridConfig::default(),
            Arc::new(MockEmbeddingService::new()),
        );
        hybrid.index_batch(sample_corpus()).await.unwrap();

        let query = "machine learning AI";
        let hits = hybrid.search(query, 2).await.unwrap();
        assert!(!hits.is_empty());

        // Positive feedback on the query just issued exercises the adaptive
        // weighting history path.
        hybrid.record_feedback(query, 0.9).await.unwrap();
    }
}