1use crate::{RragResult, SearchResult};
7use std::collections::HashMap;
8
/// Reranks search results by scoring (query, document) pairs with a
/// cross-encoder model.
pub struct CrossEncoderReranker {
    /// Reranking configuration (model choice, batching, selection strategy, ...).
    config: CrossEncoderConfig,

    /// The underlying pair-scoring model.
    model: Box<dyn CrossEncoderModel>,

    /// Cache of computed scores keyed by `"{query}|{document}"`.
    /// NOTE(review): never written to in the visible code — TODO confirm
    /// whether population was intended (would need interior mutability).
    score_cache: HashMap<String, f32>,
}
20
/// Configuration for [`CrossEncoderReranker`].
#[derive(Debug, Clone)]
pub struct CrossEncoderConfig {
    /// Which cross-encoder backend to instantiate.
    pub model_type: CrossEncoderModelType,

    /// Maximum sequence length the model accepts (tokens).
    pub max_sequence_length: usize,

    /// Number of pairs scored per batch call.
    pub batch_size: usize,

    /// How multiple scores are aggregated.
    pub score_aggregation: ScoreAggregation,

    /// How candidates are selected for (expensive) cross-encoder scoring.
    pub strategy: RerankingStrategy,

    /// Minimum confidence for a result to be considered reliable.
    pub confidence_threshold: f32,

    /// Whether to consult the in-memory score cache.
    pub enable_caching: bool,

    /// Temperature for score scaling (1.0 = no scaling).
    pub temperature: f32,
}
48
49impl Default for CrossEncoderConfig {
50 fn default() -> Self {
51 Self {
52 model_type: CrossEncoderModelType::SimulatedBert,
53 max_sequence_length: 512,
54 batch_size: 16,
55 score_aggregation: ScoreAggregation::Mean,
56 strategy: RerankingStrategy::TopK(50),
57 confidence_threshold: 0.5,
58 enable_caching: true,
59 temperature: 1.0,
60 }
61 }
62}
63
/// Supported cross-encoder backends.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossEncoderModelType {
    /// BERT-based cross-encoder.
    Bert,
    /// RoBERTa-based cross-encoder.
    RoBERTa,
    /// Distilled BERT variant (smaller/faster).
    DistilBert,
    /// User-provided model, identified by name.
    Custom(String),
    /// Heuristic stand-in used when no real model backend is available.
    SimulatedBert,
}
78
/// Strategy for combining multiple scores into one.
#[derive(Debug, Clone, PartialEq)]
pub enum ScoreAggregation {
    /// Arithmetic mean of the scores.
    Mean,
    /// Maximum score.
    Max,
    /// Minimum score.
    Min,
    /// Weighted combination using the given weights.
    Weighted(Vec<f32>),
    /// Median score.
    Median,
}
93
/// How candidates are chosen for cross-encoder scoring.
#[derive(Debug, Clone, PartialEq)]
pub enum RerankingStrategy {
    /// Rerank only the top-K results by original rank.
    TopK(usize),
    /// Rerank every result whose original score meets the threshold.
    Threshold(f32),
    /// Derive a threshold from the score distribution (mean - 0.5 * std-dev).
    Adaptive,
    /// Multi-stage reranking with per-stage candidate counts.
    Staged(Vec<usize>),
}
106
/// A search result after cross-encoder reranking.
#[derive(Debug, Clone)]
pub struct RerankedResult {
    /// Identifier of the underlying document.
    pub document_id: String,

    /// Raw score produced by the cross-encoder.
    pub cross_encoder_score: f32,

    /// Score from the original (first-stage) retrieval.
    pub original_score: f32,

    /// Final score combining original and cross-encoder signals.
    pub combined_score: f32,

    /// Confidence in the reranked score.
    pub confidence: f32,

    /// Optional per-token attention scores, if the model supports them.
    pub attention_scores: Option<Vec<f32>>,

