1use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` of the Reciprocal Rank Fusion term `1 / (k + rank)`.
///
/// Larger values shrink the gap between the contributions of top-ranked and lower-ranked hits.
const RRF_K: f32 = 60.0;
19
20type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23#[derive(Debug, Clone)]
25pub struct HybridConfig {
26 pub vector_weight: f32,
28 pub require_both: bool,
30 pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35 fn default() -> Self {
36 Self {
37 vector_weight: 0.5,
38 require_both: false,
39 fusion_strategy: FusionStrategy::MinMax,
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
46struct RawScore {
47 score: f32,
49 metadata: Option<serde_json::Value>,
51 vector: Option<Vec<f32>>,
52}
53
54#[derive(Debug, Clone)]
56pub struct HybridResult {
57 pub id: String,
59 pub combined_score: f32,
61 pub vector_score: f32,
63 pub text_score: f32,
65 pub metadata: Option<serde_json::Value>,
67 pub vector: Option<Vec<f32>>,
69}
70
71pub struct HybridSearcher {
73 config: HybridConfig,
74}
75
76impl HybridSearcher {
77 pub fn new(config: HybridConfig) -> Self {
78 Self { config }
79 }
80
81 pub fn with_vector_weight(mut self, weight: f32) -> Self {
82 self.config.vector_weight = weight.clamp(0.0, 1.0);
83 self
84 }
85
86 pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88 self.config.fusion_strategy = strategy;
89 self
90 }
91
92 pub fn search(
102 &self,
103 vector_results: Vec<VectorResultRow>,
104 text_results: Vec<FullTextResult>,
105 top_k: usize,
106 ) -> Vec<HybridResult> {
107 match self.config.fusion_strategy {
108 FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109 FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110 }
111 }
112
113 fn rrf_search(
118 &self,
119 vector_results: Vec<VectorResultRow>,
120 text_results: Vec<FullTextResult>,
121 top_k: usize,
122 ) -> Vec<HybridResult> {
123 let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124 let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125 let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127 let mut sorted_vec = vector_results;
129 sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130 for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131 vector_ranks.insert(id.clone(), i + 1);
132 vector_map.insert(
133 id,
134 RawScore {
135 score,
136 metadata,
137 vector,
138 },
139 );
140 }
141
142 let mut sorted_text = text_results;
144 sorted_text.sort_by(|a, b| {
145 b.score
146 .partial_cmp(&a.score)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149 for (i, result) in sorted_text.into_iter().enumerate() {
150 text_ranks.insert(result.doc_id, i + 1);
151 }
152
153 let mut all_ids: Vec<String> = vector_map
155 .keys()
156 .chain(text_ranks.keys())
157 .cloned()
158 .collect();
159 all_ids.sort();
160 all_ids.dedup();
161
162 let total = all_ids.len().max(1) as f32;
163 let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165 for id in all_ids {
166 let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167 let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169 if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170 continue;
171 }
172
173 let vec_rrf = if vec_rank > 0 {
174 1.0 / (RRF_K + vec_rank as f32)
175 } else {
176 0.0
177 };
178 let txt_rrf = if txt_rank > 0 {
179 1.0 / (RRF_K + txt_rank as f32)
180 } else {
181 0.0
182 };
183 let combined = vec_rrf + txt_rrf;
184
185 let vector_score = if vec_rank > 0 {
187 1.0 - (vec_rank as f32 - 1.0) / total
188 } else {
189 0.0
190 };
191 let text_score = if txt_rank > 0 {
192 1.0 - (txt_rank as f32 - 1.0) / total
193 } else {
194 0.0
195 };
196
197 let raw = vector_map.get(&id);
198 results.push(HybridResult {
199 id,
200 combined_score: combined,
201 vector_score,
202 text_score,
203 metadata: raw.and_then(|r| r.metadata.clone()),
204 vector: raw.and_then(|r| r.vector.clone()),
205 });
206 }
207
208 results.sort_by(|a, b| {
209 b.combined_score
210 .partial_cmp(&a.combined_score)
211 .unwrap_or(std::cmp::Ordering::Equal)
212 });
213 results.truncate(top_k);
214 results
215 }
216
217 fn minmax_search(
219 &self,
220 vector_results: Vec<VectorResultRow>,
221 text_results: Vec<FullTextResult>,
222 top_k: usize,
223 ) -> Vec<HybridResult> {
224 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225 let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227 let mut vector_min = f32::MAX;
228 let mut vector_max = f32::MIN;
229 let mut text_min = f32::MAX;
230 let mut text_max = f32::MIN;
231
232 for (id, score, metadata, vector) in vector_results {
233 vector_min = vector_min.min(score);
234 vector_max = vector_max.max(score);
235 vector_scores.insert(
236 id,
237 RawScore {
238 score,
239 metadata,
240 vector,
241 },
242 );
243 }
244
245 for result in text_results {
246 text_min = text_min.min(result.score);
247 text_max = text_max.max(result.score);
248 text_scores.insert(result.doc_id, result.score);
249 }
250
251 let mut all_ids: Vec<String> = vector_scores
252 .keys()
253 .chain(text_scores.keys())
254 .cloned()
255 .collect();
256 all_ids.sort();
257 all_ids.dedup();
258
259 let mut results: Vec<HybridResult> = Vec::new();
260
261 for id in all_ids {
262 let vector_raw = vector_scores.get(&id);
263 let text_raw = text_scores.get(&id);
264
265 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266 continue;
267 }
268
269 let vector_normalized = if let Some(raw) = vector_raw {
270 normalize_score(raw.score, vector_min, vector_max)
271 } else {
272 0.0
273 };
274
275 let text_normalized = if let Some(&score) = text_raw {
276 normalize_score(score, text_min, text_max)
277 } else {
278 0.0
279 };
280
281 let combined = self.config.vector_weight * vector_normalized
282 + (1.0 - self.config.vector_weight) * text_normalized;
283
284 let (metadata, vector) = if let Some(raw) = vector_raw {
285 (raw.metadata.clone(), raw.vector.clone())
286 } else {
287 (None, None)
288 };
289
290 results.push(HybridResult {
291 id,
292 combined_score: combined,
293 vector_score: vector_normalized,
294 text_score: text_normalized,
295 metadata,
296 vector,
297 });
298 }
299
300 results.sort_by(|a, b| {
301 b.combined_score
302 .partial_cmp(&a.combined_score)
303 .unwrap_or(std::cmp::Ordering::Equal)
304 });
305 results.truncate(top_k);
306 results
307 }
308}
309
310impl Default for HybridSearcher {
311 fn default() -> Self {
312 Self::new(HybridConfig::default())
313 }
314}
315
316pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
330 match kind {
331 crate::routing::QueryKind::Keyword => 0.25,
332 crate::routing::QueryKind::Hybrid => 0.50,
333 crate::routing::QueryKind::Semantic => 0.75,
334 crate::routing::QueryKind::Temporal => 0.20,
335 }
336}
337
/// Linearly rescales `score` so that `min` maps to 0.0 and `max` maps to 1.0.
///
/// If `min` and `max` differ by less than `f32::EPSILON` (for example, a list with a single
/// hit), the function returns 1.0 instead of dividing by a near-zero span.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    let degenerate = span.abs() < f32::EPSILON;
    if degenerate {
        return 1.0;
    }
    (score - min) / span
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Small shared fixture: doc1 is vector-only, doc2 appears in both lists.
    fn sample_inputs() -> (Vec<VectorResultRow>, Vec<FullTextResult>) {
        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.4, None, None),
        ];
        let text_results = vec![FullTextResult {
            doc_id: "doc2".to_string(),
            score: 1.0,
        }];
        (vector_results, text_results)
    }

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 4);

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so it normalizes to exactly 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_top_k_zero_returns_empty() {
        for strategy in [FusionStrategy::MinMax, FusionStrategy::Rrf] {
            let (vector_results, text_results) = sample_inputs();
            let results = HybridSearcher::default()
                .with_fusion_strategy(strategy)
                .search(vector_results, text_results, 0);
            assert!(results.is_empty());
        }
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // Degenerate range (all scores equal) maps to 1.0.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    #[test]
    fn test_adaptive_vector_weight_all_kinds() {
        use crate::routing::QueryKind;
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.20);
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
    }

    #[test]
    fn test_minmax_default_strategy() {
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc1.combined_score > doc3.combined_score);

        // Rank-derived display scores: top rank in a list maps to 1.0, absence to 0.0.
        assert_eq!(doc1.vector_score, 1.0);
        assert_eq!(doc2.text_score, 1.0);
        assert_eq!(doc3.text_score, 0.0);

        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        // The default strategy is MinMax, so RRF must be selected explicitly here.
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        // doc1 is rank 1 in both lists.
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_rrf_formula_k60() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_rrf_ignores_vector_weight() {
        let scores = |weight: f32| -> Vec<(String, f32)> {
            let (vector_results, text_results) = sample_inputs();
            HybridSearcher::default()
                .with_fusion_strategy(FusionStrategy::Rrf)
                .with_vector_weight(weight)
                .search(vector_results, text_results, 10)
                .into_iter()
                .map(|r| (r.id, r.combined_score))
                .collect()
        };
        assert_eq!(scores(0.0), scores(1.0));
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        // Start from the MinMax default so the builder call is observable.
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);

        let searcher = searcher.with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}