#[cfg(feature = "embeddings")]
pub mod cross_encoder;
#[cfg(feature = "embeddings")]
pub use cross_encoder::{CrossEncoderModel, CrossEncoderReranker};
pub mod colbert;
pub use colbert::{
ColBERTBatchReranker, ColBERTConfig, ColBERTReranker, SimilarityMetric, TokenEmbeddings,
};
use crate::store::Neighbor;
use anyhow::Result;
/// Common interface for strategies that reorder a list of search results.
///
/// Implementors must be `Send + Sync`, so a reranker can be stored behind a
/// `Box<dyn Reranker>` and shared between threads.
pub trait Reranker: Send + Sync {
/// Reorders `results` for `query` and returns the reranked list.
///
/// `results` is taken by value. `top_k` is the requested maximum number of
/// results; the implementations in this file return at most `top_k` items.
fn rerank(&self, query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>>;
/// Returns a human-readable name for this reranking strategy.
fn name(&self) -> &str;
}
/// Greedy Maximal Marginal Relevance (MMR) reranker.
///
/// Every round selects the remaining candidate with the highest MMR score,
/// `lambda * relevance - (1 - lambda) * max_similarity`, where `relevance` is
/// the candidate's existing `score`.
pub struct MMRReranker {
    /// Weight of the relevance term; the redundancy penalty gets `1 - lambda`.
    lambda: f32,
}

impl MMRReranker {
    /// Creates an MMR reranker with the given trade-off parameter.
    ///
    /// # Panics
    ///
    /// Panics if `lambda` is not within `0.0..=1.0` (this includes NaN).
    pub fn new(lambda: f32) -> Self {
        let in_range = (0.0..=1.0).contains(&lambda);
        assert!(in_range, "lambda must be between 0.0 and 1.0");
        Self { lambda }
    }

    /// Cosine similarity of two vectors of equal length.
    ///
    /// Returns `0.0` when either vector has zero magnitude.
    ///
    /// Nothing visible in this file calls this helper, hence the
    /// `dead_code` allowance.
    ///
    /// # Panics
    ///
    /// Panics if the slices differ in length.
    #[allow(dead_code)]
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len(), "Vectors must have same length");
        let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let (mag_a, mag_b) = (magnitude(a), magnitude(b));
        if mag_a != 0.0 && mag_b != 0.0 {
            dot / (mag_a * mag_b)
        } else {
            0.0
        }
    }

    /// MMR score of `candidate` given the results selected so far.
    ///
    /// With nothing selected yet the score is the candidate's raw relevance.
    /// Similarity is approximated by id equality: `1.0` when a selected
    /// result shares the candidate's id, `0.0` otherwise.
    fn mmr_score(&self, candidate: &Neighbor, selected: &[&Neighbor]) -> f32 {
        let relevance = candidate.score;
        if selected.is_empty() {
            return relevance;
        }
        let already_selected = selected.iter().any(|chosen| chosen.id == candidate.id);
        let max_similarity: f32 = if already_selected { 1.0 } else { 0.0 };
        self.lambda * relevance - (1.0 - self.lambda) * max_similarity
    }
}

impl Reranker for MMRReranker {
    /// Greedily selects up to `top_k` results by MMR score.
    ///
    /// The query text is not used. When several candidates compare equal in
    /// a round, the one latest in the remaining list is chosen.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        let limit = top_k.min(results.len());
        if limit == 0 {
            return Ok(Vec::new());
        }

        let mut pool = results;
        let mut picked: Vec<Neighbor> = Vec::with_capacity(limit);
        while picked.len() < limit {
            let chosen: Vec<&Neighbor> = picked.iter().collect();
            let winner = pool
                .iter()
                .enumerate()
                .max_by(|(_, lhs), (_, rhs)| {
                    let left = self.mmr_score(lhs, &chosen);
                    let right = self.mmr_score(rhs, &chosen);
                    // Incomparable scores (NaN) are treated as equal.
                    left.partial_cmp(&right)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(position, _)| position);
            match winner {
                Some(position) => picked.push(pool.remove(position)),
                // Cannot happen while `picked.len() < limit <= results.len()`.
                None => break,
            }
        }
        Ok(picked)
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "MMR (Maximal Marginal Relevance)"
    }
}
/// Reranker that scores every result with a caller-supplied function and
/// returns the `top_k` highest-scoring results.
pub struct ScoreReranker<F>
where
    F: Fn(&Neighbor) -> f32 + Send + Sync,
{
    /// Scoring callback; a higher value ranks the result earlier.
    score_fn: F,
}