    /// Diagnostics about how this result was scored.
    pub metadata: CrossEncoderMetadata,
}
131
/// Diagnostics recorded while scoring a single (query, document) pair.
#[derive(Debug, Clone)]
pub struct CrossEncoderMetadata {
    /// Name/type of the model that produced the score.
    pub model_type: String,

    /// Length of the encoded sequence.
    pub sequence_length: usize,

    /// Wall-clock time spent scoring, in milliseconds.
    pub processing_time_ms: u64,

    /// Number of tokens in the pair.
    pub num_tokens: usize,

    /// Whether the score was served from the cache.
    pub from_cache: bool,
}
150
/// Interface every cross-encoder backend must implement.
///
/// Implementations must be `Send + Sync` so a reranker can be shared
/// across threads behind the trait object.
pub trait CrossEncoderModel: Send + Sync {
    /// Scores a single (query, document) pair; higher means more relevant.
    fn score(&self, query: &str, document: &str) -> RragResult<f32>;

    /// Scores a batch of pairs, returning one score per pair in order.
    fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>>;

    /// Describes the model (name, version, limits, capabilities).
    fn model_info(&self) -> ModelInfo;

    /// Optional per-token attention scores; default implementation reports
    /// no attention support by returning `Ok(None)`.
    fn get_attention_scores(&self, query: &str, document: &str) -> RragResult<Option<Vec<f32>>> {
        let _ = (query, document);
        Ok(None)
    }
}
168
/// Static description of a cross-encoder model.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Human-readable model name.
    pub name: String,

    /// Model version string.
    pub version: String,

    /// Maximum input sequence length (tokens).
    pub max_sequence_length: usize,

    /// Approximate parameter count, if known.
    pub parameters: Option<usize>,

    /// Whether the model can return per-token attention scores.
    pub supports_attention: bool,
}
187
188impl CrossEncoderReranker {
189 pub fn new(config: CrossEncoderConfig) -> Self {
191 let model = Self::create_model(&config.model_type);
192
193 Self {
194 config,
195 model,
196 score_cache: HashMap::new(),
197 }
198 }
199
200 fn create_model(model_type: &CrossEncoderModelType) -> Box<dyn CrossEncoderModel> {
202 match model_type {
203 CrossEncoderModelType::SimulatedBert => Box::new(SimulatedBertCrossEncoder::new()),
204 CrossEncoderModelType::Bert => Box::new(SimulatedBertCrossEncoder::new()), CrossEncoderModelType::RoBERTa => Box::new(SimulatedRobertaCrossEncoder::new()),
206 CrossEncoderModelType::DistilBert => Box::new(SimulatedDistilBertCrossEncoder::new()),
207 CrossEncoderModelType::Custom(name) => Box::new(CustomCrossEncoder::new(name.clone())),
208 }
209 }
210
211 pub async fn rerank(
213 &self,
214 query: &str,
215 results: &[SearchResult],
216 ) -> RragResult<HashMap<usize, f32>> {
217 let _start_time = std::time::Instant::now();
218
219 let candidates = self.select_candidates(results)?;
221
222 let pairs: Vec<(String, String)> = candidates
224 .iter()
225 .map(|&idx| (query.to_string(), results[idx].content.clone()))
226 .collect();
227
228 let scores = if self.config.batch_size > 1 && pairs.len() > 1 {
230 self.score_batch(&pairs).await?
231 } else {
232 self.score_sequential(&pairs).await?
233 };
234
235 let mut score_map = HashMap::new();
237 for (i, &candidate_idx) in candidates.iter().enumerate() {
238 if let Some(&score) = scores.get(i) {
239 score_map.insert(candidate_idx, score);
240 }
241 }
242
243 Ok(score_map)
244 }
245
246 fn select_candidates(&self, results: &[SearchResult]) -> RragResult<Vec<usize>> {
248 match &self.config.strategy {
249 RerankingStrategy::TopK(k) => Ok((0..results.len().min(*k)).collect()),
250 RerankingStrategy::Threshold(threshold) => Ok(results
251 .iter()
252 .enumerate()
253 .filter(|(_, result)| result.score >= *threshold)
254 .map(|(idx, _)| idx)
255 .collect()),
256 RerankingStrategy::Adaptive => {
257 let scores: Vec<f32> = results.iter().map(|r| r.score).collect();
259 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
260 let std_dev = {
261 let variance = scores
262 .iter()
263 .map(|score| (score - mean).powi(2))
264 .sum::<f32>()
265 / scores.len() as f32;
266 variance.sqrt()
267 };
268
269 let adaptive_threshold = mean - std_dev * 0.5;
270 Ok(results
271 .iter()
272 .enumerate()
273 .filter(|(_, result)| result.score >= adaptive_threshold)
274 .map(|(idx, _)| idx)
275 .take(self.config.batch_size * 3) .collect())
277 }
278 RerankingStrategy::Staged(stages) => {
279 let stage_size = stages.first().copied().unwrap_or(10);
281 Ok((0..results.len().min(stage_size)).collect())
282 }
283 }
284 }
285
286 async fn score_sequential(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
288 let mut scores = Vec::new();
289
290 for (query, document) in pairs {
291 let cache_key = format!("{}|{}", query, document);
292
293 let score = if self.config.enable_caching && self.score_cache.contains_key(&cache_key) {
294 *self.score_cache.get(&cache_key).unwrap()
295 } else {
296 let score = self.model.score(query, document)?;
297 if self.config.enable_caching {
298 }
300 score
301 };
302
303 scores.push(score);
304 }
305
306 Ok(scores)
307 }
308
309 async fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
311 let mut all_scores = Vec::new();
312
313 for chunk in pairs.chunks(self.config.batch_size) {
314 let batch_scores = self.model.score_batch(chunk)?;
315 all_scores.extend(batch_scores);
316 }
317
318 Ok(all_scores)
319 }
320
321 fn apply_temperature(&self, score: f32) -> f32 {
323 if self.config.temperature == 1.0 {
324 score
325 } else {
326 score / self.config.temperature
327 }
328 }
329
330 pub fn get_model_info(&self) -> ModelInfo {
332 self.model.model_info()
333 }
334}
335
/// Lightweight, dependency-free stand-in for a BERT cross-encoder.
/// Scores pairs with lexical-overlap heuristics rather than a real model.
struct SimulatedBertCrossEncoder;

