// oxirs_vec/reranking/reranker.rs

//! Main cross-encoder re-ranker implementation.
2
3use crate::reranking::{
4    cache::RerankingCache,
5    config::{RerankingConfig, RerankingMode},
6    cross_encoder::CrossEncoder,
7    diversity::DiversityReranker,
8    fusion::ScoreFusion,
9    types::{RerankingError, RerankingResult, ScoredCandidate},
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Instant;
14
15/// Statistics for a re-ranking operation
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct RerankingStats {
18    /// Number of candidates processed
19    pub num_candidates: usize,
20
21    /// Number of candidates actually re-ranked
22    pub num_reranked: usize,
23
24    /// Number of cache hits
25    pub cache_hits: usize,
26
27    /// Total time (milliseconds)
28    pub total_time_ms: f64,
29
30    /// Model inference time (milliseconds)
31    pub inference_time_ms: f64,
32
33    /// Score fusion time (milliseconds)
34    pub fusion_time_ms: f64,
35
36    /// Average score change
37    pub avg_score_change: f32,
38
39    /// Rank correlation (Kendall's tau)
40    pub rank_correlation: Option<f32>,
41}
42
43/// Output of a re-ranking operation
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct RerankingOutput {
46    /// Re-ranked candidates
47    pub candidates: Vec<ScoredCandidate>,
48
49    /// Statistics
50    pub stats: RerankingStats,
51}
52
53/// Cross-encoder re-ranker
54pub struct CrossEncoderReranker {
55    /// Configuration
56    config: RerankingConfig,
57
58    /// Cross-encoder model
59    encoder: Arc<CrossEncoder>,
60
61    /// Score fusion
62    fusion: Arc<ScoreFusion>,
63
64    /// Diversity re-ranker
65    diversity: Option<Arc<DiversityReranker>>,
66
67    /// Cache
68    cache: Option<Arc<RerankingCache>>,
69}
70
71impl CrossEncoderReranker {
72    /// Create new re-ranker
73    pub fn new(config: RerankingConfig) -> RerankingResult<Self> {
74        config
75            .validate()
76            .map_err(|e| RerankingError::InvalidConfiguration { message: e })?;
77
78        let encoder = Arc::new(CrossEncoder::new(
79            &config.model_name,
80            &config.model_backend,
81        )?);
82        let fusion = Arc::new(ScoreFusion::new(
83            config.fusion_strategy,
84            config.retrieval_weight,
85        ));
86
87        let diversity = if config.enable_diversity {
88            Some(Arc::new(DiversityReranker::new(config.diversity_weight)))
89        } else {
90            None
91        };
92
93        let cache = if config.enable_caching {
94            Some(Arc::new(RerankingCache::new(config.cache_size)))
95        } else {
96            None
97        };
98
99        Ok(Self {
100            config,
101            encoder,
102            fusion,
103            diversity,
104            cache,
105        })
106    }
107
108    /// Re-rank candidates
109    pub fn rerank(
110        &self,
111        query: &str,
112        candidates: &[ScoredCandidate],
113    ) -> RerankingResult<RerankingOutput> {
114        let start = Instant::now();
115
116        // Filter candidates based on mode
117        let candidates_to_rerank = self.select_candidates_for_reranking(candidates);
118
119        let mut stats = RerankingStats {
120            num_candidates: candidates.len(),
121            num_reranked: candidates_to_rerank.len(),
122            ..Default::default()
123        };
124
125        // Check mode
126        if self.config.mode == RerankingMode::Disabled {
127            return Ok(RerankingOutput {
128                candidates: candidates.to_vec(),
129                stats,
130            });
131        }
132
133        // Re-rank with cross-encoder
134        let inference_start = Instant::now();
135        let mut reranked = self.apply_cross_encoder(query, candidates_to_rerank, &mut stats)?;
136        stats.inference_time_ms = inference_start.elapsed().as_secs_f64() * 1000.0;
137
138        // Fuse scores
139        let fusion_start = Instant::now();
140        for candidate in &mut reranked {
141            if let Some(reranking_score) = candidate.reranking_score {
142                candidate.final_score =
143                    self.fusion.fuse(candidate.retrieval_score, reranking_score);
144            }
145        }
146        stats.fusion_time_ms = fusion_start.elapsed().as_secs_f64() * 1000.0;
147
148        // Apply diversity if enabled
149        if let Some(ref diversity) = self.diversity {
150            reranked = diversity.apply_diversity(&reranked)?;
151        }
152
153        // Sort by final score
154        reranked.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
155
156        // Take top-k
157        reranked.truncate(self.config.top_k);
158
159        // Calculate statistics
160        self.calculate_stats(&mut stats, candidates, &reranked);
161        stats.total_time_ms = start.elapsed().as_secs_f64() * 1000.0;
162
163        Ok(RerankingOutput {
164            candidates: reranked,
165            stats,
166        })
167    }
168
169    /// Select candidates for re-ranking based on mode
170    fn select_candidates_for_reranking(
171        &self,
172        candidates: &[ScoredCandidate],
173    ) -> Vec<ScoredCandidate> {
174        let max_candidates = self.config.max_candidates.min(candidates.len());
175
176        match self.config.mode {
177            RerankingMode::Full => candidates.to_vec(),
178            RerankingMode::TopK => candidates[..max_candidates].to_vec(),
179            RerankingMode::Adaptive => {
180                // Use score threshold for adaptive selection
181                let threshold = self.calculate_adaptive_threshold(candidates);
182                candidates
183                    .iter()
184                    .filter(|c| c.retrieval_score >= threshold)
185                    .take(max_candidates)
186                    .cloned()
187                    .collect()
188            }
189            RerankingMode::Disabled => Vec::new(),
190        }
191    }
192
193    /// Calculate adaptive threshold based on score distribution
194    fn calculate_adaptive_threshold(&self, candidates: &[ScoredCandidate]) -> f32 {
195        if candidates.is_empty() {
196            return 0.0;
197        }
198
199        // Use mean - 0.5 * std as threshold
200        let scores: Vec<f32> = candidates.iter().map(|c| c.retrieval_score).collect();
201        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
202        let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
203        let std = variance.sqrt();
204
205        (mean - 0.5 * std).max(0.0)
206    }
207
208    /// Apply cross-encoder to candidates
209    fn apply_cross_encoder(
210        &self,
211        query: &str,
212        candidates: Vec<ScoredCandidate>,
213        stats: &mut RerankingStats,
214    ) -> RerankingResult<Vec<ScoredCandidate>> {
215        let mut reranked = Vec::new();
216
217        // Process in batches
218        for batch in candidates.chunks(self.config.batch_size) {
219            let mut batch_results = Vec::new();
220
221            for candidate in batch {
222                // Check cache first
223                let cache_key = format!("{}:{}", query, candidate.id);
224                let score = if let Some(ref cache) = self.cache {
225                    if let Some(cached_score) = cache.get(&cache_key) {
226                        stats.cache_hits += 1;
227                        cached_score
228                    } else {
229                        let score = self
230                            .encoder
231                            .score(query, candidate.content.as_deref().unwrap_or(""))?;
232                        cache.put(cache_key, score);
233                        score
234                    }
235                } else {
236                    self.encoder
237                        .score(query, candidate.content.as_deref().unwrap_or(""))?
238                };
239
240                let mut updated = candidate.clone();
241                updated.reranking_score = Some(score);
242                batch_results.push(updated);
243            }
244
245            reranked.extend(batch_results);
246        }
247
248        Ok(reranked)
249    }
250
251    /// Calculate additional statistics
252    fn calculate_stats(
253        &self,
254        stats: &mut RerankingStats,
255        original: &[ScoredCandidate],
256        reranked: &[ScoredCandidate],
257    ) {
258        // Calculate average score change
259        let score_changes: Vec<f32> = reranked
260            .iter()
261            .filter_map(|c| c.reranking_score.map(|r| (r - c.retrieval_score).abs()))
262            .collect();
263
264        if !score_changes.is_empty() {
265            stats.avg_score_change = score_changes.iter().sum::<f32>() / score_changes.len() as f32;
266        }
267
268        // Calculate rank correlation (simplified - just check if order changed)
269        if original.len() == reranked.len() && !original.is_empty() {
270            let original_ids: Vec<&String> = original.iter().map(|c| &c.id).collect();
271            let reranked_ids: Vec<&String> = reranked.iter().map(|c| &c.id).collect();
272            let same_order = original_ids == reranked_ids;
273            stats.rank_correlation = Some(if same_order { 1.0 } else { 0.5 });
274        }
275    }
276
277    /// Get configuration
278    pub fn config(&self) -> &RerankingConfig {
279        &self.config
280    }
281
282    /// Clear cache
283    pub fn clear_cache(&self) {
284        if let Some(ref cache) = self.cache {
285            cache.clear();
286        }
287    }
288
289    /// Get cache statistics
290    pub fn cache_stats(&self) -> Option<(usize, usize)> {
291        self.cache.as_ref().map(|c| c.stats())
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reranking::config::FusionStrategy;

    /// Assemble a re-ranker directly from its parts (bypassing
    /// `CrossEncoderReranker::new`), using the dummy local encoder, linear
    /// fusion, and neither a diversity stage nor a cache.
    fn reranker_with(config: RerankingConfig) -> CrossEncoderReranker {
        CrossEncoderReranker {
            config,
            encoder: Arc::new(CrossEncoder::new("dummy", "local").unwrap()),
            fusion: Arc::new(ScoreFusion::new(FusionStrategy::Linear, 0.3)),
            diversity: None,
            cache: None,
        }
    }

    /// A default `RerankingStats` has all counters at zero.
    #[test]
    fn test_reranking_stats_default() {
        let stats = RerankingStats::default();

        assert_eq!(stats.num_candidates, 0);
        assert_eq!(stats.num_reranked, 0);
        assert_eq!(stats.cache_hits, 0);
    }

    /// TopK mode keeps exactly `max_candidates` of a longer candidate list.
    #[test]
    fn test_select_candidates_topk() {
        let reranker = reranker_with(RerankingConfig {
            mode: RerankingMode::TopK,
            max_candidates: 5,
            ..RerankingConfig::default_config()
        });

        // Ten candidates with strictly decreasing retrieval scores.
        let mut candidates: Vec<ScoredCandidate> = Vec::new();
        for i in 0..10 {
            candidates.push(ScoredCandidate::new(
                format!("doc{}", i),
                0.9 - i as f32 * 0.05,
                i,
            ));
        }

        let selected = reranker.select_candidates_for_reranking(&candidates);
        assert_eq!(selected.len(), 5);
    }

    /// The adaptive threshold lies strictly between zero and the best score.
    #[test]
    fn test_adaptive_threshold() {
        let reranker = reranker_with(RerankingConfig::default_config());

        let candidates = vec![
            ScoredCandidate::new("doc1", 0.9, 0),
            ScoredCandidate::new("doc2", 0.8, 1),
            ScoredCandidate::new("doc3", 0.7, 2),
            ScoredCandidate::new("doc4", 0.3, 3),
            ScoredCandidate::new("doc5", 0.2, 4),
        ];

        let threshold = reranker.calculate_adaptive_threshold(&candidates);
        assert!(threshold > 0.0);
        assert!(threshold < 0.9);
    }
}