impl<F> ScoreReranker<F>
where
    F: Fn(&Neighbor) -> f32 + Send + Sync,
{
    /// Creates a reranker that orders results by `score_fn`, highest first.
    pub fn new(score_fn: F) -> Self {
        Self { score_fn }
    }

    /// Orders scores from highest to lowest, with NaN after every number.
    ///
    /// Mapping "incomparable" to `Equal` does not form a total order once a
    /// NaN is present, which leaves the sort order unspecified (and `sort_by`
    /// is allowed to panic). Placing NaN last keeps the comparator total.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}

impl<F> Reranker for ScoreReranker<F>
where
    F: Fn(&Neighbor) -> f32 + Send + Sync,
{
    /// Returns up to `top_k` results ordered by the callback's score,
    /// highest first.
    ///
    /// The query text is not used. The sort is stable, so results with equal
    /// scores keep their input order; results scored NaN are placed last.
    /// Each returned `Neighbor` keeps its original `score` field.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        let mut scored: Vec<(f32, Neighbor)> = results
            .into_iter()
            .map(|neighbor| ((self.score_fn)(&neighbor), neighbor))
            .collect();
        scored.sort_by(|a, b| Self::descending_nan_last(a.0, b.0));
        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(_, neighbor)| neighbor)
            .collect())
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Score-based Reranker"
    }
}
/// Reranker that scores each (query, document text) pair with a
/// caller-supplied function.
///
/// The document text is read from the `"text"` metadata field of each result.
pub struct CrossEncoderFn<F>
where
    F: Fn(&str, &str) -> f32 + Send + Sync,
{
    /// Callback invoked as `score_fn(query, document_text)`; higher ranks earlier.
    score_fn: F,
}

impl<F> CrossEncoderFn<F>
where
    F: Fn(&str, &str) -> f32 + Send + Sync,
{
    /// Creates a reranker around the given pair-scoring function.
    pub fn new(score_fn: F) -> Self {
        Self { score_fn }
    }

    /// Orders scores from highest to lowest, with NaN after every number.
    ///
    /// Mapping "incomparable" to `Equal` does not form a total order once a
    /// NaN is present, which leaves the sort order unspecified (and `sort_by`
    /// is allowed to panic). Placing NaN last keeps the comparator total.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}

impl<F> Reranker for CrossEncoderFn<F>
where
    F: Fn(&str, &str) -> f32 + Send + Sync,
{
    /// Returns up to `top_k` results ordered by the callback's score for
    /// `(query, text)`, highest first.
    ///
    /// A result whose metadata has no string `"text"` field is scored against
    /// the empty string. The sort is stable, so equal scores keep their input
    /// order; results scored NaN are placed last. Each returned `Neighbor`
    /// keeps its original `score` field.
    fn rerank(&self, query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        let mut scored: Vec<(f32, Neighbor)> = results
            .into_iter()
            .map(|neighbor| {
                let doc_text = neighbor
                    .metadata
                    .fields
                    .get("text")
                    .and_then(|value| value.as_str())
                    .unwrap_or("");
                let score = (self.score_fn)(query, doc_text);
                (score, neighbor)
            })
            .collect();
        scored.sort_by(|a, b| Self::descending_nan_last(a.0, b.0));
        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(_, neighbor)| neighbor)
            .collect())
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Cross-Encoder Function"
    }
}
/// Pass-through reranker: keeps the incoming order and only truncates the
/// list to `top_k` entries.
pub struct IdentityReranker;

impl Reranker for IdentityReranker {
    /// Returns the first `top_k` results unchanged; the query is not used.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        let mut kept = results;
        kept.truncate(top_k);
        Ok(kept)
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Identity (No Reranking)"
    }
}
/// Reciprocal Rank Fusion (RRF) reranker.
///
/// A document at 1-based rank `r` in a list contributes `1 / (k + r)`;
/// contributions for the same id are summed across lists.
pub struct RRFReranker {
    /// Smoothing constant added to every rank; always positive.
    k: f32,
}

