use std::collections::HashMap;
/// Fuses several ranked id lists into one score table using weighted
/// reciprocal rank fusion.
///
/// Each element of `lists` is a `(weight, ids)` pair, where `ids` is ordered
/// best-first. An id at zero-based position `rank` contributes
/// `weight * (1.0 / (rrf_k + rank + 1.0))`, and contributions for the same id
/// are summed across all lists (and across repeated occurrences within one
/// list).
///
/// Every id seen in any list gets an entry in the returned map, including ids
/// that only occur in lists whose weight is `0.0`.
pub fn rrf_fuse(lists: &[(f64, &Vec<i64>)], rrf_k: f64) -> HashMap<i64, f64> {
    let mut fused: HashMap<i64, f64> = HashMap::new();

    // Lazily produce one `(id, share)` pair per list position, in list order.
    let shares = lists.iter().flat_map(|&(weight, ids)| {
        ids.iter().enumerate().map(move |(position, &id)| {
            let reciprocal = 1.0 / (rrf_k + position as f64 + 1.0);
            (id, weight * reciprocal)
        })
    });

    for (id, share) in shares {
        *fused.entry(id).or_default() += share;
    }

    fused
}
/// Returns the sum over `weights` of `weight * (1.0 / (rrf_k + 1.0))`.
///
/// This is the fused score of an id that sits at position 0 of every list
/// when the lists carry these weights. It acts as an upper bound on a fused
/// score provided the weights and `rrf_k` are non-negative and each id occurs
/// at most once per list.
pub fn rrf_max_possible(weights: &[f64], rrf_k: f64) -> f64 {
    // The first-position share is the same for every list, so compute it once.
    let top_rank_share = 1.0 / (rrf_k + 1.0);
    weights.iter().map(|&weight| weight * top_rank_share).sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Within a single list, earlier positions must receive strictly larger
    /// scores than later ones.
    #[test]
    fn rrf_fuse_single_list_rank_order_preserved() {
        let list = vec![10i64, 20, 30];
        let scores = rrf_fuse(&[(1.0, &list)], 60.0);
        assert!(scores[&10] > scores[&20]);
        assert!(scores[&20] > scores[&30]);
    }

    /// An id present in both lists accumulates contributions from each and
    /// therefore outranks ids that appear in only one list.
    #[test]
    fn rrf_fuse_two_lists_overlap_accumulates() {
        let knn = vec![1i64, 2];
        let fts = vec![1i64, 3];
        let scores = rrf_fuse(&[(1.0, &knn), (1.0, &fts)], 60.0);
        assert!(scores[&1] > scores[&2], "overlap item must score higher");
        assert!(scores[&1] > scores[&3], "overlap item must score higher");
    }

    /// Fusing only empty lists yields an empty score map.
    #[test]
    fn rrf_fuse_empty_lists_returns_empty() {
        let empty: Vec<i64> = vec![];
        let scores = rrf_fuse(&[(1.0, &empty)], 60.0);
        assert!(scores.is_empty());
    }

    /// A list with weight `0.0` contributes nothing: its own ids score zero
    /// and the scores of ids from the other list match a run without it.
    #[test]
    fn rrf_fuse_zero_weight_list_has_no_effect() {
        let list_a = vec![1i64, 2];
        let list_b = vec![3i64, 4];
        let scores_with = rrf_fuse(&[(1.0, &list_a), (0.0, &list_b)], 60.0);
        let scores_without = rrf_fuse(&[(1.0, &list_a)], 60.0);

        // Ids that only occur in the zero-weight list may be present in the
        // map, but their score must be zero.
        assert_eq!(scores_with.get(&3).copied().unwrap_or(0.0), 0.0);
        assert_eq!(scores_with.get(&4).copied().unwrap_or(0.0), 0.0);

        // The weighted list's scores must be identical to the baseline run.
        assert_eq!(scores_with[&1], scores_without[&1]);
        assert_eq!(scores_with[&2], scores_without[&2]);
    }

    /// A larger weight yields a larger score for the same id and position.
    #[test]
    fn rrf_fuse_weights_scale_contribution() {
        let list = vec![1i64];
        let low = rrf_fuse(&[(0.5, &list)], 60.0);
        let high = rrf_fuse(&[(2.0, &list)], 60.0);
        assert!(high[&1] > low[&1]);
    }

    /// The maximum is the sum of each weight's first-position contribution.
    #[test]
    fn rrf_max_possible_sums_weights() {
        let max = rrf_max_possible(&[1.0], 60.0);
        let expected = 1.0 / 61.0;
        assert!((max - expected).abs() < 1e-9);
        let max2 = rrf_max_possible(&[1.0, 1.0], 60.0);
        assert!((max2 - 2.0 / 61.0).abs() < 1e-9);
    }

    /// Two runs over the same input must produce identical scores for every id.
    #[test]
    fn rrf_fuse_deterministic_for_same_input() {
        let list_a = vec![1i64, 2, 3];
        let list_b = vec![2i64, 1, 4];
        let scores_1 = rrf_fuse(&[(1.0, &list_a), (1.0, &list_b)], 60.0);
        let scores_2 = rrf_fuse(&[(1.0, &list_a), (1.0, &list_b)], 60.0);
        for id in [1i64, 2, 3, 4] {
            assert_eq!(
                scores_1.get(&id).copied().unwrap_or(0.0),
                scores_2.get(&id).copied().unwrap_or(0.0)
            );
        }
    }
}