1use super::{
7 BM25Config, BM25Retriever, RankFusion, ReciprocalRankFusion, SemanticConfig, SemanticRetriever,
8 WeightedFusion,
9};
10use crate::{Document, EmbeddingProvider, RragResult, SearchResult};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::RwLock;
15
/// Configuration for the hybrid (BM25 + semantic) retriever.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
    /// Settings for the lexical BM25 sub-retriever.
    pub bm25_config: BM25Config,

    /// Settings for the embedding-based semantic sub-retriever.
    pub semantic_config: SemanticConfig,

    /// How the BM25 and semantic ranked lists are merged.
    pub fusion_strategy: FusionStrategy,

    /// When true, the semantic weight is adjusted per query from the
    /// query's characteristics and recorded feedback history.
    pub adaptive_weights: bool,

    /// Base weight given to semantic results; BM25 receives
    /// `1.0 - semantic_weight` under weighted combination.
    pub semantic_weight: f32,

    /// Run the two sub-retrievers concurrently via `tokio::join!`.
    pub parallel_retrieval: bool,

    /// Minimum confidence threshold passed to the semantic retriever.
    pub min_confidence: f32,

    /// Use the richer query analysis (entities, question words,
    /// technical terms) instead of the simple token count.
    pub enable_query_analysis: bool,
}
43
44impl Default for HybridConfig {
45 fn default() -> Self {
46 Self {
47 bm25_config: BM25Config::default(),
48 semantic_config: SemanticConfig::default(),
49 fusion_strategy: FusionStrategy::ReciprocalRankFusion,
50 adaptive_weights: true,
51 semantic_weight: 0.6,
52 parallel_retrieval: true,
53 min_confidence: 0.0,
54 enable_query_analysis: true,
55 }
56 }
57}
58
/// Strategy used to merge the BM25 and semantic result lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: rank-based merging, ignores raw scores.
    ReciprocalRankFusion,

    /// Weighted linear combination of the two lists using
    /// `HybridConfig::semantic_weight`.
    WeightedCombination,

    /// Learned fusion model; currently falls back to RRF at construction.
    LearnedFusion,

    /// Caller-defined strategy; currently falls back to RRF at construction.
    Custom,
}
74
/// Lightweight features extracted from a query, used to adapt the
/// BM25/semantic weighting.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct QueryCharacteristics {
    /// Number of whitespace-separated tokens in the query.
    num_tokens: usize,

    /// Any token starts with an uppercase letter (crude entity heuristic).
    has_entities: bool,

    /// Query contains '?' or begins with a common question word.
    is_question: bool,

    /// Query contains one of a small fixed list of technical terms.
    has_technical_terms: bool,

    /// Rough complexity in `[0, 1]`, derived from the token count.
    complexity: f32,
}
94
/// Retriever that combines lexical (BM25) and semantic (embedding)
/// search and fuses their ranked results.
pub struct HybridRetriever {
    /// Shared, immutable configuration.
    config: Arc<HybridConfig>,

    /// Lexical sub-retriever.
    bm25_retriever: Arc<BM25Retriever>,

    /// Embedding-based sub-retriever.
    semantic_retriever: Arc<SemanticRetriever>,

    /// Fusion strategy object built once at construction time.
    fusion: Arc<dyn RankFusion>,

    /// `(semantic_weight, satisfaction)` pairs recorded on positive
    /// feedback, consulted by the adaptive weighting.
    weight_history: Arc<RwLock<Vec<(f32, f32)>>>,