impl RRFReranker {
    /// Creates an RRF reranker with smoothing constant `k`.
    ///
    /// # Panics
    ///
    /// Panics unless `k > 0.0` (NaN is rejected as well).
    pub fn new(k: f32) -> Self {
        assert!(k > 0.0, "k must be positive");
        Self { k }
    }

    /// Fuses several ranked lists into one list of at most `top_k` results.
    ///
    /// Results are identified by `id`; the first occurrence of an id supplies
    /// the returned `Neighbor`, and its `score` is replaced by the fused RRF
    /// score. Results are ordered by fused score, highest first. Ties keep
    /// first-seen order (list by list, rank by rank), so the output is
    /// deterministic for a given input.
    pub fn fuse_multiple(
        &self,
        ranked_lists: Vec<Vec<Neighbor>>,
        top_k: usize,
    ) -> Result<Vec<Neighbor>> {
        // `fused` holds entries in first-seen order; the map only locates an
        // id's slot. Sorting the map's values directly would order ties by
        // the map's arbitrary iteration order.
        let mut slot_by_id: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut fused: Vec<(f32, Neighbor)> = Vec::new();

        for ranked_list in ranked_lists {
            for (rank, neighbor) in ranked_list.into_iter().enumerate() {
                let rrf_score = 1.0 / (self.k + (rank + 1) as f32);
                match slot_by_id.get(&neighbor.id).copied() {
                    Some(slot) => fused[slot].0 += rrf_score,
                    None => {
                        slot_by_id.insert(neighbor.id.clone(), fused.len());
                        fused.push((rrf_score, neighbor));
                    }
                }
            }
        }

        // Stable sort: equal scores stay in first-seen order.
        fused.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        Ok(fused
            .into_iter()
            .take(top_k)
            .map(|(rrf_score, mut neighbor)| {
                neighbor.score = rrf_score;
                neighbor
            })
            .collect())
    }
}

impl Reranker for RRFReranker {
    /// Applies RRF to a single list: input order is kept (scores strictly
    /// decrease with rank unless `k` is so large that neighbouring ranks
    /// round to the same value) and each `score` becomes `1 / (k + rank)`.
    /// The query is not used.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        self.fuse_multiple(vec![results], top_k)
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Reciprocal Rank Fusion (RRF)"
    }
}
/// Combines several rerankers by a weighted sum of the scores they return.
pub struct EnsembleReranker {
    /// Member rerankers paired with their (un-normalized) weights.
    rerankers: Vec<(Box<dyn Reranker>, f32)>,
}

impl EnsembleReranker {
    /// Creates an ensemble with no members.
    pub fn new() -> Self {
        Self {
            rerankers: Vec::new(),
        }
    }

    /// Adds one reranker with the given weight (builder style).
    pub fn add(mut self, reranker: Box<dyn Reranker>, weight: f32) -> Self {
        self.rerankers.push((reranker, weight));
        self
    }

    /// Adds several `(reranker, weight)` pairs at once (builder style).
    pub fn add_all(mut self, rerankers: Vec<(Box<dyn Reranker>, f32)>) -> Self {
        self.rerankers.extend(rerankers);
        self
    }

