use std::collections::HashMap;
use std::hash::Hash;
/// A single scored hit produced by one retriever.
///
/// Only the position of a candidate inside its [`Bucket`] feeds into fusion; `score` is
/// read solely to apply the bucket's optional [`Bucket::min_score`] floor.
#[derive(Debug, Clone, PartialEq)]
pub struct Candidate<Id> {
    /// Identifier used to merge the same item across buckets.
    pub id: Id,
    /// Retriever-specific score, compared against the bucket's `min_score`.
    pub score: f64,
}

/// One ranked list to be fused.
#[derive(Debug, Clone)]
pub struct Bucket<Id> {
    /// Candidates in rank order: the first surviving element gets rank 1.
    pub candidates: Vec<Candidate<Id>>,
    /// Optional inclusive floor. Candidates whose `score` is not `>=` this value are
    /// dropped *before* ranks are assigned, so the survivors are ranked contiguously.
    /// A NaN score never satisfies a floor.
    pub min_score: Option<f64>,
}

/// An item of the fused output together with its summed RRF score.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedItem<Id> {
    /// Identifier shared by the merged candidates.
    pub id: Id,
    /// Sum of `1 / (k + rank)` over every bucket occurrence of `id`.
    pub rrf_score: f64,
}

/// Default value for the `k` smoothing parameter of [`fuse`].
pub const RRF_K_DEFAULT: u32 = 60;

/// Fuses several ranked lists with Reciprocal Rank Fusion (RRF).
///
/// Within each bucket, candidates failing the bucket's `min_score` floor are removed and
/// the rest are given 1-based ranks in their original order. Each surviving candidate at
/// rank `r` contributes `1 / (k + r)` to its id's total. An id listed more than once in
/// the same bucket collects one contribution per occurrence.
///
/// Returns at most `total_k` items sorted by descending `rrf_score`, with ties broken by
/// ascending `id`. Each id's contributions are summed in a canonical order, so the output
/// does not depend on the order of `buckets`. Returns an empty vector when `total_k == 0`
/// or when no candidate survives.
pub fn fuse<Id>(buckets: &[Bucket<Id>], k: u32, total_k: usize) -> Vec<FusedItem<Id>>
where
    Id: Clone + Eq + Hash + Ord,
{
    if total_k == 0 {
        return Vec::new();
    }
    let k_f = f64::from(k);

    // Keep every contribution instead of summing on the fly: f64 addition is not
    // associative, so a running sum would depend on bucket order and could split
    // mathematically tied ids by an ulp, bypassing the id tie-break.
    let mut contributions: HashMap<Id, Vec<f64>> = HashMap::new();
    for bucket in buckets {
        let survivors = bucket
            .candidates
            .iter()
            .filter(|cand| passes_floor(cand.score, bucket.min_score));
        for (idx, cand) in survivors.enumerate() {
            // `enumerate` is 0-based while RRF ranks are 1-based; usize -> f64 is exact
            // for any realistic list length (below 2^53).
            let rank = (idx + 1) as f64;
            contributions
                .entry(cand.id.clone())
                .or_default()
                .push(1.0 / (k_f + rank));
        }
    }

    let mut fused: Vec<FusedItem<Id>> = contributions
        .into_iter()
        .map(|(id, parts)| FusedItem {
            id,
            rrf_score: canonical_sum(parts),
        })
        .collect();
    // Scores are finite and positive, so `total_cmp` matches numeric order; ids are
    // unique map keys, which makes the comparator total and an unstable sort safe.
    fused.sort_unstable_by(|a, b| {
        b.rrf_score
            .total_cmp(&a.rrf_score)
            .then_with(|| a.id.cmp(&b.id))
    });
    fused.truncate(total_k);
    fused
}

/// Returns whether a candidate with `score` survives the optional inclusive `floor`.
/// A NaN score fails any floor, and a NaN floor rejects every candidate.
fn passes_floor(score: f64, floor: Option<f64>) -> bool {
    match floor {
        Some(floor) => score >= floor,
        None => true,
    }
}

