use std::cmp::Ordering;
use std::collections::HashMap;
use async_trait::async_trait;
use chrono::Utc;
use crate::{
error::VectorResult,
types::{RerankerConfig, SearchResult},
};
/// A strategy for reordering (and possibly rescoring) search results for a query.
///
/// Implementations are used both directly and as `Box<dyn Reranker + Send + Sync>` trait
/// objects (see `CompositeReranker` and `build_reranker`).
#[async_trait]
pub trait Reranker {
    /// Reranks `results` against the query embedding `query`.
    ///
    /// Takes ownership of the candidates and returns them in their new order. Implementations
    /// in this module may rewrite scores (e.g. recency boosting) but never add results.
    async fn rerank(
        &self,
        query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>>;
}
/// Placeholder for a cross-encoder based reranker.
///
/// No model is invoked yet: `rerank` hands back its input unchanged (same order, same scores)
/// whatever the value of `service_available`.
pub struct CrossEncoderReranker {
    /// Whether a cross-encoder service is reachable. Currently read but not acted upon.
    pub service_available: bool,
}

#[async_trait]
impl Reranker for CrossEncoderReranker {
    /// Identity pass-through; the query embedding is ignored.
    async fn rerank(
        &self,
        _query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>> {
        // Read the flag so the field counts as used until a real implementation consults it.
        let _service_available: bool = self.service_available;
        let unchanged = results;
        Ok(unchanged)
    }
}
/// Reranker that reorders results for diversity using maximal marginal relevance (MMR).
pub struct DiversityReranker {
    /// Relevance/diversity trade-off forwarded to [`mmr_select`], which clamps it to `[0, 1]`.
    /// Higher values favour query relevance; lower values favour dissimilarity to results
    /// already chosen.
    pub lambda: f32,
}

#[async_trait]
impl Reranker for DiversityReranker {
    /// Reorders all of `results` via [`mmr_select`]. No result is dropped and every score is
    /// returned unchanged; only the order differs.
    async fn rerank(
        &self,
        query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>> {
        let keep_all = results.len();
        let reordered = mmr_select(query, &results, self.lambda, keep_all);
        Ok(reordered)
    }
}
/// Reranker that boosts recently created results and re-sorts by the boosted score.
pub struct RecencyReranker {
    /// Maximum relative boost: a result created "now" has its score multiplied by
    /// `1 + recency_weight`.
    pub recency_weight: f32,
    /// Age, in days, at which the boost has decayed to half of its maximum.
    /// Values below `0.001` (and NaN) are treated as `0.001`.
    pub half_life_days: f32,
}

#[async_trait]
impl Reranker for RecencyReranker {
    /// Multiplies each score by `1 + recency_weight * 0.5^(age_days / half_life_days)` and
    /// sorts the results by the new score, highest first.
    ///
    /// Ages are measured against the current wall-clock time at whole-second resolution;
    /// results timestamped in the future are treated as age zero.
    ///
    /// NOTE(review): the boost is multiplicative, so for negative scores a positive
    /// `recency_weight` pushes recent results further down — confirm scores are non-negative.
    async fn rerank(
        &self,
        _query: &[f32],
        mut results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>> {
        let now = Utc::now();
        // Guard against zero/negative half-lives, which would divide by zero or invert the decay.
        let half_life_days = self.half_life_days.max(0.001);
        for result in &mut results {
            let age_seconds = now
                .signed_duration_since(result.created_at)
                .num_seconds()
                .max(0) as f32;
            let age_days = age_seconds / 86_400.0;
            // True half-life decay: exactly 0.5 when `age_days == half_life_days`.
            // (`exp(-age / half_life)` would only reach 1/e ≈ 0.37 at that age.)
            let decay_factor = 0.5_f32.powf(age_days / half_life_days);
            result.score *= 1.0 + self.recency_weight * decay_factor;
        }
        sort_results_desc(&mut results);
        Ok(results)
    }
}
/// Reranker that runs a sequence of rerankers, feeding each one's output into the next.
pub struct CompositeReranker(pub Vec<Box<dyn Reranker + Send + Sync>>);

#[async_trait]
impl Reranker for CompositeReranker {
    /// Applies every inner reranker in order. The first error aborts the chain and is returned;
    /// an empty chain returns `results` unchanged.
    async fn rerank(
        &self,
        query: &[f32],
        results: Vec<SearchResult>,
    ) -> VectorResult<Vec<SearchResult>> {
        let CompositeReranker(stages) = self;
        let mut pipeline = results;
        for stage in stages.iter() {
            pipeline = stage.rerank(query, pipeline).await?;
        }
        Ok(pipeline)
    }
}
/// Selects up to `top_k` results from `candidates` using maximal marginal relevance (MMR).
///
/// Each round picks the remaining candidate maximising
/// `lambda * relevance - (1 - lambda) * max_similarity`, where `relevance` comes from
/// `query_relevance` and `max_similarity` is the highest `candidate_similarity` to any result
/// already picked (0 before the first pick). `lambda` is clamped to `[0, 1]`.
///
/// Results are returned in selection order with their original scores untouched. Ties go to
/// the candidate that appeared first in `candidates`, so equal-scored inputs keep their
/// relative order. NaN MMR scores never win a round; if no candidate has a comparable score,
/// the first remaining candidate is taken.
///
/// Returns an empty vector when `candidates` is empty or `top_k` is 0.
pub fn mmr_select(
    query: &[f32],
    candidates: &[SearchResult],
    lambda: f32,
    top_k: usize,
) -> Vec<SearchResult> {
    if candidates.is_empty() || top_k == 0 {
        return Vec::new();
    }
    let lambda = lambda.clamp(0.0, 1.0);
    let target = top_k.min(candidates.len());

    // Per candidate: (result, relevance to the query, max similarity to anything selected so far).
    // Relevance is fixed and max similarity only grows, so both are cached and updated
    // incrementally instead of being recomputed against the whole selected set every round.
    let mut pool: Vec<(SearchResult, f32, f32)> = candidates
        .iter()
        .map(|candidate| (candidate.clone(), query_relevance(query, candidate), 0.0))
        .collect();
    let mut selected = Vec::with_capacity(target);

    // `target <= candidates.len()`, so `pool` is non-empty on every iteration.
    while selected.len() < target {
        let mut best_index = 0;
        let mut best_score = f32::NEG_INFINITY;
        for (index, &(_, relevance, max_similarity)) in pool.iter().enumerate() {
            let score = lambda * relevance - (1.0 - lambda) * max_similarity;
            // Strict `>` keeps the earliest candidate on ties and never lets NaN win.
            if score > best_score {
                best_index = index;
                best_score = score;
            }
        }
        // `Vec::remove` (not `swap_remove`) preserves input order for later tie-breaks.
        let (chosen, _, _) = pool.remove(best_index);
        for (candidate, _, max_similarity) in pool.iter_mut() {
            *max_similarity = f32::max(*max_similarity, candidate_similarity(candidate, &chosen));
        }
        selected.push(chosen);
    }
    selected
}
/// Applies the reranking strategy described by `config` to `results`.
///
/// - `None` or `RerankerConfig::None`: `results` are returned untouched.
/// - `RerankerConfig::Composite`: the stages' normalized scores are blended by weight
///   (see `apply_composite_reranker`).
/// - Any other variant: a single reranker is built with `build_reranker` and applied.
///
/// # Errors
/// Propagates any error returned by an underlying reranker.
pub async fn apply_reranker_config(
    query: &[f32],
    results: Vec<SearchResult>,
    config: Option<&RerankerConfig>,
) -> VectorResult<Vec<SearchResult>> {
    let config = match config {
        None | Some(RerankerConfig::None) => return Ok(results),
        Some(config) => config,
    };
    if let RerankerConfig::Composite(stages) = config {
        apply_composite_reranker(query, results, stages).await
    } else {
        build_reranker(config).rerank(query, results).await
    }
}
/// Reports whether applying `config` requires results to carry their embedding vectors.
///
/// `Diversity` stages compare vectors during MMR selection, while `Recency` uses only scores
/// and creation timestamps. A `Composite` needs vectors when any nested stage does; an absent
/// or `None` config never does.
pub fn reranker_needs_vectors(config: Option<&RerankerConfig>) -> bool {
    config.map_or(false, |config| match config {
        RerankerConfig::Diversity { .. } => true,
        RerankerConfig::Composite(stages) => stages
            .iter()
            .any(|stage| reranker_needs_vectors(Some(stage))),
        RerankerConfig::None | RerankerConfig::Recency { .. } => false,
    })
}
/// Builds a boxed reranker for a single config node.
///
/// `Composite` becomes a [`CompositeReranker`] that chains the nested stages sequentially,
/// with each stage reranking the previous stage's output. `None` becomes an empty chain, i.e. a
/// pass-through. The per-stage `weight` fields are not consulted here.
fn build_reranker(config: &RerankerConfig) -> Box<dyn Reranker + Send + Sync> {
    match config {
        RerankerConfig::Composite(stages) => {
            let chain: Vec<Box<dyn Reranker + Send + Sync>> =
                stages.iter().map(build_reranker).collect();
            Box::new(CompositeReranker(chain))
        }
        RerankerConfig::Recency {
            boost: recency_weight,
            half_life_days,
            ..
        } => Box::new(RecencyReranker {
            recency_weight: *recency_weight,
            half_life_days: *half_life_days,
        }),
        RerankerConfig::Diversity { lambda, .. } => Box::new(DiversityReranker { lambda: *lambda }),
        RerankerConfig::None => Box::new(CompositeReranker(Vec::new())),
    }
}
/// Runs `configs` as a pipeline and blends the stages' scores into a single final ranking.
///
/// Each stage reranks the previous stage's output. Its scores are then min-max normalized
/// (`normalize_scores`), and those normalized results are what the next stage receives.
///
/// A result's final score is the `stage_weight`-weighted sum of its normalized scores, divided
/// by the total weight (or by `1.0` when the total weight is not positive).
///
/// Results are returned sorted by final score, highest first. Results with equal final scores
/// keep the order in which they were first seen, so the output does not depend on hash-map
/// iteration order. Results sharing an `id` are merged: their weighted scores are summed and the
/// last-seen copy is kept. An empty `configs` returns `results` unchanged.
///
/// NOTE(review): nested `Composite` stages go through `build_reranker`, which chains them
/// sequentially instead of blending them like this function does — confirm that is intended.
///
/// NOTE(review): stages that reorder without rescoring (e.g. `Diversity`, whose MMR keeps the
/// original scores) contribute their normalized *input* scores, so their reordering does not
/// itself influence the blend.
///
/// # Errors
/// Returns the first error produced by any stage.
async fn apply_composite_reranker(
    query: &[f32],
    results: Vec<SearchResult>,
    configs: &[RerankerConfig],
) -> VectorResult<Vec<SearchResult>> {
    if configs.is_empty() {
        return Ok(results);
    }
    let mut current = results;
    let mut aggregate_scores: HashMap<uuid::Uuid, f32> = HashMap::new();
    let mut by_id: HashMap<uuid::Uuid, SearchResult> = HashMap::new();
    // First-seen order of ids. Used instead of HashMap iteration order (randomized per process)
    // so that results with tied final scores come out in a reproducible order.
    let mut first_seen: Vec<uuid::Uuid> = Vec::new();
    let mut total_weight = 0.0f32;
    for config in configs {
        let reranked = build_reranker(config).rerank(query, current).await?;
        let normalized = normalize_scores(reranked);
        let weight = stage_weight(config);
        total_weight += weight;
        for result in &normalized {
            *aggregate_scores.entry(result.id).or_insert(0.0) += weight * result.score;
            if by_id.insert(result.id, result.clone()).is_none() {
                first_seen.push(result.id);
            }
        }
        current = normalized;
    }
    // A non-positive (or NaN) total weight would flip or poison every score; fall back to 1.0.
    let denominator = if total_weight > 0.0 { total_weight } else { 1.0 };
    let mut final_results: Vec<SearchResult> = first_seen
        .into_iter()
        .filter_map(|id| {
            let mut result = by_id.remove(&id)?;
            result.score = aggregate_scores.get(&id).copied()? / denominator;
            Some(result)
        })
        .collect();
    // Stable sort: equal final scores keep first-seen order.
    sort_results_desc(&mut final_results);
    Ok(final_results)
}
/// Rescales scores to `[0, 1]` with min-max normalization; the order of `results` is unchanged.
///
/// When every score lies within `1e-9` of the others, all scores become `1.0`.
/// NaN scores are ignored when finding the bounds and remain NaN. An empty vector is returned
/// as-is.
fn normalize_scores(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
    const EPSILON: f32 = 1e-9;
    if results.is_empty() {
        return results;
    }
    // One pass for both bounds; `f32::min`/`f32::max` skip NaN operands.
    let (lowest, highest) = results.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lowest, highest), result| (lowest.min(result.score), highest.max(result.score)),
    );
    let spread = highest - lowest;
    let all_equal = spread.abs() < EPSILON;
    for result in results.iter_mut() {
        result.score = if all_equal {
            1.0
        } else {
            ((result.score - lowest) / (spread + EPSILON)).clamp(0.0, 1.0)
        };
    }
    results
}
/// Weight a stage contributes when blending composite scores.
///
/// `Diversity` and `Recency` use their configured `weight`. A nested `Composite` counts as
/// `1.0`, and `None` contributes nothing.
fn stage_weight(config: &RerankerConfig) -> f32 {
    match *config {
        RerankerConfig::Diversity { weight, .. } | RerankerConfig::Recency { weight, .. } => weight,
        RerankerConfig::Composite(_) => 1.0,
        RerankerConfig::None => 0.0,
    }
}
/// Relevance of `candidate` to the query embedding, used as the MMR relevance term.
///
/// When the candidate has a vector with a defined cosine similarity to `query`, the cosine is
/// rescaled from `[-1, 1]` to `[0, 1]`. Otherwise the candidate's existing `score` is returned
/// as-is.
///
/// NOTE(review): the fallback score is not rescaled, so it may be on a different scale than the
/// cosine branch — confirm mixing the two is intended.
fn query_relevance(query: &[f32], candidate: &SearchResult) -> f32 {
    candidate
        .vector
        .as_deref()
        .and_then(|vector| cosine_similarity(query, vector))
        .map(|similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
        .unwrap_or(candidate.score)
}
/// Similarity between two results' vectors, rescaled from cosine `[-1, 1]` to `[0, 1]`.
///
/// Returns `0.0` (treated as "not similar") when either result lacks a vector or the cosine
/// similarity is undefined.
fn candidate_similarity(left: &SearchResult, right: &SearchResult) -> f32 {
    left.vector
        .as_deref()
        .zip(right.vector.as_deref())
        .and_then(|(a, b)| cosine_similarity(a, b))
        .map_or(0.0, |similarity| ((similarity + 1.0) / 2.0).clamp(0.0, 1.0))
}
/// Cosine similarity of two equal-length vectors, in (up to rounding) `[-1, 1]`.
///
/// Returns `None` when the slices are empty, differ in length, or either has zero norm. The
/// similarity is undefined in those cases.
fn cosine_similarity(left: &[f32], right: &[f32]) -> Option<f32> {
    let comparable = !left.is_empty() && left.len() == right.len();
    if !comparable {
        return None;
    }
    let mut dot = 0.0_f32;
    let mut left_squares = 0.0_f32;
    let mut right_squares = 0.0_f32;
    for (&a, &b) in left.iter().zip(right) {
        dot += a * b;
        left_squares += a * a;
        right_squares += b * b;
    }
    let left_norm = left_squares.sqrt();
    let right_norm = right_squares.sqrt();
    (left_norm != 0.0 && right_norm != 0.0).then(|| dot / (left_norm * right_norm))
}
/// Sorts `results` in place by `score`, highest first.
///
/// The sort is stable, so results with equal scores keep their relative order. NaN scores sort
/// after every comparable score and compare equal to each other.
///
/// The comparator must be a genuine total order. Since Rust 1.81 `sort_by` may panic when it is
/// not, and treating NaN as equal to every value would break transitivity.
fn sort_results_desc(results: &mut [SearchResult]) {
    results.sort_by(|left, right| match (left.score.is_nan(), right.score.is_nan()) {
        // Non-NaN floats are totally ordered by `partial_cmp` (with -0.0 == 0.0).
        (false, false) => right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(Ordering::Equal),
        (true, true) => Ordering::Equal,
        // NaN sinks to the end of the descending order.
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
    });
}
#[cfg(test)]
mod tests {
    use chrono::{Duration, Utc};

    use super::apply_reranker_config;
    use crate::types::{RerankerConfig, SearchResult};

    /// Builds a result with a fixed base score of `0.5`, the given embedding and creation time,
    /// empty JSON metadata and no text.
    fn fixture_result(
        id: uuid::Uuid,
        vector: Vec<f32>,
        created_at: chrono::DateTime<Utc>,
    ) -> SearchResult {
        let metadata = serde_json::json!({});
        SearchResult {
            id,
            created_at,
            metadata,
            score: 0.5,
            text: None,
            vector: Some(vector),
        }
    }

    /// With identical base scores, the most query-aligned and newest result should lead, and the
    /// oldest, least-aligned result should trail.
    #[tokio::test]
    async fn composite_diversity_then_recency_is_deterministic() {
        let parse_id = |raw: &str| uuid::Uuid::parse_str(raw).unwrap();
        let first_id = parse_id("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa");
        let second_id = parse_id("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb");
        let third_id = parse_id("cccccccc-cccc-cccc-cccc-cccccccccccc");

        let now = Utc::now();
        let results = vec![
            fixture_result(first_id, vec![0.95, 0.05, 0.0], now),
            fixture_result(second_id, vec![0.5, 0.5, 0.0], now - Duration::days(1)),
            fixture_result(third_id, vec![0.2, 0.8, 0.0], now - Duration::days(30)),
        ];
        let config = RerankerConfig::Composite(vec![
            RerankerConfig::Diversity {
                lambda: 0.8,
                weight: 0.6,
            },
            RerankerConfig::Recency {
                boost: 0.5,
                half_life_days: 7.0,
                weight: 0.4,
            },
        ]);
        let query = [1.0_f32, 0.0, 0.0];

        let reranked = apply_reranker_config(&query, results, Some(&config))
            .await
            .unwrap();
        let ids: Vec<uuid::Uuid> = reranked.iter().map(|item| item.id).collect();
        assert_eq!(ids, [first_id, second_id, third_id]);
    }
}