use super::bm25::{Bm25Index, Bm25Score};
use super::reranker::{mmr_rerank, rrf_fuse, RrfParams};
/// Strategy used by [`hybrid_search`] to merge the BM25 and vector result lists.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal-rank fusion over the doc-id orderings of both lists (`rrf_fuse`).
    Rrf,
    /// Maximal-marginal-relevance re-ranking of hits that carry an embedding (`mmr_rerank`).
    Mmr,
    /// Weighted sum of per-list min-max normalised scores.
    Linear,
}
/// One result from a single retriever, used as input to [`hybrid_search`].
#[derive(Clone, Debug)]
pub struct ScoredHit {
    /// Identifier of the matched document.
    pub doc_id: u64,
    /// Retriever-specific relevance score. Linear fusion normalises it per list,
    /// so scores are only compared within the list they came from.
    pub score: f64,
    /// Optional document embedding; MMR fusion only considers hits where this is `Some`.
    pub vector: Option<Vec<f32>>,
}
/// One fused result produced by [`hybrid_search`].
#[derive(Clone, Debug, PartialEq)]
pub struct HybridHit {
    /// Identifier of the document.
    pub doc_id: u64,
    /// Fused score; its scale depends on the [`FusionMethod`] that produced it.
    pub score: f64,
}
/// Fuses a BM25 result list and a vector-search result list into one ranking.
///
/// * `bm25_hits` / `vector_hits` — candidate lists from the two retrievers.
///   RRF uses only their doc-id order; Linear and MMR also use their `score`s.
/// * `method` — fusion strategy, see [`FusionMethod`].
/// * `lambda` — ignored by [`FusionMethod::Rrf`]; forwarded unchanged to `mmr_rerank`
///   for [`FusionMethod::Mmr`]; for [`FusionMethod::Linear`] it is the weight on the
///   normalised vector score, with `1.0 - lambda` on the normalised BM25 score. It is
///   not clamped, so values outside `[0, 1]` give negative weights.
/// * `limit` — maximum number of hits returned.
///
/// Returns at most `limit` hits.
#[must_use]
pub fn hybrid_search(
    bm25_hits: &[ScoredHit],
    vector_hits: &[ScoredHit],
    method: FusionMethod,
    lambda: f64,
    limit: usize,
) -> Vec<HybridHit> {
    match method {
        FusionMethod::Rrf => fuse_rrf(bm25_hits, vector_hits, limit),
        FusionMethod::Mmr => fuse_mmr(bm25_hits, vector_hits, lambda, limit),
        FusionMethod::Linear => fuse_linear(bm25_hits, vector_hits, lambda, limit),
    }
}

/// Reciprocal-rank fusion over the doc-id orderings of both lists (scores are not used).
fn fuse_rrf(bm25_hits: &[ScoredHit], vector_hits: &[ScoredHit], limit: usize) -> Vec<HybridHit> {
    let lists = vec![
        bm25_hits.iter().map(|h| h.doc_id).collect::<Vec<_>>(),
        vector_hits.iter().map(|h| h.doc_id).collect::<Vec<_>>(),
    ];
    rrf_fuse(&lists, RrfParams::default())
        .into_iter()
        .take(limit)
        .map(|(doc_id, score)| HybridHit { doc_id, score })
        .collect()
}

/// MMR re-ranking over the de-duplicated union of both lists, limited to hits with a vector.
fn fuse_mmr(
    bm25_hits: &[ScoredHit],
    vector_hits: &[ScoredHit],
    lambda: f64,
    limit: usize,
) -> Vec<HybridHit> {
    use std::collections::HashSet;

    // The first occurrence of each doc_id wins. Vector hits come first, so for a doc
    // present in both lists the vector-side hit (score and embedding) is kept. The
    // candidate list is built in this fixed order rather than by draining a HashMap,
    // whose iteration order is randomised. That keeps the input to `mmr_rerank`, and
    // hence its tie-breaking, reproducible.
    let mut seen = HashSet::new();
    let candidates: Vec<_> = vector_hits
        .iter()
        .chain(bm25_hits)
        .filter(|h| seen.insert(h.doc_id))
        // Only the surviving embeddings are cloned, not every hit.
        .filter_map(|h| h.vector.as_ref().map(|v| (h.doc_id, h.score, v.clone())))
        .collect();

    mmr_rerank(candidates, lambda, limit)
        .into_iter()
        // Enforce the documented `limit` cap regardless of `mmr_rerank`'s own handling.
        .take(limit)
        .map(|(doc_id, score)| HybridHit { doc_id, score })
        .collect()
}

