1use crate::reranking::{
4 cache::RerankingCache,
5 config::{RerankingConfig, RerankingMode},
6 cross_encoder::CrossEncoder,
7 diversity::DiversityReranker,
8 fusion::ScoreFusion,
9 types::{RerankingError, RerankingResult, ScoredCandidate},
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14
/// Counters and timings collected while serving one reranking request.
///
/// All fields start at their `Default` value and are filled in as the
/// request progresses.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RerankingStats {
    /// Number of candidates supplied by the caller.
    pub num_candidates: usize,

    /// Number of candidates selected for cross-encoder scoring.
    pub num_reranked: usize,

    /// Number of scores that were served from the cache instead of the encoder.
    pub cache_hits: usize,

    /// Wall-clock duration of the reranking call, in milliseconds.
    pub total_time_ms: f64,

    /// Time spent obtaining cross-encoder scores (cache lookups included),
    /// in milliseconds.
    pub inference_time_ms: f64,

    /// Time spent fusing retrieval and reranking scores, in milliseconds.
    pub fusion_time_ms: f64,

    /// Mean absolute difference between the reranking score and the retrieval
    /// score, taken over the returned candidates that carry a reranking score.
    pub avg_score_change: f32,

    /// Coarse indicator of how well the output order agrees with the input
    /// order: `Some(1.0)` when the candidate IDs appear in the same order,
    /// `Some(0.5)` otherwise. `None` when it was not computed (the input is
    /// empty, the output has a different length than the input, or reranking
    /// is disabled).
    pub rank_correlation: Option<f32>,
}
42
/// Result of a reranking call: the candidates to return plus the statistics
/// gathered while producing them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankingOutput {
    /// Candidates after reranking (or the unchanged input when reranking is
    /// disabled).
    pub candidates: Vec<ScoredCandidate>,

    /// Counters and timings for the call that produced `candidates`.
    pub stats: RerankingStats,
}
52
/// Reranks retrieval candidates with a cross-encoder and fuses the resulting
/// scores with the original retrieval scores.
pub struct CrossEncoderReranker {
    /// Settings that drive candidate selection, batching, and truncation.
    config: RerankingConfig,

    /// Model used to score a (query, document text) pair.
    encoder: Arc<CrossEncoder>,

    /// Combines a retrieval score and a reranking score into a final score.
    fusion: Arc<ScoreFusion>,

    /// Optional diversity pass; `new` creates it only when
    /// `config.enable_diversity` is set.
    diversity: Option<Arc<DiversityReranker>>,

