Skip to main content

engine/
hybrid.rs

1//! Hybrid search combining vector similarity and full-text search
2//!
3//! Provides a unified search experience by combining:
4//! - Vector similarity scores (cosine, euclidean, dot product)
5//! - Full-text BM25 scores
6//!
7//! Two fusion strategies are supported:
8//! - **MinMax** (default): weighted linear combination of min-max normalised scores
9//! - **RRF**: Reciprocal Rank Fusion — `score(d) = Σ 1/(k + rank_r(d))`, k=60
10
11use std::collections::HashMap;
12
13use common::FusionStrategy;
14
15use crate::fulltext::FullTextResult;
16
/// Smoothing constant `k` in the Reciprocal Rank Fusion term `1 / (k + rank)`.
///
/// 60 is the canonical default from Cormack et al. (SIGIR 2009).
const RRF_K: f32 = 60.0;
19
20/// A vector search result row: (id, score, optional metadata, optional vector).
21type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
22
23/// Configuration for hybrid search
24#[derive(Debug, Clone)]
25pub struct HybridConfig {
26    /// Weight for vector search (0.0 to 1.0) — used by MinMax strategy only.
27    pub vector_weight: f32,
28    /// Whether to require matches in both indices.
29    pub require_both: bool,
30    /// CE-14: Fusion strategy (default: MinMax since v0.11.2).
31    pub fusion_strategy: FusionStrategy,
32}
33
34impl Default for HybridConfig {
35    fn default() -> Self {
36        Self {
37            vector_weight: 0.5,
38            require_both: false,
39            fusion_strategy: FusionStrategy::MinMax,
40        }
41    }
42}
43
44/// Raw score from a single search type
45#[derive(Debug, Clone)]
46struct RawScore {
47    /// Original score before normalization
48    score: f32,
49    /// Additional data (metadata, vector)
50    metadata: Option<serde_json::Value>,
51    vector: Option<Vec<f32>>,
52}
53
54/// Result of hybrid search
55#[derive(Debug, Clone)]
56pub struct HybridResult {
57    /// Document/vector ID
58    pub id: String,
59    /// Combined score (weighted average of normalized scores)
60    pub combined_score: f32,
61    /// Normalized vector similarity score (0-1)
62    pub vector_score: f32,
63    /// Normalized text search score (0-1)
64    pub text_score: f32,
65    /// Optional metadata
66    pub metadata: Option<serde_json::Value>,
67    /// Optional vector values
68    pub vector: Option<Vec<f32>>,
69}
70
71/// Hybrid search engine that combines vector and text search
72pub struct HybridSearcher {
73    config: HybridConfig,
74}
75
76impl HybridSearcher {
77    pub fn new(config: HybridConfig) -> Self {
78        Self { config }
79    }
80
81    pub fn with_vector_weight(mut self, weight: f32) -> Self {
82        self.config.vector_weight = weight.clamp(0.0, 1.0);
83        self
84    }
85
86    /// CE-14: Override the fusion strategy.
87    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
88        self.config.fusion_strategy = strategy;
89        self
90    }
91
92    /// Combine vector search results with full-text search results.
93    ///
94    /// Dispatches to [`Self::rrf_search`] (RRF, default) or [`Self::minmax_search`]
95    /// depending on `config.fusion_strategy`.
96    ///
97    /// # Arguments
98    /// * `vector_results` - Results from vector similarity search (id, score, metadata, vector)
99    /// * `text_results` - Results from full-text BM25 search
100    /// * `top_k` - Number of results to return
101    pub fn search(
102        &self,
103        vector_results: Vec<VectorResultRow>,
104        text_results: Vec<FullTextResult>,
105        top_k: usize,
106    ) -> Vec<HybridResult> {
107        match self.config.fusion_strategy {
108            FusionStrategy::Rrf => self.rrf_search(vector_results, text_results, top_k),
109            FusionStrategy::MinMax => self.minmax_search(vector_results, text_results, top_k),
110        }
111    }
112
113    /// Reciprocal Rank Fusion (Cormack et al., SIGIR 2009).
114    ///
115    /// Each document receives `score(d) = Σ_r 1 / (k + rank_r(d))` where k=60.
116    /// Documents appearing in only one result list receive 0 from the missing retriever.
117    fn rrf_search(
118        &self,
119        vector_results: Vec<VectorResultRow>,
120        text_results: Vec<FullTextResult>,
121        top_k: usize,
122    ) -> Vec<HybridResult> {
123        let mut vector_map: HashMap<String, RawScore> = HashMap::new();
124        let mut vector_ranks: HashMap<String, usize> = HashMap::new();
125        let mut text_ranks: HashMap<String, usize> = HashMap::new();
126
127        // Sort vector results by score descending, assign 1-based ranks.
128        let mut sorted_vec = vector_results;
129        sorted_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
130        for (i, (id, score, metadata, vector)) in sorted_vec.into_iter().enumerate() {
131            vector_ranks.insert(id.clone(), i + 1);
132            vector_map.insert(
133                id,
134                RawScore {
135                    score,
136                    metadata,
137                    vector,
138                },
139            );
140        }
141
142        // Sort text results by score descending, assign 1-based ranks.
143        let mut sorted_text = text_results;
144        sorted_text.sort_by(|a, b| {
145            b.score
146                .partial_cmp(&a.score)
147                .unwrap_or(std::cmp::Ordering::Equal)
148        });
149        for (i, result) in sorted_text.into_iter().enumerate() {
150            text_ranks.insert(result.doc_id, i + 1);
151        }
152
153        // Union of all document IDs.
154        let mut all_ids: Vec<String> = vector_map
155            .keys()
156            .chain(text_ranks.keys())
157            .cloned()
158            .collect();
159        all_ids.sort();
160        all_ids.dedup();
161
162        let total = all_ids.len().max(1) as f32;
163        let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
164
165        for id in all_ids {
166            let vec_rank = vector_ranks.get(&id).copied().unwrap_or(0);
167            let txt_rank = text_ranks.get(&id).copied().unwrap_or(0);
168
169            if self.config.require_both && (vec_rank == 0 || txt_rank == 0) {
170                continue;
171            }
172
173            let vec_rrf = if vec_rank > 0 {
174                1.0 / (RRF_K + vec_rank as f32)
175            } else {
176                0.0
177            };
178            let txt_rrf = if txt_rank > 0 {
179                1.0 / (RRF_K + txt_rank as f32)
180            } else {
181                0.0
182            };
183            let combined = vec_rrf + txt_rrf;
184
185            // Rank-normalised display scores (0=absent, 1=top-ranked).
186            let vector_score = if vec_rank > 0 {
187                1.0 - (vec_rank as f32 - 1.0) / total
188            } else {
189                0.0
190            };
191            let text_score = if txt_rank > 0 {
192                1.0 - (txt_rank as f32 - 1.0) / total
193            } else {
194                0.0
195            };
196
197            let raw = vector_map.get(&id);
198            results.push(HybridResult {
199                id,
200                combined_score: combined,
201                vector_score,
202                text_score,
203                metadata: raw.and_then(|r| r.metadata.clone()),
204                vector: raw.and_then(|r| r.vector.clone()),
205            });
206        }
207
208        results.sort_by(|a, b| {
209            b.combined_score
210                .partial_cmp(&a.combined_score)
211                .unwrap_or(std::cmp::Ordering::Equal)
212        });
213        results.truncate(top_k);
214        results
215    }
216
217    /// Weighted min-max normalization (legacy fusion strategy).
218    fn minmax_search(
219        &self,
220        vector_results: Vec<VectorResultRow>,
221        text_results: Vec<FullTextResult>,
222        top_k: usize,
223    ) -> Vec<HybridResult> {
224        let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
225        let mut text_scores: HashMap<String, f32> = HashMap::new();
226
227        let mut vector_min = f32::MAX;
228        let mut vector_max = f32::MIN;
229        let mut text_min = f32::MAX;
230        let mut text_max = f32::MIN;
231
232        for (id, score, metadata, vector) in vector_results {
233            vector_min = vector_min.min(score);
234            vector_max = vector_max.max(score);
235            vector_scores.insert(
236                id,
237                RawScore {
238                    score,
239                    metadata,
240                    vector,
241                },
242            );
243        }
244
245        for result in text_results {
246            text_min = text_min.min(result.score);
247            text_max = text_max.max(result.score);
248            text_scores.insert(result.doc_id, result.score);
249        }
250
251        let mut all_ids: Vec<String> = vector_scores
252            .keys()
253            .chain(text_scores.keys())
254            .cloned()
255            .collect();
256        all_ids.sort();
257        all_ids.dedup();
258
259        let mut results: Vec<HybridResult> = Vec::new();
260
261        for id in all_ids {
262            let vector_raw = vector_scores.get(&id);
263            let text_raw = text_scores.get(&id);
264
265            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
266                continue;
267            }
268
269            let vector_normalized = if let Some(raw) = vector_raw {
270                normalize_score(raw.score, vector_min, vector_max)
271            } else {
272                0.0
273            };
274
275            let text_normalized = if let Some(&score) = text_raw {
276                normalize_score(score, text_min, text_max)
277            } else {
278                0.0
279            };
280
281            let combined = self.config.vector_weight * vector_normalized
282                + (1.0 - self.config.vector_weight) * text_normalized;
283
284            let (metadata, vector) = if let Some(raw) = vector_raw {
285                (raw.metadata.clone(), raw.vector.clone())
286            } else {
287                (None, None)
288            };
289
290            results.push(HybridResult {
291                id,
292                combined_score: combined,
293                vector_score: vector_normalized,
294                text_score: text_normalized,
295                metadata,
296                vector,
297            });
298        }
299
300        results.sort_by(|a, b| {
301            b.combined_score
302                .partial_cmp(&a.combined_score)
303                .unwrap_or(std::cmp::Ordering::Equal)
304        });
305        results.truncate(top_k);
306        results
307    }
308}
309
310impl Default for HybridSearcher {
311    fn default() -> Self {
312        Self::new(HybridConfig::default())
313    }
314}
315
316// ============================================================================
317// CE-12c: Adaptive hybrid weighting
318// ============================================================================
319
320/// Return an adaptive `vector_weight` (0.0–1.0) for a [`HybridSearcher`]
321/// based on the inferred [`QueryKind`].
322///
323/// | QueryKind | vector_weight | Rationale                                         |
324/// |-----------|---------------|---------------------------------------------------|
325/// | Keyword   | 0.25          | Exact-term signals dominate                        |
326/// | Hybrid    | 0.50          | Balanced blend                                     |
327/// | Semantic  | 0.75          | Embedding captures intent better                   |
328/// | Temporal  | 0.00          | Pure BM25 — any vector weight degrades temporal    |
329///
330/// CE-19: Temporal weight reduced from 0.20 to 0.00. LoCoMo CE-18 data showed
331/// Hybrid(80/20) scored 40.6% vs pure BM25 43.8% (−3.2pp). Date-prefixed memories
332/// have near-zero cosine similarity to "when did X happen?" queries; even 20% vector
333/// weight consistently pulls wrong candidates into the MinMax score. vector_weight=0.0
334/// makes the Hybrid path mathematically equivalent to pure BM25 scoring.
335pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
336    match kind {
337        crate::routing::QueryKind::Keyword => 0.25,
338        crate::routing::QueryKind::Hybrid => 0.50,
339        crate::routing::QueryKind::Semantic => 0.75,
340        crate::routing::QueryKind::Temporal => 0.00,
341        // CE-34 v2: BM25-tilted weight for multi-hop — entity co-occurrence
342        // in bridging memories is better captured by BM25 exact matching than
343        // by vector similarity. Slight tilt toward BM25 (0.10 shift from 0.50).
344        crate::routing::QueryKind::MultiHop => 0.40,
345    }
346}
347
/// Normalize a score to the 0-1 range using min-max normalization.
///
/// * `score` - the value to normalize
/// * `min` / `max` - smallest and largest score of the list `score` came from
///
/// Returns `(score - min) / (max - min)`. Two special cases:
/// * When `max - min` is smaller than `f32::EPSILON` (all scores are the
///   same), every score maps to 1.0.
/// * When the result is not a number (a NaN score, or non-finite bounds), 0.0
///   is returned so the value is treated as the weakest match instead of
///   leaking NaN into the combined score and its ordering.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let range = max - min;
    if range.abs() < f32::EPSILON {
        // All scores are the same, return 1.0
        return 1.0;
    }

    let normalized = (score - min) / range;
    if normalized.is_nan() {
        0.0
    } else {
        normalized
    }
}
357
#[cfg(test)]
mod tests {
    use super::*;