impl SimulatedBertCrossEncoder {
    /// Creates the simulated encoder (stateless, zero-sized).
    fn new() -> Self {
        Self
    }
}
344
345impl CrossEncoderModel for SimulatedBertCrossEncoder {
346 fn score(&self, query: &str, document: &str) -> RragResult<f32> {
347 let query_tokens: Vec<&str> = query.split_whitespace().collect();
349 let doc_tokens: Vec<&str> = document.split_whitespace().collect();
350
351 let mut score = 0.0;
353 let mut matches = 0;
354
355 for q_token in &query_tokens {
356 for d_token in &doc_tokens {
357 let similarity = self.token_similarity(q_token, d_token);
358 if similarity > 0.3 {
359 score += similarity;
360 matches += 1;
361 }
362 }
363 }
364
365 let length_penalty = 1.0 / (1.0 + (doc_tokens.len() as f32 / 100.0));
367 let coverage_bonus = if matches as f32 / query_tokens.len() as f32 > 0.5 {
368 0.2
369 } else {
370 0.0
371 };
372
373 let final_score = ((score / query_tokens.len() as f32) * length_penalty + coverage_bonus)
374 .max(0.0)
375 .min(1.0);
376
377 Ok(final_score)
378 }
379
380 fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
381 pairs
382 .iter()
383 .map(|(query, document)| self.score(query, document))
384 .collect()
385 }
386
387 fn model_info(&self) -> ModelInfo {
388 ModelInfo {
389 name: "SimulatedBERT-CrossEncoder".to_string(),
390 version: "1.0".to_string(),
391 max_sequence_length: 512,
392 parameters: Some(110_000_000),
393 supports_attention: true,
394 }
395 }
396
397 fn get_attention_scores(&self, query: &str, document: &str) -> RragResult<Option<Vec<f32>>> {
398 let query_tokens: Vec<&str> = query.split_whitespace().collect();
400 let doc_tokens: Vec<&str> = document.split_whitespace().collect();
401
402 let mut attention_scores = Vec::new();
403 for d_token in &doc_tokens {
404 let max_attention = query_tokens
405 .iter()
406 .map(|q_token| self.token_similarity(q_token, d_token))
407 .fold(0.0f32, |a, b| a.max(b));
408 attention_scores.push(max_attention);
409 }
410
411 Ok(Some(attention_scores))
412 }
413}
414
415impl SimulatedBertCrossEncoder {
416 fn token_similarity(&self, token1: &str, token2: &str) -> f32 {
418 let t1_lower = token1.to_lowercase();
419 let t2_lower = token2.to_lowercase();
420
421 if t1_lower == t2_lower {
423 return 1.0;
424 }
425
426 if t1_lower.contains(&t2_lower) || t2_lower.contains(&t1_lower) {
428 return 0.7;
429 }
430
431 let chars1: std::collections::HashSet<char> = t1_lower.chars().collect();
433 let chars2: std::collections::HashSet<char> = t2_lower.chars().collect();
434
435 let intersection = chars1.intersection(&chars2).count();
436 let union = chars1.union(&chars2).count();
437
438 if union == 0 {
439 0.0
440 } else {
441 (intersection as f32 / union as f32) * 0.5
442 }
443 }
444}
445
/// Simulated RoBERTa cross-encoder; wraps the simulated BERT scorer with a
/// small deterministic adjustment so the two backends produce distinct scores.
struct SimulatedRobertaCrossEncoder;