    /// Optional score cache; `new` creates it only when
    /// `config.enable_caching` is set.
    cache: Option<Arc<RerankingCache>>,
}
70
71impl CrossEncoderReranker {
72 pub fn new(config: RerankingConfig) -> RerankingResult<Self> {
74 config
75 .validate()
76 .map_err(|e| RerankingError::InvalidConfiguration { message: e })?;
77
78 let encoder = Arc::new(CrossEncoder::new(
79 &config.model_name,
80 &config.model_backend,
81 )?);
82 let fusion = Arc::new(ScoreFusion::new(
83 config.fusion_strategy,
84 config.retrieval_weight,
85 ));
86
87 let diversity = if config.enable_diversity {
88 Some(Arc::new(DiversityReranker::new(config.diversity_weight)))
89 } else {
90 None
91 };
92
93 let cache = if config.enable_caching {
94 Some(Arc::new(RerankingCache::new(config.cache_size)))
95 } else {
96 None
97 };
98
99 Ok(Self {
100 config,
101 encoder,
102 fusion,
103 diversity,
104 cache,
105 })
106 }
107
108 pub fn rerank(
110 &self,
111 query: &str,
112 candidates: &[ScoredCandidate],
113 ) -> RerankingResult<RerankingOutput> {
114 let start = Instant::now();
115
116 let candidates_to_rerank = self.select_candidates_for_reranking(candidates);
118
119 let mut stats = RerankingStats {
120 num_candidates: candidates.len(),
121 num_reranked: candidates_to_rerank.len(),
122 ..Default::default()
123 };
124
125 if self.config.mode == RerankingMode::Disabled {
127 return Ok(RerankingOutput {
128 candidates: candidates.to_vec(),
129 stats,
130 });
131 }
132
133 let inference_start = Instant::now();
135 let mut reranked = self.apply_cross_encoder(query, candidates_to_rerank, &mut stats)?;
136 stats.inference_time_ms = inference_start.elapsed().as_secs_f64() * 1000.0;
137
138 let fusion_start = Instant::now();
140 for candidate in &mut reranked {
141 if let Some(reranking_score) = candidate.reranking_score {
142 candidate.final_score =
143 self.fusion.fuse(candidate.retrieval_score, reranking_score);
144 }
145 }
146 stats.fusion_time_ms = fusion_start.elapsed().as_secs_f64() * 1000.0;
147
148 if let Some(ref diversity) = self.diversity {
150 reranked = diversity.apply_diversity(&reranked)?;
151 }
152
153 reranked.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
155
156 reranked.truncate(self.config.top_k);
158
159 self.calculate_stats(&mut stats, candidates, &reranked);
161 stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
162
163 Ok(RerankingOutput {
164 candidates: reranked,
165 stats,
166 })
167 }
168
169 fn select_candidates_for_reranking(
171 &self,
172 candidates: &[ScoredCandidate],
173 ) -> Vec<ScoredCandidate> {
174 let max_candidates = self.config.max_candidates.min(candidates.len());
175
176 match self.config.mode {
177 RerankingMode::Full => candidates.to_vec(),
178 RerankingMode::TopK => candidates[..max_candidates].to_vec(),
179 RerankingMode::Adaptive => {
180 let threshold = self.calculate_adaptive_threshold(candidates);
182 candidates
183 .iter()
184 .filter(|c| c.retrieval_score >= threshold)
185 .take(max_candidates)
186 .cloned()
187 .collect()
188 }
189 RerankingMode::Disabled => Vec::new(),
190 }
191 }
192
193 fn calculate_adaptive_threshold(&self, candidates: &[ScoredCandidate]) -> f32 {
195 if candidates.is_empty() {
196 return 0.0;
197 }
198
199 let scores: Vec<f32> = candidates.iter().map(|c| c.retrieval_score).collect();
201 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
202 let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
203 let std = variance.sqrt();
204
205 (mean - 0.5 * std).max(0.0)
206 }
207
208 fn apply_cross_encoder(
210 &self,
211 query: &str,
212 candidates: Vec<ScoredCandidate>,
213 stats: &mut RerankingStats,
214 ) -> RerankingResult<Vec<ScoredCandidate>> {
215 let mut reranked = Vec::new();
216
217 for batch in candidates.chunks(self.config.batch_size) {
219 let mut batch_results = Vec::new();
220
221 for candidate in batch {
222 let cache_key = format!("{}:{}", query, candidate.id);
224 let score = if let Some(ref cache) = self.cache {
225 if let Some(cached_score) = cache.get(&cache_key) {
226 stats.cache_hits += 1;
227 cached_score
228 } else {
229 let score = self
230 .encoder
231 .score(query, candidate.content.as_deref().unwrap_or(""))?;
232 cache.put(cache_key, score);
233 score
234 }
235 } else {
236 self.encoder
237 .score(query, candidate.content.as_deref().unwrap_or(""))?
238 };
239
240 let mut updated = candidate.clone();
241 updated.reranking_score = Some(score);
242 batch_results.push(updated);
243 }
244
245 reranked.extend(batch_results);
246 }
247
248 Ok(reranked)
249 }
250
251 fn calculate_stats(
253 &self,
254 stats: &mut RerankingStats,
255 original: &[ScoredCandidate],
256 reranked: &[ScoredCandidate],
257 ) {
258 let score_changes: Vec<f32> = reranked
260 .iter()
261 .filter_map(|c| c.reranking_score.map(|r| (r - c.retrieval_score).abs()))
262 .collect();
263
264 if !score_changes.is_empty() {
265 stats.avg_score_change = score_changes.iter().sum::<f32>() / score_changes.len() as f32;
266 }
267
268 if original.len() == reranked.len() && !original.is_empty() {
270 let original_ids: Vec<&String> = original.iter().map(|c| &c.id).collect();
271 let reranked_ids: Vec<&String> = reranked.iter().map(|c| &c.id).collect();
272 let same_order = original_ids == reranked_ids;
273 stats.rank_correlation = Some(if same_order { 1.0 } else { 0.5 });
274 }
275 }
276
277 pub fn config(&self) -> &RerankingConfig {
279 &self.config
280 }
281
282 pub fn clear_cache(&self) {
284 if let Some(ref cache) = self.cache {
285 cache.clear();
286 }
287 }
288
289 pub fn cache_stats(&self) -> Option<(usize, usize)> {
291 self.cache.as_ref().map(|c| c.stats())
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reranking::config::FusionStrategy;

    /// Assembles a reranker around a dummy local encoder and linear fusion,
    /// with neither diversity nor caching, without going through `new`.
    fn build_reranker(config: RerankingConfig) -> CrossEncoderReranker {
        let encoder = CrossEncoder::new("dummy", "local").unwrap();
        let fusion = ScoreFusion::new(FusionStrategy::Linear, 0.3);

        CrossEncoderReranker {
            config,
            encoder: Arc::new(encoder),
            fusion: Arc::new(fusion),
            diversity: None,
            cache: None,
        }
    }

    /// A default `RerankingStats` starts with all counters at zero.
    #[test]
    fn test_reranking_stats_default() {
        let stats = RerankingStats::default();

        assert_eq!(stats.num_candidates, 0);
        assert_eq!(stats.num_reranked, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    /// `TopK` mode keeps no more than `max_candidates` candidates.
    #[test]
    fn test_select_candidates_topk() {
        let reranker = build_reranker(RerankingConfig {
            mode: RerankingMode::TopK,
            max_candidates: 5,
            ..RerankingConfig::default_config()
        });

        let mut candidates: Vec<ScoredCandidate> = Vec::new();
        for i in 0..10 {
            candidates.push(ScoredCandidate::new(
                format!("doc{}", i),
                0.9 - i as f32 * 0.05,
                i,
            ));
        }

        let selected = reranker.select_candidates_for_reranking(&candidates);
        assert_eq!(selected.len(), 5);
    }

    /// The adaptive threshold lies strictly between zero and the best score.
    #[test]
    fn test_adaptive_threshold() {
        let reranker = build_reranker(RerankingConfig::default_config());

        let candidates = vec![
            ScoredCandidate::new("doc1", 0.9, 0),
            ScoredCandidate::new("doc2", 0.8, 1),
            ScoredCandidate::new("doc3", 0.7, 2),
            ScoredCandidate::new("doc4", 0.3, 3),
            ScoredCandidate::new("doc5", 0.2, 4),
        ];

        let threshold = reranker.calculate_adaptive_threshold(&candidates);
        assert!(threshold > 0.0);
        assert!(threshold < 0.9);
    }
}