use crate::types::VectorId;
use std::collections::HashMap;
/// Default smoothing constant `k` for reciprocal rank fusion.
///
/// Nothing in this file reads it. NOTE(review): presumably callers pass it as the `k`
/// argument of the fusion functions below — confirm at the call sites.
pub const DEFAULT_RRF_K: f32 = 60.0;
/// One entry of a ranked result list that is fed into rank fusion.
///
/// `PartialEq` is derived so that values can be compared directly (for example in
/// assertions); `Eq` is not available because `original_score` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier of the ranked item; fusion accumulates scores per id.
    pub id: VectorId,
    /// Position of the item within its list. The fusion functions in this file add 1
    /// to it before use, i.e. they treat `0` as the best rank.
    pub rank: usize,
    /// Score attached to the item before fusion. The fusion functions in this file
    /// never read it. NOTE(review): presumably the score from the originating
    /// ranker — confirm with the producers of these lists.
    pub original_score: f32,
}
/// Fuses several ranked lists into a single ranking with Reciprocal Rank Fusion (RRF).
///
/// Every occurrence of an id adds `1 / (k + rank + 1)` to that id's fused score, where
/// `rank` comes from [`RankedResult::rank`] (so `0` yields the largest contribution for
/// a positive `k`). `original_score` is not consulted.
///
/// # Arguments
/// * `ranked_lists` - one list of results per source ranking; lists may be empty.
/// * `k` - smoothing constant added to every rank; larger values reduce the gap
///   between the contributions of neighbouring ranks.
///
/// # Returns
/// `(id, fused_score)` pairs sorted by descending score. Ids whose scores are equal
/// keep the order in which they were first encountered, so the output does not depend
/// on `HashMap` iteration order. `NaN` scores (e.g. when `k` is `NaN`) are sorted last.
pub fn reciprocal_rank_fusion(
    ranked_lists: Vec<Vec<RankedResult>>,
    k: f32,
) -> Vec<(VectorId, f32)> {
    // id -> (accumulated score, index at which the id was first seen).
    let mut rrf_scores: HashMap<VectorId, (f32, usize)> = HashMap::new();
    for result in ranked_lists.into_iter().flatten() {
        let rrf_contribution = 1.0 / (k + result.rank as f32 + 1.0);
        // The map only grows on a first sighting, so its length is the next free index.
        let first_seen = rrf_scores.len();
        rrf_scores
            .entry(result.id)
            .or_insert((0.0, first_seen))
            .0 += rrf_contribution;
    }

    let mut results: Vec<(VectorId, f32, usize)> = rrf_scores
        .into_iter()
        .map(|(id, (score, first_seen))| (id, score, first_seen))
        .collect();
    // A strict total order (NaN last, then score descending, then first-seen index):
    // `partial_cmp` with an `Equal` fallback is not a total order once a NaN shows up,
    // and the standard sorts are allowed to panic on such comparators.
    results.sort_unstable_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.total_cmp(&a.1))
            .then_with(|| a.2.cmp(&b.2))
    });
    results
        .into_iter()
        .map(|(id, score, _)| (id, score))
        .collect()
}
/// Fuses several ranked lists with per-list weights using Reciprocal Rank Fusion (RRF).
///
/// Every occurrence of an id adds `weight * (1 / (k + rank + 1))` to that id's fused
/// score, where `weight` belongs to the list the occurrence came from and `rank` comes
/// from [`RankedResult::rank`]. `original_score` is not consulted.
///
/// # Arguments
/// * `ranked_lists` - `(results, weight)` pairs, one per source ranking; lists may be
///   empty.
/// * `k` - smoothing constant added to every rank; larger values reduce the gap
///   between the contributions of neighbouring ranks.
///
/// # Returns
/// `(id, fused_score)` pairs sorted by descending score. Ids whose scores are equal
/// keep the order in which they were first encountered, so the output does not depend
/// on `HashMap` iteration order. `NaN` scores (e.g. a zero weight multiplied by an
/// infinite contribution, or a `NaN` weight or `k`) are sorted last.
pub fn weighted_reciprocal_rank_fusion(
    ranked_lists: Vec<(Vec<RankedResult>, f32)>,
    k: f32,
) -> Vec<(VectorId, f32)> {
    // id -> (accumulated score, index at which the id was first seen).
    let mut rrf_scores: HashMap<VectorId, (f32, usize)> = HashMap::new();
    for (list, weight) in ranked_lists {
        for result in list {
            let rrf_contribution = weight * (1.0 / (k + result.rank as f32 + 1.0));
            // The map only grows on a first sighting, so its length is the next free index.
            let first_seen = rrf_scores.len();
            rrf_scores
                .entry(result.id)
                .or_insert((0.0, first_seen))
                .0 += rrf_contribution;
        }
    }

    let mut results: Vec<(VectorId, f32, usize)> = rrf_scores
        .into_iter()
        .map(|(id, (score, first_seen))| (id, score, first_seen))
        .collect();
    // A strict total order (NaN last, then score descending, then first-seen index):
    // `partial_cmp` with an `Equal` fallback is not a total order once a NaN shows up,
    // and the standard sorts are allowed to panic on such comparators.
    results.sort_unstable_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.total_cmp(&a.1))
            .then_with(|| a.2.cmp(&b.2))
    });
    results
        .into_iter()
        .map(|(id, score, _)| (id, score))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor so each fixture entry fits on one line.
    fn ranked(id: &'static str, rank: usize, original_score: f32) -> RankedResult {
        RankedResult {
            id: id.into(),
            rank,
            original_score,
        }
    }

    /// Items present in both lists must outrank an item present in only one.
    #[test]
    fn test_rrf_basic() {
        let first = vec![
            ranked("a", 0, 1.0),
            ranked("b", 1, 0.9),
            ranked("c", 2, 0.8),
        ];
        let second = vec![
            ranked("b", 0, 5.0),
            ranked("a", 1, 4.0),
            ranked("d", 2, 3.0),
        ];

        let results = reciprocal_rank_fusion(vec![first, second], 60.0);
        assert!(results.len() >= 3);

        let position_of =
            |wanted: &'static str| results.iter().position(|(id, _)| id == wanted).unwrap();
        let a_pos = position_of("a");
        let b_pos = position_of("b");
        let d_pos = position_of("d");
        assert!(a_pos < d_pos);
        assert!(b_pos < d_pos);
    }

    /// With a single input list the fused order follows the input ranks.
    #[test]
    fn test_rrf_single_list() {
        let only = vec![ranked("a", 0, 1.0), ranked("b", 1, 0.5)];

        let results = reciprocal_rank_fusion(vec![only], 60.0);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "a");
        assert_eq!(results[1].0, "b");
        assert!(results[0].1 > results[1].1);
    }

    /// At equal rank, the item from the more heavily weighted list comes first.
    #[test]
    fn test_weighted_rrf() {
        let light = (vec![ranked("a", 0, 1.0)], 0.1);
        let heavy = (vec![ranked("b", 0, 1.0)], 0.9);

        let results = weighted_reciprocal_rank_fusion(vec![light, heavy], 60.0);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "b");
        assert!(results[0].1 > results[1].1);
    }

    /// No lists, or only empty lists, produce an empty ranking.
    #[test]
    fn test_empty_lists() {
        let results = reciprocal_rank_fusion(Vec::new(), 60.0);
        assert!(results.is_empty());

        let results = reciprocal_rank_fusion(vec![Vec::new()], 60.0);
        assert!(results.is_empty());
    }
}