Skip to main content

oxirs_vec/reranking/
reranker.rs

1//! Main cross-encoder re-ranker implementation
2
3use crate::reranking::{
4    cache::RerankingCache,
5    config::{RerankingConfig, RerankingMode},
6    cross_encoder::CrossEncoder,
7    diversity::DiversityReranker,
8    fusion::ScoreFusion,
9    types::{RerankingError, RerankingResult, ScoredCandidate},
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14
15/// Statistics for a re-ranking operation
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct RerankingStats {
18    /// Number of candidates processed
19    pub num_candidates: usize,
20
21    /// Number of candidates actually re-ranked
22    pub num_reranked: usize,
23
24    /// Number of cache hits
25    pub cache_hits: usize,
26
27    /// Total time (milliseconds)
28    pub total_time_ms: f64,
29
30    /// Model inference time (milliseconds)
31    pub inference_time_ms: f64,
32
33    /// Score fusion time (milliseconds)
34    pub fusion_time_ms: f64,
35
36    /// Average score change
37    pub avg_score_change: f32,
38
39    /// Rank correlation (Kendall's tau)
40    pub rank_correlation: Option<f32>,
41}
42
43/// Output of a re-ranking operation
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RerankingOutput {
46    /// Re-ranked candidates
47    pub candidates: Vec<ScoredCandidate>,
48
49    /// Statistics
50    pub stats: RerankingStats,
51}
52
53/// Cross-encoder re-ranker
54pub struct CrossEncoderReranker {
55    /// Configuration
56    config: RerankingConfig,
57
58    /// Cross-encoder model
59    encoder: Arc<CrossEncoder>,
60
61    /// Score fusion
62    fusion: Arc<ScoreFusion>,
63
64    /// Diversity re-ranker
65    diversity: Option<Arc<DiversityReranker>>,
66
67    /// Cache
68    cache: Option<Arc<RerankingCache>>,
69}
70
71impl CrossEncoderReranker {
72    /// Create new re-ranker
73    pub fn new(config: RerankingConfig) -> RerankingResult<Self> {
74        config
75            .validate()
76            .map_err(|e| RerankingError::InvalidConfiguration { message: e })?;
77
78        let encoder = Arc::new(CrossEncoder::new(
79            &config.model_name,
80            &config.model_backend,
81        )?);
82        let fusion = Arc::new(ScoreFusion::new(
83            config.fusion_strategy,
84            config.retrieval_weight,
85        ));
86
87        let diversity = if config.enable_diversity {
88            Some(Arc::new(DiversityReranker::new(config.diversity_weight)))
89        } else {
90            None
91        };
92
93        let cache = if config.enable_caching {
94            Some(Arc::new(RerankingCache::new(config.cache_size)))
95        } else {
96            None
97        };
98
99        Ok(Self {
100            config,
101            encoder,
102            fusion,
103            diversity,
104            cache,
105        })
106    }
107
108    /// Re-rank candidates
109    pub fn rerank(
110        &self,
111        query: &str,
112        candidates: &[ScoredCandidate],
113    ) -> RerankingResult<RerankingOutput> {
114        let start = Instant::now();
115
116        // Filter candidates based on mode
117        let candidates_to_rerank = self.select_candidates_for_reranking(candidates);
118
119        let mut stats = RerankingStats {
120            num_candidates: candidates.len(),
121            num_reranked: candidates_to_rerank.len(),
122            ..Default::default()
123        };
124
125        // Check mode
126        if self.config.mode == RerankingMode::Disabled {
127            return Ok(RerankingOutput {
128                candidates: candidates.to_vec(),
129                stats,
130            });
131        }
132
133        // Re-rank with cross-encoder
134        let inference_start = Instant::now();
135        let mut reranked = self.apply_cross_encoder(query, candidates_to_rerank, &mut stats)?;
136        stats.inference_time_ms = inference_start.elapsed().as_secs_f64() * 1000.0;
137
138        // Fuse scores
139        let fusion_start = Instant::now();
140        for candidate in &mut reranked {
141            if let Some(reranking_score) = candidate.reranking_score {
142                candidate.final_score =
143                    self.fusion.fuse(candidate.retrieval_score, reranking_score);
144            }
145        }
146        stats.fusion_time_ms = fusion_start.elapsed().as_secs_f64() * 1000.0;
147
148        // Apply diversity if enabled
149        if let Some(ref diversity) = self.diversity {
150            reranked = diversity.apply_diversity(&reranked)?;
151        }
152
153        // Sort by final score
154        reranked.sort_by(|a, b| {
155            b.final_score
156                .partial_cmp(&a.final_score)
157                .unwrap_or(std::cmp::Ordering::Equal)
158        });
159
160        // Take top-k
161        reranked.truncate(self.config.top_k);
162
163        // Calculate statistics
164        self.calculate_stats(&mut stats, candidates, &reranked);
165        stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
166
167        Ok(RerankingOutput {
168            candidates: reranked,
169            stats,
170        })
171    }
172
173    /// Select candidates for re-ranking based on mode
174    fn select_candidates_for_reranking(
175        &self,
176        candidates: &[ScoredCandidate],
177    ) -> Vec<ScoredCandidate> {
178        let max_candidates = self.config.max_candidates.min(candidates.len());
179
180        match self.config.mode {
181            RerankingMode::Full => candidates.to_vec(),
182            RerankingMode::TopK => candidates[..max_candidates].to_vec(),
183            RerankingMode::Adaptive => {
184                // Use score threshold for adaptive selection
185                let threshold = self.calculate_adaptive_threshold(candidates);
186                candidates
187                    .iter()
188                    .filter(|c| c.retrieval_score >= threshold)
189                    .take(max_candidates)
190                    .cloned()
191                    .collect()
192            }
193            RerankingMode::Disabled => Vec::new(),
194        }
195    }
196
197    /// Calculate adaptive threshold based on score distribution
198    fn calculate_adaptive_threshold(&self, candidates: &[ScoredCandidate]) -> f32 {
199        if candidates.is_empty() {
200            return 0.0;
201        }
202
203        // Use mean - 0.5 * std as threshold
204        let scores: Vec<f32> = candidates.iter().map(|c| c.retrieval_score).collect();
205        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
206        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
207        let std = variance.sqrt();
208
209        (mean - 0.5 * std).max(0.0)
210    }
211
212    /// Apply cross-encoder to candidates
213    fn apply_cross_encoder(
214        &self,
215        query: &str,
216        candidates: Vec<ScoredCandidate>,
217        stats: &mut RerankingStats,
218    ) -> RerankingResult<Vec<ScoredCandidate>> {
219        let mut reranked = Vec::new();
220
221        // Process in batches
222        for batch in candidates.chunks(self.config.batch_size) {
223            let mut batch_results = Vec::new();
224
225            for candidate in batch {
226                // Check cache first
227                let cache_key = format!("{}:{}", query, candidate.id);
228                let score = if let Some(ref cache) = self.cache {
229                    if let Some(cached_score) = cache.get(&cache_key) {
230                        stats.cache_hits += 1;
231                        cached_score
232                    } else {
233                        let score = self
234                            .encoder
235                            .score(query, candidate.content.as_deref().unwrap_or(""))?;
236                        cache.put(cache_key, score);
237                        score
238                    }
239                } else {
240                    self.encoder
241                        .score(query, candidate.content.as_deref().unwrap_or(""))?
242                };
243
244                let mut updated = candidate.clone();
245                updated.reranking_score = Some(score);
246                batch_results.push(updated);
247            }
248
249            reranked.extend(batch_results);
250        }
251
252        Ok(reranked)
253    }
254
255    /// Calculate additional statistics
256    fn calculate_stats(
257        &self,
258        stats: &mut RerankingStats,
259        original: &[ScoredCandidate],
260        reranked: &[ScoredCandidate],
261    ) {
262        // Calculate average score change
263        let score_changes: Vec<f32> = reranked
264            .iter()
265            .filter_map(|c| c.reranking_score.map(|r| (r - c.retrieval_score).abs()))
266            .collect();
267
268        if !score_changes.is_empty() {
269            stats.avg_score_change = score_changes.iter().sum::<f32>() / score_changes.len() as f32;
270        }
271
272        // Calculate rank correlation (simplified - just check if order changed)
273        if original.len() == reranked.len() && !original.is_empty() {
274            let original_ids: Vec<&String> = original.iter().map(|c| &c.id).collect();
275            let reranked_ids: Vec<&String> = reranked.iter().map(|c| &c.id).collect();
276            let same_order = original_ids == reranked_ids;
277            stats.rank_correlation = Some(if same_order { 1.0 } else { 0.5 });
278        }
279    }
280
281    /// Get configuration
282    pub fn config(&self) -> &RerankingConfig {
283        &self.config
284    }
285
286    /// Clear cache
287    pub fn clear_cache(&self) {
288        if let Some(ref cache) = self.cache {
289            cache.clear();
290        }
291    }
292
293    /// Get cache statistics
294    pub fn cache_stats(&self) -> Option<(usize, usize)> {
295        self.cache.as_ref().map(|c| c.stats())
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reranking::config::FusionStrategy;

    /// Assembles a re-ranker directly from its parts (dummy local encoder,
    /// linear fusion, no diversity stage, no cache) around `config`.
    fn reranker_with(config: RerankingConfig) -> CrossEncoderReranker {
        let encoder = CrossEncoder::new("dummy", "local").unwrap();
        let fusion = ScoreFusion::new(FusionStrategy::Linear, 0.3);

        CrossEncoderReranker {
            config,
            encoder: Arc::new(encoder),
            fusion: Arc::new(fusion),
            diversity: None,
            cache: None,
        }
    }

    /// A default `RerankingStats` starts with all counters at zero.
    #[test]
    fn test_reranking_stats_default() {
        let fresh = RerankingStats::default();

        assert_eq!(fresh.num_candidates, 0);
        assert_eq!(fresh.num_reranked, 0);
        assert_eq!(fresh.cache_hits, 0);
    }

    /// `TopK` mode keeps exactly `max_candidates` of the input.
    #[test]
    fn test_select_candidates_topk() {
        let reranker = reranker_with(RerankingConfig {
            mode: RerankingMode::TopK,
            max_candidates: 5,
            ..RerankingConfig::default_config()
        });

        let pool: Vec<ScoredCandidate> = (0..10)
            .map(|i| ScoredCandidate::new(format!("doc{}", i), 0.9 - i as f32 * 0.05, i))
            .collect();

        let picked = reranker.select_candidates_for_reranking(&pool);
        assert_eq!(picked.len(), 5);
    }

    /// The adaptive threshold lands strictly between zero and the top score.
    #[test]
    fn test_adaptive_threshold() {
        let reranker = reranker_with(RerankingConfig::default_config());

        let pool = vec![
            ScoredCandidate::new("doc1", 0.9, 0),
            ScoredCandidate::new("doc2", 0.8, 1),
            ScoredCandidate::new("doc3", 0.7, 2),
            ScoredCandidate::new("doc4", 0.3, 3),
            ScoredCandidate::new("doc5", 0.2, 4),
        ];

        let cutoff = reranker.calculate_adaptive_threshold(&pool);
        assert!(cutoff > 0.0);
        assert!(cutoff < 0.9);
    }
}