/// Sums `parts` in ascending order, so the result depends only on the multiset of values.
fn canonical_sum(mut parts: Vec<f64>) -> f64 {
    parts.sort_unstable_by(f64::total_cmp);
    parts.iter().sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing RRF scores against hand-computed reference values.
    const EPS: f64 = 1e-12;

    fn cand<Id>(id: Id, score: f64) -> Candidate<Id> {
        Candidate { id, score }
    }

    fn bucket_no_floor<Id>(cs: Vec<Candidate<Id>>) -> Bucket<Id> {
        Bucket {
            candidates: cs,
            min_score: None,
        }
    }

    /// Asserts that `actual` is within `EPS` of `expected`, reporting both values on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "rrf score mismatch: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn rrf_single_list_matches_reference() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5), cand("c", 0.1)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 3);
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_close(out[1].rrf_score, 1.0 / 62.0);
        assert_close(out[2].rrf_score, 1.0 / 63.0);
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
        assert_eq!(out[2].id, "c");
    }

    #[test]
    fn rrf_two_lists_sums_contributions() {
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.9), cand("c", 0.8)]);
        let b2 = bucket_no_floor(vec![cand("a", 0.95), cand("b", 0.85), cand("d", 0.7)]);
        let out = fuse(&[b1, b2], 60, 10);
        let by_id: std::collections::HashMap<_, _> =
            out.iter().map(|f| (f.id, f.rrf_score)).collect();
        assert_close(by_id["a"], 2.0 / 61.0);
        assert_close(by_id["b"], 2.0 / 62.0);
        assert_close(by_id["c"], 1.0 / 63.0);
        assert_close(by_id["d"], 1.0 / 63.0);
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
        assert_eq!(out[2].id, "c");
        assert_eq!(out[3].id, "d");
    }

    #[test]
    fn rrf_k_default_is_60() {
        assert_eq!(RRF_K_DEFAULT, 60);
    }

    #[test]
    fn alternate_k_changes_scores() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 1, 10);
        assert_close(out[0].rrf_score, 0.5);
    }

    #[test]
    fn total_k_caps_output() {
        let bucket = bucket_no_floor(vec![
            cand("a", 1.0),
            cand("b", 0.9),
            cand("c", 0.8),
            cand("d", 0.7),
        ]);
        let out = fuse(&[bucket], 60, 2);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
    }

    #[test]
    fn total_k_zero_returns_empty() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 60, 0);
        assert!(out.is_empty());
    }

    #[test]
    fn total_k_larger_than_candidates_returns_all() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let out = fuse(&[bucket], 60, 100);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn min_score_drops_items_before_ranking() {
        let bucket = Bucket {
            candidates: vec![cand("a", 0.9), cand("b", 0.4), cand("c", 0.6)],
            min_score: Some(0.5),
        };
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].id, "a");
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_eq!(out[1].id, "c");
        assert_close(out[1].rrf_score, 1.0 / 62.0);
    }

    #[test]
    fn min_score_is_inclusive() {
        // A score exactly equal to the floor is kept and ranked.
        let bucket = Bucket {
            candidates: vec![cand("edge", 0.5), cand("below", 0.49)],
            min_score: Some(0.5),
        };
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].id, "edge");
        assert_close(out[0].rrf_score, 1.0 / 61.0);
    }

    #[test]
    fn min_score_independent_per_bucket() {
        let bm25 = Bucket {
            candidates: vec![cand("x", 0.5), cand("y", 0.3)],
            min_score: Some(0.4),
        };
        let vec_bucket = Bucket {
            candidates: vec![cand("x", 0.85), cand("y", 0.6)],
            min_score: Some(0.7),
        };
        let out = fuse(&[bm25, vec_bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].id, "x");
        assert_close(out[0].rrf_score, 2.0 / 61.0);
    }

    #[test]
    fn min_score_none_keeps_everything() {
        let bucket = bucket_no_floor(vec![cand("a", -10.0), cand("b", 0.0)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn tie_break_is_id_ascending() {
        let b1 = bucket_no_floor(vec![cand("zebra", 1.0)]);
        let b2 = bucket_no_floor(vec![cand("apple", 1.0)]);
        let b3 = bucket_no_floor(vec![cand("mango", 1.0)]);
        let out = fuse(&[b1, b2, b3], 60, 10);
        assert_eq!(
            out.iter().map(|f| f.id).collect::<Vec<_>>(),
            vec!["apple", "mango", "zebra"]
        );
    }

    #[test]
    fn fuse_is_deterministic_across_calls() {
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let a = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let c = fuse(&[b1, b2], 60, 10);
        assert_eq!(a, c);
    }

    #[test]
    fn fuse_is_order_independent_across_buckets() {
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let forward = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let reverse = fuse(&[b2, b1], 60, 10);
        assert_eq!(forward, reverse);
    }

    #[test]
    fn empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> = vec![];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn all_empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> =
            vec![bucket_no_floor(vec![]), bucket_no_floor(vec![])];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn duplicate_id_within_one_bucket_keeps_both_ranks() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("a", 0.5)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_close(out[0].rrf_score, 1.0 / 61.0 + 1.0 / 62.0);
    }

    #[test]
    fn integer_ids_supported() {
        let b1 = bucket_no_floor(vec![cand(1u64, 1.0), cand(2u64, 0.5)]);
        let b2 = bucket_no_floor(vec![cand(2u64, 0.9), cand(3u64, 0.4)]);
        let out = fuse(&[b1, b2], 60, 10);
        assert_eq!(out[0].id, 2);
    }
}