/// Weighted sum of per-list min-max normalised scores: `(1 - lambda) * bm25 + lambda * vector`.
fn fuse_linear(
    bm25_hits: &[ScoredHit],
    vector_hits: &[ScoredHit],
    lambda: f64,
    limit: usize,
) -> Vec<HybridHit> {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    // A doc missing from one list simply gets no contribution from that side.
    let mut combined: HashMap<u64, f64> =
        HashMap::with_capacity(bm25_hits.len() + vector_hits.len());
    for (id, score) in normalised_scores(bm25_hits) {
        *combined.entry(id).or_insert(0.0) += (1.0 - lambda) * score;
    }
    for (id, score) in normalised_scores(vector_hits) {
        *combined.entry(id).or_insert(0.0) += lambda * score;
    }

    let mut out: Vec<HybridHit> = combined
        .into_iter()
        .map(|(doc_id, score)| HybridHit { doc_id, score })
        .collect();

    // Sort by score descending, put NaN scores last, and break ties by ascending doc_id.
    // NaN can arise from non-finite input scores during normalisation. Ranking NaN
    // explicitly keeps this a total order. `partial_cmp(..).unwrap_or(Equal)` alone is
    // intransitive once a NaN is present, which scrambles the ranking, and `sort_by`
    // may panic on such comparators (Rust >= 1.81).
    out.sort_by(|a, b| {
        a.score
            .is_nan()
            .cmp(&b.score.is_nan())
            .then_with(|| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal))
            .then_with(|| a.doc_id.cmp(&b.doc_id))
    });
    out.truncate(limit);
    out
}
/// Min-max normalises each hit's score within its own list.
///
/// Returns `(doc_id, normalised_score)` pairs in input order. Empty input gives an
/// empty vec. When the spread between the largest and smallest score is at most
/// `f64::EPSILON`, every hit gets `1.0`. Otherwise a hit maps to
/// `(score - min) / (max - min)`.
///
/// NaN scores never become the min or max. Non-finite scores are not sanitised,
/// so they can still produce NaN outputs.
fn normalised_scores(hits: &[ScoredHit]) -> Vec<(u64, f64)> {
    // Strict comparisons, so a NaN score never replaces either bound.
    let (lo, hi) = hits.iter().fold((f64::MAX, f64::MIN), |(lo, hi), h| {
        let lo = if h.score < lo { h.score } else { lo };
        let hi = if h.score > hi { h.score } else { hi };
        (lo, hi)
    });
    let spread = hi - lo;
    // For empty input the fold leaves lo > hi, so `flat` is true. There is nothing to
    // map in that case, and the result is an empty vec.
    let flat = spread <= f64::EPSILON;
    hits.iter()
        .map(|h| {
            let norm = if flat { 1.0 } else { (h.score - lo) / spread };
            (h.doc_id, norm)
        })
        .collect()
}
#[must_use]
pub fn bm25_hits(index: &Bm25Index, query: &str, limit: Option<usize>) -> Vec<ScoredHit> {
index
.score(query, limit)
.into_iter()
.map(|Bm25Score { doc_id, score }| ScoredHit {
doc_id,
score,
vector: None,
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds vector-less hits by zipping ids with scores (extra entries are ignored).
    fn hits(ids: &[u64], scores: &[f64]) -> Vec<ScoredHit> {
        ids.iter()
            .zip(scores.iter())
            .map(|(id, s)| ScoredHit {
                doc_id: *id,
                score: *s,
                vector: None,
            })
            .collect()
    }

    fn ids_of(res: &[HybridHit]) -> Vec<u64> {
        res.iter().map(|h| h.doc_id).collect()
    }

    #[test]
    fn rrf_combines_bm25_and_vec_lists() {
        let bm25 = hits(&[1, 2, 3], &[3.0, 2.0, 1.0]);
        let vec = hits(&[3, 1, 4], &[0.9, 0.8, 0.7]);
        let res = hybrid_search(&bm25, &vec, FusionMethod::Rrf, 0.5, 10);
        let ids = ids_of(&res);
        assert_eq!(ids[0], 1);
        assert!(ids.contains(&3));
        assert!(ids.contains(&4));
    }

    #[test]
    fn linear_fusion_respects_lambda() {
        let bm25 = hits(&[1, 2], &[10.0, 1.0]);
        let vec = hits(&[2, 1], &[10.0, 1.0]);
        let r1 = hybrid_search(&bm25, &vec, FusionMethod::Linear, 1.0, 10);
        assert_eq!(r1[0].doc_id, 2);
        let r0 = hybrid_search(&bm25, &vec, FusionMethod::Linear, 0.0, 10);
        assert_eq!(r0[0].doc_id, 1);
    }

    #[test]
    fn linear_ties_break_by_ascending_doc_id() {
        // Identical scores normalise to 1.0 each, so every doc gets (1 - 0.5) * 1.0.
        let bm25 = hits(&[3, 1, 2], &[4.0, 4.0, 4.0]);
        let res = hybrid_search(&bm25, &[], FusionMethod::Linear, 0.5, 10);
        assert_eq!(ids_of(&res), vec![1, 2, 3]);
        assert!(res.iter().all(|h| (h.score - 0.5).abs() < 1e-12));
    }

    #[test]
    fn linear_limit_caps_output_and_orders_ties() {
        // bm25 normalises to 1 -> 1.0, 2 -> 0.5, 3 -> 0.0; vec to 4 -> 1.0, 5 -> 0.0.
        // At lambda = 0.5, docs 1 and 4 tie at 0.5 and must be ordered by doc_id.
        let bm25 = hits(&[1, 2, 3], &[3.0, 2.0, 1.0]);
        let vec = hits(&[4, 5], &[1.0, 0.5]);
        let res = hybrid_search(&bm25, &vec, FusionMethod::Linear, 0.5, 2);
        assert_eq!(ids_of(&res), vec![1, 4]);
    }

    #[test]
    fn mmr_uses_vectors_when_provided() {
        let bm25 = hits(&[1, 2, 3], &[1.0, 1.0, 1.0]);
        let vec = vec![
            ScoredHit {
                doc_id: 1,
                score: 0.9,
                vector: Some(vec![1.0, 0.0]),
            },
            ScoredHit {
                doc_id: 2,
                score: 0.85,
                vector: Some(vec![1.0, 0.0]),
            },
            ScoredHit {
                doc_id: 3,
                score: 0.7,
                vector: Some(vec![0.0, 1.0]),
            },
        ];
        let res = hybrid_search(&bm25, &vec, FusionMethod::Mmr, 0.3, 3);
        assert_eq!(res.len(), 3);
        assert_eq!(res[0].doc_id, 1);
        assert_eq!(res[1].doc_id, 3);
    }

    #[test]
    fn limit_caps_output() {
        let bm25 = hits(&[1, 2, 3, 4, 5], &[5.0, 4.0, 3.0, 2.0, 1.0]);
        let vec = hits(&[5, 4, 3, 2, 1], &[0.9, 0.8, 0.7, 0.6, 0.5]);
        let res = hybrid_search(&bm25, &vec, FusionMethod::Rrf, 0.5, 2);
        assert_eq!(res.len(), 2);
        let none = hybrid_search(&bm25, &vec, FusionMethod::Rrf, 0.5, 0);
        assert!(none.is_empty());
        let none_linear = hybrid_search(&bm25, &vec, FusionMethod::Linear, 0.5, 0);
        assert!(none_linear.is_empty());
    }

    #[test]
    fn empty_inputs_safe() {
        let res = hybrid_search(&[], &[], FusionMethod::Rrf, 0.5, 5);
        assert!(res.is_empty());
        let res2 = hybrid_search(&[], &[], FusionMethod::Linear, 0.5, 5);
        assert!(res2.is_empty());
    }

    #[test]
    fn bm25_hits_helper_lifts_index() {
        let idx = Bm25Index::new();
        idx.add_document(1, "alpha beta");
        idx.add_document(2, "alpha gamma");
        let h = bm25_hits(&idx, "alpha", None);
        assert_eq!(h.len(), 2);
        assert!(h.iter().all(|x| x.vector.is_none()));
    }
}