use std::collections::HashMap;
/// Identifier of an indexed document, used by both search backends and by fused results.
pub type DocId = u64;
/// A single hit returned by [`RRFFusion`] or [`HybridSearchEngine`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matched document.
    pub doc_id: DocId,
    /// Ranking score; results are ordered so that higher is better.
    ///
    /// For fused results this is the summed weighted RRF score.
    /// `HybridSearchEngine::search` passes a backend's raw score through
    /// unchanged when only one backend produced hits. Scores are therefore
    /// not comparable across those two paths.
    pub score: f32,
    /// Per-backend breakdown. It is populated only when details were requested
    /// (`keep_details = true` in `RRFFusion::fuse`, or
    /// `HybridSearchEngine::search_detailed`), and is `None` otherwise.
    pub component_scores: Option<ComponentScores>,
}
/// Per-backend provenance of a fused [`SearchResult`].
///
/// Each field is `None` when the document did not appear in that backend's list.
#[derive(Debug, Clone)]
pub struct ComponentScores {
    /// Raw score the vector backend reported for this document.
    pub vector_score: Option<f32>,
    /// 1-based position of the document in the vector result list.
    pub vector_rank: Option<usize>,
    /// Raw score the lexical backend reported for this document.
    pub lexical_score: Option<f32>,
    /// 1-based position of the document in the lexical result list.
    pub lexical_rank: Option<usize>,
}
/// Parameters for weighted Reciprocal Rank Fusion.
///
/// A document at 1-based rank `r` in a list contributes `weight / (k + r)` to
/// its fused score, where `weight` is that list's weight.
#[derive(Debug, Clone, Copy)]
pub struct RRFConfig {
    /// Rank-smoothing constant added to every rank. Larger values shrink the
    /// gap between top and lower ranks. Values at or below -1 make some
    /// denominators zero or negative.
    pub k: f32,
    /// Multiplier for contributions from the vector result list.
    pub vector_weight: f32,
    /// Multiplier for contributions from the lexical result list.
    pub lexical_weight: f32,
}
impl Default for RRFConfig {
    /// Equal vector/lexical weights of 1.0 with `k = 60`.
    fn default() -> Self {
        Self::with_weights(1.0, 1.0)
    }
}
impl RRFConfig {
    /// Smoothing constant used by every constructor below.
    const DEFAULT_K: f32 = 60.0;

    /// Builds a config with the given per-list weights and `k = 60`.
    pub fn with_weights(vector_weight: f32, lexical_weight: f32) -> Self {
        Self {
            k: Self::DEFAULT_K,
            vector_weight,
            lexical_weight,
        }
    }

    /// Preset favouring the vector list: weights 0.7 (vector) and 0.3 (lexical).
    pub fn semantic_focused() -> Self {
        Self::with_weights(0.7, 0.3)
    }

    /// Preset favouring the lexical list: weights 0.3 (vector) and 0.7 (lexical).
    pub fn keyword_focused() -> Self {
        Self::with_weights(0.3, 0.7)
    }

    /// Equal weighting; identical to [`RRFConfig::default`].
    pub fn balanced() -> Self {
        Self::default()
    }
}
/// Merges ranked result lists with weighted Reciprocal Rank Fusion using a fixed [`RRFConfig`].
pub struct RRFFusion {
    /// Smoothing constant and per-list weights applied during fusion.
    config: RRFConfig,
}
impl RRFFusion {
    /// Creates a fusion engine that uses `config` for the smoothing constant and list weights.
    pub fn new(config: RRFConfig) -> Self {
        Self { config }
    }

