use std::collections::BTreeMap;
use super::constants::{RRF_K, RRF_WEIGHTS};
use super::types::{HybridScoreData, RawSearchResult};
/// One of the ranked result lists that are fused with reciprocal rank fusion.
///
/// Declaration order matters: the derived `Ord` follows it, and it therefore
/// fixes the iteration order of every `BTreeMap<RrfList, _>` in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum RrfList {
    /// Results ranked by BM25.
    Bm25,
    /// Results from exact alias matching.
    ExactAlias,
    /// Results from fuzzy matching (reported as `fuzzy_title` after merging).
    Fuzzy,
    /// Results from semantic search.
    Semantic,
}

impl RrfList {
    /// Returns this list's RRF weight, read from `RRF_WEIGHTS`.
    const fn weight(self) -> f64 {
        match self {
            RrfList::Semantic => RRF_WEIGHTS.semantic,
            RrfList::Fuzzy => RRF_WEIGHTS.fuzzy,
            RrfList::ExactAlias => RRF_WEIGHTS.exact_alias,
            RrfList::Bm25 => RRF_WEIGHTS.bm25,
        }
    }
}
/// Collects weighted reciprocal-rank-fusion (RRF) contributions per path and per list.
///
/// `scores` maps each result path to the contribution it earned in every
/// [`RrfList`] it appeared in. `max_score_by_list` records, per list, the
/// contribution a rank-1 hit would receive (or `0.0` if the list was empty);
/// the merge step sums these caps to normalize fused scores.
#[derive(Debug, Default, Clone)]
pub struct RrfScoreAccumulator {
    /// Weighted RRF contribution per path, keyed by the list it came from.
    pub scores: BTreeMap<String, BTreeMap<RrfList, f64>>,
    /// Highest achievable contribution per list (rank-1 score, or `0.0` if empty).
    pub max_score_by_list: BTreeMap<RrfList, f64>,
}

