1use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` for reciprocal-rank fusion: a document ranked `rank`
/// (1-based) in a result list contributes `1 / (RRF_K + rank)` to its fused score.
const RRF_K: f32 = 60.0;

/// One vector-search hit: `(doc_id, score, metadata, stored vector)`.
///
/// `doc_id` is matched against `FullTextResult::doc_id` when fusing.
type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
/// Settings that control how [`HybridSearcher`] fuses vector and full-text results.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Weight of the normalized vector score in min-max fusion. The text score
    /// gets `1.0 - vector_weight`. Intended range is `[0.0, 1.0]`
    /// (`HybridSearcher::with_vector_weight` clamps to it). Reciprocal-rank
    /// fusion does not read this field.
    pub vector_weight: f32,
    /// When `true`, only documents present in BOTH the vector and the
    /// full-text results are returned.
    pub require_both: bool,
    /// Which fusion algorithm `HybridSearcher::search` uses.
    pub fusion_strategy: FusionStrategy,
}
33
34impl Default for HybridConfig {
35 fn default() -> Self {
36 Self {
37 vector_weight: 0.5,
38 require_both: false,
39 fusion_strategy: FusionStrategy::MinMax,
40 }
41 }
42}
43
/// A vector-search hit as received, before it is fused with full-text results.
#[derive(Debug, Clone)]
struct RawScore {
    /// Raw score from the vector side (fusion treats higher as better).
    score: f32,
    /// Metadata passed through unchanged into the fused `HybridResult`.
    metadata: Option<serde_json::Value>,
    /// Stored vector passed through unchanged into the fused `HybridResult`
    /// (presumably the document embedding — confirm with the vector backend).
    vector: Option<Vec<f32>>,
}
53
/// One fused hit returned by [`HybridSearcher::search`].
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Document id, shared by the vector and full-text inputs.
    pub id: String,
    /// Fused score that results are ranked by, highest first.
    pub combined_score: f32,
    /// Vector-side score.
    ///
    /// Under min-max fusion this is the min-max normalized vector score.
    /// Under RRF it is a rank-derived value.
    /// It is `0.0` when the document had no vector hit.
    pub vector_score: f32,
    /// Full-text-side score, computed the same way as `vector_score`.
    ///
    /// It is `0.0` when the document had no full-text hit.
    pub text_score: f32,
    /// Metadata from the vector hit, if any; always `None` for text-only hits.
    pub metadata: Option<serde_json::Value>,
    /// Stored vector from the vector hit, if any; always `None` for text-only hits.
    pub vector: Option<Vec<f32>>,
}
70
/// Fuses vector-search and full-text-search results into one ranked list.
///
/// Build with `HybridSearcher::new` or `Default`, then adjust with the
/// `with_*` builder methods.
pub struct HybridSearcher {
    /// Fusion settings applied by every `search` call.
    config: HybridConfig,
}
75
76impl HybridSearcher {
77 pub fn new(config: HybridConfig) -> Self {
78 Self { config }
79 }
80
81 pub fn with_vector_weight(mut self, weight: f32) -> Self {
82 self.config.vector_weight = weight.clamp(0.0, 1.0);
83 self
84 }
85
86 pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88 self.config.fusion_strategy = strategy;
89 self
90 }
91
92 pub fn search(
102 &self,
103 vector_results: Vec<VectorResultRow>,
104 text_results: Vec<FullTextResult>,
105 top_k: usize,
106 ) -> Vec<HybridResult> {
107 match self.config.fusion_strategy {
108 FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109 FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110 }
111 }
112
113 fn rrf_search(
118 &self,
119 vector_results: Vec<VectorResultRow>,
120 text_results: Vec<FullTextResult>,
121 top_k: usize,
122 ) -> Vec<HybridResult> {
123 let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124 let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125 let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127 let mut sorted_vec = vector_results;
129 sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130 for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131 vector_ranks.insert(id.clone(), i + 1);
132 vector_map.insert(
133 id,
134 RawScore {
135 score,
136 metadata,
137 vector,
138 },
139 );
140 }
141
142 let mut sorted_text = text_results;
144 sorted_text.sort_by(|a, b| {
145 b.score
146 .partial_cmp(&a.score)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149 for (i, result) in sorted_text.into_iter().enumerate() {
150 text_ranks.insert(result.doc_id, i + 1);
151 }
152
153 let mut all_ids: Vec<String> = vector_map
155 .keys()
156 .chain(text_ranks.keys())
157 .cloned()
158 .collect();
159 all_ids.sort();
160 all_ids.dedup();
161
162 let total = all_ids.len().max(1) as f32;
163 let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165 for id in all_ids {
166 let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167 let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169 if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170 continue;
171 }
172
173 let vec_rrf = if vec_rank > 0 {
174 1.0 / (RRF_K + vec_rank as f32)
175 } else {
176 0.0
177 };
178 let txt_rrf = if txt_rank > 0 {
179 1.0 / (RRF_K + txt_rank as f32)
180 } else {
181 0.0
182 };
183 let combined = vec_rrf + txt_rrf;
184
185 let vector_score = if vec_rank > 0 {
187 1.0 - (vec_rank as f32 - 1.0) / total
188 } else {
189 0.0
190 };
191 let text_score = if txt_rank > 0 {
192 1.0 - (txt_rank as f32 - 1.0) / total
193 } else {
194 0.0
195 };
196
197 let raw = vector_map.get(&id);
198 results.push(HybridResult {
199 id,
200 combined_score: combined,
201 vector_score,
202 text_score,
203 metadata: raw.and_then(|r| r.metadata.clone()),
204 vector: raw.and_then(|r| r.vector.clone()),
205 });
206 }
207
208 results.sort_by(|a, b| {
209 b.combined_score
210 .partial_cmp(&a.combined_score)
211 .unwrap_or(std::cmp::Ordering::Equal)
212 });
213 results.truncate(top_k);
214 results
215 }
216
217 fn minmax_search(
219 &self,
220 vector_results: Vec<VectorResultRow>,
221 text_results: Vec<FullTextResult>,
222 top_k: usize,
223 ) -> Vec<HybridResult> {
224 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225 let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227 let mut vector_min = f32::MAX;
228 let mut vector_max = f32::MIN;
229 let mut text_min = f32::MAX;
230 let mut text_max = f32::MIN;
231
232 for (id, score, metadata, vector) in vector_results {
233 vector_min = vector_min.min(score);
234 vector_max = vector_max.max(score);
235 vector_scores.insert(
236 id,
237 RawScore {
238 score,
239 metadata,
240 vector,
241 },
242 );
243 }
244
245 for result in text_results {
246 text_min = text_min.min(result.score);
247 text_max = text_max.max(result.score);
248 text_scores.insert(result.doc_id, result.score);
249 }
250
251 let mut all_ids: Vec<String> = vector_scores
252 .keys()
253 .chain(text_scores.keys())
254 .cloned()
255 .collect();
256 all_ids.sort();
257 all_ids.dedup();
258
259 let mut results: Vec<HybridResult> = Vec::new();
260
261 for id in all_ids {
262 let vector_raw = vector_scores.get(&id);
263 let text_raw = text_scores.get(&id);
264
265 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266 continue;
267 }
268
269 let vector_normalized = if let Some(raw) = vector_raw {
270 normalize_score(raw.score, vector_min, vector_max)
271 } else {
272 0.0
273 };
274
275 let text_normalized = if let Some(&score) = text_raw {
276 normalize_score(score, text_min, text_max)
277 } else {
278 0.0
279 };
280
281 let combined = self.config.vector_weight * vector_normalized
282 + (1.0 - self.config.vector_weight) * text_normalized;
283
284 let (metadata, vector) = if let Some(raw) = vector_raw {
285 (raw.metadata.clone(), raw.vector.clone())
286 } else {
287 (None, None)
288 };
289
290 results.push(HybridResult {
291 id,
292 combined_score: combined,
293 vector_score: vector_normalized,
294 text_score: text_normalized,
295 metadata,
296 vector,
297 });
298 }
299
300 results.sort_by(|a, b| {
301 b.combined_score
302 .partial_cmp(&a.combined_score)
303 .unwrap_or(std::cmp::Ordering::Equal)
304 });
305 results.truncate(top_k);
306 results
307 }
308}
309
310impl Default for HybridSearcher {
311 fn default() -> Self {
312 Self::new(HybridConfig::default())
313 }
314}
315
316pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
336 match kind {
337 crate::routing::QueryKind::Keyword => 0.25,
338 crate::routing::QueryKind::Hybrid => 0.50,
339 crate::routing::QueryKind::Semantic => 0.75,
340 crate::routing::QueryKind::Temporal => 0.00,
341 crate::routing::QueryKind::MultiHop => 0.40,
345 }
346}
347
/// Linearly rescales `score` from `[min, max]` onto `[0, 1]`.
///
/// When the range is degenerate (`|max - min| < f32::EPSILON`), every score
/// maps to `1.0`, so a lone or uniform hit still counts fully.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    if span.abs() < f32::EPSILON {
        return 1.0;
    }
    (score - min) / span
}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // Union of both inputs: doc1, doc2, doc3, doc4.
        assert_eq!(results.len(), 4);

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so min-max normalization maps it to 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc1");
        // With weight 1.0 the text side contributes nothing.
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc2");
        // With weight 0.0 the vector side contributes nothing.
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 has no text hit, so it is dropped.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // Degenerate range: everything maps to 1.0.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    #[test]
    fn test_adaptive_vector_weight_temporal() {
        use crate::routing::QueryKind;
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.00);
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
        assert_eq!(adaptive_vector_weight(QueryKind::MultiHop), 0.40);
    }

    #[test]
    fn test_minmax_default_strategy() {
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        // doc1 appears in both lists, doc3 only in the vector list.
        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.combined_score > doc3.combined_score);

        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        // Explicitly select RRF: the default strategy is MinMax, which would
        // leave the RRF `require_both` path untested.
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_rrf_formula_k60() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // Rank 1 in both lists: 1/(k+1) + 1/(k+1).
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        // Set a non-default strategy first, so a no-op builder would fail.
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);

        let searcher = searcher.with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}