    // --- MinMax fusion (default strategy) ---

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 4.0,
            },
            FullTextResult {
                doc_id: "doc4".to_string(),
                score: 2.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // The union of both lists holds 4 distinct documents.
        assert_eq!(results.len(), 4);

        // doc1 and doc2 appear in both lists. Scores are min-max normalised,
        // so the lowest score of each list becomes 0.0.
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, which normalises to exactly 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        // MinMax with vector_weight=1.0 → combined_score == vector_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // doc1 should be first (highest vector score)
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        // MinMax with vector_weight=0.0 → combined_score == text_score
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
            fusion_strategy: FusionStrategy::MinMax,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 1.0,
            },
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 should be first (highest text score)
        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        // Default strategy (MinMax) with require_both enabled.
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 should be in results (only one with both scores)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Metadata and vector from the vector result are carried through.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        // Normal case
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // All same scores
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }

    // --- CE-19: Temporal adaptive weight test ---

    #[test]
    fn test_adaptive_vector_weight_temporal() {
        use crate::routing::QueryKind;
        // CE-19: Temporal queries must get vector_weight=0.00 (pure BM25).
        // CE-18 data: Hybrid(80/20)=40.6% vs pure BM25=43.8% (−3.2pp regression).
        // Any vector weight contaminates temporal scores — set to 0.0.
        assert_eq!(adaptive_vector_weight(QueryKind::Temporal), 0.00);
        // Other kinds unchanged
        assert_eq!(adaptive_vector_weight(QueryKind::Keyword), 0.25);
        assert_eq!(adaptive_vector_weight(QueryKind::Hybrid), 0.50);
        assert_eq!(adaptive_vector_weight(QueryKind::Semantic), 0.75);
        // CE-34 v2: MultiHop gets BM25-tilted weight for entity co-occurrence bridging.
        assert_eq!(adaptive_vector_weight(QueryKind::MultiHop), 0.40);
    }

    // --- CE-14: fusion strategy tests ---

    #[test]
    fn test_minmax_default_strategy() {
        // Default HybridSearcher uses MinMax since v0.11.2 (A/B: +6.3pp overall, +13.5pp temporal)
        let searcher = HybridSearcher::default();
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }

    #[test]
    fn test_rrf_ranks_correctly() {
        // doc1: rank 1 in vector, rank 2 in text → 1/(60+1) + 1/(60+2) ≈ 0.03252
        // doc2: rank 2 in vector, rank 1 in text → 1/(60+2) + 1/(60+1) ≈ 0.03252  (equal)
        // doc3: rank 3 in vector, not in text    → 1/(60+3) ≈ 0.01587
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            FullTextResult {
                doc_id: "doc2".to_string(),
                score: 5.0,
            },
            FullTextResult {
                doc_id: "doc1".to_string(),
                score: 3.0,
            },
        ];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 3);

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        let doc3 = results.iter().find(|r| r.id == "doc3").unwrap();

        // doc1 and doc2 both appear in both lists — they score higher than doc3 (vector-only)
        assert!(doc1.combined_score > doc3.combined_score);
        assert!(doc2.combined_score > doc3.combined_score);

        // doc1 and doc2 hold mirrored ranks, so their fused scores match.
        assert!((doc1.combined_score - doc2.combined_score).abs() < 1e-6);

        // doc3 is absent from the text list and contributes a single RRF term.
        assert_eq!(doc3.text_score, 0.0);
        assert!((doc3.combined_score - 1.0 / (RRF_K + 3.0)).abs() < 1e-6);

        // Scores are positive
        for r in &results {
            assert!(r.combined_score > 0.0);
        }
    }

    #[test]
    fn test_rrf_require_both() {
        // The strategy must be set explicitly: the default is MinMax, which
        // would leave the RRF `require_both` path untested.
        let searcher = HybridSearcher::new(HybridConfig {
            require_both: true,
            fusion_strategy: FusionStrategy::Rrf,
            ..Default::default()
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 2.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");

        // Rank 1 in both lists → 2/(60+1). MinMax would give 1.0 here, so this
        // also proves the RRF path ran.
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_rrf_formula_k60() {
        // Verify RRF formula: score = 1/(60 + rank), k=60
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::Rrf);

        let vector_results = vec![("doc1".to_string(), 1.0, None, None)];
        let text_results = vec![FullTextResult {
            doc_id: "doc1".to_string(),
            score: 1.0,
        }];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        // doc1 is rank 1 in both → score = 1/(60+1) + 1/(60+1) = 2/61 ≈ 0.032787
        let expected = 2.0 / (RRF_K + 1.0);
        assert!((results[0].combined_score - expected).abs() < 1e-5);
    }

    #[test]
    fn test_with_fusion_strategy_builder() {
        let searcher = HybridSearcher::default().with_fusion_strategy(FusionStrategy::MinMax);
        assert_eq!(searcher.config.fusion_strategy, FusionStrategy::MinMax);
    }
}