    /// Per-query metrics recorded by `search`.
    query_metrics: Arc<RwLock<Vec<QueryMetrics>>>,
}
115
/// Record of a single search, kept for feedback and statistics.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct QueryMetrics {
    /// Raw query string; used to match feedback in `record_feedback`.
    query: String,
    /// Query analysis captured at search time.
    characteristics: QueryCharacteristics,
    /// Semantic weight actually applied to this query.
    semantic_weight_used: f32,
    /// Wall-clock search latency in milliseconds.
    response_time_ms: u64,
    /// Satisfaction score, filled in later via `record_feedback`.
    user_satisfaction: Option<f32>,
}
126
127impl HybridRetriever {
128 pub fn new(config: HybridConfig, embedding_service: Arc<dyn EmbeddingProvider>) -> Self {
130 let bm25_retriever = Arc::new(BM25Retriever::new(config.bm25_config.clone()));
131 let semantic_retriever = Arc::new(SemanticRetriever::new(
132 config.semantic_config.clone(),
133 embedding_service,
134 ));
135
136 let fusion: Arc<dyn RankFusion> = match &config.fusion_strategy {
137 FusionStrategy::ReciprocalRankFusion => Arc::new(ReciprocalRankFusion::default()),
138 FusionStrategy::WeightedCombination => Arc::new(WeightedFusion::new(vec![
139 1.0 - config.semantic_weight,
140 config.semantic_weight,
141 ])),
142 _ => Arc::new(ReciprocalRankFusion::default()),
143 };
144
145 Self {
146 config: Arc::new(config),
147 bm25_retriever,
148 semantic_retriever,
149 fusion,
150 weight_history: Arc::new(RwLock::new(Vec::new())),
151 query_metrics: Arc::new(RwLock::new(Vec::new())),
152 }
153 }
154
155 pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
157 if self.config.parallel_retrieval {
158 let (bm25_result, semantic_result) = tokio::join!(
160 self.bm25_retriever.index_document(doc),
161 self.semantic_retriever.index_document(doc)
162 );
163
164 bm25_result?;
165 semantic_result?;
166 } else {
167 self.bm25_retriever.index_document(doc).await?;
169 self.semantic_retriever.index_document(doc).await?;
170 }
171
172 Ok(())
173 }
174
175 pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
177 if self.config.parallel_retrieval {
178 let (bm25_result, semantic_result) = tokio::join!(
179 self.bm25_retriever.index_batch(documents.clone()),
180 self.semantic_retriever.index_batch(documents)
181 );
182
183 bm25_result?;
184 semantic_result?;
185 } else {
186 self.bm25_retriever.index_batch(documents.clone()).await?;
187 self.semantic_retriever.index_batch(documents).await?;
188 }
189
190 Ok(())
191 }
192
193 pub async fn search(&self, query: &str, limit: usize) -> RragResult<Vec<SearchResult>> {
195 let start_time = Instant::now();
196
197 let characteristics = if self.config.enable_query_analysis {
199 self.analyze_query(query)
200 } else {
201 self.simple_query_analysis(query)
202 };
203
204 let semantic_weight = if self.config.adaptive_weights {
206 self.calculate_adaptive_weight(&characteristics).await
207 } else {
208 self.config.semantic_weight
209 };
210
211 let (bm25_results, semantic_results) = if self.config.parallel_retrieval {
213 tokio::join!(
214 self.bm25_retriever.search(query, limit * 2),
215 self.semantic_retriever
216 .search(query, limit * 2, Some(self.config.min_confidence))
217 )
218 } else {
219 let bm25 = self.bm25_retriever.search(query, limit * 2).await;
220 let semantic = self
221 .semantic_retriever
222 .search(query, limit * 2, Some(self.config.min_confidence))
223 .await;
224 (bm25, semantic)
225 };
226
227 let bm25_results = bm25_results?;
228 let semantic_results = semantic_results?;
229
230 let fused_results = match self.config.fusion_strategy {
232 FusionStrategy::WeightedCombination => {
233 let fusion = WeightedFusion::new(vec![1.0 - semantic_weight, semantic_weight]);
234 fusion.fuse(vec![bm25_results, semantic_results], limit)?
235 }
236 _ => self
237 .fusion
238 .fuse(vec![bm25_results, semantic_results], limit)?,
239 };
240
241 let elapsed = start_time.elapsed().as_millis() as u64;
243 let metrics = QueryMetrics {
244 query: query.to_string(),
245 characteristics,
246 semantic_weight_used: semantic_weight,
247 response_time_ms: elapsed,
248 user_satisfaction: None,
249 };
250
251 let mut query_metrics = self.query_metrics.write().await;
252 query_metrics.push(metrics);
253
254 Ok(fused_results)
255 }
256
257 pub async fn advanced_search(
259 &self,
260 query: &str,
261 limit: usize,
262 strategies: Vec<SearchStrategy>,
263 ) -> RragResult<Vec<SearchResult>> {
264 let mut all_results = Vec::new();
265
266 for strategy in strategies {
267 let results = match strategy {
268 SearchStrategy::ExactMatch => {
269 self.bm25_retriever.search(query, limit).await?
271 }
272 SearchStrategy::Semantic => {
273 self.semantic_retriever.search(query, limit, None).await?
275 }
276 SearchStrategy::Hybrid => {
277 self.search(query, limit).await?
279 }
280 SearchStrategy::QueryExpansion => {
281 let expanded = self.expand_query(query);
283 self.search(&expanded, limit).await?
284 }
285 };
286
287 all_results.push(results);
288 }
289
290 self.fusion.fuse(all_results, limit)
292 }
293
294 fn analyze_query(&self, query: &str) -> QueryCharacteristics {
296 let tokens: Vec<&str> = query.split_whitespace().collect();
297 let num_tokens = tokens.len();
298
299 let is_question = query.contains('?')
301 || query.starts_with("what")
302 || query.starts_with("how")
303 || query.starts_with("why")
304 || query.starts_with("when")
305 || query.starts_with("where")
306 || query.starts_with("who");
307
308 let has_entities = tokens
310 .iter()
311 .any(|t| t.chars().next().map_or(false, |c| c.is_uppercase()));
312
313 let technical_terms = [
315 "algorithm",
316 "function",
317 "method",
318 "system",
319 "protocol",
320 "framework",
321 ];
322 let has_technical_terms = tokens
323 .iter()
324 .any(|t| technical_terms.contains(&t.to_lowercase().as_str()));
325
326 let complexity = (num_tokens as f32 / 10.0).min(1.0);
328
329 QueryCharacteristics {
330 num_tokens,
331 has_entities,
332 is_question,
333 has_technical_terms,
334 complexity,
335 }
336 }
337
338 fn simple_query_analysis(&self, query: &str) -> QueryCharacteristics {
340 let num_tokens = query.split_whitespace().count();
341
342 QueryCharacteristics {
343 num_tokens,
344 has_entities: false,
345 is_question: query.contains('?'),
346 has_technical_terms: false,
347 complexity: (num_tokens as f32 / 10.0).min(1.0),
348 }
349 }
350
351 async fn calculate_adaptive_weight(&self, characteristics: &QueryCharacteristics) -> f32 {
353 let mut base_weight = self.config.semantic_weight;
354
355 if characteristics.is_question {
357 base_weight += 0.1; }
359
360 if characteristics.has_entities {
361 base_weight -= 0.1; }
363
364 if characteristics.has_technical_terms {
365 base_weight -= 0.05; }
367
368 base_weight += characteristics.complexity * 0.1;
370
371 let history = self.weight_history.read().await;
373 if history.len() > 10 {
374 let recent_weights: Vec<f32> = history
376 .iter()
377 .rev()
378 .take(10)
379 .filter(|(_, score)| *score > 0.7)
380 .map(|(weight, _)| *weight)
381 .collect();
382
383 if !recent_weights.is_empty() {
384 let avg_weight: f32 =
385 recent_weights.iter().sum::<f32>() / recent_weights.len() as f32;
386 base_weight = 0.7 * base_weight + 0.3 * avg_weight;
387 }
388 }
389
390 base_weight.max(0.0).min(1.0)
392 }
393
394 fn expand_query(&self, query: &str) -> String {
396 let expansions = vec![
398 ("ML", "machine learning"),
399 ("AI", "artificial intelligence"),
400 ("NLP", "natural language processing"),
401 ("DB", "database"),
402 ];
403
404 let mut expanded = query.to_string();
405 for (abbr, full) in expansions {
406 if query.contains(abbr) && !query.contains(full) {
407 expanded.push_str(&format!(" {}", full));
408 }
409 }
410
411 expanded
412 }
413
414 pub async fn record_feedback(&self, query: &str, satisfaction: f32) -> RragResult<()> {
416 let mut metrics = self.query_metrics.write().await;
417
418 if let Some(metric) = metrics.iter_mut().rev().find(|m| m.query == query) {
420 metric.user_satisfaction = Some(satisfaction);
421
422 if satisfaction > 0.7 {
424 let mut history = self.weight_history.write().await;
425 history.push((metric.semantic_weight_used, satisfaction));
426
427 if history.len() > 100 {
429 history.drain(0..50);
430 }
431 }
432 }
433
434 Ok(())
435 }
436
437 pub async fn stats(&self) -> HybridStats {
439 let bm25_stats = self.bm25_retriever.stats().await;
440 let semantic_stats = self.semantic_retriever.stats().await;
441 let metrics = self.query_metrics.read().await;
442
443 let avg_response_time = if metrics.is_empty() {
444 0
445 } else {
446 metrics.iter().map(|m| m.response_time_ms).sum::<u64>() / metrics.len() as u64
447 };
448
449 HybridStats {
450 bm25_stats,
451 semantic_stats,
452 total_queries: metrics.len(),
453 avg_response_time_ms: avg_response_time,
454 fusion_strategy: format!("{:?}", self.config.fusion_strategy),
455 }
456 }
457}
458
/// Retrieval strategies selectable in `HybridRetriever::advanced_search`.
#[derive(Debug, Clone)]
pub enum SearchStrategy {
    /// BM25 lexical search only.
    ExactMatch,
    /// Embedding-based semantic search only.
    Semantic,
    /// Full hybrid search (`HybridRetriever::search`).
    Hybrid,
    /// Hybrid search over an abbreviation-expanded query.
    QueryExpansion,
}
471
/// Aggregated statistics returned by `HybridRetriever::stats`.
#[derive(Debug, Serialize)]
pub struct HybridStats {
    /// Stats reported by the BM25 sub-retriever.
    pub bm25_stats: std::collections::HashMap<String, serde_json::Value>,
    /// Stats reported by the semantic sub-retriever.
    pub semantic_stats: std::collections::HashMap<String, serde_json::Value>,
    /// Number of queries currently held in the metrics buffer.
    pub total_queries: usize,
    /// Mean search latency over recorded queries (0 when none).
    pub avg_response_time_ms: u64,
    /// Debug rendering of the configured fusion strategy.
    pub fusion_strategy: String,
}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingService;

    /// End-to-end smoke test: index a few documents with a mock embedding
    /// service, run a hybrid search, and record positive feedback.
    #[tokio::test]
    async fn test_hybrid_search() {
        let config = HybridConfig::default();
        let embedding_service = Arc::new(MockEmbeddingService::new());
        let retriever = HybridRetriever::new(config, embedding_service);

        let docs = vec![
            Document::with_id("1", "The quick brown fox jumps over the lazy dog"),
            Document::with_id(
                "2",
                "Machine learning is a subset of artificial intelligence",
            ),
            Document::with_id(
                "3",
                "Natural language processing enables text understanding",
            ),
        ];

        retriever.index_batch(docs).await.unwrap();

        // At least one document should be retrievable for this query.
        let results = retriever.search("machine learning AI", 2).await.unwrap();
        assert!(!results.is_empty());

        // Feedback above 0.7 should be accepted without error.
        retriever
            .record_feedback("machine learning AI", 0.9)
            .await
            .unwrap();
    }
}