    /// Fuses a vector result list and a lexical result list with weighted RRF.
    ///
    /// Each input list is treated as ordered best-first. The element at index
    /// `i` has 1-based rank `i + 1` and adds `weight / (config.k + rank)` to
    /// its document's fused score. Raw backend scores do not affect ranking;
    /// they are only reported in [`ComponentScores`] when `keep_details` is `true`.
    ///
    /// If a document id occurs more than once in the same list, only its first
    /// (best-ranked) occurrence counts.
    ///
    /// Returns at most `limit` results sorted by descending fused score.
    /// Equal scores are ordered by ascending `doc_id`, so the output is
    /// deterministic. NaN scores, which only a misconfigured weight or `k`
    /// can produce, are placed last instead of panicking.
    pub fn fuse(
        &self,
        vector_results: &[(DocId, f32)],
        lexical_results: &[(DocId, f32)],
        limit: usize,
        keep_details: bool,
    ) -> Vec<SearchResult> {
        let mut doc_scores: HashMap<DocId, FusionState> =
            HashMap::with_capacity(vector_results.len() + lexical_results.len());

        for (index, &(doc_id, score)) in vector_results.iter().enumerate() {
            let state = doc_scores.entry(doc_id).or_default();
            // Only this loop writes `vector_rank`, so `Some` means the id
            // already appeared earlier (at a better rank) in the vector list.
            if state.vector_rank.is_some() {
                continue;
            }
            state.rrf_score += self.rrf_contribution(self.config.vector_weight, index);
            state.vector_score = Some(score);
            state.vector_rank = Some(index + 1);
        }

        for (index, &(doc_id, score)) in lexical_results.iter().enumerate() {
            let state = doc_scores.entry(doc_id).or_default();
            // Same first-occurrence rule for the lexical list.
            if state.lexical_rank.is_some() {
                continue;
            }
            state.rrf_score += self.rrf_contribution(self.config.lexical_weight, index);
            state.lexical_score = Some(score);
            state.lexical_rank = Some(index + 1);
        }

        let results: Vec<SearchResult> = doc_scores
            .into_iter()
            .map(|(doc_id, state)| SearchResult {
                doc_id,
                score: state.rrf_score,
                component_scores: if keep_details {
                    Some(ComponentScores {
                        vector_score: state.vector_score,
                        vector_rank: state.vector_rank,
                        lexical_score: state.lexical_score,
                        lexical_rank: state.lexical_rank,
                    })
                } else {
                    None
                },
            })
            .collect();
        Self::rank_and_truncate(results, limit)
    }

    /// Fuses any number of ranked lists, each paired with its own weight.
    ///
    /// Scoring, duplicate handling, ordering and truncation follow
    /// [`RRFFusion::fuse`]. The per-list weight in `result_lists` replaces the
    /// configured `vector_weight`/`lexical_weight`; only `config.k` is read
    /// from the configuration. Results never carry [`ComponentScores`].
    pub fn fuse_multi(
        &self,
        result_lists: &[(&[(DocId, f32)], f32)],
        limit: usize,
    ) -> Vec<SearchResult> {
        // Value: (accumulated score, index of the last list that contributed).
        let mut doc_scores: HashMap<DocId, (f32, Option<usize>)> = HashMap::new();
        for (list_index, &(list, weight)) in result_lists.iter().enumerate() {
            for (index, &(doc_id, _score)) in list.iter().enumerate() {
                let (score, last_list) = doc_scores.entry(doc_id).or_insert((0.0, None));
                // Lists are processed one at a time, so finding the current
                // list index here means the id repeats within this list.
                if *last_list == Some(list_index) {
                    continue;
                }
                *score += self.rrf_contribution(weight, index);
                *last_list = Some(list_index);
            }
        }

        let results: Vec<SearchResult> = doc_scores
            .into_iter()
            .map(|(doc_id, (score, _))| SearchResult {
                doc_id,
                score,
                component_scores: None,
            })
            .collect();
        Self::rank_and_truncate(results, limit)
    }

    /// Weighted RRF contribution of the element at 0-based `index` in a list.
    fn rrf_contribution(&self, weight: f32, index: usize) -> f32 {
        weight / (self.config.k + (index + 1) as f32)
    }

