use serde::{Deserialize, Serialize};
/// Tuning parameters read by `rrf_fusion` when it merges ranked lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
/// Rank offset: a hit at zero-based position `rank` contributes
/// `weight / (k + rank + 1)` to its document's fused score. `Default` uses 60.
pub k: usize,
/// Per-list multipliers, matched to the input lists by position. A list with
/// no entry here is fused with weight 1.0. `Default` provides a single `1.0`.
pub weights: Vec<f32>,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
k: 60,
weights: vec![1.0],
}
}
}
/// Merges several ranked result lists with weighted Reciprocal Rank Fusion.
///
/// Every inner list holds `(doc_idx, score)` pairs in rank order, position 0
/// being the best rank. When two or more lists are fused the incoming scores
/// are ignored; only positions matter. A document at zero-based position
/// `rank` of list `i` contributes `weight_i / (k + rank + 1)`, and the
/// contributions are summed per `doc_idx` (repeated occurrences inside one
/// list are summed as well).
///
/// `weight_i` is `config.weights[i]`; lists without a matching entry use `1.0`.
///
/// # Returns
///
/// `(doc_idx, fused_score)` pairs sorted by descending fused score. Documents
/// with equal scores are ordered by ascending `doc_idx`, so the output is the
/// same on every call for the same input.
///
/// Special cases:
/// - no input lists: an empty vector;
/// - exactly one input list: a clone of that list, untouched (original scores
///   and order; `k` and `weights` are not applied).
pub fn rrf_fusion(ranked_lists: &[Vec<(usize, f32)>], config: &FusionConfig) -> Vec<(usize, f32)> {
    use std::collections::HashMap;

    match ranked_lists {
        [] => return Vec::new(),
        // Nothing to fuse: hand the caller's ranking back as-is.
        [only] => return only.clone(),
        _ => {}
    }

    let mut scores: HashMap<usize, f32> = HashMap::new();
    for (list_idx, ranked_list) in ranked_lists.iter().enumerate() {
        let weight = config.weights.get(list_idx).copied().unwrap_or(1.0);
        for (rank, &(doc_idx, _score)) in ranked_list.iter().enumerate() {
            *scores.entry(doc_idx).or_insert(0.0) +=
                weight / (config.k as f32 + rank as f32 + 1.0);
        }
    }

    let mut results: Vec<(usize, f32)> = scores.into_iter().collect();
    // HashMap iteration order changes between runs, so ties need an explicit
    // tie-break. `total_cmp` keeps the comparator a total order even if a
    // non-finite weight produced a NaN score. Doc ids are unique map keys, so
    // no two elements compare equal and an unstable sort is safe.
    results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    results
}
#[cfg(test)]
mod tests {
    // Unit tests for `rrf_fusion`: score arithmetic, weighting, the empty and
    // single-list shortcuts, and ordering of the fused output.
    use super::*;

    /// Two overlapping lists: the shared top document wins with the summed score.
    #[test]
    fn test_rrf_two_lists() {
        let list1 = vec![(0, 0.9), (1, 0.8), (2, 0.7)];
        let list2 = vec![(1, 0.95), (2, 0.85), (3, 0.75)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        assert!(!fused.is_empty());
        assert_eq!(fused.len(), 4);
        assert_eq!(fused[0].0, 1);
        // Doc 1 sits at position 1 in list1 and position 0 in list2.
        let expected_doc1 = 1.0 / 62.0 + 1.0 / 61.0;
        assert!((fused[0].1 - expected_doc1 as f32).abs() < 1e-6);
        for w in fused.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }

    /// Three lists: the document present in all of them ranks first.
    #[test]
    fn test_rrf_three_lists() {
        let list1 = vec![(0, 0.9), (1, 0.8)];
        let list2 = vec![(1, 0.95), (2, 0.85)];
        let list3 = vec![(1, 0.99), (0, 0.5)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2, list3], &config);
        assert_eq!(fused[0].0, 1);
        assert_eq!(fused.len(), 3);
    }

    /// A single input list is handed back unchanged.
    #[test]
    fn test_rrf_single_list_bypass() {
        let list = vec![(0, 0.9), (1, 0.8), (2, 0.7)];
        let config = FusionConfig::default();
        let fused = rrf_fusion(&[list.clone()], &config);
        assert_eq!(fused, list);
    }

    /// No input lists yields no output.
    #[test]
    fn test_rrf_empty_input() {
        let config = FusionConfig::default();
        let fused = rrf_fusion(&[], &config);
        assert!(fused.is_empty());
    }

    /// The heavier-weighted list decides the winner when the lists disagree.
    #[test]
    fn test_rrf_weighted() {
        let bm25 = vec![(0, 10.0), (1, 8.0)];
        let neural = vec![(1, 0.95), (0, 0.3)];
        let config = FusionConfig {
            k: 60,
            weights: vec![0.3, 0.7],
        };
        let fused = rrf_fusion(&[bm25, neural], &config);
        assert_eq!(fused[0].0, 1);
    }

