#![allow(clippy::unnecessary_wraps)]
use std::collections::HashMap;
/// Error returned when a fusion strategy is configured with invalid parameters.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FusionError {
    /// Weights that are required to be normalized do not add up to `1.0`.
    InvalidWeightSum {
        /// The sum that was actually observed.
        sum: f32,
    },
    /// A weight was rejected as negative. The WeightedRRF checks also report a
    /// non-positive `k` through this variant.
    NegativeWeight {
        /// The rejected value.
        weight: f32,
    },
    /// A WeightedRRF strategy has a different number of weights than result branches.
    WeightCountMismatch {
        /// Number of configured weights.
        weights: usize,
        /// Number of result branches that were passed in.
        branches: usize,
    },
}

impl std::fmt::Display for FusionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds plain values.
        match *self {
            FusionError::InvalidWeightSum { sum } => {
                write!(f, "Weights must sum to 1.0, got {sum:.4}")
            }
            FusionError::NegativeWeight { weight } => {
                write!(f, "Weights must be non-negative, got {weight:.4}")
            }
            FusionError::WeightCountMismatch { weights, branches } => write!(
                f,
                "WeightedRRF requires one weight per branch: {weights} weights for {branches} branches",
            ),
        }
    }
}

/// Lets `FusionError` be used as a standard error (e.g. boxed as `dyn Error`).
impl std::error::Error for FusionError {}
/// How several ranked result lists ("branches") are merged into a single ranking.
///
/// Variants can be built directly. The checked constructors `weighted`, `relative_score`
/// and `weighted_rrf` validate their parameters up front. `fuse` re-checks the
/// parameters of the weighted variants either way.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FusionStrategy {
    /// Mean of a document's scores over the branches that contain it. Within each
    /// branch, the document's best score is used.
    Average,
    /// A document's highest score across all branches.
    Maximum,
    /// Reciprocal rank fusion: sums `1 / (k + rank)` over branches, using 1-based ranks
    /// and each document's first position within a branch.
    RRF {
        /// Offset added to every rank.
        k: u32,
    },
    /// Linear blend of a document's average score, maximum score and hit ratio (the
    /// fraction of branches that contain it).
    Weighted {
        /// Weight applied to the average score.
        avg_weight: f32,
        /// Weight applied to the maximum score.
        max_weight: f32,
        /// Weight applied to the hit ratio.
        hit_weight: f32,
    },
    /// Min-max normalizes the first (dense) and second (sparse) branch separately, then
    /// blends them. Any further branches are ignored.
    RelativeScore {
        /// Weight applied to the normalized dense score.
        dense_weight: f32,
        /// Weight applied to the normalized sparse score.
        sparse_weight: f32,
    },
    /// Reciprocal rank fusion with one multiplier per branch: sums `weight / (k + rank)`
    /// using 0-based ranks.
    WeightedRRF {
        /// One non-negative multiplier per branch, in branch order.
        weights: Vec<f32>,
        /// Offset added to every rank. Must be strictly positive.
        k: f32,
    },
}
impl FusionStrategy {
    /// Returns reciprocal rank fusion ([`FusionStrategy::RRF`]) with `k = 60`.
    #[must_use]
    pub fn rrf_default() -> Self {
        Self::RRF { k: 60 }
    }

    /// Builds a validated [`FusionStrategy::WeightedRRF`].
    ///
    /// `weights` holds one multiplier per result branch. Whether the number of weights
    /// matches the number of branches can only be checked later, in [`Self::fuse`].
    ///
    /// # Errors
    ///
    /// Returns [`FusionError::NegativeWeight`] in two cases:
    /// - a weight is negative;
    /// - `k` is not strictly positive (NaN included). The error then carries `k`.
    pub fn weighted_rrf(weights: Vec<f32>, k: f32) -> Result<Self, FusionError> {
        validate_non_negative(&weights)?;
        Self::validate_rrf_k(k)?;
        Ok(Self::WeightedRRF { weights, k })
    }

