use crate::index::registry::MultiIndexResults;
use ahash::AHashMap;
use ordered_float::OrderedFloat;
/// Strategy used to merge per-index result lists into one ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FusionStrategy {
/// Reciprocal rank fusion (the default): an item's score is built from its
/// rank in each index rather than from the raw scores.
#[default]
RRF,
/// Sum of the weighted per-index scores.
CombSUM,
/// Weighted sum multiplied by the number of sources recorded for the item.
CombMNZ,
/// Largest weighted per-index score.
CombMAX,
/// Smallest weighted per-index score.
CombMIN,
}
/// A single item after fusion, together with the per-index evidence recorded for it.
#[derive(Debug, Clone)]
pub struct FusedResult {
/// Identifier of the item.
pub id: String,
/// Combined score assigned by the fusion strategy.
pub fused_score: f32,
/// Index names recorded through `add_source`, in call order.
pub sources: Vec<String>,
/// Score recorded for each index name through `add_source`.
pub source_scores: AHashMap<String, f32>,
}
impl FusedResult {
    /// Creates a result for `id` with the given starting `fused_score` and no
    /// recorded sources.
    #[must_use]
    pub fn new(id: String, fused_score: f32) -> Self {
        Self {
            id,
            fused_score,
            sources: Vec::new(),
            source_scores: AHashMap::new(),
        }
    }

    /// Records that the index `index_name` returned this item with `score`.
    ///
    /// Every index name is listed at most once in `sources`. Recording the same
    /// name again only replaces its entry in `source_scores`, so `sources` and
    /// `source_scores` always describe the same set of indexes and
    /// [`source_count`](Self::source_count) counts distinct indexes.
    pub fn add_source(&mut self, index_name: String, score: f32) {
        // `insert` returns `None` only when the name was not present yet; that is
        // the only case in which the name must also be appended to `sources`.
        if self
            .source_scores
            .insert(index_name.clone(), score)
            .is_none()
        {
            self.sources.push(index_name);
        }
    }