    /// Orders scores from highest to lowest, with NaN after every number.
    ///
    /// Mapping "incomparable" to `Equal` does not form a total order once a
    /// NaN is present, which leaves the sort order unspecified (and `sort_by`
    /// is allowed to panic). Placing NaN last keeps the comparator total.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}

impl Default for EnsembleReranker {
    /// Equivalent to [`EnsembleReranker::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Reranker for EnsembleReranker {
    /// Runs every member over the full result list and ranks ids by the
    /// weighted sum of the scores the members return.
    ///
    /// Weights are normalized by their sum. With no members, or when the
    /// weights sum to exactly `0.0`, the input is returned truncated to
    /// `top_k`. An error from any member is propagated.
    ///
    /// Returned neighbours are the input entries (for duplicate ids, the last
    /// input entry) with `score` replaced by the combined score. Ids that a
    /// member returns but that are absent from the input are skipped without
    /// using up one of the `top_k` slots. Ties keep the order in which ids
    /// were first returned by the members, so the output is deterministic.
    fn rerank(&self, query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        if self.rerankers.is_empty() {
            return Ok(results.into_iter().take(top_k).collect());
        }
        let total_weight: f32 = self.rerankers.iter().map(|(_, w)| w).sum();
        if total_weight == 0.0 {
            return Ok(results.into_iter().take(top_k).collect());
        }

        // `combined` keeps ids in first-seen order; the map only locates an
        // id's slot, so ties never depend on hash-map iteration order.
        let mut slot_by_id: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut combined: Vec<(f32, String)> = Vec::new();

        for (reranker, weight) in &self.rerankers {
            let normalized_weight = weight / total_weight;
            let reranked = reranker.rerank(query, results.clone(), results.len())?;
            for neighbor in reranked {
                let weighted_score = neighbor.score * normalized_weight;
                match slot_by_id.get(&neighbor.id).copied() {
                    Some(slot) => combined[slot].0 += weighted_score,
                    None => {
                        slot_by_id.insert(neighbor.id.clone(), combined.len());
                        combined.push((weighted_score, neighbor.id));
                    }
                }
            }
        }

        // The input list is no longer needed as a whole: index it by id so
        // the winners can be moved out instead of cloned.
        let mut originals: std::collections::HashMap<String, Neighbor> = results
            .into_iter()
            .map(|neighbor| (neighbor.id.clone(), neighbor))
            .collect();

        combined.sort_by(|a, b| Self::descending_nan_last(a.0, b.0));

        // Filter before truncating so unknown ids do not shrink the output.
        let reranked = combined
            .into_iter()
            .filter_map(|(score, id)| {
                originals.remove(&id).map(|mut neighbor| {
                    neighbor.score = score;
                    neighbor
                })
            })
            .take(top_k)
            .collect();
        Ok(reranked)
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Ensemble Reranker"
    }
}
/// Borda count rank aggregation.
///
/// In a list of `n` results, the result at 0-based position `rank` earns
/// `n - rank` points; points for the same id are summed across lists.
pub struct BordaCountReranker;

impl BordaCountReranker {
    /// Creates a Borda count reranker.
    pub fn new() -> Self {
        Self
    }

    /// Combines several ranked lists into one list of at most `top_k` results.
    ///
    /// Results are identified by `id`; the first occurrence of an id supplies
    /// the returned `Neighbor`, and its `score` is replaced by the total
    /// Borda points. Results are ordered by points, highest first. Ties keep
    /// first-seen order (list by list, rank by rank), so the output is
    /// deterministic for a given input.
    pub fn combine(&self, ranked_lists: Vec<Vec<Neighbor>>, top_k: usize) -> Result<Vec<Neighbor>> {
        // `tallied` holds entries in first-seen order; the map only locates
        // an id's slot. Sorting the map's values directly would order ties
        // by the map's arbitrary iteration order.
        let mut slot_by_id: std::collections::HashMap<String, usize> =
            std::collections::HashMap::new();
        let mut tallied: Vec<(f32, Neighbor)> = Vec::new();

        for ranked_list in ranked_lists {
            let n = ranked_list.len();
            for (rank, neighbor) in ranked_list.into_iter().enumerate() {
                // `rank < n`, so the subtraction cannot underflow.
                let borda_score = (n - rank) as f32;
                match slot_by_id.get(&neighbor.id).copied() {
                    Some(slot) => tallied[slot].0 += borda_score,
                    None => {
                        slot_by_id.insert(neighbor.id.clone(), tallied.len());
                        tallied.push((borda_score, neighbor));
                    }
                }
            }
        }

        // Stable sort: equal point totals stay in first-seen order.
        tallied.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        Ok(tallied
            .into_iter()
            .take(top_k)
            .map(|(borda_score, mut neighbor)| {
                neighbor.score = borda_score;
                neighbor
            })
            .collect())
    }
}

impl Default for BordaCountReranker {
    /// Equivalent to [`BordaCountReranker::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Reranker for BordaCountReranker {
    /// Applies the Borda count to a single list: input order is kept and
    /// each `score` becomes the result's Borda points. The query is not used.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        self.combine(vec![results], top_k)
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Borda Count"
    }
}
/// Reranker that blends each result's existing score with its word overlap
/// against previously issued queries.
pub struct ContextualReranker {
    /// Earlier queries used as context.
    history: Vec<String>,
    /// Share of the final score taken from the context overlap (`0.0..=1.0`).
    context_weight: f32,
}

