use super::config::RankFusionStrategy;
use super::types::{DocumentScore, HybridResult, SearchWeights};
use std::collections::{HashMap, HashSet};
/// Fuses keyword and semantic search result lists into a single ranked list
/// of [`HybridResult`]s using a configurable strategy.
pub struct RankFusion {
    // Fusion algorithm selected at construction; dispatched on in `fuse`.
    strategy: RankFusionStrategy,
}
impl RankFusion {
pub fn new(strategy: RankFusionStrategy) -> Self {
Self { strategy }
}
pub fn fuse(
&self,
keyword_results: Vec<DocumentScore>,
semantic_results: Vec<DocumentScore>,
weights: &SearchWeights,
) -> Vec<HybridResult> {
match self.strategy {
RankFusionStrategy::WeightedSum => {
self.weighted_sum(keyword_results, semantic_results, weights)
}
RankFusionStrategy::ReciprocalRankFusion => {
self.reciprocal_rank_fusion(keyword_results, semantic_results)
}
RankFusionStrategy::Cascade => self.cascade(keyword_results, semantic_results, weights),
RankFusionStrategy::Interleave => {
self.interleave(keyword_results, semantic_results, weights)
}
}
}
/// Blends min-max-normalized keyword and semantic scores using the caller's
/// weights.
///
/// Each list is normalized independently so BM25-style keyword scores and
/// cosine-style semantic scores are comparable before mixing. A document
/// present in only one list gets 0.0 for the missing component.
fn weighted_sum(
    &self,
    keyword_results: Vec<DocumentScore>,
    semantic_results: Vec<DocumentScore>,
    weights: &SearchWeights,
) -> Vec<HybridResult> {
    let mut merged: HashMap<String, HybridResult> = HashMap::new();
    // Seed the map from the keyword side; the semantic component starts at 0.0.
    for doc in Self::normalize_scores(&keyword_results) {
        let hit = HybridResult::new(doc.doc_id.clone(), doc.score, 0.0, 0.0, weights);
        merged.insert(doc.doc_id, hit);
    }
    // Fold in the semantic side, recomputing the blended score for overlaps.
    for doc in Self::normalize_scores(&semantic_results) {
        merged
            .entry(doc.doc_id.clone())
            .and_modify(|hit| {
                hit.score_breakdown.semantic_score = doc.score;
                hit.score = doc.score * weights.semantic_weight
                    + hit.score_breakdown.keyword_score * weights.keyword_weight;
            })
            .or_insert_with(|| {
                HybridResult::new(doc.doc_id.clone(), 0.0, doc.score, 0.0, weights)
            });
    }
    let mut ranked: Vec<HybridResult> = merged.into_values().collect();
    // Highest blended score first; NaN comparisons fall back to Equal.
    ranked.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked
}
/// Reciprocal Rank Fusion: score(d) = sum over lists of 1 / (K + rank + 1).
///
/// Rank-based fusion is robust to incompatible score scales between the two
/// retrievers, so no normalization or weights are applied; K = 60 is the
/// standard damping constant from Cormack, Clarke & Buettcher (2009). The
/// raw per-retriever scores and ranks are recorded in the breakdown for
/// diagnostics only — they do not affect the fused score.
fn reciprocal_rank_fusion(
    &self,
    keyword_results: Vec<DocumentScore>,
    semantic_results: Vec<DocumentScore>,
) -> Vec<HybridResult> {
    const K: f32 = 60.0;
    let mut rrf_scores: HashMap<String, f32> = HashMap::new();
    // Per-retriever (rank, raw score) keyed by doc id, built once so the
    // assembly pass below is O(1) per lookup instead of the previous
    // O(n) `.find()` scan per fused document (O(n^2) overall). On
    // duplicate ids within one list the last occurrence wins consistently
    // (the old code took the rank from the last but the score from the
    // first occurrence).
    let mut keyword_info: HashMap<&str, (usize, f32)> = HashMap::new();
    let mut semantic_info: HashMap<&str, (usize, f32)> = HashMap::new();
    for (rank, doc) in keyword_results.iter().enumerate() {
        *rrf_scores.entry(doc.doc_id.clone()).or_insert(0.0) += 1.0 / (K + rank as f32 + 1.0);
        keyword_info.insert(doc.doc_id.as_str(), (rank, doc.score));
    }
    for (rank, doc) in semantic_results.iter().enumerate() {
        *rrf_scores.entry(doc.doc_id.clone()).or_insert(0.0) += 1.0 / (K + rank as f32 + 1.0);
        semantic_info.insert(doc.doc_id.as_str(), (rank, doc.score));
    }
    let mut results: Vec<HybridResult> = rrf_scores
        .into_iter()
        .map(|(doc_id, score)| {
            let kw = keyword_info.get(doc_id.as_str()).copied();
            let sem = semantic_info.get(doc_id.as_str()).copied();
            HybridResult {
                doc_id,
                score,
                score_breakdown: super::types::ScoreBreakdown {
                    keyword_score: kw.map_or(0.0, |(_, s)| s),
                    semantic_score: sem.map_or(0.0, |(_, s)| s),
                    recency_score: 0.0,
                    keyword_rank: kw.map(|(r, _)| r),
                    semantic_rank: sem.map(|(r, _)| r),
                },
                metadata: HashMap::new(),
            }
        })
        .collect();
    // Highest fused score first; NaN comparisons fall back to Equal.
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results
}
/// Cascade fusion: keyword results gate membership; semantic scores only
/// modulate the blended score of documents the keyword pass returned.
///
/// Documents found only by semantic search are dropped entirely; a
/// keyword-only document keeps a semantic component of 0.0.
fn cascade(
    &self,
    keyword_results: Vec<DocumentScore>,
    semantic_results: Vec<DocumentScore>,
    weights: &SearchWeights,
) -> Vec<HybridResult> {
    // Only keyword documents are emitted, so a plain id -> score map of the
    // semantic side suffices. The previous HashSet + filter pre-pass was
    // redundant work: lookups below already restrict to keyword docs.
    let semantic_by_id: HashMap<&str, f32> = semantic_results
        .iter()
        .map(|d| (d.doc_id.as_str(), d.score))
        .collect();
    let mut results: Vec<HybridResult> = keyword_results
        .into_iter()
        .map(|doc| {
            let semantic_score = semantic_by_id
                .get(doc.doc_id.as_str())
                .copied()
                .unwrap_or(0.0);
            HybridResult::new(doc.doc_id, doc.score, semantic_score, 0.0, weights)
        })
        .collect();
    // Highest blended score first; NaN comparisons fall back to Equal.
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results
}
fn interleave(
&self,
keyword_results: Vec<DocumentScore>,
semantic_results: Vec<DocumentScore>,
weights: &SearchWeights,
) -> Vec<HybridResult> {
let mut results = Vec::new();
let mut seen = HashSet::new();
let max_len = keyword_results.len().max(semantic_results.len());
for i in 0..max_len {
if i < keyword_results.len() {
let doc = &keyword_results[i];
if !seen.contains(&doc.doc_id) {
let semantic_score = semantic_results
.iter()
.find(|d| d.doc_id == doc.doc_id)
.map(|d| d.score)
.unwrap_or(0.0);
results.push(HybridResult::new(
doc.doc_id.clone(),
doc.score,
semantic_score,
0.0,
weights,
));
seen.insert(doc.doc_id.clone());
}
}
if i < semantic_results.len() {
let doc = &semantic_results[i];
if !seen.contains(&doc.doc_id) {
let keyword_score = keyword_results
.iter()
.find(|d| d.doc_id == doc.doc_id)
.map(|d| d.score)
.unwrap_or(0.0);
results.push(HybridResult::new(
doc.doc_id.clone(),
keyword_score,
doc.score,
0.0,
weights,
));
seen.insert(doc.doc_id.clone());
}
}
}
results
}
/// Min-max normalizes scores into [0, 1], preserving doc ids and ranks.
///
/// The divisor is clamped to at least 0.001 so a degenerate list where all
/// scores are equal divides safely (every score then maps to 0.0).
fn normalize_scores(results: &[DocumentScore]) -> Vec<DocumentScore> {
    if results.is_empty() {
        return Vec::new();
    }
    // NaN-tolerant comparator: incomparable pairs are treated as equal.
    let cmp = |a: &f32, b: &f32| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal);
    let lo = results.iter().map(|d| d.score).min_by(cmp).unwrap_or(0.0);
    let hi = results.iter().map(|d| d.score).max_by(cmp).unwrap_or(1.0);
    let span = (hi - lo).max(0.001);
    results
        .iter()
        .map(|d| DocumentScore {
            doc_id: d.doc_id.clone(),
            score: (d.score - lo) / span,
            rank: d.rank,
        })
        .collect()
}
}
#[cfg(test)]
mod tests {
    // Shorthand so tests can use `?` with any boxed error.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;

