Skip to main content

engine/
hybrid.rs

1//! Hybrid search combining vector similarity and full-text search
2//!
3//! Provides a unified search experience by combining:
4//! - Vector similarity scores (cosine, euclidean, dot product)
5//! - Full-text BM25 scores
6//!
7//! Two fusion strategies are supported:
8//! - **MinMax** (default): weighted linear combination of min-max normalised scores
9//! - **RRF**: Reciprocal Rank Fusion — `score(d) = Σ 1/(k + rank_r(d))`, k=60
10
11use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` in the Reciprocal Rank Fusion term `1 / (k + rank)`.
///
/// 60 is the canonical value from Cormack et al. (SIGIR 2009).
const RRF_K: f32 = 60_f32;
19
20/// A vector search result row: (id, score, optional metadata, optional vector).
21type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23/// Configuration for hybrid search
24#[derive(Debug, Clone)]
25pub struct HybridConfig {
26    /// Weight for vector search (0.0 to 1.0) — used by MinMax strategy only.
27    pub vector_weight: f32,
28    /// Whether to require matches in both indices.
29    pub require_both: bool,
30    /// CE-14: Fusion strategy (default: MinMax since v0.11.2).
31    pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35    fn default() -> Self {
36        Self {
37            vector_weight: 0.5,
38            require_both: false,
39            fusion_strategy: FusionStrategy::MinMax,
40        }
41    }
42}
43
44/// Raw score from a single search type
45#[derive(Debug, Clone)]
46struct RawScore {
47    /// Original score before normalization
48    score: f32,
49    /// Additional data (metadata, vector)
50    metadata: Option<serde_json::Value>,
51    vector: Option<Vec<f32>>,
52}
53
54/// Result of hybrid search
55#[derive(Debug, Clone)]
56pub struct HybridResult {
57    /// Document/vector ID
58    pub id: String,
59    /// Combined score (weighted average of normalized scores)
60    pub combined_score: f32,
61    /// Normalized vector similarity score (0-1)
62    pub vector_score: f32,
63    /// Normalized text search score (0-1)
64    pub text_score: f32,
65    /// Optional metadata
66    pub metadata: Option<serde_json::Value>,
67    /// Optional vector values
68    pub vector: Option<Vec<f32>>,
69}
70
71/// Hybrid search engine that combines vector and text search
72pub struct HybridSearcher {
73    config: HybridConfig,
74}
75
76impl HybridSearcher {
77    pub fn new(config: HybridConfig) -> Self {
78        Self { config }
79    }
80
81    pub fn with_vector_weight(mut self, weight: f32) -> Self {
82        self.config.vector_weight = weight.clamp(0.0, 1.0);
83        self
84    }
85
86    /// CE-14: Override the fusion strategy.
87    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88        self.config.fusion_strategy = strategy;
89        self
90    }
91
92    /// Combine vector search results with full-text search results.
93    ///
94    /// Dispatches to [`Self::rrf_search`] (RRF, default) or [`Self::minmax_search`]
95    /// depending on `config.fusion_strategy`.
96    ///
97    /// # Arguments
98    /// * `vector_results` - Results from vector similarity search (id, score, metadata, vector)
99    /// * `text_results` - Results from full-text BM25 search
100    /// * `top_k` - Number of results to return
101    pub fn search(
102        &self,
103        vector_results: Vec<VectorResultRow>,
104        text_results: Vec<FullTextResult>,
105        top_k: usize,
106    ) -> Vec<HybridResult> {
107        match self.config.fusion_strategy {
108            FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109            FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110        }
111    }
112
113    /// Reciprocal Rank Fusion (Cormack et al., SIGIR 2009).
114    ///
115    /// Each document receives `score(d) = Σ_r 1 / (k + rank_r(d))` where k=60.
116    /// Documents appearing in only one result list receive 0 from the missing retriever.
117    fn rrf_search(
118        &self,
119        vector_results: Vec<VectorResultRow>,
120        text_results: Vec<FullTextResult>,
121        top_k: usize,
122    ) -> Vec<HybridResult> {
123        let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124        let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125        let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127        // Sort vector results by score descending, assign 1-based ranks.
128        let mut sorted_vec = vector_results;
129        sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130        for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131            vector_ranks.insert(id.clone(), i + 1);
132            vector_map.insert(
133                id,
134                RawScore {
135                    score,
136                    metadata,
137                    vector,
138                },
139            );
140        }
141
142        // Sort text results by score descending, assign 1-based ranks.
143        let mut sorted_text = text_results;
144        sorted_text.sort_by(|a, b| {
145            b.score
146                .partial_cmp(&a.score)
147                .unwrap_or(std::cmp::Ordering::Equal)
148        });
149        for (i, result) in sorted_text.into_iter().enumerate() {
150            text_ranks.insert(result.doc_id, i + 1);
151        }
152
153        // Union of all document IDs.
154        let mut all_ids: Vec<String> = vector_map
155            .keys()
156            .chain(text_ranks.keys())
157            .cloned()
158            .collect();
159        all_ids.sort();
160        all_ids.dedup();
161
162        let total = all_ids.len().max(1) as f32;
163        let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165        for id in all_ids {
166            let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167            let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169            if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170                continue;
171            }
172
173            let vec_rrf = if vec_rank > 0 {
174                1.0 / (RRF_K + vec_rank as f32)
175            } else {
176                0.0
177            };
178            let txt_rrf = if txt_rank > 0 {
179                1.0 / (RRF_K + txt_rank as f32)
180            } else {
181                0.0
182            };
183            let combined = vec_rrf + txt_rrf;
184
185            // Rank-normalised display scores (0=absent, 1=top-ranked).
186            let vector_score = if vec_rank > 0 {
187                1.0 - (vec_rank as f32 - 1.0) / total
188            } else {
189                0.0
190            };
191            let text_score = if txt_rank > 0 {
192                1.0 - (txt_rank as f32 - 1.0) / total
193            } else {
194                0.0
195            };
196
197            let raw = vector_map.get(&id);
198            results.push(HybridResult {
199                id,
200                combined_score: combined,
201                vector_score,
202                text_score,
203                metadata: raw.and_then(|r| r.metadata.clone()),
204                vector: raw.and_then(|r| r.vector.clone()),
205            });
206        }
207
208        results.sort_by(|a, b| {
209            b.combined_score
210                .partial_cmp(&a.combined_score)
211                .unwrap_or(std::cmp::Ordering::Equal)
212        });
213        results.truncate(top_k);
214        results
215    }
216
217    /// Weighted min-max normalization (legacy fusion strategy).
218    fn minmax_search(
219        &self,
220        vector_results: Vec<VectorResultRow>,
221        text_results: Vec<FullTextResult>,
222        top_k: usize,
223    ) -> Vec<HybridResult> {
224        let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225        let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227        let mut vector_min = f32::MAX;
228        let mut vector_max = f32::MIN;
229        let mut text_min = f32::MAX;
230        let mut text_max = f32::MIN;
231
232        for (id, score, metadata, vector) in vector_results {
233            vector_min = vector_min.min(score);
234            vector_max = vector_max.max(score);
235            vector_scores.insert(
236                id,
237                RawScore {
238                    score,
239                    metadata,
240                    vector,
241                },
242            );
243        }
244
245        for result in text_results {
246            text_min = text_min.min(result.score);
247            text_max = text_max.max(result.score);
248            text_scores.insert(result.doc_id, result.score);
249        }
250
251        let mut all_ids: Vec<String> = vector_scores
252            .keys()
253            .chain(text_scores.keys())
254            .cloned()
255            .collect();
256        all_ids.sort();
257        all_ids.dedup();
258
259        let mut results: Vec<HybridResult> = Vec::new();
260
261        for id in all_ids {
262            let vector_raw = vector_scores.get(&id);
263            let text_raw = text_scores.get(&id);
264
265            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266                continue;
267            }
268
269            let vector_normalized = if let Some(raw) = vector_raw {
270                normalize_score(raw.score, vector_min, vector_max)
271            } else {
272                0.0
273            };
274
275            let text_normalized = if let Some(&score) = text_raw {
276                normalize_score(score, text_min, text_max)
277            } else {
278                0.0
279            };
280
281            let combined = self.config.vector_weight * vector_normalized
282                + (1.0 - self.config.vector_weight) * text_normalized;
283
284            let (metadata, vector) = if let Some(raw) = vector_raw {
285                (raw.metadata.clone(), raw.vector.clone())
286            } else {
287                (None, None)
288            };
289
290            results.push(HybridResult {
291                id,
292                combined_score: combined,
293                vector_score: vector_normalized,
294                text_score: text_normalized,
295                metadata,
296                vector,
297            });
298        }
299
300        results.sort_by(|a, b| {
301            b.combined_score
302                .partial_cmp(&a.combined_score)
303                .unwrap_or(std::cmp::Ordering::Equal)
304        });
305        results.truncate(top_k);
306        results
307    }
308}
309
310impl Default for HybridSearcher {
311    fn default() -> Self {
312        Self::new(HybridConfig::default())
313    }
314}
315
316// ============================================================================
317// CE-12c: Adaptive hybrid weighting
318// ============================================================================
319
320/// Return an adaptive `vector_weight` (0.0–1.0) for a [`HybridSearcher`]
321/// based on the inferred [`QueryKind`].
322///
323/// | QueryKind | vector_weight | Rationale                                         |
324/// |-----------|---------------|---------------------------------------------------|
325/// | Keyword   | 0.25          | Exact-term signals dominate                        |
326/// | Hybrid    | 0.50          | Balanced blend                                     |
327/// | Semantic  | 0.75          | Embedding captures intent better                   |
328/// | Temporal  | 0.20          | BM25 at 0.80 — date tokens dominate over semantics |
329pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
330    match kind {
331        crate::routing::QueryKind::Keyword => 0.25,
332        crate::routing::QueryKind::Hybrid => 0.50,
333        crate::routing::QueryKind::Semantic => 0.75,
334        crate::routing::QueryKind::Temporal => 0.20,
335    }
336}
337
/// Rescale `score` into `[0, 1]` relative to the observed range `[min, max]`.
///
/// If the range is degenerate (every observed score is effectively equal), each
/// score is treated as top-ranked and 1.0 is returned.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    if span.abs() < f32::EPSILON {
        return 1.0;
    }
    (score - min) / span
}
347
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // All 4 documents should be in results
        assert_eq!(results.len(), 4);

        // doc1 and doc2 appear in both lists; normalized scores are >= 0
        // (the minimum of each list normalizes to 0.0).
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so it normalizes to exactly 1.0
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        // MinMax with vector_weight=1.0 → combined_score == vector_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // doc1 should be first (highest vector score)
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        // MinMax with vector_weight=0.0 → combined_score == text_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 should be first (highest text score)
        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        // Default strategy (MinMax) with require_both.
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 should be in results (only one with both scores)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        // Normal case
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // All same scores
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    // --- CE-15: Temporal adaptive weight test ---

    #[test]
    fn test_adaptive_vector_weight_temporal() {
        use crate::routing::QueryKind;
        // Temporal queries must get vector_weight=0.20 (BM25 at 0.80).
        // This is the key fix for Cat 3 LoCoMo temporal recall (42.7% → target 65%+).
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.20);
        // Other kinds unchanged
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
    }

    // --- CE-14: RRF tests ---

    #[test]
    fn test_minmax_default_strategy() {
        // Default HybridSearcher uses MinMax since v0.11.2 (A/B: +6.3pp overall, +13.5pp temporal)
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        // doc1: rank 1 in vector, rank 2 in text → 1/(60+1) + 1/(60+2) ≈ 0.03254
        // doc2: rank 2 in vector, rank 1 in text → 1/(60+2) + 1/(60+1) ≈ 0.03254  (equal)
        // doc3: rank 3 in vector, not in text    → 1/(60+3) ≈ 0.01587
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();

        // doc1 and doc2 swap ranks across the two retrievers, so their fused scores tie.
        assert!((doc1.combined_score - doc2.combined_score).abs() < 1e-7);
        let expected = 1.0 / (RRF_K + 1.0) + 1.0 / (RRF_K + 2.0);
        assert!((doc1.combined_score - expected).abs() < 1e-6);

        // Both beat doc3, which only the vector retriever returned.
        assert!(doc1.combined_score > doc3.combined_score);
        assert!(doc2.combined_score > doc3.combined_score);
        assert_eq!(results[2].id, "doc3");
        assert_eq!(doc3.text_score, 0.0);

        // Scores are positive
        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        // Select RRF explicitly: the default strategy is MinMax, which would leave the
        // RRF require_both path untested.
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_rrf_formula_k60() {
        // Verify RRF formula: score = 1/(60 + rank), k=60
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // doc1 is rank 1 in both → score = 1/(60+1) + 1/(60+1) = 2/61 ≈ 0.032787
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
        // Top rank in each list maps to a display score of exactly 1.0.
        assert_eq!(results[0].vector_score, 1.0);
        assert_eq!(results[0].text_score, 1.0);
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        // The default is MinMax, so switching to RRF proves the builder takes effect.
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);

        let searcher = searcher.with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}