impl ContextualReranker {
    /// Creates a reranker with empty history and a context weight of `0.2`.
    pub fn new() -> Self {
        Self {
            history: Vec::new(),
            context_weight: 0.2,
        }
    }

    /// Replaces the query history (builder style).
    pub fn with_history(mut self, history: Vec<String>) -> Self {
        self.history = history;
        self
    }

    /// Appends one query to the history.
    pub fn add_to_history(&mut self, query: String) {
        self.history.push(query);
    }

    /// Sets the context weight (builder style).
    ///
    /// # Panics
    ///
    /// Panics if `weight` is not within `0.0..=1.0` (this includes NaN).
    pub fn with_context_weight(mut self, weight: f32) -> Self {
        assert!(
            (0.0..=1.0).contains(&weight),
            "context_weight must be between 0.0 and 1.0"
        );
        self.context_weight = weight;
        self
    }

    /// Fraction of history words that occur in the result's `"text"` field.
    ///
    /// `lowered_history` must already be lower-cased; the document text is
    /// lower-cased here. Words are whitespace-separated, and every occurrence
    /// of a history word counts, including repeats. Returns `0.0` when the
    /// history is empty or contains no words, or when the result has no
    /// non-empty string `"text"` field.
    fn history_overlap(neighbor: &Neighbor, lowered_history: &[String]) -> f32 {
        if lowered_history.is_empty() {
            return 0.0;
        }
        let doc_text = neighbor
            .metadata
            .fields
            .get("text")
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_lowercase();
        if doc_text.is_empty() {
            return 0.0;
        }
        let doc_words: std::collections::HashSet<&str> = doc_text.split_whitespace().collect();

        let mut matched = 0usize;
        let mut total = 0usize;
        for word in lowered_history
            .iter()
            .flat_map(|query| query.split_whitespace())
        {
            total += 1;
            if doc_words.contains(word) {
                matched += 1;
            }
        }
        if total == 0 {
            0.0
        } else {
            matched as f32 / total as f32
        }
    }

    /// Orders scores from highest to lowest, with NaN after every number.
    ///
    /// Mapping "incomparable" to `Equal` does not form a total order once a
    /// NaN is present, which leaves the sort order unspecified (and `sort_by`
    /// is allowed to panic). Placing NaN last keeps the comparator total.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}

impl Default for ContextualReranker {
    /// Equivalent to [`ContextualReranker::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Reranker for ContextualReranker {
    /// Returns up to `top_k` results ordered by
    /// `(1 - context_weight) * score + context_weight * overlap`, highest
    /// first, with each `score` replaced by that combined value.
    ///
    /// The current query is not used; only the stored history is. The sort
    /// is stable, so equal scores keep their input order; NaN scores are
    /// placed last.
    fn rerank(&self, _query: &str, results: Vec<Neighbor>, top_k: usize) -> Result<Vec<Neighbor>> {
        // Lower-case the history once per call rather than once per result.
        let lowered_history: Vec<String> = self
            .history
            .iter()
            .map(|query| query.to_lowercase())
            .collect();

        let mut scored: Vec<(f32, Neighbor)> = results
            .into_iter()
            .map(|neighbor| {
                let context_score = Self::history_overlap(&neighbor, &lowered_history);
                let combined_score = (1.0 - self.context_weight) * neighbor.score
                    + self.context_weight * context_score;
                (combined_score, neighbor)
            })
            .collect();
        scored.sort_by(|a, b| Self::descending_nan_last(a.0, b.0));

        Ok(scored
            .into_iter()
            .take(top_k)
            .map(|(score, mut neighbor)| {
                neighbor.score = score;
                neighbor
            })
            .collect())
    }

