Skip to main content

engine/
hybrid.rs

1//! Hybrid search combining vector similarity and full-text search
2//!
3//! Provides a unified search experience by combining:
4//! - Vector similarity scores (cosine, euclidean, dot product)
5//! - Full-text BM25 scores
6//!
7//! Two fusion strategies are supported:
8//! - **RRF** (default): Reciprocal Rank Fusion — `score(d) = Σ 1/(k + rank_r(d))`, k=60
9//! - **MinMax**: weighted linear combination of min-max normalised scores
10
11use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// RRF smoothing constant (Cormack et al., SIGIR 2009 — k=60 is the canonical default).
///
/// It is added to every 1-based rank before taking the reciprocal, so each
/// retriever contributes `1 / (RRF_K + rank)` to a document's fused score.
const RRF_K: f32 = 60.0;
19
/// A vector search result row: (id, score, optional metadata, optional vector).
///
/// The score is treated as "higher is better": RRF ranks rows by descending
/// score and MinMax maps the largest score to 1.0.
// NOTE(review): callers holding distances (e.g. euclidean) presumably need to
// convert them to similarities before passing them in — confirm at call sites.
type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
/// Configuration for hybrid search.
///
/// Controls how vector and full-text result lists are fused by
/// [`HybridSearcher`].
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Weight for vector search (0.0 to 1.0) — used by MinMax strategy only.
    /// The text side is weighted by `1.0 - vector_weight`.
    pub vector_weight: f32,
    /// Whether to require matches in both indices. When `true`, an ID that is
    /// missing from either result list is dropped from the fused output.
    pub require_both: bool,
    /// CE-14: Fusion strategy (default: RRF).
    pub fusion_strategy: FusionStrategy,
}
33
34impl Default for HybridConfig {
35    fn default() -> Self {
36        Self {
37            vector_weight: 0.5,
38            require_both: false,
39            fusion_strategy: FusionStrategy::Rrf,
40        }
41    }
42}
43
/// Raw vector-search hit kept while fusing: the un-normalised score together
/// with the payload that is copied into the final [`HybridResult`].
#[derive(Debug, Clone)]
struct RawScore {
    /// Original score before normalization
    score: f32,
    /// Metadata attached to the vector hit, if the caller supplied any
    metadata: Option<serde_json::Value>,
    /// Vector values of the hit, if the caller supplied them
    vector: Option<Vec<f32>>,
}
53
/// Result of hybrid search
#[derive(Debug, Clone)]
pub struct HybridResult {
    /// Document/vector ID
    pub id: String,
    /// Combined score used for the final ordering.
    ///
    /// * RRF: sum of the reciprocal-rank terms `1 / (k + rank)` of each list
    ///   the document appears in.
    /// * MinMax: weighted average of the two normalized scores below.
    pub combined_score: f32,
    /// Vector-side score in the 0-1 range; 0.0 when the ID is absent from the
    /// vector results.
    ///
    /// * RRF: rank-normalised (`1.0` for the top-ranked vector hit).
    /// * MinMax: min-max normalised similarity, so the lowest-scoring vector
    ///   hit also gets 0.0.
    pub vector_score: f32,
    /// Text-side score in the 0-1 range; 0.0 when the ID is absent from the
    /// text results. Computed like `vector_score`, but from the full-text list.
    pub text_score: f32,
    /// Optional metadata, taken from the vector hit (`None` for IDs that only
    /// appear in the text results)
    pub metadata: Option<serde_json::Value>,
    /// Optional vector values, taken from the vector hit (`None` for IDs that
    /// only appear in the text results)
    pub vector: Option<Vec<f32>>,
}
70
/// Hybrid search engine that combines vector and text search.
///
/// It does not run any search itself: callers pass in the two result lists
/// and it fuses them according to the stored [`HybridConfig`].
pub struct HybridSearcher {
    /// Fusion settings (strategy, MinMax weight, `require_both`).
    config: HybridConfig,
}
75
76impl HybridSearcher {
77    pub fn new(config: HybridConfig) -> Self {
78        Self { config }
79    }
80
81    pub fn with_vector_weight(mut self, weight: f32) -> Self {
82        self.config.vector_weight = weight.clamp(0.0, 1.0);
83        self
84    }
85
86    /// CE-14: Override the fusion strategy.
87    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88        self.config.fusion_strategy = strategy;
89        self
90    }
91
92    /// Combine vector search results with full-text search results.
93    ///
94    /// Dispatches to [`Self::rrf_search`] (RRF, default) or [`Self::minmax_search`]
95    /// depending on `config.fusion_strategy`.
96    ///
97    /// # Arguments
98    /// * `vector_results` - Results from vector similarity search (id, score, metadata, vector)
99    /// * `text_results` - Results from full-text BM25 search
100    /// * `top_k` - Number of results to return
101    pub fn search(
102        &self,
103        vector_results: Vec<VectorResultRow>,
104        text_results: Vec<FullTextResult>,
105        top_k: usize,
106    ) -> Vec<HybridResult> {
107        match self.config.fusion_strategy {
108            FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109            FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110        }
111    }
112
113    /// Reciprocal Rank Fusion (Cormack et al., SIGIR 2009).
114    ///
115    /// Each document receives `score(d) = Σ_r 1 / (k + rank_r(d))` where k=60.
116    /// Documents appearing in only one result list receive 0 from the missing retriever.
117    fn rrf_search(
118        &self,
119        vector_results: Vec<VectorResultRow>,
120        text_results: Vec<FullTextResult>,
121        top_k: usize,
122    ) -> Vec<HybridResult> {
123        let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124        let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125        let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127        // Sort vector results by score descending, assign 1-based ranks.
128        let mut sorted_vec = vector_results;
129        sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130        for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131            vector_ranks.insert(id.clone(), i + 1);
132            vector_map.insert(
133                id,
134                RawScore {
135                    score,
136                    metadata,
137                    vector,
138                },
139            );
140        }
141
142        // Sort text results by score descending, assign 1-based ranks.
143        let mut sorted_text = text_results;
144        sorted_text.sort_by(|a, b| {
145            b.score
146                .partial_cmp(&a.score)
147                .unwrap_or(std::cmp::Ordering::Equal)
148        });
149        for (i, result) in sorted_text.into_iter().enumerate() {
150            text_ranks.insert(result.doc_id, i + 1);
151        }
152
153        // Union of all document IDs.
154        let mut all_ids: Vec<String> = vector_map
155            .keys()
156            .chain(text_ranks.keys())
157            .cloned()
158            .collect();
159        all_ids.sort();
160        all_ids.dedup();
161
162        let total = all_ids.len().max(1) as f32;
163        let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165        for id in all_ids {
166            let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167            let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169            if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170                continue;
171            }
172
173            let vec_rrf = if vec_rank > 0 {
174                1.0 / (RRF_K + vec_rank as f32)
175            } else {
176                0.0
177            };
178            let txt_rrf = if txt_rank > 0 {
179                1.0 / (RRF_K + txt_rank as f32)
180            } else {
181                0.0
182            };
183            let combined = vec_rrf + txt_rrf;
184
185            // Rank-normalised display scores (0=absent, 1=top-ranked).
186            let vector_score = if vec_rank > 0 {
187                1.0 - (vec_rank as f32 - 1.0) / total
188            } else {
189                0.0
190            };
191            let text_score = if txt_rank > 0 {
192                1.0 - (txt_rank as f32 - 1.0) / total
193            } else {
194                0.0
195            };
196
197            let raw = vector_map.get(&id);
198            results.push(HybridResult {
199                id,
200                combined_score: combined,
201                vector_score,
202                text_score,
203                metadata: raw.and_then(|r| r.metadata.clone()),
204                vector: raw.and_then(|r| r.vector.clone()),
205            });
206        }
207
208        results.sort_by(|a, b| {
209            b.combined_score
210                .partial_cmp(&a.combined_score)
211                .unwrap_or(std::cmp::Ordering::Equal)
212        });
213        results.truncate(top_k);
214        results
215    }
216
217    /// Weighted min-max normalization (legacy fusion strategy).
218    fn minmax_search(
219        &self,
220        vector_results: Vec<VectorResultRow>,
221        text_results: Vec<FullTextResult>,
222        top_k: usize,
223    ) -> Vec<HybridResult> {
224        let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225        let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227        let mut vector_min = f32::MAX;
228        let mut vector_max = f32::MIN;
229        let mut text_min = f32::MAX;
230        let mut text_max = f32::MIN;
231
232        for (id, score, metadata, vector) in vector_results {
233            vector_min = vector_min.min(score);
234            vector_max = vector_max.max(score);
235            vector_scores.insert(
236                id,
237                RawScore {
238                    score,
239                    metadata,
240                    vector,
241                },
242            );
243        }
244
245        for result in text_results {
246            text_min = text_min.min(result.score);
247            text_max = text_max.max(result.score);
248            text_scores.insert(result.doc_id, result.score);
249        }
250
251        let mut all_ids: Vec<String> = vector_scores
252            .keys()
253            .chain(text_scores.keys())
254            .cloned()
255            .collect();
256        all_ids.sort();
257        all_ids.dedup();
258
259        let mut results: Vec<HybridResult> = Vec::new();
260
261        for id in all_ids {
262            let vector_raw = vector_scores.get(&id);
263            let text_raw = text_scores.get(&id);
264
265            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266                continue;
267            }
268
269            let vector_normalized = if let Some(raw) = vector_raw {
270                normalize_score(raw.score, vector_min, vector_max)
271            } else {
272                0.0
273            };
274
275            let text_normalized = if let Some(&score) = text_raw {
276                normalize_score(score, text_min, text_max)
277            } else {
278                0.0
279            };
280
281            let combined = self.config.vector_weight * vector_normalized
282                + (1.0 - self.config.vector_weight) * text_normalized;
283
284            let (metadata, vector) = if let Some(raw) = vector_raw {
285                (raw.metadata.clone(), raw.vector.clone())
286            } else {
287                (None, None)
288            };
289
290            results.push(HybridResult {
291                id,
292                combined_score: combined,
293                vector_score: vector_normalized,
294                text_score: text_normalized,
295                metadata,
296                vector,
297            });
298        }
299
300        results.sort_by(|a, b| {
301            b.combined_score
302                .partial_cmp(&a.combined_score)
303                .unwrap_or(std::cmp::Ordering::Equal)
304        });
305        results.truncate(top_k);
306        results
307    }
308}
309
310impl Default for HybridSearcher {
311    fn default() -> Self {
312        Self::new(HybridConfig::default())
313    }
314}
315
316// ============================================================================
317// CE-12c: Adaptive hybrid weighting
318// ============================================================================
319
320/// Return an adaptive `vector_weight` (0.0–1.0) for a [`HybridSearcher`]
321/// based on the inferred [`QueryKind`].
322///
323/// | QueryKind | vector_weight | Rationale                          |
324/// |-----------|---------------|------------------------------------|
325/// | Keyword   | 0.25          | Exact-term signals dominate         |
326/// | Hybrid    | 0.50          | Balanced blend                      |
327/// | Semantic  | 0.75          | Embedding captures intent better    |
328pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
329    match kind {
330        crate::routing::QueryKind::Keyword => 0.25,
331        crate::routing::QueryKind::Hybrid => 0.50,
332        crate::routing::QueryKind::Semantic => 0.75,
333    }
334}
335
/// Min-max normalise `score` into the 0-1 range spanned by `min` and `max`.
///
/// When the span between `min` and `max` is smaller than `f32::EPSILON` the
/// scores are considered identical and 1.0 is returned.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    if span.abs() < f32::EPSILON {
        // Degenerate range: every score is (effectively) the same.
        return 1.0;
    }
    (score - min) / span
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Default (RRF) fusion over partially overlapping lists returns the union
    /// of IDs and exposes per-retriever display scores.
    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // All 4 documents should be in results
        assert_eq!(results.len(), 4);

        // The default strategy is RRF, so the per-retriever scores are
        // rank-normalised: 0.0 only when absent, 1.0 for the top-ranked hit.
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0); // doc2 is rank 1 in the text list, should be 1.0
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest raw text score, i.e. text rank 1 → display score 1.0
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        // MinMax with vector_weight=1.0 → combined_score == vector_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // doc1 should be first (highest vector score)
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        // MinMax with vector_weight=0.0 → combined_score == text_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 should be first (highest text score)
        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    /// `require_both` drops IDs that are missing from either list
    /// (default strategy, i.e. RRF).
    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 should be in results (only one with both scores)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    /// The output is truncated to `top_k` even when more candidates exist.
    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    /// Metadata and vector values of the vector hit are carried into the result.
    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        // Normal case
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // All same scores
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    /// `with_vector_weight` stores an in-range weight unchanged.
    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    /// `with_vector_weight` clamps out-of-range weights to 0.0..=1.0.
    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    // --- CE-14: RRF tests ---

    #[test]
    fn test_rrf_default_strategy() {
        // Default HybridSearcher uses RRF
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        // doc1: rank 1 in vector, rank 2 in text → 1/(60+1) + 1/(60+2) ≈ 0.03252
        // doc2: rank 2 in vector, rank 1 in text → 1/(60+2) + 1/(60+1) ≈ 0.03252  (equal)
        // doc3: rank 3 in vector, not in text    → 1/(60+3) ≈ 0.01587
        let searcher = HybridSearcher::default(); // RRF

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        // doc1 and doc2 both appear in both lists — they score higher than doc3 (vector-only)
        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.combined_score > doc3.combined_score);

        // Scores are positive
        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            ..Default::default() // fusion_strategy: Rrf
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_rrf_formula_k60() {
        // Verify RRF formula: score = 1/(60 + rank), k=60
        let searcher = HybridSearcher::default();

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // doc1 is rank 1 in both → score = 1/(60+1) + 1/(60+1) = 2/61 ≈ 0.032787
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    /// `with_fusion_strategy` replaces the configured strategy.
    #[test]
    fn test_with_fusion_strategy_builder() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}