    /// Number of indexes recorded for this item.
    #[must_use]
    pub fn source_count(&self) -> usize {
        self.sources.len()
    }
}
/// Settings that control how `ScoreFusion` combines per-index results.
#[derive(Debug, Clone)]
pub struct FusionConfig {
/// Fusion algorithm to apply.
pub strategy: FusionStrategy,
/// Constant added to the 1-based rank in the RRF denominator.
pub rrf_k: usize,
/// Optional per-index multipliers keyed by index name; an index without an
/// entry (or a `None` map) is weighted 1.0.
pub weights: Option<AHashMap<String, f32>>,
/// When `true`, the Comb* strategies min-max normalize each index's scores
/// before combining them. The RRF path does not read this flag.
pub normalize_scores: bool,
}
impl Default for FusionConfig {
    /// Returns the baseline configuration: reciprocal rank fusion with
    /// `rrf_k = 60`, no per-index weights, and score normalization enabled.
    fn default() -> Self {
        let strategy = FusionStrategy::RRF;
        let rrf_k = 60;
        let normalize_scores = true;
        Self {
            strategy,
            rrf_k,
            weights: None,
            normalize_scores,
        }
    }
}
impl FusionConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Selects reciprocal rank fusion and sets its rank constant to `k`.
    #[must_use]
    pub const fn with_rrf(mut self, k: usize) -> Self {
        self.strategy = FusionStrategy::RRF;
        self.rrf_k = k;
        self
    }

    /// Selects the CombSUM strategy.
    #[must_use]
    pub const fn with_comb_sum(mut self) -> Self {
        self.strategy = FusionStrategy::CombSUM;
        self
    }

    /// Selects the CombMNZ strategy.
    #[must_use]
    pub const fn with_comb_mnz(mut self) -> Self {
        self.strategy = FusionStrategy::CombMNZ;
        self
    }

    /// Selects the CombMAX strategy.
    #[must_use]
    pub const fn with_comb_max(mut self) -> Self {
        self.strategy = FusionStrategy::CombMAX;
        self
    }

    /// Selects the CombMIN strategy.
    #[must_use]
    pub const fn with_comb_min(mut self) -> Self {
        self.strategy = FusionStrategy::CombMIN;
        self
    }

    /// Sets the per-index weights, keyed by index name.
    #[must_use]
    pub fn with_weights(mut self, weights: AHashMap<String, f32>) -> Self {
        self.weights = Some(weights);
        self
    }

    /// Enables or disables score normalization.
    #[must_use]
    pub const fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize_scores = normalize;
        self
    }
}
/// Fuses the result lists of several indexes into one ranked list according to
/// its `FusionConfig`.
pub struct ScoreFusion {
// Strategy, RRF constant, per-index weights and normalization flag used by `fuse`.
config: FusionConfig,
}
impl ScoreFusion {
    /// Creates a fuser that uses the default [`FusionConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(FusionConfig::default())
    }

    /// Creates a fuser that uses the supplied configuration.
    #[must_use]
    pub fn with_config(config: FusionConfig) -> Self {
        Self { config }
    }

    /// Creates a reciprocal-rank-fusion fuser; identical to [`ScoreFusion::new`]
    /// because RRF is the default strategy.
    #[must_use]
    pub fn rrf() -> Self {
        Self::new()
    }

    /// Creates a reciprocal-rank-fusion fuser with rank constant `k`.
    #[must_use]
    pub fn rrf_with_k(k: usize) -> Self {
        Self::with_config(FusionConfig::new().with_rrf(k))
    }

    /// Fuses `results` with the configured strategy.
    ///
    /// The returned list holds one entry per distinct item id, ordered by
    /// descending fused score; items with equal scores are ordered by
    /// ascending id so the output is deterministic.
    #[must_use]
    pub fn fuse(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        match self.config.strategy {
            FusionStrategy::RRF => self.fuse_rrf(results),
            FusionStrategy::CombSUM => self.fuse_comb_sum(results),
            FusionStrategy::CombMNZ => self.fuse_comb_mnz(results),
            FusionStrategy::CombMAX => self.fuse_comb_max(results),
            FusionStrategy::CombMIN => self.fuse_comb_min(results),
        }
    }

    /// Fuses `results` and keeps at most the first `k` entries of the ranking.
    #[must_use]
    pub fn fuse_top_k(&self, results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
        let mut fused = self.fuse(results);
        fused.truncate(k);
        fused
    }

    /// Reciprocal rank fusion: a hit at 1-based rank `r` in an index with weight
    /// `w` adds `w / (rrf_k + r)` to its item's fused score. The raw score of the
    /// hit is only recorded as source evidence.
    fn fuse_rrf(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        let k = self.config.rrf_k as f32;
        let mut fused_by_id: AHashMap<String, FusedResult> = AHashMap::new();
        for idx_result in &results.by_index {
            let index_name = &idx_result.index_name;
            let weight = self.get_weight(index_name);
            for (position, hit) in idx_result.results.iter().enumerate() {
                // `enumerate` is 0-based; RRF ranks start at 1.
                let contribution = weight / (k + (position + 1) as f32);
                let fused = fused_by_id
                    .entry(hit.id.clone())
                    .or_insert_with(|| FusedResult::new(hit.id.clone(), 0.0));
                fused.fused_score += contribution;
                fused.add_source(index_name.clone(), hit.score);
            }
        }
        self.sort_results(fused_by_id)
    }

    /// CombSUM: the fused score is the sum of the weighted per-index scores.
    fn fuse_comb_sum(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        self.fuse_comb(results, |weighted, _| {
            weighted.iter().copied().fold(0.0, |acc, w| acc + w)
        })
    }

    /// CombMNZ: the weighted sum multiplied by the item's source count.
    fn fuse_comb_mnz(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        self.fuse_comb(results, |weighted, source_count| {
            let sum = weighted.iter().copied().fold(0.0, |acc, w| acc + w);
            sum * source_count as f32
        })
    }

    /// CombMAX: the largest weighted per-index score.
    ///
    /// NaN scores are ignored. Negative maxima are preserved (no implicit floor
    /// at zero); 0.0 is used only when there is no non-NaN score at all.
    fn fuse_comb_max(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        self.fuse_comb(results, |weighted, _| {
            weighted
                .iter()
                .copied()
                .filter(|w| !w.is_nan())
                .reduce(f32::max)
                .unwrap_or(0.0)
        })
    }

    /// CombMIN: the smallest weighted per-index score.
    ///
    /// NaN scores are ignored; 0.0 is used only when there is no non-NaN score
    /// at all (no sentinel value that could collide with a real score).
    fn fuse_comb_min(&self, results: &MultiIndexResults) -> Vec<FusedResult> {
        self.fuse_comb(results, |weighted, _| {
            weighted
                .iter()
                .copied()
                .filter(|w| !w.is_nan())
                .reduce(f32::min)
                .unwrap_or(0.0)
        })
    }

    /// Shared skeleton of the Comb* strategies.
    ///
    /// Gathers the (optionally normalized) per-index scores of every item,
    /// records each of them as source evidence, and sets the fused score to
    /// `combine(weighted_scores, source_count)`, where `weighted_scores` holds
    /// `weight * score` for each contributing entry in index order.
    fn fuse_comb<F>(&self, results: &MultiIndexResults, combine: F) -> Vec<FusedResult>
    where
        F: Fn(&[f32], usize) -> f32,
    {
        let per_id = if self.config.normalize_scores {
            self.normalize_per_index(results)
        } else {
            self.collect_scores(results)
        };
        let mut fused_by_id: AHashMap<String, FusedResult> = AHashMap::new();
        // Scratch buffer reused across items to avoid one allocation per id.
        let mut weighted: Vec<f32> = Vec::new();
        for (id, index_scores) in per_id {
            let mut fused = FusedResult::new(id.clone(), 0.0);
            weighted.clear();
            for (index_name, score) in index_scores {
                weighted.push(self.get_weight(&index_name) * score);
                fused.add_source(index_name, score);
            }
            fused.fused_score = combine(&weighted, fused.source_count());
            fused_by_id.insert(id, fused);
        }
        self.sort_results(fused_by_id)
    }

    /// Weight configured for `index_name`, or 1.0 when no weight map is set or
    /// the map has no entry for that index.
    fn get_weight(&self, index_name: &str) -> f32 {
        self.config
            .weights
            .as_ref()
            .and_then(|weights| weights.get(index_name))
            .copied()
            .unwrap_or(1.0)
    }

    /// Groups the raw scores by item id as `(index_name, score)` pairs, in the
    /// order the indexes appear in `results`.
    fn collect_scores(&self, results: &MultiIndexResults) -> AHashMap<String, Vec<(String, f32)>> {
        let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
        for idx_result in &results.by_index {
            for hit in &idx_result.results {
                collected
                    .entry(hit.id.clone())
                    .or_default()
                    .push((idx_result.index_name.clone(), hit.score));
            }
        }
        collected
    }

    /// Like [`collect_scores`](Self::collect_scores), but min-max normalizes the
    /// scores of each index to `[0, 1]` first.
    ///
    /// When an index's score range is not larger than `f32::EPSILON` (for
    /// example a single hit), every hit of that index is assigned 1.0.
    fn normalize_per_index(
        &self,
        results: &MultiIndexResults,
    ) -> AHashMap<String, Vec<(String, f32)>> {
        let mut collected: AHashMap<String, Vec<(String, f32)>> = AHashMap::new();
        for idx_result in &results.by_index {
            // Single pass for both bounds; no temporary score vector needed.
            let (min_score, max_score) = idx_result.results.iter().fold(
                (f32::INFINITY, f32::NEG_INFINITY),
                |(lo, hi), hit| (lo.min(hit.score), hi.max(hit.score)),
            );
            let range = max_score - min_score;
            for hit in &idx_result.results {
                let normalized = if range > f32::EPSILON {
                    (hit.score - min_score) / range
                } else {
                    1.0
                };
                collected
                    .entry(hit.id.clone())
                    .or_default()
                    .push((idx_result.index_name.clone(), normalized));
            }
        }
        collected
    }

    /// Turns the id → result map into a list ordered by descending fused score.
    ///
    /// The map's iteration order is arbitrary, so equal scores are ordered by
    /// ascending id to keep the ranking reproducible between runs.
    fn sort_results(&self, scores: AHashMap<String, FusedResult>) -> Vec<FusedResult> {
        let mut sorted: Vec<FusedResult> = scores.into_values().collect();
        sorted.sort_by(|a, b| {
            OrderedFloat(b.fused_score)
                .cmp(&OrderedFloat(a.fused_score))
                .then_with(|| a.id.cmp(&b.id))
        });
        sorted
    }
}
impl Default for ScoreFusion {
fn default() -> Self {
Self::new()
}
}
/// Convenience wrapper: fuses `results` with a reciprocal-rank-fusion fuser
/// built by `ScoreFusion::rrf()`.
#[must_use]
pub fn rrf_fuse(results: &MultiIndexResults) -> Vec<FusedResult> {
    let fusion = ScoreFusion::rrf();
    fusion.fuse(results)
}
/// Convenience wrapper: fuses `results` with a reciprocal-rank-fusion fuser
/// built by `ScoreFusion::rrf()` and returns what its `fuse_top_k` yields for `k`.
#[must_use]
pub fn rrf_fuse_top_k(results: &MultiIndexResults, k: usize) -> Vec<FusedResult> {
    let fusion = ScoreFusion::rrf();
    fusion.fuse_top_k(results, k)
}
// Unit tests for the fusion strategies and the convenience helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::index::registry::MultiIndexResult;
use crate::SearchResult;
// Builds a search hit with the given id and score; distance is set to `1.0 - score`.
fn make_result(id: &str, score: f32) -> SearchResult {
SearchResult {
id: id.to_string(),
distance: 1.0 - score, score,
}
}
// Shared fixture with two indexes: `a` and `b` appear in both (with swapped
// ranks), `c` only in idx1 and `d` only in idx2.
fn make_multi_results() -> MultiIndexResults {
MultiIndexResults {
by_index: vec![
MultiIndexResult {
index_name: "idx1".to_string(),
results: vec![
make_result("a", 0.9),
make_result("b", 0.8),
make_result("c", 0.7),
],
},
MultiIndexResult {
index_name: "idx2".to_string(),
results: vec![
make_result("b", 0.95), make_result("a", 0.85),
make_result("d", 0.75),
],
},
],
total_count: 6,
}
}
// `a` and `b` each hold rank 1 in one index and rank 2 in the other, and `c`/`d`
// each hold rank 3 in a single index, so either order within each pair is accepted.
#[test]
fn test_rrf_fusion() {
let results = make_multi_results();
let fused = ScoreFusion::rrf().fuse(&results);
assert_eq!(fused.len(), 4);
assert!(fused[0].id == "a" || fused[0].id == "b");
assert_eq!(fused[0].source_count(), 2);
assert!(fused[1].id == "a" || fused[1].id == "b");
assert_eq!(fused[1].source_count(), 2);
assert_ne!(fused[0].id, fused[1].id);
assert!(fused[2].id == "c" || fused[2].id == "d");
assert!(fused[3].id == "c" || fused[3].id == "d");
}
// `b` is rank 2 in idx1 and rank 1 in idx2, so with k = 60 its score is 1/62 + 1/61.
#[test]
fn test_rrf_scores() {
let results = make_multi_results();
let fusion = ScoreFusion::rrf_with_k(60);
let fused = fusion.fuse(&results);
let b = fused.iter().find(|r| r.id == "b").unwrap();
let expected = 1.0 / 62.0 + 1.0 / 61.0;
assert!((b.fused_score - expected).abs() < 0.0001);
}
// CombSUM (normalization on by default): the two items present in both indexes
// must occupy the top two positions, in either order.
#[test]
fn test_comb_sum() {
let results = make_multi_results();
let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_sum());
let fused = fusion.fuse(&results);
assert!(fused[0].id == "a" || fused[0].id == "b");
assert!(fused[1].id == "a" || fused[1].id == "b");
assert_ne!(fused[0].id, fused[1].id);
}
// CombMNZ: checks only the recorded source counts (`b` in two indexes, `c` in one).
#[test]
fn test_comb_mnz() {
let results = make_multi_results();
let fusion = ScoreFusion::with_config(FusionConfig::new().with_comb_mnz());
let fused = fusion.fuse(&results);
let b = fused.iter().find(|r| r.id == "b").unwrap();
let c = fused.iter().find(|r| r.id == "c").unwrap();
assert_eq!(b.source_count(), 2);
assert_eq!(c.source_count(), 1);
}
// Weighted RRF: with idx1 weighted 2.0, `a` (rank 1 in idx1) must outrank `b`.
#[test]
fn test_weighted_fusion() {
let results = make_multi_results();
let mut weights = AHashMap::new();
weights.insert("idx1".to_string(), 2.0);
weights.insert("idx2".to_string(), 1.0);
let fusion = ScoreFusion::with_config(FusionConfig::new().with_weights(weights));
let fused = fusion.fuse(&results);
assert_eq!(fused[0].id, "a");
}
// `fuse_top_k` truncates the four fused items to the requested two.
#[test]
fn test_top_k() {
let results = make_multi_results();
let fused = ScoreFusion::rrf().fuse_top_k(&results, 2);
assert_eq!(fused.len(), 2);
}
// The free-function helpers return the same result counts as the fuser methods.
#[test]
fn test_convenience_functions() {
let results = make_multi_results();
let fused1 = rrf_fuse(&results);
let fused2 = rrf_fuse_top_k(&results, 2);
assert_eq!(fused1.len(), 4);
assert_eq!(fused2.len(), 2);
}
// Fusing a default (empty) result set yields an empty list.
#[test]
fn test_empty_results() {
let results = MultiIndexResults::default();
let fused = ScoreFusion::rrf().fuse(&results);
assert!(fused.is_empty());
}
// With a single index, RRF preserves that index's rank order.
#[test]
fn test_single_index() {
let results = MultiIndexResults {
by_index: vec![MultiIndexResult {
index_name: "only".to_string(),
results: vec![make_result("a", 0.9), make_result("b", 0.8)],
}],
total_count: 2,
};
let fused = ScoreFusion::rrf().fuse(&results);
assert_eq!(fused.len(), 2);
assert_eq!(fused[0].id, "a");
assert_eq!(fused[1].id, "b");
}
// Both contributing indexes are recorded in `sources` and `source_scores`.
#[test]
fn test_fused_result_sources() {
let results = make_multi_results();
let fused = ScoreFusion::rrf().fuse(&results);
let b = fused.iter().find(|r| r.id == "b").unwrap();
assert!(b.sources.contains(&"idx1".to_string()));
assert!(b.sources.contains(&"idx2".to_string()));
assert!(b.source_scores.contains_key("idx1"));
assert!(b.source_scores.contains_key("idx2"));
}
// CombMAX without normalization: `b` wins with its best raw score, 0.9
// (`a` peaks at 0.8).
#[test]
fn test_comb_max() {
let results = MultiIndexResults {
by_index: vec![
MultiIndexResult {
index_name: "idx1".to_string(),
results: vec![make_result("a", 0.5), make_result("b", 0.9)],
},
MultiIndexResult {
index_name: "idx2".to_string(),
results: vec![make_result("a", 0.8), make_result("b", 0.3)],
},
],
total_count: 4,
};
let fusion = ScoreFusion::with_config(FusionConfig {
strategy: FusionStrategy::CombMAX,
normalize_scores: false,
..Default::default()
});
let fused = fusion.fuse(&results);
assert_eq!(fused[0].id, "b");
assert!((fused[0].fused_score - 0.9).abs() < 0.001);
}
// CombMIN without normalization: `a` wins because its worst raw score (0.5)
// beats `b`'s worst (0.3).
#[test]
fn test_comb_min() {
let results = MultiIndexResults {
by_index: vec![
MultiIndexResult {
index_name: "idx1".to_string(),
results: vec![make_result("a", 0.5), make_result("b", 0.9)],
},
MultiIndexResult {
index_name: "idx2".to_string(),
results: vec![make_result("a", 0.8), make_result("b", 0.3)],
},
],
total_count: 4,
};
let fusion = ScoreFusion::with_config(FusionConfig {
strategy: FusionStrategy::CombMIN,
normalize_scores: false,
..Default::default()
});
let fused = fusion.fuse(&results);
assert_eq!(fused[0].id, "a");
assert!((fused[0].fused_score - 0.5).abs() < 0.001);
}
}