    /// Sorts by descending score (NaN last, ties by ascending `doc_id`) and keeps the first `limit`.
    fn rank_and_truncate(mut results: Vec<SearchResult>, limit: usize) -> Vec<SearchResult> {
        // Doc ids are unique here (they were map keys), so this comparator is a
        // strict total order and the unstable sort is fully deterministic.
        results.sort_unstable_by(|a, b| {
            a.score
                .is_nan()
                .cmp(&b.score.is_nan())
                .then_with(|| b.score.total_cmp(&a.score))
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results.truncate(limit);
        results
    }
}
/// Per-document accumulator used by [`RRFFusion::fuse`] while merging the two lists.
#[derive(Default)]
struct FusionState {
    /// Running sum of weighted RRF contributions from both lists.
    rrf_score: f32,
    /// Raw vector-backend score, if the document appeared in the vector list.
    vector_score: Option<f32>,
    /// 1-based rank in the vector list, if present.
    vector_rank: Option<usize>,
    /// Raw lexical-backend score, if the document appeared in the lexical list.
    lexical_score: Option<f32>,
    /// 1-based rank in the lexical list, if present.
    lexical_rank: Option<usize>,
}
impl Default for RRFFusion {
fn default() -> Self {
Self::new(RRFConfig::default())
}
}
/// Queries a vector backend and a lexical backend and merges their hits with RRF.
pub struct HybridSearchEngine<V, L> {
    /// Backend queried with a `&[f32]` query vector.
    vector_search: V,
    /// Backend queried with a text query.
    lexical_search: L,
    /// Fusion parameters handed to `RRFFusion` when results are merged.
    fusion_config: RRFConfig,
    /// Multiplier applied to the caller's `limit` to decide how many candidates
    /// to request from each backend; kept >= 1.0 by `with_overfetch`.
    overfetch_factor: f32,
}
/// Backend that ranks documents for a query vector.
pub trait VectorSearchBackend {
    /// Returns `(doc_id, score)` pairs for `query`; `k` is how many candidates the caller asks for.
    ///
    /// The engine treats the order of the returned vector as the ranking
    /// (index 0 is rank 1), so implementations should return their best hits first.
    fn search(&self, query: &[f32], k: usize) -> Vec<(DocId, f32)>;
}
/// Backend that ranks documents for a text query.
pub trait LexicalSearchBackend {
    /// Returns `(doc_id, score)` pairs for `query`; `k` is how many candidates the caller asks for.
    ///
    /// The engine treats the order of the returned vector as the ranking
    /// (index 0 is rank 1), so implementations should return their best hits first.
    fn search(&self, query: &str, k: usize) -> Vec<(DocId, f32)>;
}
impl<V, L> HybridSearchEngine<V, L>
where
V: VectorSearchBackend,
L: LexicalSearchBackend,
{
pub fn new(vector_search: V, lexical_search: L) -> Self {
Self {
vector_search,
lexical_search,
fusion_config: RRFConfig::default(),
overfetch_factor: 2.0,
}
}
pub fn with_fusion_config(mut self, config: RRFConfig) -> Self {
self.fusion_config = config;
self
}
pub fn with_overfetch(mut self, factor: f32) -> Self {
self.overfetch_factor = factor.max(1.0);
self
}
pub fn search(
&self,
vector_query: Option<&[f32]>,
text_query: Option<&str>,
limit: usize,
) -> Vec<SearchResult> {
let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
let vector_results = match vector_query {
Some(q) => self.vector_search.search(q, fetch_k),
None => Vec::new(),
};
let lexical_results = match text_query {
Some(q) => self.lexical_search.search(q, fetch_k),
None => Vec::new(),
};
if vector_results.is_empty() {
return lexical_results
.into_iter()
.take(limit)
.map(|(doc_id, score)| SearchResult {
doc_id,
score,
component_scores: None,
})
.collect();
}
if lexical_results.is_empty() {
return vector_results
.into_iter()
.take(limit)
.map(|(doc_id, score)| SearchResult {
doc_id,
score,
component_scores: None,
})
.collect();
}
let fusion = RRFFusion::new(self.fusion_config);
fusion.fuse(&vector_results, &lexical_results, limit, false)
}
pub fn search_detailed(
&self,
vector_query: Option<&[f32]>,
text_query: Option<&str>,
limit: usize,
) -> Vec<SearchResult> {
let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
let vector_results = vector_query
.map(|q| self.vector_search.search(q, fetch_k))
.unwrap_or_default();
let lexical_results = text_query
.map(|q| self.lexical_search.search(q, fetch_k))
.unwrap_or_default();
let fusion = RRFFusion::new(self.fusion_config);
fusion.fuse(&vector_results, &lexical_results, limit, true)
}
}
/// Keeps the results whose `doc_id` satisfies `predicate`, preserving input
/// order, and returns at most `limit` of them.
///
/// Results are not re-sorted. Evaluation is lazy: once `limit` results have
/// been kept, the remaining results are not passed to `predicate`. The
/// predicate may carry mutable state (`FnMut`).
pub fn filter_results<F>(
    results: Vec<SearchResult>,
    mut predicate: F,
    limit: usize,
) -> Vec<SearchResult>
where
    F: FnMut(DocId) -> bool,
{
    results
        .into_iter()
        .filter(|r| predicate(r.doc_id))
        .take(limit)
        .collect()
}
/// Unit tests pinning the RRF formula, ranking behaviour and result filtering.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rrf_fusion_basic() {
        let fusion = RRFFusion::default();
        let vector_results = vec![(1, 0.95), (2, 0.90), (3, 0.85)];
        let lexical_results = vec![
            (2, 5.0), // also rank 2 in the vector list
            (4, 4.5), // lexical-only hit
            (3, 4.0), // also rank 3 in the vector list
        ];
        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);
        assert!(!results.is_empty());
        // With positive weights and k, every contribution is positive.
        for r in &results {
            assert!(r.score > 0.0);
        }
    }

    #[test]
    fn test_rrf_fusion_with_details() {
        let fusion = RRFFusion::default();
        let vector_results = vec![(1, 0.9), (2, 0.8)];
        let lexical_results = vec![(2, 5.0), (3, 4.0)];
        let results = fusion.fuse(&vector_results, &lexical_results, 10, true);
        let doc2 = results.iter().find(|r| r.doc_id == 2).unwrap();
        let scores = doc2.component_scores.as_ref().unwrap();
        // Ranks are 1-based positions in each input list.
        assert_eq!(scores.vector_rank, Some(2)); // second in the vector list
        assert_eq!(scores.lexical_rank, Some(1)); // first in the lexical list
        // Raw backend scores are carried through unchanged.
        assert_eq!(scores.vector_score, Some(0.8));
        assert_eq!(scores.lexical_score, Some(5.0));
    }

    #[test]
    fn test_rrf_ranking() {
        let fusion = RRFFusion::default();
        let vector_results = vec![(1, 0.95), (2, 0.90)];
        let lexical_results = vec![(2, 5.0)];
        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);
        // Doc 2 collects 1/62 + 1/61, beating doc 1's single 1/61.
        assert_eq!(results[0].doc_id, 2);
    }

    #[test]
    fn test_rrf_weights() {
        let config = RRFConfig::keyword_focused();
        let fusion = RRFFusion::new(config);
        // Both docs are rank 1 in their own list, so only the weights differ:
        // the lexical weight 0.7 beats the vector weight 0.3.
        let vector_results = vec![(1, 0.95)];
        let lexical_results = vec![(2, 5.0)];
        let results = fusion.fuse(&vector_results, &lexical_results, 10, false);
        assert_eq!(results[0].doc_id, 2);
    }

    #[test]
    fn test_fuse_multi() {
        let fusion = RRFFusion::default();
        // Each document appears in exactly two of the three lists.
        let list1: Vec<(DocId, f32)> = vec![(1, 0.9), (2, 0.8)];
        let list2: Vec<(DocId, f32)> = vec![(2, 0.9), (3, 0.8)];
        let list3: Vec<(DocId, f32)> = vec![(3, 0.9), (1, 0.8)];
        let results = fusion.fuse_multi(&[(&list1, 1.0), (&list2, 1.0), (&list3, 1.0)], 10);
        let doc_ids: Vec<_> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&2));
        assert!(doc_ids.contains(&3));
    }

    #[test]
    fn test_fuse_multi_rrf_formula_golden() {
        // Pins the exact formula: weight / (k + rank) with 1-based ranks,
        // summed across lists.
        let k = 60.0_f32;
        let fusion = RRFFusion::new(RRFConfig {
            k,
            vector_weight: 1.0,
            lexical_weight: 1.0,
        });
        let docs: Vec<(DocId, f32)> = vec![(7, 0.9), (8, 0.5)];
        let single = fusion.fuse_multi(&[(&docs, 2.0)], 10);
        let s7 = single.iter().find(|r| r.doc_id == 7).unwrap().score;
        let s8 = single.iter().find(|r| r.doc_id == 8).unwrap().score;
        assert!(
            (s7 - 2.0 / (k + 1.0)).abs() < 1e-6,
            "rank-1 must be 1-indexed weighted"
        );
        assert!(
            (s8 - 2.0 / (k + 2.0)).abs() < 1e-6,
            "rank-2 must be 1-indexed weighted"
        );
        assert!(s7 > s8, "earlier rank must score higher");
        // The same doc at rank 1 in two lists gets both weighted contributions.
        let la: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let lb: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let merged = fusion.fuse_multi(&[(&la, 1.0), (&lb, 3.0)], 10);
        let s1 = merged.iter().find(|r| r.doc_id == 1).unwrap().score;
        let expected = 1.0 / (k + 1.0) + 3.0 / (k + 1.0);
        assert!(
            (s1 - expected).abs() < 1e-6,
            "weights must sum across lists"
        );
    }

    #[test]
    fn test_filter_results() {
        let results = vec![
            SearchResult {
                doc_id: 1,
                score: 0.9,
                component_scores: None,
            },
            SearchResult {
                doc_id: 2,
                score: 0.8,
                component_scores: None,
            },
            SearchResult {
                doc_id: 3,
                score: 0.7,
                component_scores: None,
            },
            SearchResult {
                doc_id: 4,
                score: 0.6,
                component_scores: None,
            },
        ];
        // Keeps only even ids, in their original order.
        let filtered = filter_results(results, |id| id % 2 == 0, 10);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered[0].doc_id, 2);
        assert_eq!(filtered[1].doc_id, 4);
    }
}