1use crate::index::registry::MultiIndexResults;
34use ahash::AHashMap;
35use ordered_float::OrderedFloat;
36
/// Strategy used to merge the result lists of several indexes into one ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion: every hit contributes `weight / (k + rank)`,
    /// where `rank` is the 1-based position of the hit within its index.
    #[default]
    RRF,
    /// Sum of the weighted per-index scores.
    CombSUM,
    /// Sum of the weighted per-index scores, multiplied by the number of
    /// sources recorded for the item.
    CombMNZ,
    /// Maximum of the weighted per-index scores.
    CombMAX,
    /// Minimum of the weighted per-index scores.
    CombMIN,
}
52
53#[derive(Debug, Clone)]
55pub struct FusedResult {
56 pub id: String,
58 pub fused_score: f32,
60 pub sources: Vec<String>,
62 pub source_scores: AHashMap<String, f32>,
64}
65
66impl FusedResult {
67 #[must_use]
69 pub fn new(id: String, fused_score: f32) -> Self {
70 Self {
71 id,
72 fused_score,
73 sources: Vec::new(),
74 source_scores: AHashMap::new(),
75 }
76 }
77
78 pub fn add_source(&mut self, index_name: String, score: f32) {
80 self.sources.push(index_name.clone());
81 self.source_scores.insert(index_name, score);
82 }
83
84 #[must_use]
86 pub fn source_count(&self) -> usize {
87 self.sources.len()
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct FusionConfig {
94 pub strategy: FusionStrategy,
96 pub rrf_k: usize,
98 pub weights: Option<AHashMap<String, f32>>,
100 pub normalize_scores: bool,
102}
103
104impl Default for FusionConfig {
105 fn default() -> Self {
106 Self {
107 strategy: FusionStrategy::RRF,
108 rrf_k: 60,
109 weights: None,
110 normalize_scores: true,
111 }
112 }
113}
114
115impl FusionConfig {
116 #[must_use]
118 pub fn new() -> Self {
119 Self::default()
120 }
121
122 #[must_use]
124 pub const fn with_rrf(mut self, k: usize) -> Self {
125 self.strategy = FusionStrategy::RRF;
126 self.rrf_k = k;
127 self
128 }
129
130 #[must_use]
132 pub const fn with_comb_sum(mut self) -> Self {
133 self.strategy = FusionStrategy::CombSUM;
134 self
135 }
136
137 #[must_use]
139 pub const fn with_comb_mnz(mut self) -> Self {
140 self.strategy = FusionStrategy::CombMNZ;
141 self
142 }
143
144 #[must_use]
146 pub fn with_weights(mut self, weights: AHashMap<String, f32>) -> Self {
147 self.weights = Some(weights);
148 self
149 }
150
151 #[must_use]
153 pub const fn with_normalize(mut self, normalize: bool) -> Self {
154 self.normalize_scores = normalize;
155 self
156 }
157}
158
159pub struct ScoreFusion {
161 config: FusionConfig,
162}
163
164impl ScoreFusion {
165 #[must_use]
167 pub fn new() -> Self {
168 Self {
169 config: FusionConfig::default(),
170 }
171 }
172
173 #[must_use]
175 pub fn with_config(config: FusionConfig) -> Self {
176 Self { config }
177 }
178
179 #[must_use]
181 pub fn rrf() -> Self {
182 Self::new()
183 }
184
185 #[must_use]
187 pub fn rrf_with_k(k: usize) -> Self {
188 Self::with_config(FusionConfig::new().with_rrf(k))
189 }
190
191 #[must_use]
193 pub fn fuse(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
194 match self.config.strategy {
195 FusionStrategy::RRF => self.fuse_rrf(results),
196 FusionStrategy::CombSUM => self.fuse_comb_sum(results),
197 FusionStrategy::CombMNZ => self.fuse_comb_mnz(results),
198 FusionStrategy::CombMAX => self.fuse_comb_max(results),
199 FusionStrategy::CombMIN => self.fuse_comb_min(results),
200 }
201 }
202
203 #[must_use]
205 pub fn fuse_top_k(&self, results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
206 let mut fused = self.fuse(results);
207 fused.truncate(k);
208 fused
209 }
210
211 fn fuse_rrf(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
216 let k = self.config.rrf_k as f32;
217 let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
218
219 for idx_result in &results.by_index {
220 let index_name = &idx_result.index_name;
221 let weight = self.get_weight(index_name);
222
223 for (rank, result) in idx_result.results.iter().enumerate() {
224 let rrf_score = weight / (k + (rank + 1) as f32);
225
226 let fused = scores.entry(result.id.clone()).or_insert_with(|| {
227 FusedResult::new(result.id.clone(), 0.0)
228 });
229
230 fused.fused_score += rrf_score;
231 fused.add_source(index_name.clone(), result.score);
232 }
233 }
234
235 self.sort_results(scores)
236 }
237
238 fn fuse_comb_sum(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
240 let normalized = if self.config.normalize_scores {
241 self.normalize_per_index(results)
242 } else {
243 self.collect_scores(results)
244 };
245
246 let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
247
248 for (id, index_scores) in normalized {
249 let mut fused = FusedResult::new(id.clone(), 0.0);
250
251 for (index_name, score) in index_scores {
252 let weight = self.get_weight(&index_name);
253 fused.fused_score += weight * score;
254 fused.add_source(index_name, score);
255 }
256
257 scores.insert(id, fused);
258 }
259
260 self.sort_results(scores)
261 }
262
263 fn fuse_comb_mnz(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
265 let normalized = if self.config.normalize_scores {
266 self.normalize_per_index(results)
267 } else {
268 self.collect_scores(results)
269 };
270
271 let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
272
273 for (id, index_scores) in normalized {
274 let mut fused = FusedResult::new(id.clone(), 0.0);
275 let mut sum = 0.0;
276
277 for (index_name, score) in index_scores {
278 let weight = self.get_weight(&index_name);
279 sum += weight * score;
280 fused.add_source(index_name, score);
281 }
282
283 fused.fused_score = sum * fused.source_count() as f32;
285 scores.insert(id, fused);
286 }
287
288 self.sort_results(scores)
289 }
290
291 fn fuse_comb_max(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
293 let normalized = if self.config.normalize_scores {
294 self.normalize_per_index(results)
295 } else {
296 self.collect_scores(results)
297 };
298
299 let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
300
301 for (id, index_scores) in normalized {
302 let mut fused = FusedResult::new(id.clone(), 0.0);
303 let mut max_score: f32 = 0.0;
304
305 for (index_name, score) in index_scores {
306 let weight = self.get_weight(&index_name);
307 let weighted = weight * score;
308 max_score = max_score.max(weighted);
309 fused.add_source(index_name, score);
310 }
311
312 fused.fused_score = max_score;
313 scores.insert(id, fused);
314 }
315
316 self.sort_results(scores)
317 }
318
319 fn fuse_comb_min(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
321 let normalized = if self.config.normalize_scores {
322 self.normalize_per_index(results)
323 } else {
324 self.collect_scores(results)
325 };
326
327 let mut scores: AHashMap<String, FusedResult> = AHashMap::new();
328
329 for (id, index_scores) in normalized {
330 let mut fused = FusedResult::new(id.clone(), 0.0);
331 let mut min_score: f32 = f32::MAX;
332
333 for (index_name, score) in index_scores {
334 let weight = self.get_weight(&index_name);
335 let weighted = weight * score;
336 min_score = min_score.min(weighted);
337 fused.add_source(index_name, score);
338 }
339
340 fused.fused_score = if min_score == f32::MAX { 0.0 } else { min_score };
341 scores.insert(id, fused);
342 }
343
344 self.sort_results(scores)
345 }
346
347 fn get_weight(&self, index_name: &str) -> f32 {
349 self.config
350 .weights
351 .as_ref()
352 .and_then(|w| w.get(index_name))
353 .copied()
354 .unwrap_or(1.0)
355 }
356
357 fn collect_scores(&self, results: &MultiIndexResults) -> AHashMap<String, Vec<(String, f32)>> {
359 let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
360
361 for idx_result in &results.by_index {
362 for result in &idx_result.results {
363 collected
364 .entry(result.id.clone())
365 .or_default()
366 .push((idx_result.index_name.clone(), result.score));
367 }
368 }
369
370 collected
371 }
372
373 fn normalize_per_index(
375 &self,
376 results: &MultiIndexResults,
377 ) -> AHashMap<String, Vec<(String, f32)>> {
378 let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
379
380 for idx_result in &results.by_index {
381 let scores: Vec<f32> = idx_result.results.iter().map(|r| r.score).collect();
383 let min_score = scores.iter().cloned().fold(f32::INFINITY, f32::min);
384 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
385 let range = max_score - min_score;
386
387 for result in &idx_result.results {
388 let normalized = if range > f32::EPSILON {
389 (result.score - min_score) / range
390 } else {
391 1.0 };
393
394 collected
395 .entry(result.id.clone())
396 .or_default()
397 .push((idx_result.index_name.clone(), normalized));
398 }
399 }
400
401 collected
402 }
403
404 fn sort_results(&self, scores: AHashMap<String, FusedResult>) -> Vec<FusedResult> {
406 let mut sorted: Vec<FusedResult> = scores.into_values().collect();
407 sorted.sort_by(|a, b| {
408 OrderedFloat(b.fused_score).cmp(&OrderedFloat(a.fused_score))
409 });
410 sorted
411 }
412}
413
414impl Default for ScoreFusion {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
420#[must_use]
422pub fn rrf_fuse(results: &MultiIndexResults) -> Vec<FusedResult> {
423 ScoreFusion::rrf().fuse(results)
424}
425
426#[must_use]
428pub fn rrf_fuse_top_k(results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
429 ScoreFusion::rrf().fuse_top_k(results, k)
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::registry::MultiIndexResult;
    use crate::SearchResult;

    /// Builds a search hit whose distance is the complement of its score.
    fn make_result(id: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            distance: 1.0 - score,
            score,
        }
    }

    /// Builds the result list of one index from `(id, score)` pairs.
    fn index_hits(index_name: &str, hits: &[(&str, f32)]) -> MultiIndexResult {
        MultiIndexResult {
            index_name: index_name.to_string(),
            results: hits
                .iter()
                .map(|&(id, score)| make_result(id, score))
                .collect(),
        }
    }

    /// Two indexes that both return `a` and `b` (in opposite order) plus one
    /// item unique to each index.
    fn make_multi_results() -> MultiIndexResults {
        MultiIndexResults {
            by_index: vec![
                index_hits("idx1", &[("a", 0.9), ("b", 0.8), ("c", 0.7)]),
                index_hits("idx2", &[("b", 0.95), ("a", 0.85), ("d", 0.75)]),
            ],
            total_count: 6,
        }
    }

    /// Two indexes that disagree on which of `a` and `b` scores higher.
    fn make_crossed_results() -> MultiIndexResults {
        MultiIndexResults {
            by_index: vec![
                index_hits("idx1", &[("a", 0.5), ("b", 0.9)]),
                index_hits("idx2", &[("a", 0.8), ("b", 0.3)]),
            ],
            total_count: 4,
        }
    }

    /// Fuser for `strategy` with score normalization switched off.
    fn unnormalized(strategy: FusionStrategy) -> ScoreFusion {
        ScoreFusion::with_config(FusionConfig {
            strategy,
            normalize_scores: false,
            ..Default::default()
        })
    }

    /// Looks up the fused entry for `id`, panicking if it is absent.
    fn by_id<'a>(fused: &'a [FusedResult], id: &str) -> &'a FusedResult {
        fused.iter().find(|r| r.id == id).unwrap()
    }

    #[test]
    fn test_rrf_fusion() {
        let fused = ScoreFusion::rrf().fuse(&make_multi_results());

        assert_eq!(fused.len(), 4);

        // `a` and `b` appear in both indexes, so they take the top two slots.
        for top in &fused[..2] {
            assert!(matches!(top.id.as_str(), "a" | "b"));
            assert_eq!(top.source_count(), 2);
        }
        assert_ne!(fused[0].id, fused[1].id);

        // `c` and `d` each appear in only one index and trail behind.
        for tail in &fused[2..4] {
            assert!(matches!(tail.id.as_str(), "c" | "d"));
        }
    }

    #[test]
    fn test_rrf_scores() {
        let fused = ScoreFusion::rrf_with_k(60).fuse(&make_multi_results());

        // `b` is ranked 2nd in idx1 and 1st in idx2.
        let expected = 1.0 / 62.0 + 1.0 / 61.0;
        let b = by_id(&fused, "b");
        assert!((b.fused_score - expected).abs() < 0.0001);
    }

    #[test]
    fn test_comb_sum() {
        let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_sum());
        let fused = fusion.fuse(&make_multi_results());

        assert!(matches!(fused[0].id.as_str(), "a" | "b"));
        assert!(matches!(fused[1].id.as_str(), "a" | "b"));
        assert_ne!(fused[0].id, fused[1].id);
    }

    #[test]
    fn test_comb_mnz() {
        let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_mnz());
        let fused = fusion.fuse(&make_multi_results());

        assert_eq!(by_id(&fused, "b").source_count(), 2);
        assert_eq!(by_id(&fused, "c").source_count(), 1);
    }

    #[test]
    fn test_weighted_fusion() {
        let mut weights = AHashMap::new();
        for (index_name, weight) in [("idx1", 2.0), ("idx2", 1.0)] {
            weights.insert(index_name.to_string(), weight);
        }

        let fusion = ScoreFusion::with_config(FusionConfig::new().with_weights(weights));
        let fused = fusion.fuse(&make_multi_results());

        // `a` leads in the double-weighted idx1, so it wins overall.
        assert_eq!(fused[0].id, "a");
    }

    #[test]
    fn test_top_k() {
        let fused = ScoreFusion::rrf().fuse_top_k(&make_multi_results(), 2);

        assert_eq!(fused.len(), 2);
    }

    #[test]
    fn test_convenience_functions() {
        let results = make_multi_results();

        assert_eq!(rrf_fuse(&results).len(), 4);
        assert_eq!(rrf_fuse_top_k(&results, 2).len(), 2);
    }

    #[test]
    fn test_empty_results() {
        let fused = ScoreFusion::rrf().fuse(&MultiIndexResults::default());

        assert!(fused.is_empty());
    }

    #[test]
    fn test_single_index() {
        let results = MultiIndexResults {
            by_index: vec![index_hits("only", &[("a", 0.9), ("b", 0.8)])],
            total_count: 2,
        };

        let fused = ScoreFusion::rrf().fuse(&results);

        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].id, "a");
        assert_eq!(fused[1].id, "b");
    }

    #[test]
    fn test_fused_result_sources() {
        let fused = ScoreFusion::rrf().fuse(&make_multi_results());
        let b = by_id(&fused, "b");

        for index_name in ["idx1", "idx2"] {
            assert!(b.sources.contains(&index_name.to_string()));
            assert!(b.source_scores.contains_key(index_name));
        }
    }

    #[test]
    fn test_comb_max() {
        let fused = unnormalized(FusionStrategy::CombMAX).fuse(&make_crossed_results());

        // max(a) = 0.8, max(b) = 0.9.
        assert_eq!(fused[0].id, "b");
        assert!((fused[0].fused_score - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_comb_min() {
        let fused = unnormalized(FusionStrategy::CombMIN).fuse(&make_crossed_results());

        // min(a) = 0.5, min(b) = 0.3.
        assert_eq!(fused[0].id, "a");
        assert!((fused[0].fused_score - 0.5).abs() < 0.001);
    }
}