    /// Returns the display name of this reranker.
    fn name(&self) -> &str {
        "Contextual Reranker"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Metadata;
    use std::collections::HashMap;

    /// Builds a neighbor with the given id and score and empty metadata.
    fn make_neighbor(id: &str, score: f32) -> Neighbor {
        Neighbor {
            id: id.to_string(),
            score,
            metadata: Metadata {
                fields: HashMap::new(),
            },
        }
    }

    /// Builds a neighbor whose metadata holds `text` under the `"text"` key.
    fn make_text_neighbor(id: &str, score: f32, text: &str) -> Neighbor {
        let mut metadata = Metadata {
            fields: HashMap::new(),
        };
        metadata
            .fields
            .insert("text".to_string(), serde_json::json!(text));
        Neighbor {
            id: id.to_string(),
            score,
            metadata,
        }
    }

    // ----- MMR -----

    #[test]
    fn test_mmr_reranker_basic() {
        let reranker = MMRReranker::new(0.7);
        let results = vec![
            make_neighbor("doc1", 0.9),
            make_neighbor("doc2", 0.7),
            make_neighbor("doc3", 0.5),
            make_neighbor("doc4", 0.3),
        ];

        let reranked = reranker.rerank("test query", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].id, "doc1");
    }

    #[test]
    fn test_mmr_reranker_empty() {
        let reranker = MMRReranker::new(0.7);
        let results = vec![];

        let reranked = reranker.rerank("test", results, 10).unwrap();

        assert!(reranked.is_empty());
    }

    #[test]
    fn test_mmr_lambda_extremes() {
        // lambda = 1.0: relevance only.
        let reranker = MMRReranker::new(1.0);
        let results = vec![
            make_neighbor("doc1", 0.5),
            make_neighbor("doc2", 0.9),
            make_neighbor("doc3", 0.3),
        ];
        let reranked = reranker.rerank("test", results, 1).unwrap();
        assert_eq!(reranked[0].id, "doc2");

        // lambda = 0.0: relevance carries no weight after the first pick.
        let reranker = MMRReranker::new(0.0);
        let results = vec![make_neighbor("doc1", 0.8), make_neighbor("doc2", 0.9)];
        let reranked = reranker.rerank("test", results, 2).unwrap();
        assert_eq!(reranked.len(), 2);
    }

    #[test]
    #[should_panic(expected = "lambda must be between 0.0 and 1.0")]
    fn test_mmr_invalid_lambda() {
        MMRReranker::new(1.5);
    }

    #[test]
    fn test_reranker_trait() {
        let reranker: Box<dyn Reranker> = Box::new(MMRReranker::new(0.7));
        assert_eq!(reranker.name(), "MMR (Maximal Marginal Relevance)");

        let results = vec![make_neighbor("doc1", 0.1)];
        let reranked = reranker.rerank("test", results, 1).unwrap();
        assert_eq!(reranked.len(), 1);
    }

    // ----- Score / identity -----

    #[test]
    fn test_score_reranker() {
        let reranker = ScoreReranker::new(|neighbor| neighbor.score);
        let results = vec![
            make_neighbor("doc1", 0.5),
            make_neighbor("doc2", 0.9),
            make_neighbor("doc3", 0.7),
        ];

        let reranked = reranker.rerank("test", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].id, "doc2");
        assert_eq!(reranked[1].id, "doc3");
    }

