#![allow(clippy::unnecessary_wraps)]
use std::collections::HashMap;
/// Errors reported by the validated `FusionStrategy` constructors and by fusion itself.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FusionError {
    /// The supplied weights did not add up to 1.0; `sum` is the total that was observed.
    InvalidWeightSum { sum: f32 },
    /// A supplied weight was rejected as negative; `weight` is the offending value.
    NegativeWeight { weight: f32 },
}
/// Human-readable rendering; the offending value is printed with four decimal places.
impl std::fmt::Display for FusionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidWeightSum { sum } => write!(f, "Weights must sum to 1.0, got {sum:.4}"),
            Self::NegativeWeight { weight } => {
                write!(f, "Weights must be non-negative, got {weight:.4}")
            }
        }
    }
}

/// `FusionError` wraps no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for FusionError {}
/// How several ranked `(doc_id, score)` lists are merged into a single ranking.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FusionStrategy {
    /// Mean of a document's best per-list scores, over the lists it appears in.
    Average,
    /// Highest score a document received in any list.
    Maximum,
    /// Reciprocal Rank Fusion: sums `1 / (k + rank)` (1-based rank) over the lists;
    /// raw scores are ignored.
    RRF { k: u32 },
    /// Linear blend of a document's average score, maximum score and hit ratio
    /// (share of input lists that contain it).
    Weighted {
        avg_weight: f32,
        max_weight: f32,
        hit_weight: f32,
    },
    /// Min-max normalizes the first (dense) and second (sparse) branches, then takes a
    /// weighted sum of the normalized scores.
    RelativeScore {
        dense_weight: f32,
        sparse_weight: f32,
    },
}
impl FusionStrategy {
    /// Reciprocal Rank Fusion with the smoothing constant `k = 60`.
    #[must_use]
    pub fn rrf_default() -> Self {
        Self::RRF { k: 60 }
    }

    /// Builds a validated [`FusionStrategy::RelativeScore`] strategy.
    ///
    /// `dense_weight` scales the first result branch and `sparse_weight` the second.
    ///
    /// # Errors
    ///
    /// Returns [`FusionError::NegativeWeight`] if either weight is rejected as negative
    /// (checked first), or [`FusionError::InvalidWeightSum`] if the two weights do not sum
    /// to 1.0 within a small tolerance.
    pub fn relative_score(dense_weight: f32, sparse_weight: f32) -> Result<Self, FusionError> {
        validate_non_negative(&[dense_weight, sparse_weight])?;
        validate_weight_sum(dense_weight + sparse_weight)?;
        Ok(Self::RelativeScore {
            dense_weight,
            sparse_weight,
        })
    }

    /// Builds a validated [`FusionStrategy::Weighted`] strategy.
    ///
    /// # Errors
    ///
    /// Returns [`FusionError::NegativeWeight`] if any weight is rejected as negative
    /// (checked first), or [`FusionError::InvalidWeightSum`] if the three weights do not
    /// sum to 1.0 within a small tolerance.
    pub fn weighted(
        avg_weight: f32,
        max_weight: f32,
        hit_weight: f32,
    ) -> Result<Self, FusionError> {
        validate_non_negative(&[avg_weight, max_weight, hit_weight])?;
        validate_weight_sum(avg_weight + max_weight + hit_weight)?;
        Ok(Self::Weighted {
            avg_weight,
            max_weight,
            hit_weight,
        })
    }

