use std::collections::HashMap;
use crate::core::types::EntityId;
use crate::search::SearchHit;
/// Smoothing constant `k` in the Reciprocal Rank Fusion term `weight / (k + rank)`.
///
/// Larger values shrink the gap between the contributions of adjacent ranks. 60 is the
/// value proposed in the original RRF paper (Cormack, Clarke & Büttcher, 2009).
const K_RRF: f64 = 60.0;
/// Weight applied to reciprocal-rank contributions from the vector-search leg.
const VECTOR_WEIGHT: f64 = 0.7;
/// Weight applied to reciprocal-rank contributions from the BM25 leg.
/// Together with `VECTOR_WEIGHT` this sums to 1.0.
const BM25_WEIGHT: f64 = 0.3;
/// Fuses vector-search and BM25 hit lists with weighted Reciprocal Rank Fusion.
///
/// Each hit at zero-based position `rank` in a leg adds
/// `weight / (K_RRF + rank + 1)` to its entity's fused score. Here `weight` is
/// `VECTOR_WEIGHT` for the vector leg and `BM25_WEIGHT` for the BM25 leg.
///
/// Only positions matter: the incoming `SearchHit::score` values are ignored. This lets
/// legs with incomparable score scales be combined. An entity that appears in both
/// legs receives both contributions, so it outranks single-leg entities at the same
/// rank. An entity repeated within one leg receives one contribution per occurrence.
///
/// # Arguments
/// * `vector_results` - vector-search hits, expected best-first (position 0 gets the
///   largest contribution); `None` skips this leg.
/// * `bm25_results` - BM25 hits, expected best-first; `None` skips this leg.
/// * `top_k` - maximum number of hits to return; `0` yields an empty vector.
///
/// # Returns
/// At most `top_k` hits carrying their fused score (narrowed to `f32`), sorted by
/// descending score.
///
/// Ties are broken deterministically by first appearance: entities from
/// `vector_results` in list order come first, then entities seen only in `bm25_results`.
pub fn rrf_fuse(
    vector_results: Option<&[SearchHit]>,
    bm25_results: Option<&[SearchHit]>,
    top_k: usize,
) -> Vec<SearchHit> {
    /// Orders entries best score first; equal scores fall back to first-appearance
    /// index, making this a total order. Fused scores are sums of positive finite
    /// terms, so `partial_cmp` never yields `None` here.
    fn best_first(
        a: &(EntityId, f64, usize),
        b: &(EntityId, f64, usize),
    ) -> std::cmp::Ordering {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.2.cmp(&b.2))
    }

    if top_k == 0 {
        return Vec::new();
    }

    let legs: [(&[SearchHit], f64); 2] = [
        (vector_results.unwrap_or_default(), VECTOR_WEIGHT),
        (bm25_results.unwrap_or_default(), BM25_WEIGHT),
    ];
    let capacity: usize = legs.iter().map(|(hits, _)| hits.len()).sum();

    // Accumulate into a Vec kept in first-appearance order; the map only locates each
    // entity's slot. Collecting straight out of a HashMap would let its randomized
    // iteration order decide the order of tied entities, and which of them survives
    // the `top_k` cut.
    // Entries are (entity, fused score, first-appearance index).
    let mut slot_of: HashMap<EntityId, usize> = HashMap::with_capacity(capacity);
    let mut fused: Vec<(EntityId, f64, usize)> = Vec::with_capacity(capacity);
    for &(hits, weight) in &legs {
        for (rank, hit) in hits.iter().enumerate() {
            let slot = *slot_of.entry(hit.entity_id).or_insert_with(|| {
                let next = fused.len();
                fused.push((hit.entity_id, 0.0, next));
                next
            });
            // `rank` is zero-based, hence the `+ 1.0`.
            fused[slot].1 += weight / (K_RRF + rank as f64 + 1.0);
        }
    }

    if top_k < fused.len() {
        // Partition so the `top_k` best entries occupy the front, then drop the rest;
        // only the survivors need a full sort.
        fused.select_nth_unstable_by(top_k, best_first);
        fused.truncate(top_k);
    }
    // `best_first` is a total order, so an unstable sort is still deterministic.
    fused.sort_unstable_by(best_first);

    fused
        .into_iter()
        .map(|(entity_id, score, _)| SearchHit {
            entity_id,
            score: score as f32,
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit. `rrf_fuse` uses only the hit's position, never `score`.
    fn hit(id: u64, score: f32) -> SearchHit {
        SearchHit {
            entity_id: EntityId(id),
            score,
        }
    }

    /// Reciprocal-rank contribution of a hit at zero-based `rank` in a leg of `weight`.
    fn term(weight: f64, rank: usize) -> f64 {
        weight / (K_RRF + rank as f64 + 1.0)
    }

    /// Collects the entity ids of `results` in order.
    fn ids(results: &[SearchHit]) -> Vec<EntityId> {
        results.iter().map(|h| h.entity_id).collect()
    }

    #[test]
    fn test_rrf_both_legs_ranks_higher() {
        let vec_results = vec![hit(1, 0.95), hit(2, 0.80)];
        let bm25_results = vec![hit(1, 5.0), hit(3, 3.0)];
        let results = rrf_fuse(Some(&vec_results), Some(&bm25_results), 10);
        assert_eq!(results[0].entity_id, EntityId(1));
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn test_rrf_fused_score_value() {
        let vec_results = vec![hit(1, 0.95)];
        let bm25_results = vec![hit(1, 5.0)];
        let results = rrf_fuse(Some(&vec_results), Some(&bm25_results), 10);
        assert_eq!(results.len(), 1);
        let expected = (term(VECTOR_WEIGHT, 0) + term(BM25_WEIGHT, 0)) as f32;
        assert!((results[0].score - expected).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_vector_only() {
        let vec_results = vec![hit(1, 0.95), hit(2, 0.80)];
        let results = rrf_fuse(Some(&vec_results), None, 10);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, EntityId(1));
    }

    #[test]
    fn test_rrf_bm25_only() {
        let bm25_results = vec![hit(3, 5.0), hit(4, 3.0)];
        let results = rrf_fuse(None, Some(&bm25_results), 10);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, EntityId(3));
    }

    #[test]
    fn test_rrf_ignores_input_scores() {
        // Entity 1 has the lower raw score but the better position, so it must win.
        let vec_results = vec![hit(1, 0.1), hit(2, 0.9)];
        let results = rrf_fuse(Some(&vec_results), None, 10);
        assert_eq!(ids(&results), vec![EntityId(1), EntityId(2)]);
    }

    #[test]
    fn test_rrf_no_legs() {
        let results = rrf_fuse(None, None, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_empty_legs() {
        let results = rrf_fuse(Some(&[]), Some(&[]), 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_zero_top_k() {
        let vec_results = vec![hit(1, 0.95), hit(2, 0.80)];
        let bm25_results = vec![hit(1, 5.0), hit(3, 3.0)];
        let results = rrf_fuse(Some(&vec_results), Some(&bm25_results), 0);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_top_k_limit() {
        let vec_results: Vec<SearchHit> = (1..=20).map(|i| hit(i, 1.0 - i as f32 * 0.01)).collect();
        let results = rrf_fuse(Some(&vec_results), None, 5);
        assert_eq!(results.len(), 5);
        // Truncation must keep the five best-ranked entities, in rank order.
        assert_eq!(ids(&results), (1..=5).map(EntityId).collect::<Vec<_>>());
    }

    #[test]
    fn test_rrf_scores_descending() {
        let vec_results = vec![hit(1, 0.9), hit(2, 0.8), hit(3, 0.7)];
        let bm25_results = vec![hit(4, 5.0), hit(1, 4.0), hit(5, 3.0)];
        let results = rrf_fuse(Some(&vec_results), Some(&bm25_results), 10);
        assert_eq!(results.len(), 5);
        for w in results.windows(2) {
            assert!(w[0].score >= w[1].score);
        }
    }
}