    #[test]
    fn test_identity_reranker() {
        let reranker = IdentityReranker;
        let results = vec![
            make_neighbor("doc1", 0.3),
            make_neighbor("doc2", 0.1),
            make_neighbor("doc3", 0.2),
        ];

        let reranked = reranker.rerank("test", results.clone(), 2).unwrap();

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].id, "doc1");
        assert_eq!(reranked[1].id, "doc2");
    }

    // ----- Cross-encoder function -----

    #[test]
    fn test_cross_encoder_fn() {
        // Score = number of query words that also appear in the document.
        let reranker = CrossEncoderFn::new(|query: &str, doc: &str| {
            let query_words: Vec<&str> = query.split_whitespace().collect();
            let doc_words: Vec<&str> = doc.split_whitespace().collect();
            let overlap = query_words.iter().filter(|w| doc_words.contains(w)).count();
            overlap as f32
        });
        let results = vec![
            make_text_neighbor("doc1", 0.5, "rust programming language"),
            make_text_neighbor("doc2", 0.9, "python data science"),
            make_text_neighbor("doc3", 0.7, "rust async programming"),
        ];

        let reranked = reranker.rerank("rust programming", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
        assert!(reranked[0].id == "doc1" || reranked[0].id == "doc3");
        assert_ne!(reranked[0].id, "doc2");
    }

    #[test]
    fn test_cross_encoder_fn_empty_metadata() {
        let reranker = CrossEncoderFn::new(|_query: &str, doc: &str| doc.len() as f32);
        let results = vec![make_neighbor("doc1", 0.5)];

        let reranked = reranker.rerank("test", results, 1).unwrap();

        assert_eq!(reranked.len(), 1);
        assert_eq!(reranked[0].id, "doc1");
    }

    // ----- Reciprocal rank fusion -----

    #[test]
    fn test_rrf_reranker() {
        let reranker = RRFReranker::new(60.0);
        let list1 = vec![
            make_neighbor("doc1", 0.9),
            make_neighbor("doc2", 0.8),
            make_neighbor("doc3", 0.7),
        ];
        let list2 = vec![
            make_neighbor("doc2", 0.95),
            make_neighbor("doc3", 0.85),
            make_neighbor("doc1", 0.75),
        ];

        let fused = reranker.fuse_multiple(vec![list1, list2], 3).unwrap();

        assert_eq!(fused.len(), 3);
        assert_eq!(fused[0].id, "doc2");
    }

    #[test]
    fn test_rrf_single_list() {
        let reranker = RRFReranker::new(60.0);
        let results = vec![make_neighbor("doc1", 0.9), make_neighbor("doc2", 0.8)];

        let reranked = reranker.rerank("test", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
    }

    #[test]
    #[should_panic(expected = "k must be positive")]
    fn test_rrf_invalid_k() {
        RRFReranker::new(0.0);
    }

    // ----- Ensemble / Borda -----

    #[test]
    fn test_ensemble_reranker() {
        let ensemble = EnsembleReranker::new()
            .add(Box::new(MMRReranker::new(0.7)), 0.5)
            .add(Box::new(IdentityReranker), 0.5);
        let results = vec![
            make_neighbor("doc1", 0.9),
            make_neighbor("doc2", 0.7),
            make_neighbor("doc3", 0.5),
        ];

        let reranked = ensemble.rerank("test", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
    }

    #[test]
    fn test_borda_count() {
        let reranker = BordaCountReranker::new();
        let list1 = vec![
            make_neighbor("doc1", 0.9),
            make_neighbor("doc2", 0.8),
            make_neighbor("doc3", 0.7),
        ];
        let list2 = vec![
            make_neighbor("doc2", 0.95),
            make_neighbor("doc1", 0.85),
            make_neighbor("doc3", 0.75),
        ];

        let combined = reranker.combine(vec![list1, list2], 3).unwrap();

        assert_eq!(combined.len(), 3);
        assert!(combined[2].id == "doc3");
    }

    // ----- Contextual -----

    #[test]
    fn test_contextual_reranker() {
        let mut reranker = ContextualReranker::new()
            .with_history(vec![
                "rust programming".to_string(),
                "memory safety".to_string(),
            ])
            .with_context_weight(0.5);
        let results = vec![
            make_text_neighbor("doc1", 0.5, "rust is great for memory safety"),
            make_text_neighbor("doc2", 0.9, "python data science"),
        ];

        let reranked = reranker.rerank("test", results, 2).unwrap();
        assert_eq!(reranked.len(), 2);

        reranker.add_to_history("ownership".to_string());
        assert_eq!(reranker.history.len(), 3);
    }

    #[test]
    fn test_contextual_reranker_no_history() {
        let reranker = ContextualReranker::new();
        let results = vec![make_neighbor("doc1", 0.9), make_neighbor("doc2", 0.7)];

        let reranked = reranker.rerank("test", results, 2).unwrap();

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].id, "doc1");
    }

    #[test]
    #[should_panic(expected = "context_weight must be between 0.0 and 1.0")]
    fn test_contextual_reranker_invalid_weight() {
        ContextualReranker::new().with_context_weight(1.5);
    }
}