use std::collections::HashMap;
use crate::search::types::SearchResult;
/// Fuses several ranked result lists into one list using Reciprocal Rank
/// Fusion (RRF).
///
/// Every occurrence of a document at 0-based position `rank` in a list adds
/// `1 / (k + rank + 1)` to that document's fused score. Documents are
/// identified by `id`; the `score` carried by the input results is ignored.
///
/// # Arguments
///
/// * `result_sets` - the ranked lists to fuse, each ordered best-first.
/// * `k` - smoothing constant added to every rank; larger values flatten the
///   difference between top and lower ranks.
///
/// # Returns
///
/// One `SearchResult` per distinct `id`, sorted by fused score in descending
/// order. `text` is taken from the first occurrence of the `id`. Documents
/// with equal fused scores keep first-seen order (earlier lists first, then
/// earlier ranks), so the output is deterministic for a given input.
pub fn rrf_fuse(result_sets: &[Vec<SearchResult>], k: u32) -> Vec<SearchResult> {
    // Capacity hint only: the longest single list is a cheap estimate of how
    // many distinct ids will show up.
    let capacity_hint = result_sets.iter().map(Vec::len).max().unwrap_or(0);

    // `fused` holds the accumulated results in first-seen order, and `slots`
    // maps an id to its index in `fused`. Collecting straight out of a
    // `HashMap` would feed the sort an arbitrary, per-run order and make the
    // ordering of tied documents nondeterministic.
    let mut slots: HashMap<&str, usize> = HashMap::with_capacity(capacity_hint);
    let mut fused: Vec<SearchResult> = Vec::with_capacity(capacity_hint);

    for results in result_sets {
        for (rank, result) in results.iter().enumerate() {
            let contribution = 1.0 / (k as f32 + rank as f32 + 1.0);
            if let Some(&slot) = slots.get(result.id.as_str()) {
                fused[slot].score += contribution;
            } else {
                slots.insert(result.id.as_str(), fused.len());
                fused.push(SearchResult {
                    id: result.id.clone(),
                    score: contribution,
                    text: result.text.clone(),
                });
            }
        }
    }

    // `sort_by` is stable, so equal scores stay in first-seen order. Every
    // contribution is finite and positive, so `partial_cmp` never yields
    // `None` here; the fallback only keeps the comparator total.
    fused.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    fused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ranked list from `(id, score)` pairs, preserving their order.
    /// Each result's text is `"text for <id>"`.
    fn make_results(ids: &[(&str, f32)]) -> Vec<SearchResult> {
        let mut list = Vec::with_capacity(ids.len());
        for &(id, score) in ids {
            list.push(SearchResult {
                id: id.to_string(),
                score,
                text: format!("text for {id}"),
            });
        }
        list
    }

    /// Fusing no lists at all produces no results.
    #[test]
    fn empty_inputs() {
        let fused = rrf_fuse(&[], 60);
        assert!(fused.is_empty());
    }

    /// A single list keeps both entries, with its top-ranked entry first.
    #[test]
    fn single_list() {
        let only = make_results(&[("a", 0.9), ("b", 0.5)]);
        let fused = rrf_fuse(&[only], 60);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].id, "a");
    }

    /// A document present in both lists outranks documents present in one.
    #[test]
    fn two_lists_overlap() {
        let lexical = make_results(&[("a", 0.9), ("b", 0.7)]);
        let semantic = make_results(&[("b", 0.95), ("c", 0.6)]);
        let fused = rrf_fuse(&[lexical, semantic], 60);
        assert_eq!(fused.len(), 3);
        assert_eq!(fused[0].id, "b");
    }

    /// The top hit of a lone list scores `1 / (k + 1)`.
    #[test]
    fn rrf_score_formula() {
        let only = make_results(&[("x", 1.0)]);
        let fused = rrf_fuse(&[only], 60);
        let expected: f32 = 1.0 / 61.0;
        let delta = (fused[0].score - expected).abs();
        assert!(delta < 1e-6);
    }

    /// Two first-place appearances beat two second-place appearances.
    #[test]
    fn three_lists() {
        let first = make_results(&[("doc1", 1.0), ("doc2", 0.8)]);
        let second = make_results(&[("doc2", 0.9), ("doc3", 0.7)]);
        let third = make_results(&[("doc1", 0.8), ("doc3", 0.9)]);
        let fused = rrf_fuse(&[first, second, third], 60);
        assert_eq!(fused.len(), 3);

        let score_of = |id: &str| fused.iter().find(|r| r.id == id).unwrap().score;
        let doc3_score = score_of("doc3");
        let doc1_score = score_of("doc1");
        assert!(doc1_score > doc3_score);
    }
}