use crate::core::{DocId, SegmentId};
use crate::query::ast::FusionMethod;
use crate::search::results::HitRef;
use crate::search::searcher::ScoringResults;
use std::collections::HashMap;
/// Conventional default for the `rank_constant` (`k`) argument of
/// [`reciprocal_rank_fusion`]; 60 is the value used in the original RRF paper
/// (Cormack, Clarke & Büttcher, SIGIR 2009).
pub const DEFAULT_RANK_CONSTANT: f32 = 60.0;

/// Fuses several ranked hit lists with (optionally weighted) Reciprocal Rank Fusion.
///
/// A hit at 0-based position `rank` in list `i` contributes
/// `weight_i / (rank_constant + rank + 1)` to the fused score of its
/// `(segment_id, doc_id)` pair, and contributions from all lists are summed. Only list
/// positions matter; the hits' own `score` values are ignored, so each list is expected
/// to already be ordered best-first.
///
/// # Arguments
/// * `result_lists` - the hit lists to fuse.
/// * `rank_constant` - the RRF smoothing constant `k` (see [`DEFAULT_RANK_CONSTANT`]).
/// * `weights` - per-list weights indexed like `result_lists`; `None`, or a slice shorter
///   than `result_lists`, yields a weight of `1.0` for the missing lists.
/// * `top_k` - maximum number of fused hits to return.
///
/// # Returns
/// Up to `top_k` hits in descending fused-score order. Ties are broken by ascending
/// `(segment_id, doc_id)` so the output is deterministic, and NaN scores rank last.
/// `total_hits` is the exact number of distinct documents seen across all lists (before
/// truncation). `sort_values` and `collapse_key` are `None`, and aggregations are empty.
pub(crate) fn reciprocal_rank_fusion(
    result_lists: &[&ScoringResults],
    rank_constant: f32,
    weights: Option<&[f32]>,
    top_k: usize,
) -> ScoringResults {
    // Missing entries fall back to 1.0 via `get`, so no default vector is needed.
    let weights = weights.unwrap_or(&[]);
    let mut scores: HashMap<(u64, u32), f32> = HashMap::new();
    for (list_idx, results) in result_lists.iter().enumerate() {
        let list_weight = weights.get(list_idx).copied().unwrap_or(1.0);
        for (rank, hit) in results.hits.iter().enumerate() {
            let key = (hit.segment_id.as_u64(), hit.doc_id.as_u32());
            // RRF ranks are 1-based.
            let rrf_contribution = list_weight / (rank_constant + (rank + 1) as f32);
            *scores.entry(key).or_insert(0.0) += rrf_contribution;
        }
    }

    let total_unique = scores.len() as u64;
    let mut ranked: Vec<((u64, u32), f32)> = scores.into_iter().collect();
    // HashMap iteration order is random, so equal scores need an explicit tie-break for
    // reproducible output. NaN is mapped below every real score so the comparator is a
    // total order (`sort_by` may panic on comparators that are not).
    let rank_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    ranked.sort_by(|a, b| {
        rank_key(b.1)
            .partial_cmp(&rank_key(a.1))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    ranked.truncate(top_k);

    let hits: Vec<HitRef> = ranked
        .into_iter()
        .map(|((seg_id, doc_id), score)| HitRef {
            doc_id: DocId::new(doc_id),
            segment_id: SegmentId::new(seg_id),
            score,
            sort_values: None,
            collapse_key: None,
        })
        .collect();
    ScoringResults {
        hits,
        total_hits: crate::search::TotalHits::exact(total_unique),
        aggregations: HashMap::new(),
    }
}
pub(crate) fn score_fusion(
result_lists: &[&ScoringResults],
method: &FusionMethod,
weights: Option<&[f32]>,
top_k: usize,
) -> ScoringResults {
let n = result_lists.len();
let default_weights: Vec<f32> = vec![1.0; n];
let w = weights.unwrap_or(&default_weights);
let mut doc_scores: HashMap<(u64, u32), Vec<(usize, f32)>> = HashMap::new();
for (list_idx, results) in result_lists.iter().enumerate() {
for hit in &results.hits {
let key = (hit.segment_id.as_u64(), hit.doc_id.as_u32());
doc_scores
.entry(key)
.or_default()
.push((list_idx, hit.score));
}
}
let total_unique = doc_scores.len() as u64;
let mut ranked: Vec<((u64, u32), f32)> = doc_scores
.into_iter()
.map(|(key, scores)| {
let fused = match method {
FusionMethod::Sum => scores.iter().map(|&(idx, s)| weight(w, idx) * s).sum(),
FusionMethod::ArithmeticMean => {
let present_weight: f32 = scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
let weighted_sum: f32 = scores.iter().map(|&(idx, s)| weight(w, idx) * s).sum();
if present_weight > 0.0 {
weighted_sum / present_weight
} else {
0.0
}
}
FusionMethod::HarmonicMean => {
let present_weight: f32 = scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
if scores.iter().any(|&(_, s)| s <= 0.0) {
0.0
} else {
let weighted_recip: f32 =
scores.iter().map(|&(idx, s)| weight(w, idx) / s).sum();
if weighted_recip > 0.0 {
present_weight / weighted_recip
} else {
0.0
}
}
}
FusionMethod::GeometricMean => {
if scores.iter().any(|&(_, s)| s <= 0.0) {
0.0
} else {
let present_weight: f32 =
scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
let log_sum: f32 =
scores.iter().map(|&(idx, s)| weight(w, idx) * s.ln()).sum();
if present_weight > 0.0 {
(log_sum / present_weight).exp()
} else {
0.0
}
}
}
FusionMethod::ReciprocalRank => {
unreachable!("RRF handled by reciprocal_rank_fusion")
}
};
(key, fused)
})
.collect();
ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
ranked.truncate(top_k);
let hits: Vec<HitRef> = ranked
.into_iter()
.map(|((seg_id, doc_id), score)| HitRef {
doc_id: DocId::new(doc_id),
segment_id: SegmentId::new(seg_id),
score,
sort_values: None,
collapse_key: None,
})
.collect();
ScoringResults {
hits,
total_hits: crate::search::TotalHits::exact(total_unique),
aggregations: HashMap::new(),
}
}
/// Looks up the weight of result list `idx`.
///
/// Returns `weights[idx]` when present and `1.0` otherwise, so callers may pass a
/// weight slice shorter than the number of lists (including an empty one).
fn weight(weights: &[f32], idx: usize) -> f32 {
    match weights.get(idx) {
        Some(&list_weight) => list_weight,
        None => 1.0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ScoringResults` whose hits all live in segment 1, in the given order.
    fn make_results(docs: &[(u32, f32)]) -> ScoringResults {
        let hits = docs
            .iter()
            .map(|&(id, score)| HitRef {
                doc_id: DocId::new(id),
                segment_id: SegmentId::new(1),
                score,
                sort_values: None,
                collapse_key: None,
            })
            .collect();
        ScoringResults {
            hits,
            total_hits: crate::search::TotalHits::exact(docs.len() as u64),
            aggregations: HashMap::new(),
        }
    }

    #[test]
    fn rrf_single_list() {
        let r1 = make_results(&[(0, 3.0), (1, 2.0), (2, 1.0)]);
        let fused = reciprocal_rank_fusion(&[&r1], 60.0, None, 10);
        assert_eq!(fused.hits.len(), 3);
        assert_eq!(fused.hits[0].doc_id, DocId::new(0));
    }

    #[test]
    fn rrf_two_lists_overlap() {
        let r1 = make_results(&[(0, 3.0), (1, 2.0), (2, 1.0)]);
        let r2 = make_results(&[(1, 3.0), (2, 2.0), (0, 1.0)]);
        let fused = reciprocal_rank_fusion(&[&r1, &r2], 60.0, None, 10);
        assert_eq!(fused.hits.len(), 3);
        assert_eq!(fused.hits[0].doc_id, DocId::new(1));
    }

    #[test]
    fn rrf_disjoint_lists() {
        let r1 = make_results(&[(0, 1.0), (1, 0.5)]);
        let r2 = make_results(&[(2, 1.0), (3, 0.5)]);
        let fused = reciprocal_rank_fusion(&[&r1, &r2], 60.0, None, 10);
        assert_eq!(fused.hits.len(), 4);
        let top_ids: Vec<u32> = fused.hits[..2].iter().map(|h| h.doc_id.as_u32()).collect();
        assert!(top_ids.contains(&0) && top_ids.contains(&2));
    }

    #[test]
    fn rrf_top_k() {
        let r1 = make_results(&[(0, 1.0), (1, 0.5), (2, 0.3)]);
        let fused = reciprocal_rank_fusion(&[&r1], 60.0, None, 2);
        assert_eq!(fused.hits.len(), 2);
    }

    #[test]
    fn rrf_empty() {
        let r1 = make_results(&[]);
        let fused = reciprocal_rank_fusion(&[&r1], 60.0, None, 10);
        assert!(fused.hits.is_empty());
    }

    #[test]
    fn rrf_doc_id_preserved() {
        let r1 = make_results(&[(42, 1.0)]);
        let fused = reciprocal_rank_fusion(&[&r1], 60.0, None, 10);
        assert_eq!(fused.hits[0].doc_id, DocId::new(42));
    }

    #[test]
    fn rrf_weighted() {
        let r1 = make_results(&[(0, 1.0), (1, 0.5)]);
        let r2 = make_results(&[(1, 1.0), (0, 0.5)]);
        let weights = vec![3.0, 1.0];
        let fused = reciprocal_rank_fusion(&[&r1, &r2], 60.0, Some(&weights), 10);
        assert_eq!(fused.hits[0].doc_id, DocId::new(0));
    }

    #[test]
    fn rrf_total_hits_counts_unique() {
        let r1 = make_results(&[(0, 1.0), (1, 0.5)]);
        let r2 = make_results(&[(1, 1.0), (2, 0.5)]);
        // total_hits counts distinct docs across all lists, independent of top_k.
        let fused = reciprocal_rank_fusion(&[&r1, &r2], 60.0, None, 1);
        assert_eq!(fused.hits.len(), 1);
        assert_eq!(fused.total_hits.value, 3);
    }

    #[test]
    fn score_fusion_sum() {
        let r1 = make_results(&[(0, 2.0), (1, 1.0)]);
        let r2 = make_results(&[(0, 3.0), (2, 1.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::Sum, None, 10);
        assert_eq!(fused.hits[0].doc_id, DocId::new(0));
        assert!((fused.hits[0].score - 5.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_arithmetic_mean_present_only() {
        let r1 = make_results(&[(0, 4.0)]);
        let r2 = make_results(&[(1, 2.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::ArithmeticMean, None, 10);
        // Each doc appears in one list only, so its mean is just that list's score.
        assert_eq!(fused.hits[0].doc_id, DocId::new(0));
        assert!((fused.hits[0].score - 4.0).abs() < 0.01);
        assert!((fused.hits[1].score - 2.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_arithmetic_mean_weighted() {
        let r1 = make_results(&[(0, 4.0)]);
        let r2 = make_results(&[(0, 1.0)]);
        let weights = vec![3.0, 1.0];
        let fused =
            score_fusion(&[&r1, &r2], &FusionMethod::ArithmeticMean, Some(&weights), 10);
        // (3 * 4 + 1 * 1) / (3 + 1) = 3.25
        assert!((fused.hits[0].score - 3.25).abs() < 0.01);
    }

    #[test]
    fn score_fusion_harmonic_mean() {
        let r1 = make_results(&[(0, 2.0)]);
        let r2 = make_results(&[(0, 6.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::HarmonicMean, None, 10);
        // 2 / (1/2 + 1/6) = 3
        assert!((fused.hits[0].score - 3.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_harmonic_mean_zero_score() {
        let r1 = make_results(&[(0, 0.0), (1, 4.0)]);
        let r2 = make_results(&[(0, 5.0), (1, 4.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::HarmonicMean, None, 10);
        assert_eq!(fused.hits[0].doc_id, DocId::new(1));
        assert!((fused.hits[0].score - 4.0).abs() < 0.01);
        assert!((fused.hits[1].score - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_geometric_mean() {
        let r1 = make_results(&[(0, 2.0)]);
        let r2 = make_results(&[(0, 8.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::GeometricMean, None, 10);
        // sqrt(2 * 8) = 4
        assert!((fused.hits[0].score - 4.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_geometric_mean_non_positive_score() {
        let r1 = make_results(&[(0, -1.0)]);
        let r2 = make_results(&[(0, 2.0)]);
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::GeometricMean, None, 10);
        assert!((fused.hits[0].score - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_weight_bounds_safe() {
        let r1 = make_results(&[(0, 1.0)]);
        let r2 = make_results(&[(0, 2.0)]);
        // Only one weight for two lists: the second list defaults to weight 1.0.
        let weights = vec![0.5];
        let fused = score_fusion(&[&r1, &r2], &FusionMethod::Sum, Some(&weights), 10);
        assert!((fused.hits[0].score - 2.5).abs() < 0.01);
    }
}