1use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` in the reciprocal-rank term `1 / (k + rank)` that
/// the RRF fusion path adds up for each result list.
const RRF_K: f32 = 60.0;
19
20type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23#[derive(Debug, Clone)]
25pub struct HybridConfig {
26 pub vector_weight: f32,
28 pub require_both: bool,
30 pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35 fn default() -> Self {
36 Self {
37 vector_weight: 0.5,
38 require_both: false,
39 fusion_strategy: FusionStrategy::MinMax,
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
46struct RawScore {
47 score: f32,
49 metadata: Option<serde_json::Value>,
51 vector: Option<Vec<f32>>,
52}
53
54#[derive(Debug, Clone)]
56pub struct HybridResult {
57 pub id: String,
59 pub combined_score: f32,
61 pub vector_score: f32,
63 pub text_score: f32,
65 pub metadata: Option<serde_json::Value>,
67 pub vector: Option<Vec<f32>>,
69}
70
71pub struct HybridSearcher {
73 config: HybridConfig,
74}
75
76impl HybridSearcher {
77 pub fn new(config: HybridConfig) -> Self {
78 Self { config }
79 }
80
81 pub fn with_vector_weight(mut self, weight: f32) -> Self {
82 self.config.vector_weight = weight.clamp(0.0, 1.0);
83 self
84 }
85
86 pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88 self.config.fusion_strategy = strategy;
89 self
90 }
91
92 pub fn search(
102 &self,
103 vector_results: Vec<VectorResultRow>,
104 text_results: Vec<FullTextResult>,
105 top_k: usize,
106 ) -> Vec<HybridResult> {
107 match self.config.fusion_strategy {
108 FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109 FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110 }
111 }
112
113 fn rrf_search(
118 &self,
119 vector_results: Vec<VectorResultRow>,
120 text_results: Vec<FullTextResult>,
121 top_k: usize,
122 ) -> Vec<HybridResult> {
123 let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124 let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125 let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127 let mut sorted_vec = vector_results;
129 sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130 for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131 vector_ranks.insert(id.clone(), i + 1);
132 vector_map.insert(
133 id,
134 RawScore {
135 score,
136 metadata,
137 vector,
138 },
139 );
140 }
141
142 let mut sorted_text = text_results;
144 sorted_text.sort_by(|a, b| {
145 b.score
146 .partial_cmp(&a.score)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149 for (i, result) in sorted_text.into_iter().enumerate() {
150 text_ranks.insert(result.doc_id, i + 1);
151 }
152
153 let mut all_ids: Vec<String> = vector_map
155 .keys()
156 .chain(text_ranks.keys())
157 .cloned()
158 .collect();
159 all_ids.sort();
160 all_ids.dedup();
161
162 let total = all_ids.len().max(1) as f32;
163 let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165 for id in all_ids {
166 let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167 let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169 if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170 continue;
171 }
172
173 let vec_rrf = if vec_rank > 0 {
174 1.0 / (RRF_K + vec_rank as f32)
175 } else {
176 0.0
177 };
178 let txt_rrf = if txt_rank > 0 {
179 1.0 / (RRF_K + txt_rank as f32)
180 } else {
181 0.0
182 };
183 let combined = vec_rrf + txt_rrf;
184
185 let vector_score = if vec_rank > 0 {
187 1.0 - (vec_rank as f32 - 1.0) / total
188 } else {
189 0.0
190 };
191 let text_score = if txt_rank > 0 {
192 1.0 - (txt_rank as f32 - 1.0) / total
193 } else {
194 0.0
195 };
196
197 let raw = vector_map.get(&id);
198 results.push(HybridResult {
199 id,
200 combined_score: combined,
201 vector_score,
202 text_score,
203 metadata: raw.and_then(|r| r.metadata.clone()),
204 vector: raw.and_then(|r| r.vector.clone()),
205 });
206 }
207
208 results.sort_by(|a, b| {
209 b.combined_score
210 .partial_cmp(&a.combined_score)
211 .unwrap_or(std::cmp::Ordering::Equal)
212 });
213 results.truncate(top_k);
214 results
215 }
216
217 fn minmax_search(
219 &self,
220 vector_results: Vec<VectorResultRow>,
221 text_results: Vec<FullTextResult>,
222 top_k: usize,
223 ) -> Vec<HybridResult> {
224 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225 let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227 let mut vector_min = f32::MAX;
228 let mut vector_max = f32::MIN;
229 let mut text_min = f32::MAX;
230 let mut text_max = f32::MIN;
231
232 for (id, score, metadata, vector) in vector_results {
233 vector_min = vector_min.min(score);
234 vector_max = vector_max.max(score);
235 vector_scores.insert(
236 id,
237 RawScore {
238 score,
239 metadata,
240 vector,
241 },
242 );
243 }
244
245 for result in text_results {
246 text_min = text_min.min(result.score);
247 text_max = text_max.max(result.score);
248 text_scores.insert(result.doc_id, result.score);
249 }
250
251 let mut all_ids: Vec<String> = vector_scores
252 .keys()
253 .chain(text_scores.keys())
254 .cloned()
255 .collect();
256 all_ids.sort();
257 all_ids.dedup();
258
259 let mut results: Vec<HybridResult> = Vec::new();
260
261 for id in all_ids {
262 let vector_raw = vector_scores.get(&id);
263 let text_raw = text_scores.get(&id);
264
265 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266 continue;
267 }
268
269 let vector_normalized = if let Some(raw) = vector_raw {
270 normalize_score(raw.score, vector_min, vector_max)
271 } else {
272 0.0
273 };
274
275 let text_normalized = if let Some(&score) = text_raw {
276 normalize_score(score, text_min, text_max)
277 } else {
278 0.0
279 };
280
281 let combined = self.config.vector_weight * vector_normalized
282 + (1.0 - self.config.vector_weight) * text_normalized;
283
284 let (metadata, vector) = if let Some(raw) = vector_raw {
285 (raw.metadata.clone(), raw.vector.clone())
286 } else {
287 (None, None)
288 };
289
290 results.push(HybridResult {
291 id,
292 combined_score: combined,
293 vector_score: vector_normalized,
294 text_score: text_normalized,
295 metadata,
296 vector,
297 });
298 }
299
300 results.sort_by(|a, b| {
301 b.combined_score
302 .partial_cmp(&a.combined_score)
303 .unwrap_or(std::cmp::Ordering::Equal)
304 });
305 results.truncate(top_k);
306 results
307 }
308}
309
310impl Default for HybridSearcher {
311 fn default() -> Self {
312 Self::new(HybridConfig::default())
313 }
314}
315
316pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
336 match kind {
337 crate::routing::QueryKind::Keyword => 0.25,
338 crate::routing::QueryKind::Hybrid => 0.50,
339 crate::routing::QueryKind::Semantic => 0.75,
340 crate::routing::QueryKind::Temporal => 0.00,
341 }
342}
343
/// Rescales `score` linearly so that `min` maps to `0.0` and `max` to `1.0`.
///
/// When the range `max - min` is smaller in magnitude than `f32::EPSILON`
/// (for example a single-element list) every score is reported as `1.0`
/// instead of dividing by a near-zero span.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    // Keep the `<` test: a NaN span must fall through to the division.
    if span.abs() < f32::EPSILON {
        return 1.0;
    }
    (score - min) / span
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Min-max fusion returns the union of both lists and normalizes each
    /// side independently.
    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // doc1..doc4: the union of both id sets.
        assert_eq!(results.len(), 4);

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so it normalizes to exactly 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    /// With `vector_weight = 1.0` the combined score equals the vector score.
    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    /// With `vector_weight = 0.0` the combined score equals the text score.
    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    /// Min-max fusion with `require_both` keeps only ids present in both lists.
    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    /// The result list is truncated to `top_k`.
    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    /// Metadata and vector from the vector row are carried into the result.
    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    /// Linear rescaling, plus the degenerate zero-width range.
    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // min == max: every score is reported as 1.0.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    /// Out-of-range weights are clamped to `[0.0, 1.0]`.
    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    #[test]
    fn test_adaptive_vector_weight_temporal() {
        use crate::routing::QueryKind;
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.00);
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
    }

    #[test]
    fn test_minmax_default_strategy() {
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    /// Under RRF an id ranked in both lists outscores one ranked in only one.
    #[test]
    fn test_rrf_ranks_correctly() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.combined_score > doc3.combined_score);

        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    /// RRF with `require_both` keeps only ids present in both lists.
    #[test]
    fn test_rrf_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            // The default strategy is MinMax; select RRF explicitly so this
            // test exercises the RRF code path its name promises.
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    /// An id ranked first in both lists scores `2 / (RRF_K + 1)`.
    #[test]
    fn test_rrf_formula_k60() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    /// The builder must change the strategy away from the MinMax default;
    /// setting MinMax again would pass even if the builder did nothing.
    #[test]
    fn test_with_fusion_strategy_builder() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);
    }
}