    /// Builds a validated [`FusionStrategy::RelativeScore`].
    ///
    /// # Errors
    ///
    /// - [`FusionError::NegativeWeight`] if either weight is negative.
    /// - [`FusionError::InvalidWeightSum`] if the two weights do not sum to `1.0`
    ///   (within a small tolerance).
    pub fn relative_score(dense_weight: f32, sparse_weight: f32) -> Result<Self, FusionError> {
        validate_non_negative(&[dense_weight, sparse_weight])?;
        validate_weight_sum(dense_weight + sparse_weight)?;
        Ok(Self::RelativeScore {
            dense_weight,
            sparse_weight,
        })
    }

    /// Builds a validated [`FusionStrategy::Weighted`].
    ///
    /// # Errors
    ///
    /// - [`FusionError::NegativeWeight`] if any weight is negative.
    /// - [`FusionError::InvalidWeightSum`] if the three weights do not sum to `1.0`
    ///   (within a small tolerance).
    pub fn weighted(
        avg_weight: f32,
        max_weight: f32,
        hit_weight: f32,
    ) -> Result<Self, FusionError> {
        validate_non_negative(&[avg_weight, max_weight, hit_weight])?;
        validate_weight_sum(avg_weight + max_weight + hit_weight)?;
        Ok(Self::Weighted {
            avg_weight,
            max_weight,
            hit_weight,
        })
    }

