rrag/reranking/cross_encoder.rs

//! # Cross-Encoder Reranking
//!
//! Cross-encoder models that jointly encode query and document pairs
//! to produce more accurate relevance scores than bi-encoder approaches.
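//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore` because it assumes a populated
//! `results: Vec<SearchResult>` from an earlier retrieval stage):
//!
//! ```rust,ignore
//! let reranker = CrossEncoderReranker::new(CrossEncoderConfig::default());
//! let scores = reranker.rerank("what is machine learning?", &results).await?;
//! for (result_index, score) in &scores {
//!     println!("result {} -> cross-encoder score {:.3}", result_index, score);
//! }
//! ```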

use crate::{RragResult, SearchResult};
use std::collections::HashMap;
use std::sync::Mutex;

/// Cross-encoder reranker for query-document relevance scoring
pub struct CrossEncoderReranker {
    /// Configuration
    config: CrossEncoderConfig,

    /// Model interface
    model: Box<dyn CrossEncoderModel>,

    /// Scoring cache (behind a `Mutex` so scores can be cached through `&self`)
    score_cache: Mutex<HashMap<String, f32>>,
}

/// Configuration for cross-encoder reranking
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Model type to use
    pub model_type: CrossEncoderModelType,

    /// Maximum sequence length for input
    pub max_sequence_length: usize,

    /// Batch size for processing
    pub batch_size: usize,

    /// Score aggregation method
    pub score_aggregation: ScoreAggregation,

    /// Reranking strategy
    pub strategy: RerankingStrategy,

    /// Confidence threshold
    pub confidence_threshold: f32,

    /// Enable caching
    pub enable_caching: bool,

    /// Temperature for score calibration
    pub temperature: f32,
}

impl Default for CrossEncoderConfig {
    fn default() -> Self {
        Self {
            model_type: CrossEncoderModelType::SimulatedBert,
            max_sequence_length: 512,
            batch_size: 16,
            score_aggregation: ScoreAggregation::Mean,
            strategy: RerankingStrategy::TopK(50),
            confidence_threshold: 0.5,
            enable_caching: true,
            temperature: 1.0,
        }
    }
}
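
// Illustrative override pattern (not required by the API): a caller tuning for
// latency might swap in the distilled model and a smaller candidate pool while
// keeping the remaining defaults, e.g.
//
//     let fast_config = CrossEncoderConfig {
//         model_type: CrossEncoderModelType::DistilBert,
//         strategy: RerankingStrategy::TopK(20),
//         ..CrossEncoderConfig::default()
//     };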

/// Types of cross-encoder models
#[derive(Debug, Clone, PartialEq)]
pub enum CrossEncoderModelType {
    /// BERT-based cross-encoder
    Bert,
    /// RoBERTa-based cross-encoder
    RoBERTa,
    /// DistilBERT for faster inference
    DistilBert,
    /// Custom model
    Custom(String),
    /// Simulated model for demonstration
    SimulatedBert,
}

/// Score aggregation methods
#[derive(Debug, Clone, PartialEq)]
pub enum ScoreAggregation {
    /// Average all scores
    Mean,
    /// Maximum score
    Max,
    /// Minimum score
    Min,
    /// Weighted average
    Weighted(Vec<f32>),
    /// Median score
    Median,
}

/// Reranking strategies
#[derive(Debug, Clone, PartialEq)]
pub enum RerankingStrategy {
    /// Rerank top-k candidates
    TopK(usize),
    /// Rerank all candidates above threshold
    Threshold(f32),
    /// Adaptive reranking based on score distribution
    Adaptive,
    /// Stage-wise reranking
    Staged(Vec<usize>),
}
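
// Illustrative example: given ten candidates scored 0.9 down to 0.0,
// `TopK(3)` keeps the three highest-ranked, `Threshold(0.5)` keeps every
// candidate scoring at least 0.5, and `Staged(vec![50, 10])` currently behaves
// like `TopK(50)` (see `select_candidates` below). `Adaptive` derives its
// cutoff from the mean and standard deviation of the incoming scores.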

/// Result from cross-encoder reranking
#[derive(Debug, Clone)]
pub struct RerankedResult {
    /// Document identifier
    pub document_id: String,

    /// Cross-encoder relevance score
    pub cross_encoder_score: f32,

    /// Original retrieval score
    pub original_score: f32,

    /// Combined score
    pub combined_score: f32,

    /// Confidence in the score
    pub confidence: f32,

    /// Token-level attention scores (if available)
    pub attention_scores: Option<Vec<f32>>,

    /// Processing metadata
    pub metadata: CrossEncoderMetadata,
}

/// Metadata from cross-encoder processing
#[derive(Debug, Clone)]
pub struct CrossEncoderMetadata {
    /// Model used
    pub model_type: String,

    /// Input sequence length
    pub sequence_length: usize,

    /// Processing time in milliseconds
    pub processing_time_ms: u64,

    /// Number of tokens processed
    pub num_tokens: usize,

    /// Whether result was cached
    pub from_cache: bool,
}

/// Trait for cross-encoder models
pub trait CrossEncoderModel: Send + Sync {
    /// Score a single query-document pair
    fn score(&self, query: &str, document: &str) -> RragResult<f32>;

    /// Score multiple query-document pairs in batch
    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>>;

    /// Get model information
    fn model_info(&self) -> ModelInfo;

    /// Get attention scores if supported
    fn get_attention_scores(&self, query: &str, document: &str) -> RragResult<Option<Vec<f32>>> {
        let _ = (query, document);
        Ok(None)
    }
}

/// Model information
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model name
    pub name: String,

    /// Model version
    pub version: String,

    /// Maximum sequence length
    pub max_sequence_length: usize,

    /// Model size in parameters
    pub parameters: Option<usize>,

    /// Whether model supports attention extraction
    pub supports_attention: bool,
}

impl CrossEncoderReranker {
    /// Create a new cross-encoder reranker
    pub fn new(config: CrossEncoderConfig) -> Self {
        let model = Self::create_model(&config.model_type);

        Self {
            config,
            model,
            score_cache: Mutex::new(HashMap::new()),
        }
    }

    /// Create model based on configuration
    fn create_model(model_type: &CrossEncoderModelType) -> Box<dyn CrossEncoderModel> {
        match model_type {
            CrossEncoderModelType::SimulatedBert => Box::new(SimulatedBertCrossEncoder::new()),
            CrossEncoderModelType::Bert => Box::new(SimulatedBertCrossEncoder::new()), // Would be real BERT
            CrossEncoderModelType::RoBERTa => Box::new(SimulatedRobertaCrossEncoder::new()),
            CrossEncoderModelType::DistilBert => Box::new(SimulatedDistilBertCrossEncoder::new()),
            CrossEncoderModelType::Custom(name) => Box::new(CustomCrossEncoder::new(name.clone())),
        }
    }

    /// Rerank search results using cross-encoder
    pub async fn rerank(
        &self,
        query: &str,
        results: &[SearchResult],
    ) -> RragResult<HashMap<usize, f32>> {
        let _start_time = std::time::Instant::now();

        // Apply reranking strategy to select candidates
        let candidates = self.select_candidates(results)?;

        // Prepare query-document pairs
        let pairs: Vec<(String, String)> = candidates
            .iter()
            .map(|&idx| (query.to_string(), results[idx].content.clone()))
            .collect();

        // Score the pairs
        let scores = if self.config.batch_size > 1 && pairs.len() > 1 {
            self.score_batch(&pairs).await?
        } else {
            self.score_sequential(&pairs).await?
        };

        // Map candidate indices to temperature-calibrated scores
        let mut score_map = HashMap::new();
        for (i, &candidate_idx) in candidates.iter().enumerate() {
            if let Some(&score) = scores.get(i) {
                score_map.insert(candidate_idx, self.apply_temperature(score));
            }
        }

        Ok(score_map)
    }
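
    // Illustrative only: `rerank` returns raw cross-encoder scores keyed by
    // result index. A caller building `RerankedResult` values would typically
    // blend them with the first-stage score; the 0.7 / 0.3 weights below are an
    // assumption for the sketch, not something this module prescribes:
    //
    //     let combined_score = 0.7 * cross_encoder_score + 0.3 * original_score;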

    /// Select candidates for reranking based on strategy
    fn select_candidates(&self, results: &[SearchResult]) -> RragResult<Vec<usize>> {
        match &self.config.strategy {
            RerankingStrategy::TopK(k) => Ok((0..results.len().min(*k)).collect()),
            RerankingStrategy::Threshold(threshold) => Ok(results
                .iter()
                .enumerate()
                .filter(|(_, result)| result.score >= *threshold)
                .map(|(idx, _)| idx)
                .collect()),
            RerankingStrategy::Adaptive => {
                if results.is_empty() {
                    return Ok(Vec::new());
                }

                // Adaptive threshold derived from the score distribution
                let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
                let mean = scores.iter().sum::<f32>() / scores.len() as f32;
                let std_dev = {
                    let variance = scores
                        .iter()
                        .map(|score| (score - mean).powi(2))
                        .sum::<f32>()
                        / scores.len() as f32;
                    variance.sqrt()
                };

                let adaptive_threshold = mean - std_dev * 0.5;
                Ok(results
                    .iter()
                    .enumerate()
                    .filter(|(_, result)| result.score >= adaptive_threshold)
                    .map(|(idx, _)| idx)
                    .take(self.config.batch_size * 3) // Reasonable upper limit
                    .collect())
            }
            RerankingStrategy::Staged(stages) => {
                // Simplified staged reranking: only the first stage size is used
                let stage_size = stages.first().copied().unwrap_or(10);
                Ok((0..results.len().min(stage_size)).collect())
            }
        }
    }

    /// Score pairs sequentially, consulting and updating the cache when enabled
    async fn score_sequential(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        let mut scores = Vec::new();

        for (query, document) in pairs {
            let cache_key = format!("{}|{}", query, document);

            if self.config.enable_caching {
                if let Some(&cached) = self.score_cache.lock().unwrap().get(&cache_key) {
                    scores.push(cached);
                    continue;
                }
            }

            let score = self.model.score(query, document)?;
            if self.config.enable_caching {
                self.score_cache.lock().unwrap().insert(cache_key, score);
            }

            scores.push(score);
        }

        Ok(scores)
    }

    /// Score pairs in batches
    async fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        let mut all_scores = Vec::new();

        for chunk in pairs.chunks(self.config.batch_size) {
            let batch_scores = self.model.score_batch(chunk)?;
            all_scores.extend(batch_scores);
        }

        Ok(all_scores)
    }

    /// Apply temperature scaling to scores
    fn apply_temperature(&self, score: f32) -> f32 {
        if self.config.temperature == 1.0 {
            score
        } else {
            score / self.config.temperature
        }
    }

    /// Get model information
    pub fn get_model_info(&self) -> ModelInfo {
        self.model.model_info()
    }
}

/// Simulated BERT cross-encoder for demonstration
struct SimulatedBertCrossEncoder;

impl SimulatedBertCrossEncoder {
    fn new() -> Self {
        Self
    }
}

impl CrossEncoderModel for SimulatedBertCrossEncoder {
    fn score(&self, query: &str, document: &str) -> RragResult<f32> {
        // Simulate BERT cross-encoder scoring
        let query_tokens: Vec<&str> = query.split_whitespace().collect();
        let doc_tokens: Vec<&str> = document.split_whitespace().collect();

        if query_tokens.is_empty() || doc_tokens.is_empty() {
            return Ok(0.0);
        }

        // Simulate attention-based scoring
        let mut score = 0.0;
        let mut matches = 0;

        for q_token in &query_tokens {
            for d_token in &doc_tokens {
                let similarity = self.token_similarity(q_token, d_token);
                if similarity > 0.3 {
                    score += similarity;
                    matches += 1;
                }
            }
        }

        // Normalize by query length, penalize very long documents, and reward
        // high query-term coverage
        let length_penalty = 1.0 / (1.0 + (doc_tokens.len() as f32 / 100.0));
        let coverage_bonus = if matches as f32 / query_tokens.len() as f32 > 0.5 {
            0.2
        } else {
            0.0
        };

        let final_score = ((score / query_tokens.len() as f32) * length_penalty + coverage_bonus)
            .max(0.0)
            .min(1.0);

        Ok(final_score)
    }

    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        pairs
            .iter()
            .map(|(query, document)| self.score(query, document))
            .collect()
    }

    fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: "SimulatedBERT-CrossEncoder".to_string(),
            version: "1.0".to_string(),
            max_sequence_length: 512,
            parameters: Some(110_000_000),
            supports_attention: true,
        }
    }

    fn get_attention_scores(&self, query: &str, document: &str) -> RragResult<Option<Vec<f32>>> {
        // Simulate attention scores
        let query_tokens: Vec<&str> = query.split_whitespace().collect();
        let doc_tokens: Vec<&str> = document.split_whitespace().collect();

        let mut attention_scores = Vec::new();
        for d_token in &doc_tokens {
            let max_attention = query_tokens
                .iter()
                .map(|q_token| self.token_similarity(q_token, d_token))
                .fold(0.0f32, |a, b| a.max(b));
            attention_scores.push(max_attention);
        }

        Ok(Some(attention_scores))
    }
}

impl SimulatedBertCrossEncoder {
    /// Simulate token-level similarity (would be learned embeddings in a real model)
    fn token_similarity(&self, token1: &str, token2: &str) -> f32 {
        let t1_lower = token1.to_lowercase();
        let t2_lower = token2.to_lowercase();

        // Exact match
        if t1_lower == t2_lower {
            return 1.0;
        }

        // Partial matches
        if t1_lower.contains(&t2_lower) || t2_lower.contains(&t1_lower) {
            return 0.7;
        }

        // Character-level similarity (simplified Jaccard)
        let chars1: std::collections::HashSet<char> = t1_lower.chars().collect();
        let chars2: std::collections::HashSet<char> = t2_lower.chars().collect();

        let intersection = chars1.intersection(&chars2).count();
        let union = chars1.union(&chars2).count();

        if union == 0 {
            0.0
        } else {
            (intersection as f32 / union as f32) * 0.5
        }
    }
}

/// Simulated RoBERTa cross-encoder
struct SimulatedRobertaCrossEncoder;

impl SimulatedRobertaCrossEncoder {
    fn new() -> Self {
        Self
    }
}

impl CrossEncoderModel for SimulatedRobertaCrossEncoder {
    fn score(&self, query: &str, document: &str) -> RragResult<f32> {
        // Simulate RoBERTa with slightly different scoring
        let bert_encoder = SimulatedBertCrossEncoder::new();
        let base_score = bert_encoder.score(query, document)?;

        // RoBERTa might have different biases
        let roberta_adjustment = 0.05 * (document.len() as f32).log10().sin().abs();
        Ok((base_score + roberta_adjustment).min(1.0))
    }

    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        pairs
            .iter()
            .map(|(query, document)| self.score(query, document))
            .collect()
    }

    fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: "SimulatedRoBERTa-CrossEncoder".to_string(),
            version: "1.0".to_string(),
            max_sequence_length: 512,
            parameters: Some(125_000_000),
            supports_attention: true,
        }
    }
}

/// Simulated DistilBERT cross-encoder (faster, smaller)
struct SimulatedDistilBertCrossEncoder;

impl SimulatedDistilBertCrossEncoder {
    fn new() -> Self {
        Self
    }
}

impl CrossEncoderModel for SimulatedDistilBertCrossEncoder {
    fn score(&self, query: &str, document: &str) -> RragResult<f32> {
        // Simulate DistilBERT with faster but slightly less accurate scoring
        let bert_encoder = SimulatedBertCrossEncoder::new();
        let base_score = bert_encoder.score(query, document)?;

        // DistilBERT might be slightly less accurate
        let distillation_noise = 0.02 * (query.len() as f32 % 7.0) / 7.0;
        Ok((base_score - distillation_noise).max(0.0))
    }

    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        pairs
            .iter()
            .map(|(query, document)| self.score(query, document))
            .collect()
    }

    fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: "SimulatedDistilBERT-CrossEncoder".to_string(),
            version: "1.0".to_string(),
            max_sequence_length: 512,
            parameters: Some(66_000_000),
            supports_attention: false, // Simplified model
        }
    }
}

/// Custom cross-encoder model
struct CustomCrossEncoder {
    name: String,
}

impl CustomCrossEncoder {
    fn new(name: String) -> Self {
        Self { name }
    }
}

impl CrossEncoderModel for CustomCrossEncoder {
    fn score(&self, query: &str, document: &str) -> RragResult<f32> {
        // Placeholder for custom model
        let _ = (query, document);
        Ok(0.5) // Neutral score
    }

    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
        Ok(vec![0.5; pairs.len()])
    }

    fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: self.name.clone(),
            version: "custom".to_string(),
            max_sequence_length: 512,
            parameters: None,
            supports_attention: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    #[tokio::test]
    async fn test_cross_encoder_reranking() {
        let config = CrossEncoderConfig::default();
        let reranker = CrossEncoderReranker::new(config);

        let results = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence".to_string(),
                score: 0.8,
                rank: 0,
                metadata: Default::default(),
                embedding: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Deep learning uses neural networks with multiple layers".to_string(),
                score: 0.6,
                rank: 1,
                metadata: Default::default(),
                embedding: None,
            },
        ];

        let query = "What is machine learning?";
        let reranked_scores = reranker.rerank(query, &results).await.unwrap();

        assert!(!reranked_scores.is_empty());
        assert!(reranked_scores.contains_key(&0));
    }

    #[test]
    fn test_simulated_bert_scoring() {
        let model = SimulatedBertCrossEncoder::new();

        let score = model
            .score(
                "machine learning",
                "artificial intelligence and machine learning",
            )
            .unwrap();
        assert!(score > 0.0);
        assert!(score <= 1.0);

        // Should score higher for better matches
        let high_score = model
            .score("rust programming", "rust is a programming language")
            .unwrap();
        let low_score = model
            .score("rust programming", "cooking recipes for dinner")
            .unwrap();
        assert!(high_score > low_score);
    }

    #[test]
    fn test_batch_scoring() {
        let model = SimulatedBertCrossEncoder::new();

        let pairs = vec![
            ("query1".to_string(), "relevant document".to_string()),
            ("query2".to_string(), "another document".to_string()),
        ];

        let scores = model.score_batch(&pairs).unwrap();
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|&s| s >= 0.0 && s <= 1.0));
    }

    #[test]
    fn test_attention_scores() {
        let model = SimulatedBertCrossEncoder::new();

        let attention = model
            .get_attention_scores("machine learning", "artificial intelligence")
            .unwrap();
        assert!(attention.is_some());

        let scores = attention.unwrap();
        assert_eq!(scores.len(), 2); // "artificial" and "intelligence"
        assert!(scores.iter().all(|&s| s >= 0.0 && s <= 1.0));
    }
}