use crate::reranking::{
    cache::RerankingCache,
    config::{RerankingConfig, RerankingMode},
    cross_encoder::CrossEncoder,
    diversity::DiversityReranker,
    fusion::ScoreFusion,
    types::{RerankingError, RerankingResult, ScoredCandidate},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct RerankingStats {
18 pub num_candidates: usize,
20
21 pub num_reranked: usize,
23
24 pub cache_hits: usize,
26
27 pub total_time_ms: f64,
29
30 pub inference_time_ms: f64,
32
33 pub fusion_time_ms: f64,
35
36 pub avg_score_change: f32,
38
39 pub rank_correlation: Option<f32>,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RerankingOutput {
46 pub candidates: Vec<ScoredCandidate>,
48
49 pub stats: RerankingStats,
51}
52
53pub struct CrossEncoderReranker {
55 config: RerankingConfig,
57
58 encoder: Arc<CrossEncoder>,
60
61 fusion: Arc<ScoreFusion>,
63
64 diversity: Option<Arc<DiversityReranker>>,
66
67 cache: Option<Arc<RerankingCache>>,
69}
70
71impl CrossEncoderReranker {
72 pub fn new(config: RerankingConfig) -> RerankingResult<Self> {
74 config
75 .validate()
76 .map_err(|e| RerankingError::InvalidConfiguration { message: e })?;
77
78 let encoder = Arc::new(CrossEncoder::new(
79 &config.model_name,
80 &config.model_backend,
81 )?);
82 let fusion = Arc::new(ScoreFusion::new(
83 config.fusion_strategy,
84 config.retrieval_weight,
85 ));
86
87 let diversity = if config.enable_diversity {
88 Some(Arc::new(DiversityReranker::new(config.diversity_weight)))
89 } else {
90 None
91 };
92
93 let cache = if config.enable_caching {
94 Some(Arc::new(RerankingCache::new(config.cache_size)))
95 } else {
96 None
97 };
98
99 Ok(Self {
100 config,
101 encoder,
102 fusion,
103 diversity,
104 cache,
105 })
106 }
107
108 pub fn rerank(
110 &self,
111 query: &str,
112 candidates: &[ScoredCandidate],
113 ) -> RerankingResult<RerankingOutput> {
114 let start = Instant::now();
115
116 let candidates_to_rerank = self.select_candidates_for_reranking(candidates);
118
119 let mut stats = RerankingStats {
120 num_candidates: candidates.len(),
121 num_reranked: candidates_to_rerank.len(),
122 ..Default::default()
123 };
124
125 if self.config.mode == RerankingMode::Disabled {
127 return Ok(RerankingOutput {
128 candidates: candidates.to_vec(),
129 stats,
130 });
131 }
132
133 let inference_start = Instant::now();
135 let mut reranked = self.apply_cross_encoder(query, candidates_to_rerank, &mut stats)?;
136 stats.inference_time_ms = inference_start.elapsed().as_secs_f64() * 1000.0;
137
138 let fusion_start = Instant::now();
140 for candidate in &mut reranked {
141 if let Some(reranking_score) = candidate.reranking_score {
142 candidate.final_score =
143 self.fusion.fuse(candidate.retrieval_score, reranking_score);
144 }
145 }
146 stats.fusion_time_ms = fusion_start.elapsed().as_secs_f64() * 1000.0;
147
148 if let Some(ref diversity) = self.diversity {
150 reranked = diversity.apply_diversity(&reranked)?;
151 }
152
153 reranked.sort_by(|a, b| {
155 b.final_score
156 .partial_cmp(&a.final_score)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 });
159
160 reranked.truncate(self.config.top_k);
162
163 self.calculate_stats(&mut stats, candidates, &reranked);
165 stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
166
167 Ok(RerankingOutput {
168 candidates: reranked,
169 stats,
170 })
171 }
172
173 fn select_candidates_for_reranking(
175 &self,
176 candidates: &[ScoredCandidate],
177 ) -> Vec<ScoredCandidate> {
178 let max_candidates = self.config.max_candidates.min(candidates.len());
179
180 match self.config.mode {
181 RerankingMode::Full => candidates.to_vec(),
182 RerankingMode::TopK => candidates[..max_candidates].to_vec(),
183 RerankingMode::Adaptive => {
184 let threshold = self.calculate_adaptive_threshold(candidates);
186 candidates
187 .iter()
188 .filter(|c| c.retrieval_score >= threshold)
189 .take(max_candidates)
190 .cloned()
191 .collect()
192 }
193 RerankingMode::Disabled => Vec::new(),
194 }
195 }
196
197 fn calculate_adaptive_threshold(&self, candidates: &[ScoredCandidate]) -> f32 {
199 if candidates.is_empty() {
200 return 0.0;
201 }
202
203 let scores: Vec<f32> = candidates.iter().map(|c| c.retrieval_score).collect();
205 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
206 let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
207 let std = variance.sqrt();
208
209 (mean - 0.5 * std).max(0.0)
210 }
211
212 fn apply_cross_encoder(
214 &self,
215 query: &str,
216 candidates: Vec<ScoredCandidate>,
217 stats: &mut RerankingStats,
218 ) -> RerankingResult<Vec<ScoredCandidate>> {
219 let mut reranked = Vec::new();
220
221 for batch in candidates.chunks(self.config.batch_size) {
223 let mut batch_results = Vec::new();
224
225 for candidate in batch {
226 let cache_key = format!("{}:{}", query, candidate.id);
228 let score = if let Some(ref cache) = self.cache {
229 if let Some(cached_score) = cache.get(&cache_key) {
230 stats.cache_hits += 1;
231 cached_score
232 } else {
233 let score = self
234 .encoder
235 .score(query, candidate.content.as_deref().unwrap_or(""))?;
236 cache.put(cache_key, score);
237 score
238 }
239 } else {
240 self.encoder
241 .score(query, candidate.content.as_deref().unwrap_or(""))?
242 };
243
244 let mut updated = candidate.clone();
245 updated.reranking_score = Some(score);
246 batch_results.push(updated);
247 }
248
249 reranked.extend(batch_results);
250 }
251
252 Ok(reranked)
253 }
254
255 fn calculate_stats(
257 &self,
258 stats: &mut RerankingStats,
259 original: &[ScoredCandidate],
260 reranked: &[ScoredCandidate],
261 ) {
262 let score_changes: Vec<f32> = reranked
264 .iter()
265 .filter_map(|c| c.reranking_score.map(|r| (r - c.retrieval_score).abs()))
266 .collect();
267
268 if !score_changes.is_empty() {
269 stats.avg_score_change = score_changes.iter().sum::<f32>() / score_changes.len() as f32;
270 }
271
272 if original.len() == reranked.len() && !original.is_empty() {
274 let original_ids: Vec<&String> = original.iter().map(|c| &c.id).collect();
275 let reranked_ids: Vec<&String> = reranked.iter().map(|c| &c.id).collect();
276 let same_order = original_ids == reranked_ids;
277 stats.rank_correlation = Some(if same_order { 1.0 } else { 0.5 });
278 }
279 }
280
281 pub fn config(&self) -> &RerankingConfig {
283 &self.config
284 }
285
286 pub fn clear_cache(&self) {
288 if let Some(ref cache) = self.cache {
289 cache.clear();
290 }
291 }
292
293 pub fn cache_stats(&self) -> Option<(usize, usize)> {
295 self.cache.as_ref().map(|c| c.stats())
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reranking::config::FusionStrategy;

    /// Assembles a reranker around the dummy local encoder and a linear
    /// fusion, with the diversity stage and the cache switched off.
    fn build_reranker(config: RerankingConfig) -> CrossEncoderReranker {
        let encoder = CrossEncoder::new("dummy", "local").unwrap();
        let fusion = ScoreFusion::new(FusionStrategy::Linear, 0.3);

        CrossEncoderReranker {
            config,
            encoder: Arc::new(encoder),
            fusion: Arc::new(fusion),
            diversity: None,
            cache: None,
        }
    }

    /// Default stats start with all counters at zero.
    #[test]
    fn test_reranking_stats_default() {
        let stats = RerankingStats::default();
        assert_eq!(stats.num_candidates, 0);
        assert_eq!(stats.num_reranked, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    /// In top-k mode only `max_candidates` candidates are selected.
    #[test]
    fn test_select_candidates_topk() {
        let reranker = build_reranker(RerankingConfig {
            mode: RerankingMode::TopK,
            max_candidates: 5,
            ..RerankingConfig::default_config()
        });

        let candidates: Vec<ScoredCandidate> = (0..10)
            .map(|i| ScoredCandidate::new(format!("doc{}", i), 0.9 - i as f32 * 0.05, i))
            .collect();

        let selected = reranker.select_candidates_for_reranking(&candidates);
        assert_eq!(selected.len(), 5);
    }

    /// The adaptive threshold lies strictly between zero and the best score.
    #[test]
    fn test_adaptive_threshold() {
        let reranker = build_reranker(RerankingConfig::default_config());

        let candidates = vec![
            ScoredCandidate::new("doc1", 0.9, 0),
            ScoredCandidate::new("doc2", 0.8, 1),
            ScoredCandidate::new("doc3", 0.7, 2),
            ScoredCandidate::new("doc4", 0.3, 3),
            ScoredCandidate::new("doc5", 0.2, 4),
        ];

        let threshold = reranker.calculate_adaptive_threshold(&candidates);
        assert!(threshold > 0.0);
        assert!(threshold < 0.9);
    }
}