    /// Lists with no common documents produce the union of both.
    #[test]
    fn test_rrf_disjoint_lists() {
        let list1 = vec![(0, 0.9), (1, 0.8)];
        let list2 = vec![(2, 0.95), (3, 0.85)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        assert_eq!(
            fused.len(),
            4,
            "Disjoint lists should produce union of all docs"
        );
    }

    /// Two different documents at the same rank with equal weights tie.
    #[test]
    fn test_rrf_single_doc_per_list() {
        let list1 = vec![(0, 1.0)];
        let list2 = vec![(1, 1.0)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        assert_eq!(fused.len(), 2);
        assert!(
            (fused[0].1 - fused[1].1).abs() < 1e-6,
            "Equal-ranked single docs should tie"
        );
    }

    /// A very large `k` flattens the gap between adjacent ranks.
    ///
    /// Both lists agree that doc 0 outranks doc 1, so the fused scores really
    /// differ. If the lists ranked the docs in opposite order the two scores
    /// would be equal for every `k` and the check would pass trivially.
    #[test]
    fn test_rrf_large_k() {
        let list1: Vec<(usize, f32)> = vec![(0, 0.9), (1, 0.8)];
        let list2: Vec<(usize, f32)> = vec![(0, 0.95), (1, 0.3)];
        let gap_for = |k: usize| -> f32 {
            let config = FusionConfig {
                k,
                weights: vec![1.0, 1.0],
            };
            let fused = rrf_fusion(&[list1.clone(), list2.clone()], &config);
            assert_eq!(fused.len(), 2);
            assert_eq!(fused[0].0, 0, "Rank order must survive any k");
            fused[0].1 - fused[1].1
        };
        let large_k_gap = gap_for(10_000);
        let default_k_gap = gap_for(60);
        assert!(large_k_gap > 0.0);
        assert!(
            large_k_gap < 1e-6,
            "Large k should minimize rank difference effects"
        );
        assert!(
            default_k_gap > large_k_gap,
            "Smaller k should separate ranks more strongly"
        );
    }

    /// `k = 0` is accepted: the `+ 1` in the denominator avoids division by zero.
    #[test]
    fn test_rrf_k_zero() {
        let list1 = vec![(0, 0.9), (1, 0.8)];
        let list2 = vec![(1, 0.95), (0, 0.3)];
        let config = FusionConfig {
            k: 0,
            weights: vec![1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        assert!(!fused.is_empty(), "k=0 should still produce results");
        assert!(
            fused.iter().all(|d| d.1.is_finite()),
            "k=0 should produce finite scores"
        );
    }

    /// A list weighted 0.0 still lists its documents, but with a zero score.
    #[test]
    fn test_rrf_zero_weight() {
        let list1 = vec![(0, 0.9), (1, 0.8)];
        let list2 = vec![(2, 0.95), (3, 0.85)];
        let config = FusionConfig {
            k: 60,
            weights: vec![0.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        let doc0_score = fused.iter().find(|d| d.0 == 0).unwrap().1;
        let doc2_score = fused.iter().find(|d| d.0 == 2).unwrap().1;
        assert_eq!(doc0_score, 0.0, "Zero-weight list should contribute 0");
        assert!(
            doc2_score > 0.0,
            "Non-zero weight list should contribute positive"
        );
    }

    /// Lists beyond the end of `weights` are fused with weight 1.0.
    #[test]
    fn test_rrf_missing_weight_defaults_to_one() {
        let list1 = vec![(0, 0.9)];
        let list2 = vec![(1, 0.95)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        assert_eq!(fused.len(), 2);
        let doc1_score = fused.iter().find(|d| d.0 == 1).unwrap().1;
        let expected = 1.0 / 61.0_f32;
        assert!(
            (doc1_score - expected).abs() < 1e-6,
            "Missing weight should default to 1.0"
        );
    }

    /// A document repeated inside one list collects a contribution per occurrence.
    #[test]
    fn test_rrf_duplicate_doc_ids_in_same_list() {
        let list1 = vec![(0, 0.9), (0, 0.8)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0],
        };
        let list2 = vec![(1, 0.5)];
        let fused = rrf_fusion(&[list1, list2], &config);
        let doc0 = fused.iter().find(|d| d.0 == 0).unwrap();
        let expected = 1.0 / 61.0 + 1.0 / 62.0;
        assert!((doc0.1 - expected).abs() < 1e-6);
    }

    /// Fused output is ordered by non-increasing score.
    #[test]
    fn test_rrf_output_sorted_descending() {
        let list1 = vec![(0, 0.9), (1, 0.8), (2, 0.7)];
        let list2 = vec![(3, 0.95), (4, 0.85), (5, 0.75)];
        let config = FusionConfig {
            k: 60,
            weights: vec![1.0, 1.0],
        };
        let fused = rrf_fusion(&[list1, list2], &config);
        for w in fused.windows(2) {
            assert!(w[0].1 >= w[1].1, "Output should be sorted descending");
        }
    }
}