use std::collections::HashMap;
use uuid::Uuid;
/// Errors reported by the rank-fusion routines in this file.
#[derive(Debug, thiserror::Error)]
pub enum RrfError {
/// The caller supplied inconsistent inputs; the payload is a human-readable
/// explanation. `reciprocal_rank_fusion` returns this when the number of
/// rankings differs from the number of weights.
#[error("Config error: {0}")]
Config(String),
}
/// Chooses the RRF rank constant `k` for a candidate pool of the given size.
///
/// The automatic value is one tenth of `candidate_count` (integer division),
/// raised to at least 1. The result is that value capped at `configured_k`,
/// so the function never returns more than the configured constant. Passing
/// `configured_k == 0` therefore yields 0.
///
/// # Parameters
/// - `candidate_count`: number of candidates being fused.
/// - `configured_k`: upper bound for the returned constant.
///
/// # Returns
/// `min(max(candidate_count / 10, 1), configured_k)`, saturating at
/// `u32::MAX` for the intermediate value.
pub fn adaptive_k(candidate_count: usize, configured_k: u32) -> u32 {
    let tenth = candidate_count / 10;
    // Saturate rather than truncate: a plain `as u32` keeps only the low 32
    // bits, so a huge count (e.g. `10 << 32`) would wrap to 0 and defeat the
    // `max(1)` floor. The trait path is spelled out so this compiles on every
    // edition without an extra `use`.
    let auto_k = <u32 as std::convert::TryFrom<usize>>::try_from(tenth)
        .unwrap_or(u32::MAX)
        .max(1);
    auto_k.min(configured_k)
}
/// Merges several ranked lists into one using weighted Reciprocal Rank Fusion.
///
/// Each entry at 1-based position `r` in list `i` adds
/// `weights[i] / (k + r)` to the fused score of its id. Only the position in
/// the list matters; the `f32` score stored alongside each id is ignored.
/// Scores are accumulated in `f64` and narrowed to `f32` for the output.
///
/// # Parameters
/// - `rankings`: the ranked lists, best candidate first.
/// - `weights`: one multiplier per list, in the same order as `rankings`.
/// - `k`: rank constant added to every position in the denominator.
///
/// # Returns
/// Every id that appears in any list, paired with its fused score, ordered by
/// descending score; equal scores are ordered by ascending id. An empty
/// `rankings` slice produces an empty vector.
///
/// # Errors
/// Returns [`RrfError::Config`] when `rankings` and `weights` have different
/// lengths.
pub fn reciprocal_rank_fusion(
    rankings: &[Vec<(Uuid, f32)>],
    weights: &[f32],
    k: u32,
) -> Result<Vec<(Uuid, f32)>, RrfError> {
    let (list_count, weight_count) = (rankings.len(), weights.len());
    if list_count != weight_count {
        return Err(RrfError::Config(format!(
            "rankings ({}) and weights ({}) must be same length",
            list_count, weight_count
        )));
    }
    if list_count == 0 {
        return Ok(Vec::new());
    }

    let smoothing = f64::from(k);
    let mut fused: HashMap<Uuid, f64> = HashMap::new();

    for (list, &weight) in rankings.iter().zip(weights) {
        let list_weight = f64::from(weight);
        for (position, &(doc_id, _)) in list.iter().enumerate() {
            // Positions are 0-based; RRF uses 1-based ranks.
            let rank = (position + 1) as f64;
            let share = list_weight / (smoothing + rank);
            *fused.entry(doc_id).or_insert(0.0) += share;
        }
    }

    let mut ordered: Vec<(Uuid, f32)> = fused
        .into_iter()
        .map(|(doc_id, total)| (doc_id, total as f32))
        .collect();

    // Highest score first; the id tie-break makes the output independent of
    // the hash map's iteration order.
    ordered.sort_by(|lhs, rhs| {
        let by_score_desc = rhs.1.total_cmp(&lhs.1);
        by_score_desc.then_with(|| lhs.0.cmp(&rhs.0))
    });

    Ok(ordered)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic UUID whose sixteen bytes all equal `n`.
    fn id(n: u8) -> Uuid {
        Uuid::from_bytes([n; 16])
    }

    /// Index of `target` within a fused result; panics if it is absent.
    fn rank_of(fused: &[(Uuid, f32)], target: Uuid) -> usize {
        fused
            .iter()
            .position(|(uid, _)| *uid == target)
            .unwrap()
    }

    /// A single list comes back in its original order.
    #[test]
    fn test_single_ranking() {
        let only_list = vec![(id(1), 0.9_f32), (id(2), 0.5), (id(3), 0.1)];
        let fused = reciprocal_rank_fusion(&[only_list], &[1.0], 60).unwrap();
        let order: Vec<Uuid> = fused.iter().map(|(uid, _)| *uid).collect();
        assert_eq!(order, vec![id(1), id(2), id(3)]);
    }

    /// An id ranked second in both lists beats an id ranked first in only one.
    #[test]
    fn test_consensus_wins() {
        let first = vec![(id(1), 1.0_f32), (id(2), 0.8), (id(3), 0.5)];
        let second = vec![(id(4), 1.0_f32), (id(2), 0.9), (id(5), 0.3)];
        let fused = reciprocal_rank_fusion(&[first, second], &[1.0, 1.0], 60).unwrap();
        let consensus_pos = rank_of(&fused, id(2));
        let single_pos = rank_of(&fused, id(1));
        assert!(
            consensus_pos < single_pos,
            "consensus item (id2) should rank above id1"
        );
    }

    /// The top entry of the heavier-weighted list outranks the other list's top entry.
    #[test]
    fn test_weighted_rrf() {
        let heavy = vec![(id(1), 1.0_f32), (id(3), 0.5)];
        let light = vec![(id(2), 1.0_f32), (id(3), 0.5)];
        let fused = reciprocal_rank_fusion(&[heavy, light], &[2.0, 1.0], 60).unwrap();
        let heavy_top_pos = rank_of(&fused, id(1));
        let light_top_pos = rank_of(&fused, id(2));
        assert!(
            heavy_top_pos < light_top_pos,
            "weighted list_a's #1 (id1) should beat list_b's #1 (id2)"
        );
    }

    /// No input lists yields an empty result rather than an error.
    #[test]
    fn test_empty_rankings() {
        let fused = reciprocal_rank_fusion(&[], &[], 60).unwrap();
        assert!(fused.is_empty());
    }

    /// One ranking with zero weights is rejected.
    #[test]
    fn test_mismatched_lengths() {
        let outcome = reciprocal_rank_fusion(&[vec![]], &[], 60);
        assert!(outcome.is_err());
    }

    /// A smaller `k` widens the gap between adjacent ranks without reordering them.
    #[test]
    fn test_k_parameter_sensitivity() {
        let lists = [vec![(id(1), 1.0_f32), (id(2), 0.5)]];
        let fused_low = reciprocal_rank_fusion(&lists, &[1.0], 1).unwrap();
        let fused_high = reciprocal_rank_fusion(&lists, &[1.0], 1000).unwrap();

        let spread_low = fused_low[0].1 - fused_low[1].1;
        let spread_high = fused_high[0].1 - fused_high[1].1;
        assert!(
            spread_low > spread_high,
            "lower k should produce wider score spread: low_k spread={spread_low}, high_k spread={spread_high}"
        );

        assert_eq!(fused_low[0].0, id(1), "id1 should be first under low k");
        assert_eq!(
            fused_high[0].0,
            id(1),
            "id1 should be first under high k"
        );
    }

    /// 50 candidates give a tenth of 5, below the configured cap.
    #[test]
    fn test_adaptive_k_small_corpus() {
        let k = adaptive_k(50, 60);
        assert_eq!(k, 5);
    }

    /// 1000 candidates give a tenth of 100, which is capped at the configured 60.
    #[test]
    fn test_adaptive_k_large_corpus() {
        let k = adaptive_k(1000, 60);
        assert_eq!(k, 60);
    }

    /// Fewer than ten candidates still produce the floor value of 1.
    #[test]
    fn test_adaptive_k_tiny_corpus() {
        let k = adaptive_k(5, 60);
        assert_eq!(k, 1);
    }

    /// With the adaptive constant, the first of 50 items scores well over five
    /// times the last (k = 5 gives 1/6 versus 1/55).
    #[test]
    fn test_adaptive_k_preserves_discrimination() {
        let ranking: Vec<(Uuid, f32)> = (0u8..50)
            .map(|i| (id(i), 1.0 - f32::from(i) / 50.0))
            .collect();
        let k = adaptive_k(50, 60);
        let fused = reciprocal_rank_fusion(&[ranking], &[1.0], k).unwrap();

        let best = fused.first().unwrap().1;
        let worst = fused.last().unwrap().1;
        let ratio = best / worst;
        assert!(
            ratio > 5.0,
            "Adaptive k should give strong discrimination, ratio={ratio}"
        );
    }
}