    /// Fuses several ranked result lists into one list sorted by descending fused score.
    ///
    /// Each inner vector holds `(doc_id, score)` pairs for one query or retrieval branch.
    /// Documents with equal fused scores are ordered by ascending `doc_id`, so the output
    /// order is deterministic even though aggregation goes through `HashMap`s.
    ///
    /// Returns an empty vector when `results` is empty or every inner list is empty.
    ///
    /// # Errors
    ///
    /// The current strategies never fail at fusion time and always return `Ok`.
    pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        // `all` is vacuously true for an empty outer vector, so this covers both cases.
        if results.iter().all(Vec::is_empty) {
            return Ok(Vec::new());
        }
        // Captured before `results` is moved; empty lists still count toward the hit ratio.
        let total_queries = results.len();
        match self {
            Self::Average => Self::fuse_average(results),
            Self::Maximum => Self::fuse_maximum(results),
            Self::RRF { k } => Self::fuse_rrf(results, *k),
            Self::Weighted {
                avg_weight,
                max_weight,
                hit_weight,
            } => Ok(Self::fuse_weighted(
                results,
                *avg_weight,
                *max_weight,
                *hit_weight,
                total_queries,
            )),
            Self::RelativeScore {
                dense_weight,
                sparse_weight,
            } => Self::fuse_relative_score(&results, *dense_weight, *sparse_weight),
        }
    }

    /// Groups scores by document: one entry per list the document appears in.
    ///
    /// If a document occurs several times within the same list, only its highest score
    /// from that list is kept.
    fn collect_doc_scores(results: Vec<Vec<(u64, f32)>>) -> HashMap<u64, Vec<f32>> {
        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
        for query_results in results {
            let mut query_best: HashMap<u64, f32> = HashMap::new();
            for (id, score) in query_results {
                query_best
                    .entry(id)
                    .and_modify(|s| *s = s.max(score))
                    .or_insert(score);
            }
            for (id, score) in query_best {
                doc_scores.entry(id).or_default().push(score);
            }
        }
        doc_scores
    }

    /// Sorts by descending score, breaking ties by ascending document id.
    ///
    /// `total_cmp` gives a total order over `f32`, so NaN cannot cause a panic. The id
    /// tie-break makes the result independent of `HashMap` iteration order, which is
    /// randomized per map instance. Ids are unique within a fused list, so an unstable
    /// sort is still fully deterministic.
    fn sort_descending(fused: &mut [(u64, f32)]) {
        fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    }

    /// Mean of each document's per-list best scores, over only the lists it appears in.
    // usize -> f32 for list counts: precision loss only matters past 2^24 lists.
    #[allow(clippy::cast_precision_loss)]
    fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
            .into_iter()
            .map(|(id, scores)| {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                (id, avg)
            })
            .collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Highest score each document received in any list.
    fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        let mut doc_max: HashMap<u64, f32> = HashMap::new();
        for query_results in results {
            for (id, score) in query_results {
                doc_max
                    .entry(id)
                    .and_modify(|s| *s = s.max(score))
                    .or_insert(score);
            }
        }
        let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Reciprocal Rank Fusion: each list contributes `1 / (k + rank + 1)` per document.
    ///
    /// `rank` is the zero-based position of the document's first occurrence in that list.
    /// Raw scores are ignored.
    // Ranks and `k` are converted to f32; precision loss only matters for huge values.
    #[allow(clippy::cast_precision_loss)]
    fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
        let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
        let k_f32 = k as f32;
        for query_results in results {
            let mut seen: HashMap<u64, usize> = HashMap::new();
            for (rank, (id, _score)) in query_results.into_iter().enumerate() {
                // Only the first (best-ranked) occurrence counts.
                seen.entry(id).or_insert(rank);
            }
            for (id, rank) in seen {
                let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
                *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
            }
        }
        let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Blends average, maximum and hit ratio per document.
    ///
    /// The hit ratio is the number of lists containing the document divided by
    /// `total_queries`, which includes empty lists.
    // usize -> f32 for list counts: precision loss only matters past 2^24 lists.
    #[allow(clippy::cast_precision_loss)]
    fn fuse_weighted(
        results: Vec<Vec<(u64, f32)>>,
        avg_weight: f32,
        max_weight: f32,
        hit_weight: f32,
        total_queries: usize,
    ) -> Vec<(u64, f32)> {
        let total_q = total_queries as f32;
        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
            .into_iter()
            .map(|(id, scores)| {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let hit_ratio = scores.len() as f32 / total_q;
                let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
                (id, combined)
            })
            .collect();
        Self::sort_descending(&mut fused);
        fused
    }

    /// Min-max normalizes the dense (index 0) and sparse (index 1) branches, then combines
    /// them as `dense_weight * nd + sparse_weight * ns`.
    ///
    /// A document missing from a branch contributes `0.0` for that branch. Branches beyond
    /// index 1 are ignored, with a warning.
    fn fuse_relative_score(
        results: &[Vec<(u64, f32)>],
        dense_weight: f32,
        sparse_weight: f32,
    ) -> Result<Vec<(u64, f32)>, FusionError> {
        if results.len() > 2 {
            tracing::warn!(
                branch_count = results.len(),
                "RelativeScore fusion received {} branches but only supports 2 (dense + sparse). \
                 Branches beyond index 1 are ignored.",
                results.len(),
            );
        }
        let dense = results.first().map_or(&[][..], |v| v.as_slice());
        let sparse = results.get(1).map_or(&[][..], |v| v.as_slice());
        let norm_dense = min_max_normalize(dense);
        let norm_sparse = min_max_normalize(sparse);
        let mut combined: HashMap<u64, f32> =
            HashMap::with_capacity(norm_dense.len() + norm_sparse.len());
        // Union of both branches' ids; each id is scored once.
        for &id in norm_dense.keys().chain(norm_sparse.keys()) {
            combined.entry(id).or_insert_with(|| {
                let nd = norm_dense.get(&id).copied().unwrap_or(0.0);
                let ns = norm_sparse.get(&id).copied().unwrap_or(0.0);
                dense_weight * nd + sparse_weight * ns
            });
        }
        let mut fused: Vec<(u64, f32)> = combined.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::RRF { k: 60 }
}
}
/// Checks that every weight is a non-negative number.
///
/// NaN is rejected too: it compares `false` against `0.0`, so a plain `w < 0.0` test
/// would let it through and it would then poison every fused score. `-0.0` is accepted.
///
/// # Errors
///
/// Returns [`FusionError::NegativeWeight`] carrying the first weight that is negative or NaN.
fn validate_non_negative(weights: &[f32]) -> Result<(), FusionError> {
    match weights.iter().copied().find(|w| w.is_nan() || *w < 0.0) {
        Some(weight) => Err(FusionError::NegativeWeight { weight }),
        None => Ok(()),
    }
}
/// Checks that a set of weights sums to 1.0, allowing an absolute error of 0.001 to
/// absorb floating-point rounding.
///
/// A NaN sum is rejected explicitly. `(NaN - 1.0).abs() > 0.001` evaluates to `false`,
/// so without the check any NaN weight would pass validation. Infinite sums already fail
/// the tolerance test.
///
/// # Errors
///
/// Returns [`FusionError::InvalidWeightSum`] with the offending `sum`.
fn validate_weight_sum(sum: f32) -> Result<(), FusionError> {
    const TOLERANCE: f32 = 0.001;
    if sum.is_nan() || (sum - 1.0).abs() > TOLERANCE {
        return Err(FusionError::InvalidWeightSum { sum });
    }
    Ok(())
}
/// Min-max normalizes one branch's scores into `[0.0, 1.0]`, keyed by document id.
///
/// Each score `s` maps to `(s - min) / (max - min)`. If the branch's score range is below
/// `f32::EPSILON` (e.g. a single hit, or all hits tied), every document gets the neutral
/// value `0.5` instead of dividing by (nearly) zero.
///
/// If a document id occurs more than once in `branch`, its highest normalized score is
/// kept. This matches the "best score per list" deduplication the other fusion paths use,
/// rather than keeping whichever occurrence happened to come last.
///
/// Returns an empty map for an empty branch.
fn min_max_normalize(branch: &[(u64, f32)]) -> HashMap<u64, f32> {
    if branch.is_empty() {
        return HashMap::new();
    }
    // Single pass for both bounds; `f32::min`/`f32::max` skip NaN operands.
    let (min, max) = branch.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &(_, s)| (lo.min(s), hi.max(s)),
    );
    let range = max - min;
    let mut normalized: HashMap<u64, f32> = HashMap::with_capacity(branch.len());
    for &(id, s) in branch {
        let norm = if range < f32::EPSILON {
            0.5
        } else {
            (s - min) / range
        };
        normalized
            .entry(id)
            .and_modify(|best| *best = best.max(norm))
            .or_insert(norm);
    }
    normalized
}