1use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
17const RRF_K: f32 = 60.0;
19
20type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23#[derive(Debug, Clone)]
25pub struct HybridConfig {
26 pub vector_weight: f32,
28 pub require_both: bool,
30 pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35 fn default() -> Self {
36 Self {
37 vector_weight: 0.5,
38 require_both: false,
39 fusion_strategy: FusionStrategy::Rrf,
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
46struct RawScore {
47 score: f32,
49 metadata: Option<serde_json::Value>,
51 vector: Option<Vec<f32>>,
52}
53
54#[derive(Debug, Clone)]
56pub struct HybridResult {
57 pub id: String,
59 pub combined_score: f32,
61 pub vector_score: f32,
63 pub text_score: f32,
65 pub metadata: Option<serde_json::Value>,
67 pub vector: Option<Vec<f32>>,
69}
70
71pub struct HybridSearcher {
73 config: HybridConfig,
74}
75
76impl HybridSearcher {
77 pub fn new(config: HybridConfig) -> Self {
78 Self { config }
79 }
80
81 pub fn with_vector_weight(mut self, weight: f32) -> Self {
82 self.config.vector_weight = weight.clamp(0.0, 1.0);
83 self
84 }
85
86 pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88 self.config.fusion_strategy = strategy;
89 self
90 }
91
92 pub fn search(
102 &self,
103 vector_results: Vec<VectorResultRow>,
104 text_results: Vec<FullTextResult>,
105 top_k: usize,
106 ) -> Vec<HybridResult> {
107 match self.config.fusion_strategy {
108 FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109 FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110 }
111 }
112
113 fn rrf_search(
118 &self,
119 vector_results: Vec<VectorResultRow>,
120 text_results: Vec<FullTextResult>,
121 top_k: usize,
122 ) -> Vec<HybridResult> {
123 let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124 let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125 let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127 let mut sorted_vec = vector_results;
129 sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130 for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131 vector_ranks.insert(id.clone(), i + 1);
132 vector_map.insert(
133 id,
134 RawScore {
135 score,
136 metadata,
137 vector,
138 },
139 );
140 }
141
142 let mut sorted_text = text_results;
144 sorted_text.sort_by(|a, b| {
145 b.score
146 .partial_cmp(&a.score)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149 for (i, result) in sorted_text.into_iter().enumerate() {
150 text_ranks.insert(result.doc_id, i + 1);
151 }
152
153 let mut all_ids: Vec<String> = vector_map
155 .keys()
156 .chain(text_ranks.keys())
157 .cloned()
158 .collect();
159 all_ids.sort();
160 all_ids.dedup();
161
162 let total = all_ids.len().max(1) as f32;
163 let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165 for id in all_ids {
166 let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167 let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169 if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170 continue;
171 }
172
173 let vec_rrf = if vec_rank > 0 {
174 1.0 / (RRF_K + vec_rank as f32)
175 } else {
176 0.0
177 };
178 let txt_rrf = if txt_rank > 0 {
179 1.0 / (RRF_K + txt_rank as f32)
180 } else {
181 0.0
182 };
183 let combined = vec_rrf + txt_rrf;
184
185 let vector_score = if vec_rank > 0 {
187 1.0 - (vec_rank as f32 - 1.0) / total
188 } else {
189 0.0
190 };
191 let text_score = if txt_rank > 0 {
192 1.0 - (txt_rank as f32 - 1.0) / total
193 } else {
194 0.0
195 };
196
197 let raw = vector_map.get(&id);
198 results.push(HybridResult {
199 id,
200 combined_score: combined,
201 vector_score,
202 text_score,
203 metadata: raw.and_then(|r| r.metadata.clone()),
204 vector: raw.and_then(|r| r.vector.clone()),
205 });
206 }
207
208 results.sort_by(|a, b| {
209 b.combined_score
210 .partial_cmp(&a.combined_score)
211 .unwrap_or(std::cmp::Ordering::Equal)
212 });
213 results.truncate(top_k);
214 results
215 }
216
217 fn minmax_search(
219 &self,
220 vector_results: Vec<VectorResultRow>,
221 text_results: Vec<FullTextResult>,
222 top_k: usize,
223 ) -> Vec<HybridResult> {
224 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225 let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227 let mut vector_min = f32::MAX;
228 let mut vector_max = f32::MIN;
229 let mut text_min = f32::MAX;
230 let mut text_max = f32::MIN;
231
232 for (id, score, metadata, vector) in vector_results {
233 vector_min = vector_min.min(score);
234 vector_max = vector_max.max(score);
235 vector_scores.insert(
236 id,
237 RawScore {
238 score,
239 metadata,
240 vector,
241 },
242 );
243 }
244
245 for result in text_results {
246 text_min = text_min.min(result.score);
247 text_max = text_max.max(result.score);
248 text_scores.insert(result.doc_id, result.score);
249 }
250
251 let mut all_ids: Vec<String> = vector_scores
252 .keys()
253 .chain(text_scores.keys())
254 .cloned()
255 .collect();
256 all_ids.sort();
257 all_ids.dedup();
258
259 let mut results: Vec<HybridResult> = Vec::new();
260
261 for id in all_ids {
262 let vector_raw = vector_scores.get(&id);
263 let text_raw = text_scores.get(&id);
264
265 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266 continue;
267 }
268
269 let vector_normalized = if let Some(raw) = vector_raw {
270 normalize_score(raw.score, vector_min, vector_max)
271 } else {
272 0.0
273 };
274
275 let text_normalized = if let Some(&score) = text_raw {
276 normalize_score(score, text_min, text_max)
277 } else {
278 0.0
279 };
280
281 let combined = self.config.vector_weight * vector_normalized
282 + (1.0 - self.config.vector_weight) * text_normalized;
283
284 let (metadata, vector) = if let Some(raw) = vector_raw {
285 (raw.metadata.clone(), raw.vector.clone())
286 } else {
287 (None, None)
288 };
289
290 results.push(HybridResult {
291 id,
292 combined_score: combined,
293 vector_score: vector_normalized,
294 text_score: text_normalized,
295 metadata,
296 vector,
297 });
298 }
299
300 results.sort_by(|a, b| {
301 b.combined_score
302 .partial_cmp(&a.combined_score)
303 .unwrap_or(std::cmp::Ordering::Equal)
304 });
305 results.truncate(top_k);
306 results
307 }
308}
309
310impl Default for HybridSearcher {
311 fn default() -> Self {
312 Self::new(HybridConfig::default())
313 }
314}
315
316pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
329 match kind {
330 crate::routing::QueryKind::Keyword => 0.25,
331 crate::routing::QueryKind::Hybrid => 0.50,
332 crate::routing::QueryKind::Semantic => 0.75,
333 }
334}
335
/// Min-max normalizes `score` against the observed `min` / `max` of its list:
/// `(score - min) / (max - min)`.
///
/// * A degenerate range (`|max - min| < f32::EPSILON`, e.g. a single hit or
///   all-equal scores) yields `1.0`.
/// * A NaN `score`, or bounds for which the quotient is NaN (such as
///   `inf - inf`), yields `0.0` instead of letting NaN leak into the combined
///   score and the ordering of results.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    if score.is_nan() {
        return 0.0;
    }

    let range = max - min;
    if range.abs() < f32::EPSILON {
        return 1.0;
    }

    let normalized = (score - min) / range;
    if normalized.is_nan() {
        0.0
    } else {
        normalized
    }
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector hit without metadata or stored vector.
    fn vec_hit(id: &str, score: f32) -> VectorResultRow {
        (id.to_string(), score, None, None)
    }

    /// Full-text hit.
    fn text_hit(id: &str, score: f32) -> FullTextResult {
        FullTextResult {
            doc_id: id.to_string(),
            score,
        }
    }

    /// Looks up a fused result by id, panicking when it is missing.
    fn by_id<'a>(results: &'a [HybridResult], id: &str) -> &'a HybridResult {
        results.iter().find(|r| r.id == id).unwrap()
    }

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            vec_hit("doc1", 0.9),
            vec_hit("doc2", 0.7),
            vec_hit("doc3", 0.5),
        ];
        let text_results = vec![
            text_hit("doc1", 3.0),
            text_hit("doc2", 4.0),
            text_hit("doc4", 2.0),
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // Union of both lists: doc1..doc4.
        assert_eq!(results.len(), 4);

        let doc1 = by_id(&results, "doc1");
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = by_id(&results, "doc2");
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the best text score, so its rank-derived text score is 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![vec_hit("doc1", 0.9), vec_hit("doc2", 0.5)];
        let text_results = vec![text_hit("doc1", 1.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc1");
        // With weight 1.0 the combined score is the vector score alone.
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![vec_hit("doc1", 0.9), vec_hit("doc2", 0.5)];
        let text_results = vec![text_hit("doc1", 1.0), text_hit("doc2", 3.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc2");
        // With weight 0.0 the combined score is the text score alone.
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![vec_hit("doc1", 0.9), vec_hit("doc2", 0.7)];
        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            vec_hit("doc1", 0.9),
            vec_hit("doc2", 0.8),
            vec_hit("doc3", 0.7),
            vec_hit("doc4", 0.6),
            vec_hit("doc5", 0.5),
        ];
        let text_results = Vec::new();

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];
        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // Degenerate range (min == max) normalizes to 1.0.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let too_high = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(too_high.config.vector_weight, 1.0);

        let too_low = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(too_low.config.vector_weight, 0.0);
    }

    #[test]
    fn test_rrf_default_strategy() {
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            vec_hit("doc1", 0.9),
            vec_hit("doc2", 0.7),
            vec_hit("doc3", 0.5),
        ];
        let text_results = vec![text_hit("doc2", 5.0), text_hit("doc1", 3.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        // doc3 is only in the vector list, so doc1 (in both) must outscore it.
        let doc3 = by_id(&results, "doc3");
        let doc1 = by_id(&results, "doc1");
        assert!(doc1.combined_score > doc3.combined_score);

        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![vec_hit("doc1", 0.9), vec_hit("doc2", 0.7)];
        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_rrf_formula_k60() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![vec_hit("doc1", 1.0)];
        let text_results = vec![text_hit("doc1", 1.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // Rank 1 in both lists: 1/(k+1) + 1/(k+1).
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}