impl SimulatedRobertaCrossEncoder {
    /// Creates the simulated encoder (stateless, zero-sized).
    fn new() -> Self {
        Self
    }
}
454
455impl CrossEncoderModel for SimulatedRobertaCrossEncoder {
456 fn score(&self, query: &str, document: &str) -> RragResult<f32> {
457 let bert_encoder = SimulatedBertCrossEncoder::new();
459 let base_score = bert_encoder.score(query, document)?;
460
461 let roberta_adjustment = 0.05 * (document.len() as f32).log10().sin().abs();
463 Ok((base_score + roberta_adjustment).min(1.0))
464 }
465
466 fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
467 pairs
468 .iter()
469 .map(|(query, document)| self.score(query, document))
470 .collect()
471 }
472
473 fn model_info(&self) -> ModelInfo {
474 ModelInfo {
475 name: "SimulatedRoBERTa-CrossEncoder".to_string(),
476 version: "1.0".to_string(),
477 max_sequence_length: 512,
478 parameters: Some(125_000_000),
479 supports_attention: true,
480 }
481 }
482}
483
/// Simulated DistilBERT cross-encoder; wraps the simulated BERT scorer and
/// subtracts a small deterministic "distillation noise" term.
struct SimulatedDistilBertCrossEncoder;

impl SimulatedDistilBertCrossEncoder {
    /// Creates the simulated encoder (stateless, zero-sized).
    fn new() -> Self {
        Self
    }
}
492
493impl CrossEncoderModel for SimulatedDistilBertCrossEncoder {
494 fn score(&self, query: &str, document: &str) -> RragResult<f32> {
495 let bert_encoder = SimulatedBertCrossEncoder::new();
497 let base_score = bert_encoder.score(query, document)?;
498
499 let distillation_noise = 0.02 * (query.len() as f32 % 7.0) / 7.0;
501 Ok((base_score - distillation_noise).max(0.0))
502 }
503
504 fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
505 pairs
506 .iter()
507 .map(|(query, document)| self.score(query, document))
508 .collect()
509 }
510
511 fn model_info(&self) -> ModelInfo {
512 ModelInfo {
513 name: "SimulatedDistilBERT-CrossEncoder".to_string(),
514 version: "1.0".to_string(),
515 max_sequence_length: 512,
516 parameters: Some(66_000_000),
517 supports_attention: false, }
519 }
520}
521
/// Placeholder for a user-supplied cross-encoder, identified by name.
struct CustomCrossEncoder {
    /// Model name reported via `model_info()`.
    name: String,
}

