use crate::search::executor_types::RankedResults;
use crate::search::fusion::{FusedResult, FusionWeights, ScoreFusion};
use std::collections::HashMap;
use tracing::{debug, instrument};
/// Reciprocal Rank Fusion (RRF) over several ranked result lists.
///
/// A result at 0-based position `rank` in a list contributes
/// `1 / (k + rank + 1)` to its fused score. A larger `k` flattens the gap
/// between top-ranked and lower-ranked results; a smaller `k` emphasises it.
#[derive(Debug, Clone, Copy)]
pub struct RRFFusion {
    /// Smoothing constant added to every rank. Not validated: presumably it
    /// should be finite and non-negative (`k <= -1` makes the rank-0
    /// denominator zero or negative) — callers are trusted here.
    k: f32,
}

impl RRFFusion {
    /// Smoothing constant used by [`Default`]; 60 is the value customarily
    /// used for RRF.
    pub const DEFAULT_K: f32 = 60.0;

    /// Creates a fuser with smoothing constant `k`.
    #[must_use]
    pub fn new(k: f32) -> Self {
        Self { k }
    }

    /// RRF contribution of an item at 0-based position `rank`:
    /// `1 / (k + rank + 1)`. The `+ 1` turns the 0-based index into the
    /// 1-based rank used by the RRF formula.
    #[inline]
    fn rrf_score(&self, rank: usize) -> f32 {
        1.0 / (self.k + rank as f32 + 1.0)
    }
}

