rag_plusplus_core/index/
fusion.rs

1//! Score Fusion
2//!
3//! Algorithms for merging results from multiple indexes or retrieval methods.
4//!
5//! # Overview
6//!
7//! When searching multiple indexes with potentially different distance metrics
8//! or scoring scales, raw scores cannot be directly compared. Score fusion
9//! provides principled methods to combine rankings.
10//!
11//! # Algorithms
12//!
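//! - **RRF (Reciprocal Rank Fusion)**: Rank-based fusion that ignores score magnitudes
//! - **CombSUM**: Sum of (optionally normalized) scores across sources
//! - **CombMNZ**: CombSUM multiplied by the number of sources containing the result
//! - **CombMAX** / **CombMIN**: Maximum / minimum (optionally normalized) score across sources
//!
//! Every strategy can additionally scale each index's contribution by a
//! user-specified per-index weight.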
17//!
18//! # Architecture
19//!
20//! ```text
//! ┌────────────────────────────────────────────────────────────┐
//! │                    ScoreFusion                             │
//! ├────────────────────────────────────────────────────────────┤
//! │  config: FusionConfig                                      │
//! │    strategy: FusionStrategy                                │
//! │    rrf_k: usize (RRF constant, default 60)                 │
//! │    weights: Option<AHashMap<String, f32>>                  │
//! │    normalize_scores: bool (default true)                   │
//! ├────────────────────────────────────────────────────────────┤
//! │  + fuse(&MultiIndexResults) -> Vec<FusedResult>            │
//! │  + fuse_top_k(&MultiIndexResults, k) -> Vec<FusedResult>   │
//! └────────────────────────────────────────────────────────────┘
31//! ```
32
33use crate::index::registry::MultiIndexResults;
34use ahash::AHashMap;
35use ordered_float::OrderedFloat;
36
/// Algorithm used to merge several per-index rankings into a single list.
///
/// Defaults to [`FusionStrategy::RRF`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: uses only each result's rank, never its score value.
    #[default]
    RRF,
    /// CombSUM: sum of the (optionally normalized) per-index scores.
    CombSUM,
    /// CombMNZ: CombSUM multiplied by the number of indexes that returned the result.
    CombMNZ,
    /// CombMAX: the largest per-index score.
    CombMAX,
    /// CombMIN: the smallest per-index score.
    CombMIN,
}
52
53/// Result after score fusion.
54#[derive(Debug, Clone)]
55pub struct FusedResult {
56    /// Record ID
57    pub id: String,
58    /// Fused score (higher = more relevant)
59    pub fused_score: f32,
60    /// Source indexes that contributed to this result
61    pub sources: Vec<String>,
62    /// Original scores from each source (index_name -> score)
63    pub source_scores: AHashMap<String, f32>,
64}
65
66impl FusedResult {
67    /// Create a new fused result.
68    #[must_use]
69    pub fn new(id: String, fused_score: f32) -> Self {
70        Self {
71            id,
72            fused_score,
73            sources: Vec::new(),
74            source_scores: AHashMap::new(),
75        }
76    }
77
78    /// Add a source contribution.
79    pub fn add_source(&mut self, index_name: String, score: f32) {
80        self.sources.push(index_name.clone());
81        self.source_scores.insert(index_name, score);
82    }
83
84    /// Number of sources contributing to this result.
85    #[must_use]
86    pub fn source_count(&self) -> usize {
87        self.sources.len()
88    }
89}
90
91/// Configuration for score fusion.
92#[derive(Debug, Clone)]
93pub struct FusionConfig {
94    /// Fusion strategy
95    pub strategy: FusionStrategy,
96    /// RRF constant k (default 60, higher = more weight to lower ranks)
97    pub rrf_k: usize,
98    /// Per-index weights for weighted fusion (None = equal weights)
99    pub weights: Option<AHashMap<String, f32>>,
100    /// Whether to normalize scores before fusion (for Comb methods)
101    pub normalize_scores: bool,
102}
103
104impl Default for FusionConfig {
105    fn default() -> Self {
106        Self {
107            strategy: FusionStrategy::RRF,
108            rrf_k: 60,
109            weights: None,
110            normalize_scores: true,
111        }
112    }
113}
114
115impl FusionConfig {
116    /// Create new config with default RRF settings.
117    #[must_use]
118    pub fn new() -> Self {
119        Self::default()
120    }
121
122    /// Use RRF fusion strategy.
123    #[must_use]
124    pub const fn with_rrf(mut self, k: usize) -> Self {
125        self.strategy = FusionStrategy::RRF;
126        self.rrf_k = k;
127        self
128    }
129
130    /// Use CombSUM fusion strategy.
131    #[must_use]
132    pub const fn with_comb_sum(mut self) -> Self {
133        self.strategy = FusionStrategy::CombSUM;
134        self
135    }
136
137    /// Use CombMNZ fusion strategy.
138    #[must_use]
139    pub const fn with_comb_mnz(mut self) -> Self {
140        self.strategy = FusionStrategy::CombMNZ;
141        self
142    }
143
144    /// Set per-index weights.
145    #[must_use]
146    pub fn with_weights(mut self, weights: AHashMap<String, f32>) -> Self {
147        self.weights = Some(weights);
148        self
149    }
150
151    /// Set whether to normalize scores.
152    #[must_use]
153    pub const fn with_normalize(mut self, normalize: bool) -> Self {
154        self.normalize_scores = normalize;
155        self
156    }
157}
158
159/// Score fusion engine.
160pub struct ScoreFusion {
161    config: FusionConfig,
162}
163
164impl ScoreFusion {
165    /// Create a new score fusion engine with default config.
166    #[must_use]
167    pub fn new() -> Self {
168        Self {
169            config: FusionConfig::default(),
170        }
171    }
172
173    /// Create with custom configuration.
174    #[must_use]
175    pub fn with_config(config: FusionConfig) -> Self {
176        Self { config }
177    }
178
179    /// Create RRF fusion engine with default k=60.
180    #[must_use]
181    pub fn rrf() -> Self {
182        Self::new()
183    }
184
185    /// Create RRF fusion engine with custom k.
186    #[must_use]
187    pub fn rrf_with_k(k: usize) -> Self {
188        Self::with_config(FusionConfig::new().with_rrf(k))
189    }
190
191    /// Fuse results from multiple indexes.
192    #[must_use]
193    pub fn fuse(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
194        match self.config.strategy {
195            FusionStrategy::RRF => self.fuse_rrf(results),
196            FusionStrategy::CombSUM => self.fuse_comb_sum(results),
197            FusionStrategy::CombMNZ => self.fuse_comb_mnz(results),
198            FusionStrategy::CombMAX => self.fuse_comb_max(results),
199            FusionStrategy::CombMIN => self.fuse_comb_min(results),
200        }
201    }
202
203    /// Fuse and return only top-k results.
204    #[must_use]
205    pub fn fuse_top_k(&self, results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
206        let mut fused = self.fuse(results);
207        fused.truncate(k);
208        fused
209    }
210
211    /// Reciprocal Rank Fusion.
212    ///
213    /// Score = sum over sources of 1 / (k + rank)
214    /// where rank is 1-indexed position in that source.
215    fn fuse_rrf(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
216        let k = self.config.rrf_k as f32;
217        let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
218
219        for idx_result in &results.by_index {
220            let index_name = &idx_result.index_name;
221            let weight = self.get_weight(index_name);
222
223            for (rank, result) in idx_result.results.iter().enumerate() {
224                let rrf_score = weight / (k + (rank + 1) as f32);
225
226                let fused = scores.entry(result.id.clone()).or_insert_with(|| {
227                    FusedResult::new(result.id.clone(), 0.0)
228                });
229
230                fused.fused_score += rrf_score;
231                fused.add_source(index_name.clone(), result.score);
232            }
233        }
234
235        self.sort_results(scores)
236    }
237
238    /// CombSUM: Sum of (normalized) scores.
239    fn fuse_comb_sum(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
240        let normalized = if self.config.normalize_scores {
241            self.normalize_per_index(results)
242        } else {
243            self.collect_scores(results)
244        };
245
246        let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
247
248        for (id, index_scores) in normalized {
249            let mut fused = FusedResult::new(id.clone(), 0.0);
250
251            for (index_name, score) in index_scores {
252                let weight = self.get_weight(&index_name);
253                fused.fused_score += weight * score;
254                fused.add_source(index_name, score);
255            }
256
257            scores.insert(id, fused);
258        }
259
260        self.sort_results(scores)
261    }
262
263    /// CombMNZ: CombSUM weighted by number of sources.
264    fn fuse_comb_mnz(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
265        let normalized = if self.config.normalize_scores {
266            self.normalize_per_index(results)
267        } else {
268            self.collect_scores(results)
269        };
270
271        let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
272
273        for (id, index_scores) in normalized {
274            let mut fused = FusedResult::new(id.clone(), 0.0);
275            let mut sum = 0.0;
276
277            for (index_name, score) in index_scores {
278                let weight = self.get_weight(&index_name);
279                sum += weight * score;
280                fused.add_source(index_name, score);
281            }
282
283            // Multiply by number of sources (MNZ = "multiply by non-zero")
284            fused.fused_score = sum * fused.source_count() as f32;
285            scores.insert(id, fused);
286        }
287
288        self.sort_results(scores)
289    }
290
291    /// CombMAX: Maximum score across sources.
292    fn fuse_comb_max(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
293        let normalized = if self.config.normalize_scores {
294            self.normalize_per_index(results)
295        } else {
296            self.collect_scores(results)
297        };
298
299        let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
300
301        for (id, index_scores) in normalized {
302            let mut fused = FusedResult::new(id.clone(), 0.0);
303            let mut max_score: f32 = 0.0;
304
305            for (index_name, score) in index_scores {
306                let weight = self.get_weight(&index_name);
307                let weighted = weight * score;
308                max_score = max_score.max(weighted);
309                fused.add_source(index_name, score);
310            }
311
312            fused.fused_score = max_score;
313            scores.insert(id, fused);
314        }
315
316        self.sort_results(scores)
317    }
318
319    /// CombMIN: Minimum score across sources.
320    fn fuse_comb_min(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
321        let normalized = if self.config.normalize_scores {
322            self.normalize_per_index(results)
323        } else {
324            self.collect_scores(results)
325        };
326
327        let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
328
329        for (id, index_scores) in normalized {
330            let mut fused = FusedResult::new(id.clone(), 0.0);
331            let mut min_score: f32 = f32::MAX;
332
333            for (index_name, score) in index_scores {
334                let weight = self.get_weight(&index_name);
335                let weighted = weight * score;
336                min_score = min_score.min(weighted);
337                fused.add_source(index_name, score);
338            }
339
340            fused.fused_score = if min_score == f32::MAX { 0.0 } else { min_score };
341            scores.insert(id, fused);
342        }
343
344        self.sort_results(scores)
345    }
346
347    /// Get weight for an index (default 1.0).
348    fn get_weight(&self, index_name: &str) -> f32 {
349        self.config
350            .weights
351            .as_ref()
352            .and_then(|w| w.get(index_name))
353            .copied()
354            .unwrap_or(1.0)
355    }
356
357    /// Collect scores without normalization.
358    fn collect_scores(&self, results: &MultiIndexResults) -> AHashMap<String, Vec<(String, f32)>> {
359        let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
360
361        for idx_result in &results.by_index {
362            for result in &idx_result.results {
363                collected
364                    .entry(result.id.clone())
365                    .or_default()
366                    .push((idx_result.index_name.clone(), result.score));
367            }
368        }
369
370        collected
371    }
372
373    /// Normalize scores per index to [0, 1] range.
374    fn normalize_per_index(
375        &self,
376        results: &MultiIndexResults,
377    ) -> AHashMap<String, Vec<(String, f32)>> {
378        let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
379
380        for idx_result in &results.by_index {
381            // Find min/max for this index
382            let scores: Vec<f32> = idx_result.results.iter().map(|r| r.score).collect();
383            let min_score = scores.iter().cloned().fold(f32::INFINITY, f32::min);
384            let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
385            let range = max_score - min_score;
386
387            for result in &idx_result.results {
388                let normalized = if range > f32::EPSILON {
389                    (result.score - min_score) / range
390                } else {
391                    1.0 // All scores equal
392                };
393
394                collected
395                    .entry(result.id.clone())
396                    .or_default()
397                    .push((idx_result.index_name.clone(), normalized));
398            }
399        }
400
401        collected
402    }
403
404    /// Sort results by fused score (descending).
405    fn sort_results(&self, scores: AHashMap<String, FusedResult>) -> Vec<FusedResult> {
406        let mut sorted: Vec<FusedResult> = scores.into_values().collect();
407        sorted.sort_by(|a, b| {
408            OrderedFloat(b.fused_score).cmp(&OrderedFloat(a.fused_score))
409        });
410        sorted
411    }
412}
413
414impl Default for ScoreFusion {
415    fn default() -> Self {
416        Self::new()
417    }
418}
419
420/// Convenience function for RRF fusion.
421#[must_use]
422pub fn rrf_fuse(results: &MultiIndexResults) -> Vec<FusedResult> {
423    ScoreFusion::rrf().fuse(results)
424}
425
426/// Convenience function for RRF fusion with top-k.
427#[must_use]
428pub fn rrf_fuse_top_k(results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
429    ScoreFusion::rrf().fuse_top_k(results, k)
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::registry::MultiIndexResult;
    use crate::SearchResult;

    fn make_result(id: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            distance: 1.0 - score, // Fake distance
            score,
        }
    }

    fn make_multi_results() -> MultiIndexResults {
        MultiIndexResults {
            by_index: vec![
                MultiIndexResult {
                    index_name: "idx1".to_string(),
                    results: vec![
                        make_result("a", 0.9),
                        make_result("b", 0.8),
                        make_result("c", 0.7),
                    ],
                },
                MultiIndexResult {
                    index_name: "idx2".to_string(),
                    results: vec![
                        make_result("b", 0.95), // b is top in idx2
                        make_result("a", 0.85),
                        make_result("d", 0.75),
                    ],
                },
            ],
            total_count: 6,
        }
    }

    #[test]
    fn test_rrf_fusion() {
        let results = make_multi_results();
        let fused = ScoreFusion::rrf().fuse(&results);

        // Should have 4 unique IDs: a, b, c, d
        assert_eq!(fused.len(), 4);

        // a and b appear in both indexes with symmetric ranks, so they have equal RRF scores
        // Top two should be a and b (either order is valid for tied scores)
        assert!(fused[0].id == "a" || fused[0].id == "b");
        assert_eq!(fused[0].source_count(), 2);

        assert!(fused[1].id == "a" || fused[1].id == "b");
        assert_eq!(fused[1].source_count(), 2);
        assert_ne!(fused[0].id, fused[1].id); // Both should be present

        // c and d only appear in one index each (both at rank 3)
        assert!(fused[2].id == "c" || fused[2].id == "d");
        assert!(fused[3].id == "c" || fused[3].id == "d");
    }

    #[test]
    fn test_rrf_scores() {
        let results = make_multi_results();
        let fusion = ScoreFusion::rrf_with_k(60);
        let fused = fusion.fuse(&results);

        // For item b:
        // idx1: rank 2 -> 1/(60+2) = 1/62
        // idx2: rank 1 -> 1/(60+1) = 1/61
        // Total: 1/62 + 1/61 ≈ 0.01613 + 0.01639 ≈ 0.03252
        let b = fused.iter().find(|r| r.id == "b").unwrap();
        let expected = 1.0 / 62.0 + 1.0 / 61.0;
        assert!((b.fused_score - expected).abs() < 0.0001);
    }

    #[test]
    fn test_comb_sum() {
        let results = make_multi_results();
        let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_sum());
        let fused = fusion.fuse(&results);

        // Scores are min-max normalized per index by default:
        //   idx1: a = 1.0, b = 0.5, c = 0.0
        //   idx2: b = 1.0, a = 0.5, d = 0.0
        // so a and b both sum to ~1.5 (equal up to f32 rounding); either order is valid.
        assert!(fused[0].id == "a" || fused[0].id == "b");
        assert!(fused[1].id == "a" || fused[1].id == "b");
        assert_ne!(fused[0].id, fused[1].id);
    }

    #[test]
    fn test_comb_mnz() {
        let results = make_multi_results();
        let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_mnz());
        let fused = fusion.fuse(&results);

        // Items appearing in both indexes should be boosted
        let b = fused.iter().find(|r| r.id == "b").unwrap();
        let c = fused.iter().find(|r| r.id == "c").unwrap();

        // b appears in 2 sources, c in 1
        assert_eq!(b.source_count(), 2);
        assert_eq!(c.source_count(), 1);

        // Normalized: b = (0.5 + 1.0) * 2 ≈ 3.0, c = 0.0 * 1 = 0.0.
        assert!(b.fused_score > c.fused_score);
    }

    #[test]
    fn test_weighted_fusion() {
        let results = make_multi_results();

        let mut weights = AHashMap::new();
        weights.insert("idx1".to_string(), 2.0);
        weights.insert("idx2".to_string(), 1.0);

        let fusion = ScoreFusion::with_config(FusionConfig::new().with_weights(weights));
        let fused = fusion.fuse(&results);

        // 'a' is ranked #1 in idx1 (weight 2.0) vs 'b' ranked #2
        // With double weight on idx1, 'a' should score higher
        assert_eq!(fused[0].id, "a");
    }

    #[test]
    fn test_top_k() {
        let results = make_multi_results();
        let fused = ScoreFusion::rrf().fuse_top_k(&results, 2);

        assert_eq!(fused.len(), 2);
    }

    #[test]
    fn test_convenience_functions() {
        let results = make_multi_results();

        let fused1 = rrf_fuse(&results);
        let fused2 = rrf_fuse_top_k(&results, 2);

        assert_eq!(fused1.len(), 4);
        assert_eq!(fused2.len(), 2);
    }

    #[test]
    fn test_empty_results() {
        let results = MultiIndexResults::default();
        let fused = ScoreFusion::rrf().fuse(&results);
        assert!(fused.is_empty());
    }

    #[test]
    fn test_single_index() {
        let results = MultiIndexResults {
            by_index: vec![MultiIndexResult {
                index_name: "only".to_string(),
                results: vec![make_result("a", 0.9), make_result("b", 0.8)],
            }],
            total_count: 2,
        };

        let fused = ScoreFusion::rrf().fuse(&results);

        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].id, "a");
        assert_eq!(fused[1].id, "b");
    }

    #[test]
    fn test_fused_result_sources() {
        let results = make_multi_results();
        let fused = ScoreFusion::rrf().fuse(&results);

        let b = fused.iter().find(|r| r.id == "b").unwrap();
        assert!(b.sources.contains(&"idx1".to_string()));
        assert!(b.sources.contains(&"idx2".to_string()));
        assert!(b.source_scores.contains_key("idx1"));
        assert!(b.source_scores.contains_key("idx2"));

        // RRF records each index's raw score unchanged.
        assert_eq!(b.source_scores.get("idx1").copied(), Some(0.8));
        assert_eq!(b.source_scores.get("idx2").copied(), Some(0.95));
    }

    #[test]
    fn test_comb_max() {
        let results = MultiIndexResults {
            by_index: vec![
                MultiIndexResult {
                    index_name: "idx1".to_string(),
                    results: vec![make_result("a", 0.5), make_result("b", 0.9)],
                },
                MultiIndexResult {
                    index_name: "idx2".to_string(),
                    results: vec![make_result("a", 0.8), make_result("b", 0.3)],
                },
            ],
            total_count: 4,
        };

        let fusion = ScoreFusion::with_config(FusionConfig {
            strategy: FusionStrategy::CombMAX,
            normalize_scores: false,
            ..Default::default()
        });
        let fused = fusion.fuse(&results);

        // a: max(0.5, 0.8) = 0.8
        // b: max(0.9, 0.3) = 0.9
        assert_eq!(fused[0].id, "b");
        assert!((fused[0].fused_score - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_comb_min() {
        let results = MultiIndexResults {
            by_index: vec![
                MultiIndexResult {
                    index_name: "idx1".to_string(),
                    results: vec![make_result("a", 0.5), make_result("b", 0.9)],
                },
                MultiIndexResult {
                    index_name: "idx2".to_string(),
                    results: vec![make_result("a", 0.8), make_result("b", 0.3)],
                },
            ],
            total_count: 4,
        };

        let fusion = ScoreFusion::with_config(FusionConfig {
            strategy: FusionStrategy::CombMIN,
            normalize_scores: false,
            ..Default::default()
        });
        let fused = fusion.fuse(&results);

        // a: min(0.5, 0.8) = 0.5
        // b: min(0.9, 0.3) = 0.3
        // a should be ranked higher with min strategy
        assert_eq!(fused[0].id, "a");
        assert!((fused[0].fused_score - 0.5).abs() < 0.001);
    }
}