impl RrfScoreAccumulator {
    /// Creates an empty accumulator.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the ranked `results` of one retrieval `list` to the accumulator.
    ///
    /// The hit at zero-based position `i` contributes
    /// `weight / (RRF_K + i + 1)`, where `weight` is the list's RRF weight.
    ///
    /// If a path occurs several times in `results`, or the same list is
    /// accumulated more than once, the highest contribution is kept. A path's
    /// score therefore reflects its best rank, matching the first-occurrence
    /// metadata the merge step uses. The per-list cap is likewise never lowered,
    /// so re-accumulating a list with an empty slice cannot zero it out.
    pub fn accumulate(&mut self, results: &[RawSearchResult], list: RrfList) {
        let weight = list.weight();
        let mut rank_f = 0.0_f64;
        for result in results {
            let rrf = weight / (RRF_K + rank_f + 1.0);
            self.scores
                .entry(result.path.clone())
                .or_default()
                .entry(list)
                // A later (worse-ranked) duplicate must not overwrite a better rank.
                .and_modify(|best| *best = best.max(rrf))
                .or_insert(rrf);
            rank_f += 1.0;
        }
        let cap = if results.is_empty() {
            0.0
        } else {
            weight / (RRF_K + 1.0)
        };
        self.max_score_by_list
            .entry(list)
            .and_modify(|best| *best = best.max(cap))
            .or_insert(cap);
    }
}
/// Adds up the per-list score caps.
///
/// The total is the largest fused score a single path could reach, and it is
/// used as the normalization denominator when merging.
fn sum_list_caps(max_by_list: &BTreeMap<RrfList, f64>) -> f64 {
    max_by_list.iter().map(|(_list, cap)| *cap).sum::<f64>()
}
/// Maps every path to the first result that carries it, scanning `lists` in order.
///
/// Earlier lists take precedence over later ones. Within a single list, earlier
/// entries take precedence over later ones.
fn first_hit_by_path<'a>(
    lists: &'a [&'a [RawSearchResult]],
) -> BTreeMap<String, &'a RawSearchResult> {
    lists
        .iter()
        .copied()
        .flatten()
        .fold(BTreeMap::new(), |mut seen, hit| {
            // `or_insert` leaves an existing (earlier) entry untouched.
            seen.entry(hit.path.clone()).or_insert(hit);
            seen
        })
}
/// The ranked result lists to fuse, one borrowed slice per retrieval source.
///
/// Lists left at their default are empty. When the same path appears in
/// several lists, the merge takes that path's metadata from the first list in
/// the order semantic, bm25, exact-alias, fuzzy.
#[derive(Clone, Debug, Default)]
pub struct RrfInputs<'a> {
    /// Semantic-search results (highest metadata priority).
    pub semantic: &'a [RawSearchResult],
    /// BM25-ranked results.
    pub bm25: &'a [RawSearchResult],
    /// Exact-alias matches.
    pub exact_alias: &'a [RawSearchResult],
    /// Fuzzy matches (lowest metadata priority).
    pub fuzzy: &'a [RawSearchResult],
}
#[must_use]
pub fn normalize_and_merge_rrf_results(
acc: &RrfScoreAccumulator,
inputs: &RrfInputs<'_>,
limit: usize,
) -> Vec<HybridScoreData> {
let max_possible = sum_list_caps(&acc.max_score_by_list);
let priority: [&[RawSearchResult]; 4] = [
inputs.semantic,
inputs.bm25,
inputs.exact_alias,
inputs.fuzzy,
];
let result_map = first_hit_by_path(&priority);
let mut hybrid: Vec<HybridScoreData> = acc
.scores
.iter()
.filter_map(|(path, path_scores)| {
let base = result_map.get(path)?;
let raw = path_scores.get(&RrfList::Semantic).copied().unwrap_or(0.0)
+ path_scores.get(&RrfList::Bm25).copied().unwrap_or(0.0)
+ path_scores
.get(&RrfList::ExactAlias)
.copied()
.unwrap_or(0.0)
+ path_scores.get(&RrfList::Fuzzy).copied().unwrap_or(0.0);
let hybrid_before_norm = if max_possible == 0.0 {
0.0
} else {
(raw / max_possible).clamp(0.0, 1.0)
};
Some(HybridScoreData {
path: path.clone(),
title: base.title.clone(),
tags: base.tags.clone(),
aliases: base.aliases.clone(),
snippet: base.snippet.clone(),
bm25: path_scores.get(&RrfList::Bm25).copied(),
fuzzy_title: path_scores.get(&RrfList::Fuzzy).copied(),
semantic: path_scores.get(&RrfList::Semantic).copied(),
hybrid_before_norm: Some(hybrid_before_norm),
semantic_heading: base.semantic_heading.clone(),
semantic_char_start: base.semantic_char_start,
semantic_char_end: base.semantic_char_end,
})
})
.collect();
hybrid.sort_by(|a, b| {
b.hybrid_before_norm
.unwrap_or(0.0)
.partial_cmp(&a.hybrid_before_norm.unwrap_or(0.0))
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.path.cmp(&b.path))
});
hybrid.truncate(limit);
hybrid
}
/// Unit tests for RRF accumulation, normalization and merging.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::search::types::SearchScores;

    /// Absolute tolerance for comparing normalized scores.
    const EPS: f64 = 1e-9;

    /// Builds a minimal raw hit for `path` with a zero raw score.
    ///
    /// RRF only looks at a hit's position in its list, so the raw score is
    /// irrelevant here.
    fn hit(path: &str) -> RawSearchResult {
        RawSearchResult {
            path: path.into(),
            title: format!("Title {path}"),
            tags: vec![],
            aliases: vec![],
            snippet: format!("snip {path}"),
            score: 0.0,
            scores: SearchScores::default(),
            semantic_heading: None,
            semantic_char_start: None,
            semantic_char_end: None,
        }
    }

    /// Builds one hit per path, preserving order (and therefore rank).
    fn hits(paths: &[&str]) -> Vec<RawSearchResult> {
        paths.iter().copied().map(hit).collect()
    }

    /// Reads the normalized score of a merged row.
    fn norm(row: &HybridScoreData) -> f64 {
        row.hybrid_before_norm.unwrap()
    }

    #[test]
    fn empty_accumulator_yields_no_results() {
        let merged = normalize_and_merge_rrf_results(
            &RrfScoreAccumulator::new(),
            &RrfInputs::default(),
            10,
        );
        assert!(merged.is_empty());
    }

    #[test]
    fn single_list_normalizes_top_hit_to_one() {
        let bm25 = hits(&["a.md", "b.md"]);
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&bm25, RrfList::Bm25);
        let inputs = RrfInputs {
            bm25: &bm25,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 10);
        assert_eq!(merged.len(), 2);
        let (top, runner_up) = (&merged[0], &merged[1]);
        assert_eq!(top.path, "a.md");
        assert!((norm(top) - 1.0).abs() < EPS);
        assert!(norm(runner_up) < 1.0);
    }

    #[test]
    fn multiple_lists_sum_contributions() {
        let bm25 = hits(&["a.md", "b.md"]);
        let semantic = hits(&["a.md"]);
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&bm25, RrfList::Bm25);
        acc.accumulate(&semantic, RrfList::Semantic);
        let inputs = RrfInputs {
            bm25: &bm25,
            semantic: &semantic,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 10);
        assert_eq!(merged[0].path, "a.md");
        assert!(norm(&merged[0]) > norm(&merged[1]));
        // a.md is rank 1 in every non-empty list, so it reaches the full cap.
        assert!((norm(&merged[0]) - 1.0).abs() < EPS);
    }

    #[test]
    fn weights_favor_bm25_and_alias_over_fuzzy() {
        let bm25 = hits(&["bm.md"]);
        let fuzzy = hits(&["fz.md"]);
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&bm25, RrfList::Bm25);
        acc.accumulate(&fuzzy, RrfList::Fuzzy);
        let inputs = RrfInputs {
            bm25: &bm25,
            fuzzy: &fuzzy,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 10);
        let score_for = |p: &str| norm(merged.iter().find(|row| row.path == p).unwrap());
        // Each hit is rank 1 in its own list, so its normalized score is its
        // list weight over the summed weights: 1.5 / 1.75 and 0.25 / 1.75.
        let bm_expected = 6.0 / 7.0;
        let fz_expected = 1.0 / 7.0;
        let bm_actual = score_for("bm.md");
        let fz_actual = score_for("fz.md");
        assert!(
            (bm_actual - bm_expected).abs() < EPS,
            "bm.md expected {bm_expected}, got {bm_actual}"
        );
        assert!(
            (fz_actual - fz_expected).abs() < EPS,
            "fz.md expected {fz_expected}, got {fz_actual}"
        );
    }

    #[test]
    fn first_hit_wins_for_metadata_priority() {
        let semantic = vec![RawSearchResult {
            snippet: "from semantic".into(),
            ..hit("a.md")
        }];
        let bm25 = vec![RawSearchResult {
            snippet: "from bm25".into(),
            ..hit("a.md")
        }];
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&semantic, RrfList::Semantic);
        acc.accumulate(&bm25, RrfList::Bm25);
        let inputs = RrfInputs {
            semantic: &semantic,
            bm25: &bm25,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 10);
        assert_eq!(merged[0].snippet, "from semantic");
    }

    #[test]
    fn limit_truncates_output() {
        let bm25: Vec<RawSearchResult> = (0..5).map(|i| hit(&format!("p{i}.md"))).collect();
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&bm25, RrfList::Bm25);
        let inputs = RrfInputs {
            bm25: &bm25,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 3);
        assert_eq!(merged.len(), 3);
    }

    #[test]
    fn equal_rrf_scores_sort_by_path() {
        // Mirror-image ranks give both paths the same fused score, so only the
        // path tie-breaker decides the order.
        let bm25 = hits(&["z.md", "a.md"]);
        let semantic = hits(&["a.md", "z.md"]);
        let mut acc = RrfScoreAccumulator::new();
        acc.accumulate(&bm25, RrfList::Bm25);
        acc.accumulate(&semantic, RrfList::Semantic);
        let inputs = RrfInputs {
            semantic: &semantic,
            bm25: &bm25,
            ..Default::default()
        };
        let merged = normalize_and_merge_rrf_results(&acc, &inputs, 10);
        assert_eq!(merged[0].path, "a.md");
        assert_eq!(merged[1].path, "z.md");
    }

    #[test]
    fn rrf_weights_match_ohs_benchmark_values() {
        let expectations = [
            (RRF_WEIGHTS.bm25, 1.5),
            (RRF_WEIGHTS.exact_alias, 2.0),
            (RRF_WEIGHTS.fuzzy, 0.25),
            (RRF_WEIGHTS.semantic, 1.5),
        ];
        for (actual, wanted) in expectations {
            assert!((actual - wanted).abs() < f64::EPSILON);
        }
    }
}