    /// Fuses several ranked result lists ("branches") into a single ranking.
    ///
    /// Each branch is a list of `(document id, score)` pairs. The output holds every
    /// document that appears in at least one considered branch, paired with its fused
    /// score and sorted by descending score. Documents with equal scores are ordered by
    /// ascending id, so the output is deterministic. `RelativeScore` considers only the
    /// first two branches and logs a warning if more are given.
    ///
    /// If `results` is empty or every branch is empty, an empty list is returned before
    /// any of the strategy's parameters are validated.
    ///
    /// # Errors
    ///
    /// `Weighted`, `RelativeScore` and `WeightedRRF` re-validate their parameters here,
    /// because the variants can be constructed without the checked constructors. This can
    /// yield [`FusionError::NegativeWeight`] or [`FusionError::InvalidWeightSum`].
    /// `WeightedRRF` additionally returns [`FusionError::WeightCountMismatch`] when its
    /// number of weights differs from the number of branches.
    pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        // Also covers `results.is_empty()`: `all` is vacuously true for no branches.
        if results.iter().all(Vec::is_empty) {
            return Ok(Vec::new());
        }
        match self {
            Self::Average => Self::fuse_average(results),
            Self::Maximum => Self::fuse_maximum(results),
            Self::RRF { k } => Self::fuse_rrf(&results, *k),
            Self::Weighted {
                avg_weight,
                max_weight,
                hit_weight,
            } => {
                // Counts empty branches too, so they lower every document's hit ratio.
                let total_queries = results.len();
                Self::fuse_weighted(
                    results,
                    *avg_weight,
                    *max_weight,
                    *hit_weight,
                    total_queries,
                )
            }
            Self::RelativeScore {
                dense_weight,
                sparse_weight,
            } => Self::fuse_relative_score(&results, *dense_weight, *sparse_weight),
            Self::WeightedRRF { weights, k } => Self::fuse_weighted_rrf(&results, weights, *k),
        }
    }

    /// Groups per-branch scores by document id.
    ///
    /// Within one branch only the highest score of a repeated id is kept. Each id's
    /// vector therefore has one entry per branch it appears in, in branch order.
    fn collect_doc_scores(results: Vec<Vec<(u64, f32)>>) -> HashMap<u64, Vec<f32>> {
        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
        for branch in results {
            let mut branch_best: HashMap<u64, f32> = HashMap::with_capacity(branch.len());
            for (id, score) in branch {
                branch_best
                    .entry(id)
                    .and_modify(|best| *best = best.max(score))
                    .or_insert(score);
            }
            for (id, score) in branch_best {
                doc_scores.entry(id).or_default().push(score);
            }
        }
        doc_scores
    }

    /// Maps each id in `branch` to the 0-based position of its first occurrence.
    ///
    /// Duplicate ids within a branch therefore count once, at their best rank.
    fn first_ranks(branch: &[(u64, f32)]) -> HashMap<u64, usize> {
        let mut ranks: HashMap<u64, usize> = HashMap::with_capacity(branch.len());
        for (rank, &(id, _score)) in branch.iter().enumerate() {
            ranks.entry(id).or_insert(rank);
        }
        ranks
    }

    /// Sorts by descending score, breaking ties by ascending id.
    ///
    /// The pairs are collected from `HashMap`s, whose iteration order is unspecified and,
    /// with std's default hasher, varies between runs. Without the id tie-break,
    /// equal-scoring documents would come back in a different order each time.
    fn sort_descending(fused: &mut [(u64, f32)]) {
        fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    }

    /// Rejects an RRF offset `k` that is not strictly positive (NaN included).
    ///
    /// The rejection is reported as [`FusionError::NegativeWeight`] carrying `k`, matching
    /// the existing error reporting for this check.
    fn validate_rrf_k(k: f32) -> Result<(), FusionError> {
        // `k > 0.0` is false for NaN, so NaN is rejected along with zero and negatives.
        if k > 0.0 {
            Ok(())
        } else {
            Err(FusionError::NegativeWeight { weight: k })
        }
    }

    /// Average of each document's per-branch best scores.
    ///
    /// Only branches that contain the document count toward the average.
    #[allow(clippy::cast_precision_loss)] // branch counts are converted to f32 for division
    fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
            .into_iter()
            .map(|(id, scores)| {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                (id, avg)
            })
            .collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Highest score each document received in any branch.
    fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
        let mut doc_max: HashMap<u64, f32> = HashMap::new();
        for (id, score) in results.into_iter().flatten() {
            doc_max
                .entry(id)
                .and_modify(|best| *best = best.max(score))
                .or_insert(score);
        }
        let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Reciprocal rank fusion: sums `1 / (k + rank)` over branches, with 1-based ranks.
    #[allow(clippy::cast_precision_loss)] // k and ranks are converted to f32
    fn fuse_rrf(results: &[Vec<(u64, f32)>], k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
        let k = k as f32;
        let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
        for branch in results {
            for (id, rank) in Self::first_ranks(branch) {
                // `rank` is 0-based, so the top hit contributes 1 / (k + 1).
                *doc_rrf.entry(id).or_default() += 1.0 / (k + (rank + 1) as f32);
            }
        }
        let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Blends average score, maximum score and hit ratio per document.
    ///
    /// The hit ratio is the number of branches containing the document divided by
    /// `total_queries`.
    #[allow(clippy::cast_precision_loss)] // counts are converted to f32 for ratios
    fn fuse_weighted(
        results: Vec<Vec<(u64, f32)>>,
        avg_weight: f32,
        max_weight: f32,
        hit_weight: f32,
        total_queries: usize,
    ) -> Result<Vec<(u64, f32)>, FusionError> {
        validate_non_negative(&[avg_weight, max_weight, hit_weight])?;
        validate_weight_sum(avg_weight + max_weight + hit_weight)?;
        let total_q = total_queries as f32;
        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
            .into_iter()
            .map(|(id, scores)| {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let hit_ratio = scores.len() as f32 / total_q;
                let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
                (id, combined)
            })
            .collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Min-max normalizes the dense (index 0) and sparse (index 1) branches, then blends.
    ///
    /// A document missing from one branch gets `0.0` for that branch. Branches beyond
    /// index 1 are ignored, with a warning.
    fn fuse_relative_score(
        results: &[Vec<(u64, f32)>],
        dense_weight: f32,
        sparse_weight: f32,
    ) -> Result<Vec<(u64, f32)>, FusionError> {
        validate_non_negative(&[dense_weight, sparse_weight])?;
        validate_weight_sum(dense_weight + sparse_weight)?;
        if results.len() > 2 {
            tracing::warn!(
                branch_count = results.len(),
                "RelativeScore fusion received {} branches but only supports 2 (dense + sparse). \
                 Branches beyond index 1 are ignored.",
                results.len(),
            );
        }
        let dense = results.first().map_or(&[][..], |v| v.as_slice());
        let sparse = results.get(1).map_or(&[][..], |v| v.as_slice());
        let norm_dense = min_max_normalize(dense);
        let norm_sparse = min_max_normalize(sparse);
        let mut all_ids: HashMap<u64, f32> =
            HashMap::with_capacity(norm_dense.len() + norm_sparse.len());
        for (&id, &nd) in &norm_dense {
            let ns = norm_sparse.get(&id).copied().unwrap_or(0.0);
            all_ids.insert(id, dense_weight * nd + sparse_weight * ns);
        }
        // Sparse-only documents: their dense contribution is zero.
        for (&id, &ns) in &norm_sparse {
            all_ids.entry(id).or_insert(sparse_weight * ns);
        }
        let mut fused: Vec<(u64, f32)> = all_ids.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }

    /// Weighted reciprocal rank fusion: sums `weight / (k + rank)` over branches.
    ///
    /// NOTE(review): unlike `fuse_rrf`, ranks here are 0-based, so the top hit contributes
    /// `weight / k`. This is why `k` must be strictly positive.
    #[allow(clippy::cast_precision_loss)] // ranks are converted to f32
    fn fuse_weighted_rrf(
        branches: &[Vec<(u64, f32)>],
        weights: &[f32],
        k: f32,
    ) -> Result<Vec<(u64, f32)>, FusionError> {
        validate_non_negative(weights)?;
        Self::validate_rrf_k(k)?;
        if weights.len() != branches.len() {
            return Err(FusionError::WeightCountMismatch {
                weights: weights.len(),
                branches: branches.len(),
            });
        }
        let mut doc_scores: HashMap<u64, f32> = HashMap::new();
        for (branch, &weight) in branches.iter().zip(weights) {
            for (id, rank) in Self::first_ranks(branch) {
                *doc_scores.entry(id).or_default() += weight / (rank as f32 + k);
            }
        }
        let mut fused: Vec<(u64, f32)> = doc_scores.into_iter().collect();
        Self::sort_descending(&mut fused);
        Ok(fused)
    }
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::RRF { k: 60 }
}
}
/// Checks that every weight in `weights` is a usable non-negative number.
///
/// An empty slice is accepted.
///
/// # Errors
///
/// Returns [`FusionError::NegativeWeight`] carrying the first rejected weight. NaN is
/// rejected as well. A plain `w < 0.0` test would let NaN through, because every
/// comparison involving NaN is false.
fn validate_non_negative(weights: &[f32]) -> Result<(), FusionError> {
    match weights.iter().copied().find(|w| w.is_nan() || *w < 0.0) {
        Some(weight) => Err(FusionError::NegativeWeight { weight }),
        None => Ok(()),
    }
}
/// Checks that a set of weights sums to `1.0`, allowing a tolerance of `0.001` for
/// floating-point rounding.
///
/// # Errors
///
/// Returns [`FusionError::InvalidWeightSum`] carrying `sum` when it lies outside the
/// tolerance, or when it is NaN.
fn validate_weight_sum(sum: f32) -> Result<(), FusionError> {
    const TOLERANCE: f32 = 0.001;
    // Phrased as "within tolerance" so a NaN sum fails the check. The negated form
    // `(sum - 1.0).abs() > TOLERANCE` is false for NaN and would accept it.
    if (sum - 1.0).abs() <= TOLERANCE {
        Ok(())
    } else {
        Err(FusionError::InvalidWeightSum { sum })
    }
}
/// Min-max normalizes one branch's scores, keyed by document id.
///
/// - The lowest score maps to `0.0` and the highest to `1.0`.
/// - If the score range is below `f32::EPSILON` (all scores effectively equal), every
///   document gets `0.5`, since there is no ordering information to preserve.
/// - If an id appears more than once, its best (highest) normalized score is kept. This
///   matches how the other strategies in this file treat duplicates within a branch:
///   best score in `collect_doc_scores`, first rank in the RRF variants.
/// - An empty branch yields an empty map.
fn min_max_normalize(branch: &[(u64, f32)]) -> HashMap<u64, f32> {
    if branch.is_empty() {
        return HashMap::new();
    }
    let (min, max) = branch.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &(_, s)| (lo.min(s), hi.max(s)),
    );
    let range = max - min;
    let mut out: HashMap<u64, f32> = HashMap::with_capacity(branch.len());
    for &(id, s) in branch {
        let norm = if range < f32::EPSILON {
            0.5
        } else {
            (s - min) / range
        };
        // Keep the best occurrence rather than letting a later (worse) duplicate overwrite it.
        out.entry(id)
            .and_modify(|best| *best = best.max(norm))
            .or_insert(norm);
    }
    out
}