impl CustomCrossEncoder {
    /// Creates a custom encoder stub with the given display name.
    fn new(name: String) -> Self {
        Self { name }
    }
}
532
533impl CrossEncoderModel for CustomCrossEncoder {
534 fn score(&self, query: &str, document: &str) -> RragResult<f32> {
535 let _ = (query, document);
537 Ok(0.5) }
539
540 fn score_batch(&self, pairs: &[(String, String)]) -> RragResult<Vec<f32>> {
541 Ok(vec![0.5; pairs.len()])
542 }
543
544 fn model_info(&self) -> ModelInfo {
545 ModelInfo {
546 name: self.name.clone(),
547 version: "custom".to_string(),
548 max_sequence_length: 512,
549 parameters: None,
550 supports_attention: false,
551 }
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SearchResult;

    #[tokio::test]
    async fn test_cross_encoder_reranking() {
        let reranker = CrossEncoderReranker::new(CrossEncoderConfig::default());

        let candidates = vec![
            SearchResult {
                id: "doc1".to_string(),
                content: "Machine learning is a subset of artificial intelligence".to_string(),
                score: 0.8,
                rank: 0,
                metadata: Default::default(),
                embedding: None,
            },
            SearchResult {
                id: "doc2".to_string(),
                content: "Deep learning uses neural networks with multiple layers".to_string(),
                score: 0.6,
                rank: 1,
                metadata: Default::default(),
                embedding: None,
            },
        ];

        let reranked_scores = reranker
            .rerank("What is machine learning?", &candidates)
            .await
            .unwrap();

        // Something must have been scored, and the first candidate in particular.
        assert!(!reranked_scores.is_empty());
        assert!(reranked_scores.contains_key(&0));
    }

    #[test]
    fn test_simulated_bert_scoring() {
        let model = SimulatedBertCrossEncoder::new();

        // Scores are always within [0, 1].
        let overlap_score = model
            .score(
                "machine learning",
                "artificial intelligence and machine learning",
            )
            .unwrap();
        assert!(overlap_score > 0.0 && overlap_score <= 1.0);

        // A lexically related document should outscore an unrelated one.
        let related = model
            .score("rust programming", "rust is a programming language")
            .unwrap();
        let unrelated = model
            .score("rust programming", "cooking recipes for dinner")
            .unwrap();
        assert!(related > unrelated);
    }

    #[test]
    fn test_batch_scoring() {
        let model = SimulatedBertCrossEncoder::new();

        let pairs = vec![
            ("query1".to_string(), "relevant document".to_string()),
            ("query2".to_string(), "another document".to_string()),
        ];

        let scores = model.score_batch(&pairs).unwrap();

        // One in-range score per pair.
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|&s| (0.0..=1.0).contains(&s)));
    }

    #[test]
    fn test_attention_scores() {
        let model = SimulatedBertCrossEncoder::new();

        let attention = model
            .get_attention_scores("machine learning", "artificial intelligence")
            .unwrap()
            .expect("simulated BERT supports attention");

        // One in-range attention value per document token.
        assert_eq!(attention.len(), 2);
        assert!(attention.iter().all(|&s| (0.0..=1.0).contains(&s)));
    }
}