use crate::reranking::{
cache::RerankingCache,
config::{RerankingConfig, RerankingMode},
cross_encoder::CrossEncoder,
diversity::DiversityReranker,
fusion::ScoreFusion,
types::{RerankingError, RerankingResult, ScoredCandidate},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
/// Timing and quality metrics collected during a single reranking call.
///
/// Every field starts at its `Default` value and is filled in as the
/// pipeline progresses.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RerankingStats {
/// Number of candidates handed to the reranker.
pub num_candidates: usize,
/// Number of candidates selected for cross-encoder scoring.
pub num_reranked: usize,
/// Number of scores that were served from the cache.
pub cache_hits: usize,
/// Wall-clock duration of the reranking call, in milliseconds.
pub total_time_ms: f64,
/// Time spent in the cross-encoder scoring step (cache lookups included),
/// in milliseconds.
pub inference_time_ms: f64,
/// Time spent fusing retrieval and reranking scores, in milliseconds.
pub fusion_time_ms: f64,
/// Mean absolute difference between the reranking score and the retrieval
/// score over the returned candidates that carry a reranking score.
pub avg_score_change: f32,
/// Agreement between the input ordering and the output ordering.
/// Only populated when both lists have the same non-zero length.
pub rank_correlation: Option<f32>,
}
/// Result of a reranking call: the returned candidates plus the statistics
/// gathered while producing them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankingOutput {
/// Candidates returned by the reranker.
pub candidates: Vec<ScoredCandidate>,
/// Metrics describing how the candidates were produced.
pub stats: RerankingStats,
}
/// Reranks retrieval candidates by scoring each one against the query with a
/// cross-encoder and fusing that score with the original retrieval score.
pub struct CrossEncoderReranker {
/// Settings that drive candidate selection, batching, and truncation.
config: RerankingConfig,
/// Model used to score a (query, candidate content) pair.
encoder: Arc<CrossEncoder>,
/// Combines the retrieval score and the reranking score into the final score.
fusion: Arc<ScoreFusion>,
/// Optional diversity pass applied after fusion; `None` when not configured.
diversity: Option<Arc<DiversityReranker>>,
/// Optional cache of cross-encoder scores; `None` when not configured.
cache: Option<Arc<RerankingCache>>,
}
impl CrossEncoderReranker {
/// Builds a reranker from `config`.
///
/// The configuration is validated first. The cross-encoder is then created
/// from `model_name` / `model_backend`, the score fusion from
/// `fusion_strategy` / `retrieval_weight`, and the diversity pass and the
/// score cache are created only when their `enable_*` flags are set.
///
/// # Errors
///
/// Returns `RerankingError::InvalidConfiguration` when `config.validate()`
/// fails, and propagates any error reported by `CrossEncoder::new`.
pub fn new(config: RerankingConfig) -> RerankingResult<Self> {
    config
        .validate()
        .map_err(|message| RerankingError::InvalidConfiguration { message })?;

    let encoder = CrossEncoder::new(&config.model_name, &config.model_backend)?;
    let fusion = ScoreFusion::new(config.fusion_strategy, config.retrieval_weight);

    // Optional stages are only constructed when switched on in the config.
    let diversity = config
        .enable_diversity
        .then(|| Arc::new(DiversityReranker::new(config.diversity_weight)));
    let cache = config
        .enable_caching
        .then(|| Arc::new(RerankingCache::new(config.cache_size)));

    Ok(Self {
        encoder: Arc::new(encoder),
        fusion: Arc::new(fusion),
        diversity,
        cache,
        config,
    })
}
/// Reranks `candidates` against `query` and returns the best results together
/// with statistics about the run.
///
/// Pipeline: select the subset to rerank according to `config.mode`, score it
/// with the cross-encoder, fuse every reranking score with the retrieval
/// score, optionally run the diversity pass, sort by `final_score` in
/// descending order and keep the first `config.top_k` entries.
///
/// Only candidates that were selected for reranking can appear in the output.
/// With `RerankingMode::Disabled` the input is returned as-is: no scoring, no
/// sorting and no truncation.
///
/// # Errors
///
/// Propagates errors from cross-encoder scoring and from the diversity pass.
pub fn rerank(
    &self,
    query: &str,
    candidates: &[ScoredCandidate],
) -> RerankingResult<RerankingOutput> {
    let start = Instant::now();
    let candidates_to_rerank = self.select_candidates_for_reranking(candidates);
    let mut stats = RerankingStats {
        num_candidates: candidates.len(),
        num_reranked: candidates_to_rerank.len(),
        ..Default::default()
    };

    if self.config.mode == RerankingMode::Disabled {
        // Report elapsed time on the pass-through path too, so
        // `total_time_ms` is populated for every successful call.
        stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
        return Ok(RerankingOutput {
            candidates: candidates.to_vec(),
            stats,
        });
    }

    let inference_start = Instant::now();
    let mut reranked = self.apply_cross_encoder(query, candidates_to_rerank, &mut stats)?;
    stats.inference_time_ms = inference_start.elapsed().as_secs_f64() * 1000.0;

    let fusion_start = Instant::now();
    for candidate in &mut reranked {
        if let Some(reranking_score) = candidate.reranking_score {
            candidate.final_score =
                self.fusion.fuse(candidate.retrieval_score, reranking_score);
        }
    }
    stats.fusion_time_ms = fusion_start.elapsed().as_secs_f64() * 1000.0;

    if let Some(ref diversity) = self.diversity {
        reranked = diversity.apply_diversity(&reranked)?;
    }

    // Descending by final score. NaN scores are ordered after every real
    // score: mapping an undefined comparison to `Equal` is not a total order,
    // which leaves the result unspecified and may make `sort_by` panic.
    // The sort is stable, so ties keep the order produced by earlier stages.
    reranked.sort_by(|a, b| {
        match (a.final_score.is_nan(), b.final_score.is_nan()) {
            (false, false) => b
                .final_score
                .partial_cmp(&a.final_score)
                .unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        }
    });
    reranked.truncate(self.config.top_k);

    self.calculate_stats(&mut stats, candidates, &reranked);
    stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
    Ok(RerankingOutput {
        candidates: reranked,
        stats,
    })
}
/// Picks the candidates that will be scored by the cross-encoder.
///
/// * `Full` - every candidate (the `max_candidates` limit is not applied).
/// * `TopK` - the first `max_candidates` candidates in input order.
/// * `Adaptive` - the first `max_candidates` candidates whose retrieval
///   score reaches the adaptive threshold.
/// * `Disabled` - nothing.
fn select_candidates_for_reranking(
    &self,
    candidates: &[ScoredCandidate],
) -> Vec<ScoredCandidate> {
    // Never ask for more than the input provides.
    let limit = candidates.len().min(self.config.max_candidates);

    match self.config.mode {
        RerankingMode::Disabled => Vec::new(),
        RerankingMode::Full => candidates.to_vec(),
        RerankingMode::TopK => candidates.iter().take(limit).cloned().collect(),
        RerankingMode::Adaptive => {
            let cutoff = self.calculate_adaptive_threshold(candidates);
            let mut selected = Vec::new();
            for candidate in candidates {
                if selected.len() >= limit {
                    break;
                }
                if candidate.retrieval_score >= cutoff {
                    selected.push(candidate.clone());
                }
            }
            selected
        }
    }
}
/// Computes the retrieval-score cut-off used by `RerankingMode::Adaptive`.
///
/// The threshold is the mean retrieval score minus half of the population
/// standard deviation, clamped so it never drops below `0.0`. An empty
/// input yields `0.0`.
fn calculate_adaptive_threshold(&self, candidates: &[ScoredCandidate]) -> f32 {
    let count = candidates.len();
    if count == 0 {
        return 0.0;
    }

    let n = count as f32;
    let mean = candidates.iter().map(|c| c.retrieval_score).sum::<f32>() / n;
    let variance = candidates
        .iter()
        .map(|c| (c.retrieval_score - mean).powi(2))
        .sum::<f32>()
        / n;

    (mean - 0.5 * variance.sqrt()).max(0.0)
}
fn apply_cross_encoder(
&self,
query: &str,
candidates: Vec<ScoredCandidate>,
stats: &mut RerankingStats,
) -> RerankingResult<Vec<ScoredCandidate>> {
let mut reranked = Vec::new();
for batch in candidates.chunks(self.config.batch_size) {
let mut batch_results = Vec::new();
for candidate in batch {
let cache_key = format!("{}:{}", query, candidate.id);
let score = if let Some(ref cache) = self.cache {
if let Some(cached_score) = cache.get(&cache_key) {
stats.cache_hits += 1;
cached_score
} else {
let score = self
.encoder
.score(query, candidate.content.as_deref().unwrap_or(""))?;
cache.put(cache_key, score);
score
}
} else {
self.encoder
.score(query, candidate.content.as_deref().unwrap_or(""))?
};
let mut updated = candidate.clone();
updated.reranking_score = Some(score);
batch_results.push(updated);
}
reranked.extend(batch_results);
}
Ok(reranked)
}
fn calculate_stats(
&self,
stats: &mut RerankingStats,
original: &[ScoredCandidate],
reranked: &[ScoredCandidate],
) {
let score_changes: Vec<f32> = reranked
.iter()
.filter_map(|c| c.reranking_score.map(|r| (r - c.retrieval_score).abs()))
.collect();
if !score_changes.is_empty() {
stats.avg_score_change = score_changes.iter().sum::<f32>() / score_changes.len() as f32;
}
if original.len() == reranked.len() && !original.is_empty() {
let original_ids: Vec<&String> = original.iter().map(|c| &c.id).collect();
let reranked_ids: Vec<&String> = reranked.iter().map(|c| &c.id).collect();
let same_order = original_ids == reranked_ids;
stats.rank_correlation = Some(if same_order { 1.0 } else { 0.5 });
}
}
/// Returns a reference to the reranker's configuration.
pub fn config(&self) -> &RerankingConfig {
&self.config
}
pub fn clear_cache(&self) {
if let Some(ref cache) = self.cache {
cache.clear();
}
}
/// Returns the pair reported by the cache's `stats()`, or `None` when no
/// cache is configured.
///
/// NOTE(review): the meaning of the two counters is defined by
/// `RerankingCache::stats` - confirm there before relying on their order.
pub fn cache_stats(&self) -> Option<(usize, usize)> {
    let cache = self.cache.as_ref()?;
    Some(cache.stats())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reranking::config::FusionStrategy;

    /// Lets the tests use `?` on any error type.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Assembles a reranker around the "dummy"/"local" encoder with linear
    /// fusion, no diversity pass and no cache, bypassing `new`.
    fn reranker_with(config: RerankingConfig) -> Result<CrossEncoderReranker> {
        let encoder = CrossEncoder::new("dummy", "local")?;
        let fusion = ScoreFusion::new(FusionStrategy::Linear, 0.3);
        Ok(CrossEncoderReranker {
            config,
            encoder: Arc::new(encoder),
            fusion: Arc::new(fusion),
            diversity: None,
            cache: None,
        })
    }

    /// Default statistics start with all counters at zero.
    #[test]
    fn test_reranking_stats_default() {
        let stats = RerankingStats::default();
        assert_eq!(stats.num_candidates, 0);
        assert_eq!(stats.num_reranked, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    /// `TopK` mode keeps exactly `max_candidates` entries when more are given.
    #[test]
    fn test_select_candidates_topk() -> Result<()> {
        let reranker = reranker_with(RerankingConfig {
            mode: RerankingMode::TopK,
            max_candidates: 5,
            ..RerankingConfig::default_config()
        })?;

        let candidates: Vec<ScoredCandidate> = (0..10)
            .map(|i| ScoredCandidate::new(format!("doc{}", i), 0.9 - i as f32 * 0.05, i))
            .collect();

        let selected = reranker.select_candidates_for_reranking(&candidates);
        assert_eq!(selected.len(), 5);
        Ok(())
    }

    /// The adaptive threshold lies strictly between zero and the best score.
    #[test]
    fn test_adaptive_threshold() -> Result<()> {
        let reranker = reranker_with(RerankingConfig::default_config())?;

        let candidates = vec![
            ScoredCandidate::new("doc1", 0.9, 0),
            ScoredCandidate::new("doc2", 0.8, 1),
            ScoredCandidate::new("doc3", 0.7, 2),
            ScoredCandidate::new("doc4", 0.3, 3),
            ScoredCandidate::new("doc5", 0.2, 4),
        ];

        let threshold = reranker.calculate_adaptive_threshold(&candidates);
        assert!(threshold > 0.0);
        assert!(threshold < 0.9);
        Ok(())
    }
}