    /// Fixture: two overlapping result lists on deliberately different scales.
    /// Keyword (BM25-like, 0-10): doc1 > doc2 > doc3.
    /// Semantic (cosine-like, 0-1): doc2 > doc4 > doc1.
    /// doc1/doc2 appear in both; doc3 is keyword-only, doc4 semantic-only.
    fn create_test_results() -> (Vec<DocumentScore>, Vec<DocumentScore>) {
        let keyword = vec![
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 10.0,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc2".to_string(),
                score: 8.0,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc3".to_string(),
                score: 5.0,
                rank: 2,
            },
        ];
        let semantic = vec![
            DocumentScore {
                doc_id: "doc2".to_string(),
                score: 0.95,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc4".to_string(),
                score: 0.90,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 0.85,
                rank: 2,
            },
        ];
        (keyword, semantic)
    }

    /// Weighted sum should produce a non-empty, positively scored list.
    #[test]
    fn test_weighted_sum() {
        let (keyword, semantic) = create_test_results();
        let fusion = RankFusion::new(RankFusionStrategy::WeightedSum);
        let weights = SearchWeights::default();
        let results = fusion.fuse(keyword, semantic, &weights);
        assert!(!results.is_empty());
        assert!(results[0].score > 0.0);
    }

    /// RRF should rank doc1 (present in both lists) above doc4
    /// (semantic-only), since doc1 accumulates two reciprocal-rank terms.
    #[test]
    fn test_reciprocal_rank_fusion() -> Result<()> {
        let (keyword, semantic) = create_test_results();
        let fusion = RankFusion::new(RankFusionStrategy::ReciprocalRankFusion);
        let weights = SearchWeights::default();
        let results = fusion.fuse(keyword, semantic, &weights);
        assert!(!results.is_empty());
        let doc1_score = results
            .iter()
            .find(|r| r.doc_id == "doc1")
            .expect("doc1 should be found")
            .score;
        let doc4_score = results
            .iter()
            .find(|r| r.doc_id == "doc4")
            .expect("doc4 should be found")
            .score;
        assert!(doc1_score > doc4_score);
        Ok(())
    }

    /// Cascade must drop doc4, which only the semantic retriever returned.
    #[test]
    fn test_cascade() {
        let (keyword, semantic) = create_test_results();
        let fusion = RankFusion::new(RankFusionStrategy::Cascade);
        let weights = SearchWeights::default();
        let results = fusion.fuse(keyword, semantic, &weights);
        assert!(results.iter().all(|r| r.doc_id != "doc4"));
    }

    /// Interleave must emit the union of both lists exactly once each
    /// (4 distinct docs in the fixture).
    #[test]
    fn test_interleave() {
        let (keyword, semantic) = create_test_results();
        let fusion = RankFusion::new(RankFusionStrategy::Interleave);
        let weights = SearchWeights::default();
        let results = fusion.fuse(keyword, semantic, &weights);
        let doc_ids: HashSet<String> = results.iter().map(|r| r.doc_id.clone()).collect();
        assert_eq!(doc_ids.len(), 4);
    }

    /// Min-max normalization maps the max score to 1.0 and the min to 0.0.
    #[test]
    fn test_normalize_scores() {
        let results = vec![
            DocumentScore {
                doc_id: "doc1".to_string(),
                score: 10.0,
                rank: 0,
            },
            DocumentScore {
                doc_id: "doc2".to_string(),
                score: 5.0,
                rank: 1,
            },
            DocumentScore {
                doc_id: "doc3".to_string(),
                score: 0.0,
                rank: 2,
            },
        ];
        let normalized = RankFusion::normalize_scores(&results);
        assert!((normalized[0].score - 1.0).abs() < 0.001);
        assert!((normalized[2].score - 0.0).abs() < 0.001);
    }
}