impl Default for RRFFusion {
    /// Equivalent to `RRFFusion::new(RRFFusion::DEFAULT_K)`, i.e. `k = 60`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_K)
    }
}
impl ScoreFusion for RRFFusion {
    /// Fuses several ranked result lists with Reciprocal Rank Fusion.
    ///
    /// Every occurrence of a chunk at 0-based position `rank` in a result set
    /// adds `1 / (k + rank + 1)` to that chunk's fused score. Only positions
    /// matter: the raw per-source scores do not influence the fused score, and
    /// `_weights` is ignored entirely (plain RRF treats all sources equally).
    ///
    /// Each returned [`FusedResult`] additionally carries:
    /// - `source_scores`: the raw score each source reported for the chunk
    ///   (if the same source lists the chunk more than once, the later entry
    ///   overwrites the earlier one);
    /// - the exact-match multiplier, if any result reported one (when several
    ///   do, the last one encountered wins).
    ///
    /// Output is sorted by fused score, highest first, with ties broken by
    /// ascending `chunk_id`, then truncated to at most `limit` entries.
    #[instrument(skip(self, results), fields(
        k = self.k,
        num_result_sets = results.len(),
        limit = limit
    ))]
    fn fuse(
        &self,
        results: Vec<RankedResults>,
        _weights: &FusionWeights,
        limit: usize,
    ) -> Vec<FusedResult> {
        let num_result_sets = results.len();
        let mut chunk_scores: HashMap<i64, f32> = HashMap::new();
        let mut chunk_source_scores: HashMap<
            i64,
            HashMap<crate::search::executor_types::SearchSource, f32>,
        > = HashMap::new();
        let mut chunk_exact_match: HashMap<i64, f32> = HashMap::new();

        for result_set in results {
            let source = result_set.source;
            // Position in the list (not `result.score`) drives the RRF score,
            // so each list is assumed to already be ordered best-first.
            for (rank, result) in result_set.results.iter().enumerate() {
                *chunk_scores.entry(result.chunk_id).or_insert(0.0) += self.rrf_score(rank);
                chunk_source_scores
                    .entry(result.chunk_id)
                    .or_default()
                    .insert(source, result.score);
                if let Some(mult) = result.exact_match_multiplier {
                    chunk_exact_match.insert(result.chunk_id, mult);
                }
            }
        }

        debug!(
            "RRF fusing {} unique chunks from {} result sets with k={}",
            chunk_scores.len(),
            num_result_sets,
            self.k
        );

        let mut fused_results: Vec<FusedResult> = chunk_scores
            .into_iter()
            .map(|(chunk_id, score)| {
                let source_scores = chunk_source_scores.remove(&chunk_id).unwrap_or_default();
                let exact_mult = chunk_exact_match.get(&chunk_id).copied();
                FusedResult::with_exact_match(chunk_id, score, source_scores, exact_mult)
            })
            .collect();

        // `HashMap` iteration order is randomized, and equal RRF scores are
        // common (e.g. the same rank in different lists). Without a
        // tie-breaker the order of tied chunks, and therefore which ones
        // survive `truncate`, would vary from run to run. Breaking ties on
        // `chunk_id` makes the output deterministic. `total_cmp` is a true
        // total order even for NaN scores (possible with a NaN `k`), unlike
        // `partial_cmp(..).unwrap_or(Equal)`, which std's sort may reject by
        // panicking.
        fused_results.sort_unstable_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| a.chunk_id.cmp(&b.chunk_id))
        });
        fused_results.truncate(limit);

        debug!(
            "RRF fusion produced {} results (requested: {})",
            fused_results.len(),
            limit
        );
        fused_results
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::executor_types::{RankedResult, SearchSource};

    /// Asserts that `actual` lies strictly within `1e-4` of `expected`.
    ///
    /// Generic over `Into<f64>` so the same helper serves the `f32` returned
    /// by `rrf_score` and whatever float type `FusedResult::score` uses.
    fn assert_close(actual: impl Into<f64>, expected: f64) {
        let actual: f64 = actual.into();
        assert!(
            (actual - expected).abs() < 0.0001,
            "{} is not within 1e-4 of {}",
            actual,
            expected
        );
    }

    #[test]
    fn test_rrf_score_calculation() {
        let fusion = RRFFusion::new(60.0);
        for (rank, denominator) in [(0, 61.0), (1, 62.0), (9, 70.0)] {
            assert_close(fusion.rrf_score(rank), 1.0 / denominator);
        }
    }

    #[test]
    fn test_rrf_default_k() {
        assert_eq!(RRFFusion::default().k, 60.0);
    }

    #[test]
    fn test_rrf_custom_k() {
        let fusion = RRFFusion::new(100.0);
        assert_eq!(fusion.k, 100.0);
        assert_close(fusion.rrf_score(0), 1.0 / 101.0);
    }

    #[test]
    fn test_rrf_fusion_single_source() {
        let fts_results = RankedResults::new(
            vec![
                RankedResult::new(1, 0.9, 1),
                RankedResult::new(2, 0.8, 2),
                RankedResult::new(3, 0.7, 3),
            ],
            SearchSource::FTS,
        );
        let fused = RRFFusion::default().fuse(vec![fts_results], &FusionWeights::default(), 10);

        // With one source the fused order mirrors the input, and each score
        // is 1 / (61 + position).
        assert_eq!(fused.len(), 3);
        let expected = [(1, 61.0), (2, 62.0), (3, 63.0)];
        for (result, (chunk_id, denominator)) in fused.iter().zip(expected) {
            assert_eq!(result.chunk_id, chunk_id);
            assert_close(result.score, 1.0 / denominator);
        }
    }

    #[test]
    fn test_rrf_fusion_multiple_sources_same_chunk() {
        let inputs = vec![
            RankedResults::new(vec![RankedResult::new(1, 0.9, 1)], SearchSource::FTS),
            RankedResults::new(vec![RankedResult::new(1, 0.8, 1)], SearchSource::Vector),
        ];
        let fused = RRFFusion::default().fuse(inputs, &FusionWeights::default(), 10);

        // Top of both lists: two position-0 contributions, both raw scores kept.
        assert_eq!(fused.len(), 1);
        let only = &fused[0];
        assert_eq!(only.chunk_id, 1);
        assert_close(only.score, 2.0 / 61.0);
        assert_eq!(only.source_scores.len(), 2);
        assert_eq!(only.source_scores.get(&SearchSource::FTS), Some(&0.9));
        assert_eq!(only.source_scores.get(&SearchSource::Vector), Some(&0.8));
    }

    #[test]
    fn test_rrf_fusion_multiple_sources_different_ranks() {
        let inputs = vec![
            RankedResults::new(
                vec![RankedResult::new(1, 0.9, 1), RankedResult::new(2, 0.7, 2)],
                SearchSource::FTS,
            ),
            RankedResults::new(
                vec![RankedResult::new(2, 0.8, 1), RankedResult::new(3, 0.6, 2)],
                SearchSource::Vector,
            ),
        ];
        let fused = RRFFusion::default().fuse(inputs, &FusionWeights::default(), 10);
        assert_eq!(fused.len(), 3);

        // Chunk 2 sits at position 1 in FTS and position 0 in Vector, so its
        // two contributions outweigh any single position-0 hit.
        let expectations = [
            (2, 1.0 / 62.0 + 1.0 / 61.0),
            (1, 1.0 / 61.0),
            (3, 1.0 / 62.0),
        ];
        for (chunk_id, expected_score) in expectations {
            let hit = fused.iter().find(|r| r.chunk_id == chunk_id).unwrap();
            assert_close(hit.score, expected_score);
        }
        assert_eq!(fused[0].chunk_id, 2);
    }

    #[test]
    fn test_rrf_fusion_all_four_sources() {
        let inputs: Vec<RankedResults> = vec![
            (0.9, SearchSource::FTS),
            (0.85, SearchSource::Vector),
            (0.75, SearchSource::Graph),
            (0.65, SearchSource::Signals),
        ]
        .into_iter()
        .map(|(score, source)| RankedResults::new(vec![RankedResult::new(1, score, 1)], source))
        .collect();
        let fused = RRFFusion::default().fuse(inputs, &FusionWeights::default(), 10);

        // Position 0 in every one of the four lists.
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].chunk_id, 1);
        assert_close(fused[0].score, 4.0 / 61.0);
        assert_eq!(fused[0].source_scores.len(), 4);
    }

    #[test]
    fn test_rrf_fusion_empty_results() {
        let fused = RRFFusion::default().fuse(vec![], &FusionWeights::default(), 10);
        assert!(fused.is_empty());
    }

    #[test]
    fn test_rrf_fusion_limit() {
        let fts_results = RankedResults::new(
            vec![
                RankedResult::new(1, 0.9, 1),
                RankedResult::new(2, 0.8, 2),
                RankedResult::new(3, 0.7, 3),
                RankedResult::new(4, 0.6, 4),
                RankedResult::new(5, 0.5, 5),
            ],
            SearchSource::FTS,
        );
        let fused = RRFFusion::default().fuse(vec![fts_results], &FusionWeights::default(), 3);

        // Only the three best-positioned chunks survive, in order.
        let ids: Vec<_> = fused.iter().map(|r| r.chunk_id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_rrf_k_parameter_effect() {
        let weights = FusionWeights::default();
        let fts_results = RankedResults::new(
            vec![RankedResult::new(1, 0.9, 1), RankedResult::new(2, 0.8, 2)],
            SearchSource::FTS,
        );
        let fused_low = RRFFusion::new(10.0).fuse(vec![fts_results.clone()], &weights, 10);
        let fused_high = RRFFusion::new(100.0).fuse(vec![fts_results], &weights, 10);

        // A smaller k gives larger absolute scores at each position...
        for i in 0..2 {
            assert!(fused_low[i].score > fused_high[i].score);
        }
        // ...and a wider gap between adjacent positions.
        let gap = |fused: &[FusedResult]| fused[0].score - fused[1].score;
        assert!(gap(&fused_low[..]) > gap(&fused_high[..]));
    }

    #[test]
    fn test_rrf_sorting_correctness() {
        let inputs = vec![
            RankedResults::new(
                vec![RankedResult::new(1, 0.9, 1), RankedResult::new(3, 0.7, 3)],
                SearchSource::FTS,
            ),
            RankedResults::new(
                vec![RankedResult::new(2, 0.8, 1), RankedResult::new(3, 0.75, 2)],
                SearchSource::Vector,
            ),
            RankedResults::new(
                vec![RankedResult::new(3, 0.85, 1), RankedResult::new(4, 0.6, 2)],
                SearchSource::Graph,
            ),
        ];
        let fused = RRFFusion::default().fuse(inputs, &FusionWeights::default(), 10);

        // Chunk 3 appears in all three lists (positions 1, 1, 0) and must lead
        // with a strictly higher score than everything else.
        let leader = &fused[0];
        assert_eq!(leader.chunk_id, 3);
        assert_close(leader.score, 1.0 / 62.0 + 1.0 / 62.0 + 1.0 / 61.0);
        assert!(fused[1..].iter().all(|r| r.score < leader.score));
    }
}