Skip to main content

engine/
hybrid.rs

1//! Hybrid search combining vector similarity and full-text search
2//!
3//! Provides a unified search experience by combining:
4//! - Vector similarity scores (cosine, euclidean, dot product)
5//! - Full-text BM25 scores
6//!
7//! Two fusion strategies are supported:
8//! - **MinMax** (default): weighted linear combination of min-max normalised scores
9//! - **RRF**: Reciprocal Rank Fusion — `score(d) = Σ 1/(k + rank_r(d))`, k=60
10
11use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` in the Reciprocal Rank Fusion term `1 / (k + rank)`.
///
/// 60 is the value used by Cormack et al. (SIGIR 2009) and the customary default.
const RRF_K: f32 = 60.0_f32;
19
20/// A vector search result row: (id, score, optional metadata, optional vector).
21type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23/// Configuration for hybrid search
24#[derive(Debug, Clone)]
25pub struct HybridConfig {
26    /// Weight for vector search (0.0 to 1.0) — used by MinMax strategy only.
27    pub vector_weight: f32,
28    /// Whether to require matches in both indices.
29    pub require_both: bool,
30    /// CE-14: Fusion strategy (default: MinMax since v0.11.2).
31    pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35    fn default() -> Self {
36        Self {
37            vector_weight: 0.5,
38            require_both: false,
39            fusion_strategy: FusionStrategy::MinMax,
40        }
41    }
42}
43
44/// Raw score from a single search type
45#[derive(Debug, Clone)]
46struct RawScore {
47    /// Original score before normalization
48    score: f32,
49    /// Additional data (metadata, vector)
50    metadata: Option<serde_json::Value>,
51    vector: Option<Vec<f32>>,
52}
53
54/// Result of hybrid search
55#[derive(Debug, Clone)]
56pub struct HybridResult {
57    /// Document/vector ID
58    pub id: String,
59    /// Combined score (weighted average of normalized scores)
60    pub combined_score: f32,
61    /// Normalized vector similarity score (0-1)
62    pub vector_score: f32,
63    /// Normalized text search score (0-1)
64    pub text_score: f32,
65    /// Optional metadata
66    pub metadata: Option<serde_json::Value>,
67    /// Optional vector values
68    pub vector: Option<Vec<f32>>,
69}
70
71/// Hybrid search engine that combines vector and text search
72pub struct HybridSearcher {
73    config: HybridConfig,
74}
75
76impl HybridSearcher {
77    pub fn new(config: HybridConfig) -> Self {
78        Self { config }
79    }
80
81    pub fn with_vector_weight(mut self, weight: f32) -> Self {
82        self.config.vector_weight = weight.clamp(0.0, 1.0);
83        self
84    }
85
86    /// CE-14: Override the fusion strategy.
87    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88        self.config.fusion_strategy = strategy;
89        self
90    }
91
92    /// Combine vector search results with full-text search results.
93    ///
94    /// Dispatches to [`Self::rrf_search`] (RRF, default) or [`Self::minmax_search`]
95    /// depending on `config.fusion_strategy`.
96    ///
97    /// # Arguments
98    /// * `vector_results` - Results from vector similarity search (id, score, metadata, vector)
99    /// * `text_results` - Results from full-text BM25 search
100    /// * `top_k` - Number of results to return
101    pub fn search(
102        &self,
103        vector_results: Vec<VectorResultRow>,
104        text_results: Vec<FullTextResult>,
105        top_k: usize,
106    ) -> Vec<HybridResult> {
107        match self.config.fusion_strategy {
108            FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109            FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110        }
111    }
112
113    /// Reciprocal Rank Fusion (Cormack et al., SIGIR 2009).
114    ///
115    /// Each document receives `score(d) = Σ_r 1 / (k + rank_r(d))` where k=60.
116    /// Documents appearing in only one result list receive 0 from the missing retriever.
117    fn rrf_search(
118        &self,
119        vector_results: Vec<VectorResultRow>,
120        text_results: Vec<FullTextResult>,
121        top_k: usize,
122    ) -> Vec<HybridResult> {
123        let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124        let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125        let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127        // Sort vector results by score descending, assign 1-based ranks.
128        let mut sorted_vec = vector_results;
129        sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130        for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131            vector_ranks.insert(id.clone(), i + 1);
132            vector_map.insert(
133                id,
134                RawScore {
135                    score,
136                    metadata,
137                    vector,
138                },
139            );
140        }
141
142        // Sort text results by score descending, assign 1-based ranks.
143        let mut sorted_text = text_results;
144        sorted_text.sort_by(|a, b| {
145            b.score
146                .partial_cmp(&a.score)
147                .unwrap_or(std::cmp::Ordering::Equal)
148        });
149        for (i, result) in sorted_text.into_iter().enumerate() {
150            text_ranks.insert(result.doc_id, i + 1);
151        }
152
153        // Union of all document IDs.
154        let mut all_ids: Vec<String> = vector_map
155            .keys()
156            .chain(text_ranks.keys())
157            .cloned()
158            .collect();
159        all_ids.sort();
160        all_ids.dedup();
161
162        let total = all_ids.len().max(1) as f32;
163        let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165        for id in all_ids {
166            let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167            let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169            if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170                continue;
171            }
172
173            let vec_rrf = if vec_rank > 0 {
174                1.0 / (RRF_K + vec_rank as f32)
175            } else {
176                0.0
177            };
178            let txt_rrf = if txt_rank > 0 {
179                1.0 / (RRF_K + txt_rank as f32)
180            } else {
181                0.0
182            };
183            let combined = vec_rrf + txt_rrf;
184
185            // Rank-normalised display scores (0=absent, 1=top-ranked).
186            let vector_score = if vec_rank > 0 {
187                1.0 - (vec_rank as f32 - 1.0) / total
188            } else {
189                0.0
190            };
191            let text_score = if txt_rank > 0 {
192                1.0 - (txt_rank as f32 - 1.0) / total
193            } else {
194                0.0
195            };
196
197            let raw = vector_map.get(&id);
198            results.push(HybridResult {
199                id,
200                combined_score: combined,
201                vector_score,
202                text_score,
203                metadata: raw.and_then(|r| r.metadata.clone()),
204                vector: raw.and_then(|r| r.vector.clone()),
205            });
206        }
207
208        results.sort_by(|a, b| {
209            b.combined_score
210                .partial_cmp(&a.combined_score)
211                .unwrap_or(std::cmp::Ordering::Equal)
212        });
213        results.truncate(top_k);
214        results
215    }
216
217    /// Weighted min-max normalization (legacy fusion strategy).
218    fn minmax_search(
219        &self,
220        vector_results: Vec<VectorResultRow>,
221        text_results: Vec<FullTextResult>,
222        top_k: usize,
223    ) -> Vec<HybridResult> {
224        let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225        let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227        let mut vector_min = f32::MAX;
228        let mut vector_max = f32::MIN;
229        let mut text_min = f32::MAX;
230        let mut text_max = f32::MIN;
231
232        for (id, score, metadata, vector) in vector_results {
233            vector_min = vector_min.min(score);
234            vector_max = vector_max.max(score);
235            vector_scores.insert(
236                id,
237                RawScore {
238                    score,
239                    metadata,
240                    vector,
241                },
242            );
243        }
244
245        for result in text_results {
246            text_min = text_min.min(result.score);
247            text_max = text_max.max(result.score);
248            text_scores.insert(result.doc_id, result.score);
249        }
250
251        let mut all_ids: Vec<String> = vector_scores
252            .keys()
253            .chain(text_scores.keys())
254            .cloned()
255            .collect();
256        all_ids.sort();
257        all_ids.dedup();
258
259        let mut results: Vec<HybridResult> = Vec::new();
260
261        for id in all_ids {
262            let vector_raw = vector_scores.get(&id);
263            let text_raw = text_scores.get(&id);
264
265            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266                continue;
267            }
268
269            let vector_normalized = if let Some(raw) = vector_raw {
270                normalize_score(raw.score, vector_min, vector_max)
271            } else {
272                0.0
273            };
274
275            let text_normalized = if let Some(&score) = text_raw {
276                normalize_score(score, text_min, text_max)
277            } else {
278                0.0
279            };
280
281            let combined = self.config.vector_weight * vector_normalized
282                + (1.0 - self.config.vector_weight) * text_normalized;
283
284            let (metadata, vector) = if let Some(raw) = vector_raw {
285                (raw.metadata.clone(), raw.vector.clone())
286            } else {
287                (None, None)
288            };
289
290            results.push(HybridResult {
291                id,
292                combined_score: combined,
293                vector_score: vector_normalized,
294                text_score: text_normalized,
295                metadata,
296                vector,
297            });
298        }
299
300        results.sort_by(|a, b| {
301            b.combined_score
302                .partial_cmp(&a.combined_score)
303                .unwrap_or(std::cmp::Ordering::Equal)
304        });
305        results.truncate(top_k);
306        results
307    }
308}
309
310impl Default for HybridSearcher {
311    fn default() -> Self {
312        Self::new(HybridConfig::default())
313    }
314}
315
316// ============================================================================
317// CE-12c: Adaptive hybrid weighting
318// ============================================================================
319
320/// Return an adaptive `vector_weight` (0.0–1.0) for a [`HybridSearcher`]
321/// based on the inferred [`QueryKind`].
322///
323/// | QueryKind | vector_weight | Rationale                                         |
324/// |-----------|---------------|---------------------------------------------------|
325/// | Keyword   | 0.25          | Exact-term signals dominate                        |
326/// | Hybrid    | 0.50          | Balanced blend                                     |
327/// | Semantic  | 0.75          | Embedding captures intent better                   |
328/// | Temporal  | 0.00          | Pure BM25 — any vector weight degrades temporal    |
329///
330/// CE-19: Temporal weight reduced from 0.20 to 0.00. LoCoMo CE-18 data showed
331/// Hybrid(80/20) scored 40.6% vs pure BM25 43.8% (−3.2pp). Date-prefixed memories
332/// have near-zero cosine similarity to "when did X happen?" queries; even 20% vector
333/// weight consistently pulls wrong candidates into the MinMax score. vector_weight=0.0
334/// makes the Hybrid path mathematically equivalent to pure BM25 scoring.
335pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
336    match kind {
337        crate::routing::QueryKind::Keyword => 0.25,
338        crate::routing::QueryKind::Hybrid => 0.50,
339        crate::routing::QueryKind::Semantic => 0.75,
340        crate::routing::QueryKind::Temporal => 0.00,
341    }
342}
343
/// Min-max rescale `score` relative to the range `[min, max]`.
///
/// If the range is degenerate (`|max - min| < f32::EPSILON`), every score is
/// treated as the best one and `1.0` is returned.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    let degenerate = span.abs() < f32::EPSILON;
    if degenerate {
        return 1.0;
    }
    (score - min) / span
}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // All 4 documents should be in results
        assert_eq!(results.len(), 4);

        // doc1 and doc2 appear in both lists; normalised scores are >= 0
        // (the minimum of each list normalises to 0.0).
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so it normalises to exactly 1.0
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        // MinMax with vector_weight=1.0 → combined_score == vector_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // doc1 should be first (highest vector score)
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        // MinMax with vector_weight=0.0 → combined_score == text_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 should be first (highest text score)
        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        // Default strategy (MinMax) with require_both.
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 should be in results (only one with both scores)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        // Normal case
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // All same scores
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    // --- CE-19: Temporal adaptive weight test ---

    #[test]
    fn test_adaptive_vector_weight_temporal() {
        use crate::routing::QueryKind;
        // CE-19: Temporal queries must get vector_weight=0.00 (pure BM25).
        // CE-18 data: Hybrid(80/20)=40.6% vs pure BM25=43.8% (−3.2pp regression).
        // Any vector weight contaminates temporal scores — set to 0.0.
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.00);
        // Other kinds unchanged
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
    }

    // --- CE-14: fusion strategy / RRF tests ---

    #[test]
    fn test_minmax_default_strategy() {
        // Default HybridSearcher uses MinMax since v0.11.2 (A/B: +6.3pp overall, +13.5pp temporal)
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        // doc1: rank 1 in vector, rank 2 in text → 1/(60+1) + 1/(60+2) ≈ 0.03254
        // doc2: rank 2 in vector, rank 1 in text → 1/(60+2) + 1/(60+1) ≈ 0.03254  (equal)
        // doc3: rank 3 in vector, not in text    → 1/(60+3) ≈ 0.01587
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        // doc1 and doc2 both appear in both lists — they score higher than doc3 (vector-only)
        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc1.combined_score > doc3.combined_score);
        assert!(doc2.combined_score > doc3.combined_score);

        // The two terms are the same values summed in a different order, so the
        // scores are bit-identical; the tie is broken by id ascending.
        assert_eq!(doc1.combined_score, doc2.combined_score);
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["doc1", "doc2", "doc3"]);

        // Scores are positive
        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        // rank 1 in both lists
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_rrf_formula_k60() {
        // Verify RRF formula: score = 1/(60 + rank), k=60
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // doc1 is rank 1 in both → score = 1/(60+1) + 1/(60+1) = 2/61 ≈ 0.032787
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        // Override to the non-default strategy so the builder is actually exercised.
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::Rrf);
    }
}