use std::cmp::Ordering;
use std::collections::HashMap;
/// Default value of the Reciprocal Rank Fusion smoothing constant `k`.
///
/// Used by [`FusionMethod::rrf`] and by `FusionMethod::default()`.
pub const RRF_DEFAULT_K: u32 = 60;
/// One document in a fused ranking, together with where it came from.
#[derive(Clone, Debug, PartialEq)]
pub struct FusionResult {
    /// Document identifier, copied from the input lists.
    pub id: u64,
    /// Fused score; the fusion functions in this module sort results by it, highest first.
    pub score: f32,
    /// 1-based position in the dense input list, or `None` if the document was not in it.
    pub dense_rank: Option<usize>,
    /// 1-based position in the sparse input list, or `None` if the document was not in it.
    pub sparse_rank: Option<usize>,
}

impl FusionResult {
    /// Builds a result that carries only an id and a score (no rank provenance).
    #[inline]
    #[must_use]
    pub fn new(id: u64, score: f32) -> Self {
        Self::with_ranks(id, score, None, None)
    }

    /// Builds a result with explicit per-list ranks.
    ///
    /// # Arguments
    ///
    /// * `id` - document identifier.
    /// * `score` - fused score.
    /// * `dense_rank` - position in the dense list, if present there.
    /// * `sparse_rank` - position in the sparse list, if present there.
    #[inline]
    #[must_use]
    pub fn with_ranks(
        id: u64,
        score: f32,
        dense_rank: Option<usize>,
        sparse_rank: Option<usize>,
    ) -> Self {
        Self {
            id,
            score,
            dense_rank,
            sparse_rank,
        }
    }
}
#[derive(Clone, Debug)]
pub enum FusionMethod {
Rrf {
k: u32,
},
Linear {
alpha: f32,
},
}
impl Default for FusionMethod {
fn default() -> Self {
FusionMethod::Rrf { k: RRF_DEFAULT_K }
}
}
impl FusionMethod {
#[inline]
#[must_use]
pub fn rrf() -> Self {
FusionMethod::Rrf { k: RRF_DEFAULT_K }
}
#[inline]
#[must_use]
pub fn rrf_with_k(k: u32) -> Self {
FusionMethod::Rrf { k }
}
#[inline]
pub fn linear(alpha: f32) -> Result<Self, String> {
if !(0.0..=1.0).contains(&alpha) {
return Err(format!("Alpha must be in range [0.0, 1.0], got {alpha}"));
}
Ok(FusionMethod::Linear { alpha })
}
#[inline]
#[must_use]
pub fn linear_balanced() -> Self {
FusionMethod::Linear { alpha: 0.5 }
}
}
/// Fuses a dense and a sparse ranked list with Reciprocal Rank Fusion (RRF).
///
/// A document at 1-based position `rank` in a list receives `1 / (k + rank)` from that list,
/// and its fused score is the sum over the lists it appears in. The scores inside the input
/// pairs are ignored: only the order of each list matters, with its first element treated as
/// rank 1.
///
/// # Arguments
///
/// * `dense_results` - `(id, score)` pairs from the dense retriever.
/// * `sparse_results` - `(id, score)` pairs from the sparse retriever.
/// * `k` - smoothing constant added to every rank; larger values shrink the gap between the
///   contributions of adjacent ranks.
/// * `top_n` - maximum number of results to return; `0` yields an empty vector.
///
/// # Returns
///
/// At most `top_n` results in descending score order, each carrying the document's rank in
/// either list (`None` if absent). Equal scores are ordered by ascending `id`, so the output
/// does not depend on hash-map iteration order. If an id occurs more than once in the same
/// list, only its first (best-ranked) occurrence counts.
#[must_use]
pub fn rrf_fusion(
    dense_results: &[(u64, f32)],
    sparse_results: &[(u64, f32)],
    k: u32,
    top_n: usize,
) -> Vec<FusionResult> {
    /// Running fused score and 1-based list positions for one document.
    #[derive(Default)]
    struct DocInfo {
        score: f32,
        dense_rank: Option<usize>,
        sparse_rank: Option<usize>,
    }

    if top_n == 0 {
        return Vec::new();
    }

    let k_f64 = f64::from(k);
    // `rank` is a list position (far below 2^52, so exact as f64) and the reciprocal lies in
    // (0, 1], so narrowing to f32 only rounds; both casts are intentional.
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
    let contribution = |rank: usize| (1.0 / (k_f64 + rank as f64)) as f32;

    let mut doc_map: HashMap<u64, DocInfo> =
        HashMap::with_capacity(dense_results.len() + sparse_results.len());

    for (rank, &(id, _)) in (1..).zip(dense_results) {
        let info = doc_map.entry(id).or_default();
        // A repeated id must neither add a second contribution nor overwrite its best rank.
        if info.dense_rank.is_none() {
            info.dense_rank = Some(rank);
            info.score += contribution(rank);
        }
    }
    for (rank, &(id, _)) in (1..).zip(sparse_results) {
        let info = doc_map.entry(id).or_default();
        if info.sparse_rank.is_none() {
            info.sparse_rank = Some(rank);
            info.score += contribution(rank);
        }
    }

    let mut results: Vec<FusionResult> = doc_map
        .into_iter()
        .map(|(id, info)| {
            FusionResult::with_ranks(id, info.score, info.dense_rank, info.sparse_rank)
        })
        .collect();
    // RRF scores are always finite and positive, so `partial_cmp` never fails here. The id
    // tie-break makes this a strict total order (ids are unique map keys), which keeps the
    // ordering and the `top_n` cut deterministic.
    results.sort_unstable_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
    results.truncate(top_n);
    results
}
/// Min–max normalizes the scores of one result list into a map keyed by id.
///
/// Each score becomes `(score - min) / (max - min)`, so the list's highest score maps to
/// `1.0` and its lowest to `0.0`. When all scores are equal (zero range) every id maps to
/// `1.0`. An empty input yields an empty map. If an id is repeated, the value computed for
/// its last occurrence is the one kept.
fn normalize_scores(results: &[(u64, f32)]) -> HashMap<u64, f32> {
    if results.is_empty() {
        return HashMap::new();
    }

    // Both bounds in one pass; `f32::min`/`f32::max` return the non-NaN operand when one
    // side is NaN.
    let (lo, hi) = results.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &(_, score)| (lo.min(score), hi.max(score)),
    );
    let span = hi - lo;

    let rescale = |score: f32| {
        if span == 0.0 {
            // Degenerate range: treat every entry as maximally relevant.
            1.0
        } else {
            (score - lo) / span
        }
    };

    results
        .iter()
        .map(|&(id, score)| (id, rescale(score)))
        .collect()
}
/// Fuses a dense and a sparse result list by a weighted sum of min–max normalized scores.
///
/// Each list's scores are rescaled to `[0, 1]` independently by `normalize_scores`. A
/// document's fused score is `alpha * dense_norm + (1 - alpha) * sparse_norm`, where a list
/// that does not contain the document contributes `0`.
///
/// # Arguments
///
/// * `dense_results` - `(id, score)` pairs from the dense retriever. Scores are expected to
///   be non-negative (checked with `debug_assert!`).
/// * `sparse_results` - `(id, score)` pairs from the sparse retriever, same expectation.
/// * `alpha` - weight of the dense side, clamped into `[0.0, 1.0]`. A `NaN` alpha propagates
///   into every fused score.
/// * `top_n` - maximum number of results to return; `0` yields an empty vector.
///
/// # Returns
///
/// At most `top_n` results in descending score order, each carrying the document's 1-based
/// position in either list (`None` if absent). Equal scores are ordered by ascending `id`, and
/// any `NaN` score is placed last, so the output is deterministic. If an id occurs more than
/// once in the same list, only its first occurrence is used, both for normalization and for
/// the reported rank.
///
/// # Panics
///
/// In debug builds only, if any input score is negative or `NaN`.
#[must_use]
pub fn linear_fusion(
    dense_results: &[(u64, f32)],
    sparse_results: &[(u64, f32)],
    alpha: f32,
    top_n: usize,
) -> Vec<FusionResult> {
    /// Running fused score and 1-based list positions for one document.
    #[derive(Default)]
    struct DocInfo {
        score: f32,
        dense_rank: Option<usize>,
        sparse_rank: Option<usize>,
    }

    /// Returns `list` without repeated ids, keeping each id's first occurrence.
    fn first_occurrences(list: &[(u64, f32)]) -> Vec<(u64, f32)> {
        let mut seen = std::collections::HashSet::with_capacity(list.len());
        list.iter().copied().filter(|&(id, _)| seen.insert(id)).collect()
    }

    if top_n == 0 {
        return Vec::new();
    }
    let alpha = alpha.clamp(0.0, 1.0);
    debug_assert!(
        dense_results.iter().all(|(_, s)| *s >= 0.0),
        "Dense scores must be non-negative for linear fusion normalization"
    );
    debug_assert!(
        sparse_results.iter().all(|(_, s)| *s >= 0.0),
        "Sparse scores must be non-negative for linear fusion normalization"
    );

    // A repeated id would otherwise be summed twice and could skew the min–max range.
    let dense_results = first_occurrences(dense_results);
    let sparse_results = first_occurrences(sparse_results);
    let dense_norm = normalize_scores(&dense_results);
    let sparse_norm = normalize_scores(&sparse_results);

    let mut doc_map: HashMap<u64, DocInfo> =
        HashMap::with_capacity(dense_results.len() + sparse_results.len());

    for (rank, &(id, _)) in (1..).zip(&dense_results) {
        let norm_score = dense_norm.get(&id).copied().unwrap_or(0.0);
        let info = doc_map.entry(id).or_default();
        info.score += alpha * norm_score;
        info.dense_rank = Some(rank);
    }
    for (rank, &(id, _)) in (1..).zip(&sparse_results) {
        let norm_score = sparse_norm.get(&id).copied().unwrap_or(0.0);
        let info = doc_map.entry(id).or_default();
        info.score += (1.0 - alpha) * norm_score;
        info.sparse_rank = Some(rank);
    }

    let mut results: Vec<FusionResult> = doc_map
        .into_iter()
        .map(|(id, info)| {
            FusionResult::with_ranks(id, info.score, info.dense_rank, info.sparse_rank)
        })
        .collect();
    // `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a NaN appears
    // (e.g. from non-finite inputs or a NaN alpha), and std sorts may panic on such
    // comparators. Ordering NaN last and breaking ties by the unique id gives a strict total
    // order independent of HashMap iteration order.
    results.sort_unstable_by(|a, b| {
        match (a.score.is_nan(), b.score.is_nan()) {
            (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
        .then_with(|| a.id.cmp(&b.id))
    });
    results.truncate(top_n);
    results
}
/// Unit tests for RRF fusion, score normalization, linear fusion, and the public
/// `FusionResult` / `FusionMethod` constructors.
#[cfg(test)]
mod tests {
    use super::*;

    // ----- rrf_fusion -----

    #[test]
    fn test_rrf_identical_lists() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7)];
        let sparse = vec![(1, 5.0), (2, 4.0), (3, 3.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1);
        assert_eq!(results[1].id, 2);
        assert_eq!(results[2].id, 3);
        // Doc 1 is rank 1 in both lists with k = 60: 1/61 + 1/61.
        let expected_score = 2.0 / 61.0; assert!((results[0].score - expected_score).abs() < 1e-6);
    }
    #[test]
    fn test_rrf_disjoint_lists() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse = vec![(3, 5.0), (4, 4.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        assert_eq!(results.len(), 4);
        // With no overlap, every document carries exactly one of the two ranks.
        for result in &results {
            assert!(result.dense_rank.is_some() != result.sparse_rank.is_some());
        }
    }
    #[test]
    fn test_rrf_partial_overlap() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7)];
        let sparse = vec![(2, 5.0), (4, 4.0), (5, 3.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        // Doc 2 is the only document in both lists (dense rank 2, sparse rank 1).
        assert_eq!(results[0].id, 2);
        assert_eq!(results[0].dense_rank, Some(2));
        assert_eq!(results[0].sparse_rank, Some(1));
        let expected = 1.0 / 62.0 + 1.0 / 61.0;
        assert!((results[0].score - expected).abs() < 1e-6);
    }
    #[test]
    fn test_rrf_k_parameter_effect() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse = vec![(1, 5.0), (2, 4.0)];
        let results_k60 = rrf_fusion(&dense, &sparse, 60, 10);
        let results_k1 = rrf_fusion(&dense, &sparse, 1, 10);
        // A smaller k gives larger reciprocal contributions; the winner is unchanged.
        assert!(results_k60[0].score < results_k1[0].score);
        assert_eq!(results_k60[0].id, results_k1[0].id);
    }
    #[test]
    fn test_rrf_empty_dense() {
        let dense: Vec<(u64, f32)> = vec![];
        let sparse = vec![(1, 5.0), (2, 4.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        assert_eq!(results.len(), 2);
        assert!(results[0].dense_rank.is_none());
        assert!(results[0].sparse_rank.is_some());
    }
    #[test]
    fn test_rrf_empty_sparse() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse: Vec<(u64, f32)> = vec![];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        assert_eq!(results.len(), 2);
        assert!(results[0].dense_rank.is_some());
        assert!(results[0].sparse_rank.is_none());
    }
    #[test]
    fn test_rrf_both_empty() {
        let dense: Vec<(u64, f32)> = vec![];
        let sparse: Vec<(u64, f32)> = vec![];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        assert!(results.is_empty());
    }
    #[test]
    fn test_rrf_top_n_zero() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse = vec![(1, 5.0), (2, 4.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 0);
        assert!(results.is_empty());
    }
    #[test]
    fn test_rrf_top_n_truncation() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7)];
        let sparse = vec![(4, 5.0), (5, 4.0), (6, 3.0)];
        // Six distinct documents are fused, but only three are requested.
        let results = rrf_fusion(&dense, &sparse, 60, 3);
        assert_eq!(results.len(), 3);
    }
    #[test]
    fn test_rrf_score_ordering() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7), (4, 0.6)];
        let sparse = vec![(4, 5.0), (3, 4.0), (2, 3.0), (1, 2.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        // Scores must be non-increasing from front to back.
        for i in 1..results.len() {
            assert!(
                results[i - 1].score >= results[i].score,
                "Results not sorted: {} < {} at positions {} and {}",
                results[i - 1].score,
                results[i].score,
                i - 1,
                i
            );
        }
    }
    #[test]
    fn test_rrf_rank_tracking() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse = vec![(2, 5.0), (3, 4.0)];
        let results = rrf_fusion(&dense, &sparse, 60, 10);
        // Ranks are 1-based list positions; `None` marks absence from a list.
        let r1 = results.iter().find(|r| r.id == 1).unwrap();
        assert_eq!(r1.dense_rank, Some(1));
        assert_eq!(r1.sparse_rank, None);
        let r2 = results.iter().find(|r| r.id == 2).unwrap();
        assert_eq!(r2.dense_rank, Some(2));
        assert_eq!(r2.sparse_rank, Some(1));
        let r3 = results.iter().find(|r| r.id == 3).unwrap();
        assert_eq!(r3.dense_rank, None);
        assert_eq!(r3.sparse_rank, Some(2));
    }
    #[test]
    fn test_rrf_large_lists() {
        // Ids 500..1000 appear in both lists; the test expects most of the top 100 to come
        // from that overlap.
        let dense: Vec<(u64, f32)> = (0..1000u64)
            .map(|i| (i, 1.0 - (i as f32 / 1000.0)))
            .collect();
        let sparse: Vec<(u64, f32)> = (500..1500u64)
            .map(|i| (i, 1.0 - ((i - 500) as f32 / 1000.0)))
            .collect();
        let results = rrf_fusion(&dense, &sparse, 60, 100);
        assert_eq!(results.len(), 100);
        let overlap_count = results
            .iter()
            .filter(|r| r.dense_rank.is_some() && r.sparse_rank.is_some())
            .count();
        assert!(
            overlap_count > 50,
            "Expected most results from overlap, got {}",
            overlap_count
        );
    }

    // ----- normalize_scores -----

    #[test]
    fn test_normalize_empty() {
        let results: Vec<(u64, f32)> = vec![];
        let normalized = normalize_scores(&results);
        assert!(normalized.is_empty());
    }
    #[test]
    fn test_normalize_single() {
        // A single score has zero range and therefore maps to 1.0.
        let results = vec![(1, 0.5)];
        let normalized = normalize_scores(&results);
        assert_eq!(normalized.len(), 1);
        assert!((normalized.get(&1).copied().unwrap() - 1.0).abs() < f32::EPSILON);
    }
    #[test]
    fn test_normalize_all_same() {
        // Equal scores also have zero range, so every id maps to 1.0.
        let results = vec![(1, 0.5), (2, 0.5), (3, 0.5)];
        let normalized = normalize_scores(&results);
        assert_eq!(normalized.len(), 3);
        for id in 1..=3 {
            assert!((normalized.get(&id).copied().unwrap() - 1.0).abs() < f32::EPSILON);
        }
    }
    #[test]
    fn test_normalize_range() {
        let results = vec![(1, 0.0), (2, 0.5), (3, 1.0)];
        let normalized = normalize_scores(&results);
        assert!((normalized.get(&1).copied().unwrap() - 0.0).abs() < 1e-6);
        assert!((normalized.get(&2).copied().unwrap() - 0.5).abs() < 1e-6);
        assert!((normalized.get(&3).copied().unwrap() - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_normalize_large_range() {
        let results = vec![(1, 0.0), (2, 100.0)];
        let normalized = normalize_scores(&results);
        assert!((normalized.get(&1).copied().unwrap() - 0.0).abs() < 1e-6);
        assert!((normalized.get(&2).copied().unwrap() - 1.0).abs() < 1e-6);
    }

    // ----- linear_fusion -----

    #[test]
    fn test_linear_identical_lists() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7)];
        let sparse = vec![(1, 5.0), (2, 4.0), (3, 3.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1);
    }
    #[test]
    fn test_linear_alpha_weighting() {
        let dense = vec![(1, 1.0), (2, 0.0)];
        let sparse = vec![(2, 1.0), (1, 0.0)];
        let results_dense = linear_fusion(&dense, &sparse, 1.0, 10);
        assert_eq!(results_dense[0].id, 1);
        let results_sparse = linear_fusion(&dense, &sparse, 0.0, 10);
        assert_eq!(results_sparse[0].id, 2);
        // With alpha = 0.5 both docs score 0.5 * 1.0 + 0.5 * 0.0 and tie.
        let results_balanced = linear_fusion(&dense, &sparse, 0.5, 10);
        assert!((results_balanced[0].score - results_balanced[1].score).abs() < 1e-6);
    }
    #[test]
    fn test_linear_empty_dense() {
        let dense: Vec<(u64, f32)> = vec![];
        let sparse = vec![(1, 5.0), (2, 4.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert_eq!(results.len(), 2);
        // Only the sparse side contributes, at most (1 - 0.5) * 1.0.
        assert!(results[0].score <= 0.5 + 1e-6);
    }
    #[test]
    fn test_linear_empty_sparse() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse: Vec<(u64, f32)> = vec![];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert_eq!(results.len(), 2);
    }
    #[test]
    fn test_linear_both_empty() {
        let dense: Vec<(u64, f32)> = vec![];
        let sparse: Vec<(u64, f32)> = vec![];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert!(results.is_empty());
    }
    #[test]
    fn test_linear_top_n_zero() {
        let dense = vec![(1, 0.9)];
        let sparse = vec![(1, 5.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 0);
        assert!(results.is_empty());
    }
    #[test]
    fn test_linear_single_score_normalization() {
        // Each single-element list normalizes to 1.0, so the fused score is 0.5 + 0.5.
        let dense = vec![(1, 0.5)];
        let sparse = vec![(1, 3.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert!((results[0].score - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_linear_alpha_zero() {
        // alpha = 0 uses only the normalized sparse scores.
        let dense = vec![(1, 1.0), (2, 0.5), (3, 0.0)];
        let sparse = vec![(1, 0.0), (2, 0.5), (3, 1.0)];
        let results = linear_fusion(&dense, &sparse, 0.0, 10);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 3);
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(results[1].id, 2);
        assert!((results[1].score - 0.5).abs() < 1e-6);
        assert_eq!(results[2].id, 1);
        assert!(results[2].score.abs() < 1e-6);
    }
    #[test]
    fn test_linear_alpha_one() {
        // alpha = 1 uses only the normalized dense scores.
        let dense = vec![(1, 1.0), (2, 0.5), (3, 0.0)];
        let sparse = vec![(1, 0.0), (2, 0.5), (3, 1.0)];
        let results = linear_fusion(&dense, &sparse, 1.0, 10);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 1);
        assert!((results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(results[1].id, 2);
        assert!((results[1].score - 0.5).abs() < 1e-6);
        assert_eq!(results[2].id, 3);
        assert!(results[2].score.abs() < 1e-6);
    }
    #[test]
    fn test_linear_disjoint_lists() {
        let dense = vec![(1, 0.9), (2, 0.8)];
        let sparse = vec![(3, 5.0), (4, 4.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        assert_eq!(results.len(), 4);
        for result in &results {
            let has_dense = result.dense_rank.is_some();
            let has_sparse = result.sparse_rank.is_some();
            assert!(
                has_dense ^ has_sparse,
                "Disjoint items should have exactly one rank"
            );
        }
        let ids: Vec<u64> = results.iter().map(|r| r.id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
        assert!(ids.contains(&3));
        assert!(ids.contains(&4));
    }
    #[test]
    fn test_linear_preserves_ranks() {
        let dense = vec![(1, 0.9), (2, 0.7), (3, 0.5)];
        let sparse = vec![(2, 5.0), (3, 4.0), (4, 3.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        // Ranks are the 1-based input positions, independent of the fused order.
        let item_1 = results.iter().find(|r| r.id == 1).unwrap();
        assert_eq!(item_1.dense_rank, Some(1)); assert_eq!(item_1.sparse_rank, None);
        let item_2 = results.iter().find(|r| r.id == 2).unwrap();
        assert_eq!(item_2.dense_rank, Some(2)); assert_eq!(item_2.sparse_rank, Some(1));
        let item_3 = results.iter().find(|r| r.id == 3).unwrap();
        assert_eq!(item_3.dense_rank, Some(3)); assert_eq!(item_3.sparse_rank, Some(2));
        let item_4 = results.iter().find(|r| r.id == 4).unwrap();
        assert_eq!(item_4.dense_rank, None); assert_eq!(item_4.sparse_rank, Some(3)); }
    #[test]
    fn test_linear_score_ordering() {
        let dense = vec![(1, 0.9), (2, 0.8), (3, 0.7)];
        let sparse = vec![(3, 5.0), (4, 4.0), (5, 3.0)];
        let results = linear_fusion(&dense, &sparse, 0.5, 10);
        for window in results.windows(2) {
            assert!(
                window[0].score >= window[1].score,
                "Results should be sorted in descending order by score"
            );
        }
        assert!(!results.is_empty());
    }
    #[test]
    fn test_linear_alpha_boundary_low() {
        // Near alpha = 0 the sparse-only document outranks the dense-only one.
        let dense = vec![(1, 0.9)];
        let sparse = vec![(2, 5.0)];
        let results = linear_fusion(&dense, &sparse, 0.01, 10);
        assert_eq!(results[0].id, 2);
    }
    #[test]
    fn test_linear_alpha_boundary_high() {
        // Near alpha = 1 the dense-only document outranks the sparse-only one.
        let dense = vec![(1, 0.9)];
        let sparse = vec![(2, 5.0)];
        let results = linear_fusion(&dense, &sparse, 0.99, 10);
        assert_eq!(results[0].id, 1);
    }
    #[test]
    fn test_linear_alpha_clamping() {
        // Out-of-range alpha is clamped (1.5 -> 1.0, -0.5 -> 0.0) rather than rejected.
        let dense = vec![(1, 0.9)];
        let sparse = vec![(2, 5.0)];
        let results_high = linear_fusion(&dense, &sparse, 1.5, 10);
        assert!(!results_high.is_empty());
        assert_eq!(results_high[0].id, 1);
        let results_low = linear_fusion(&dense, &sparse, -0.5, 10);
        assert!(!results_low.is_empty());
        assert_eq!(results_low[0].id, 2);
    }

    // ----- FusionResult -----

    #[test]
    fn test_fusion_result_new() {
        let result = FusionResult::new(42, 0.5);
        assert_eq!(result.id, 42);
        assert!((result.score - 0.5).abs() < f32::EPSILON);
        assert_eq!(result.dense_rank, None);
        assert_eq!(result.sparse_rank, None);
    }
    #[test]
    fn test_fusion_result_with_ranks() {
        let result = FusionResult::with_ranks(42, 0.5, Some(1), Some(2));
        assert_eq!(result.id, 42);
        assert!((result.score - 0.5).abs() < f32::EPSILON);
        assert_eq!(result.dense_rank, Some(1));
        assert_eq!(result.sparse_rank, Some(2));
    }
    #[test]
    fn test_fusion_result_partial_eq() {
        let r1 = FusionResult::with_ranks(42, 0.5, Some(1), Some(2));
        let r2 = FusionResult::with_ranks(42, 0.5, Some(1), Some(2));
        let r3 = FusionResult::with_ranks(43, 0.5, Some(1), Some(2));
        assert_eq!(r1, r2);
        assert_ne!(r1, r3);
    }

    // ----- FusionMethod -----

    #[test]
    fn test_fusion_method_default() {
        let method = FusionMethod::default();
        match method {
            FusionMethod::Rrf { k } => assert_eq!(k, 60),
            FusionMethod::Linear { .. } => panic!("Expected RRF"),
        }
    }
    #[test]
    fn test_fusion_method_rrf() {
        let method = FusionMethod::rrf();
        match method {
            FusionMethod::Rrf { k } => assert_eq!(k, RRF_DEFAULT_K),
            FusionMethod::Linear { .. } => panic!("Expected RRF"),
        }
    }
    #[test]
    fn test_fusion_method_rrf_with_k() {
        let method = FusionMethod::rrf_with_k(100);
        match method {
            FusionMethod::Rrf { k } => assert_eq!(k, 100),
            FusionMethod::Linear { .. } => panic!("Expected RRF"),
        }
    }
    #[test]
    fn test_fusion_method_linear() {
        let method = FusionMethod::linear(0.7).unwrap();
        match method {
            FusionMethod::Linear { alpha } => assert!((alpha - 0.7).abs() < f32::EPSILON),
            FusionMethod::Rrf { .. } => panic!("Expected Linear"),
        }
    }
    #[test]
    fn test_fusion_method_linear_balanced() {
        let method = FusionMethod::linear_balanced();
        match method {
            FusionMethod::Linear { alpha } => assert!((alpha - 0.5).abs() < f32::EPSILON),
            FusionMethod::Rrf { .. } => panic!("Expected Linear"),
        }
    }
    #[test]
    fn test_fusion_method_linear_invalid_high() {
        // Unlike `linear_fusion`, the constructor rejects out-of-range alpha.
        let result = FusionMethod::linear(1.5);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("Alpha must be in range"),
            "Error message should mention valid range"
        );
    }
    #[test]
    fn test_fusion_method_linear_invalid_low() {
        let result = FusionMethod::linear(-0.1);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("Alpha must be in range"),
            "Error message should mention valid range"
        );
    }
}