rexis_rag/reranking/
mod.rs

1//! # Advanced Reranking Module
2//!
3//! State-of-the-art reranking algorithms for improving retrieval relevance and precision.
4//!
5//! This module provides multiple reranking strategies to refine initial retrieval results,
6//! ensuring the most relevant documents are ranked at the top. It supports various
7//! approaches from simple score-based reranking to sophisticated neural models.
8//!
9//! ## Features
10//!
11//! - **Cross-Encoder Reranking**: Transformer-based relevance scoring
12//! - **Neural Reranking**: Deep learning models for relevance prediction
13//! - **Learning-to-Rank**: Machine learning ranking with feature engineering
14//! - **Multi-Signal Fusion**: Combine multiple ranking signals
15//! - **Diversity Reranking**: Ensure result diversity while maintaining relevance
16//!
17//! ## Performance
18//!
19//! - Batch processing for efficiency
20//! - Async operations for non-blocking I/O
21//! - Caching of model predictions
22//! - GPU acceleration support (when available)
23//!
24//! ## Examples
25//!
26//! ### Cross-Encoder Reranking
//! ```rust,ignore
28//! use rrag::reranking::{CrossEncoderReranker, CrossEncoderConfig};
29//! use rrag::Document;
30//!
31//! # async fn example() -> rrag::RragResult<()> {
32//! let reranker = CrossEncoderReranker::new(
33//!     CrossEncoderConfig::default()
34//!         .with_model("cross-encoder/ms-marco-MiniLM-L-12-v2")
35//!         .with_batch_size(32)
36//! );
37//!
38//! let query = "What is machine learning?";
39//! let documents = vec![
40//!     Document::new("Machine learning is a subset of AI..."),
41//!     Document::new("Deep learning uses neural networks..."),
42//!     Document::new("Python is a programming language..."),
43//! ];
44//!
45//! let reranked = reranker.rerank(query, documents).await?;
46//! // Most relevant documents are now at the top
47//! # Ok(())
48//! # }
49//! ```
50//!
51//! ### Learning-to-Rank
//! ```rust,ignore
53//! use rrag::reranking::{LearningToRankReranker, LTRConfig};
54//!
55//! # async fn example() -> rrag::RragResult<()> {
56//! let ltr_reranker = LearningToRankReranker::new(
57//!     LTRConfig::default()
58//!         .with_features(vec!["bm25", "tfidf", "semantic_similarity"])
59//!         .with_model("lightgbm")
60//! );
61//!
62//! let reranked = ltr_reranker.rerank_with_features(
63//!     query,
64//!     documents,
65//!     feature_matrix
66//! ).await?;
67//! # Ok(())
68//! # }
69//! ```
70//!
71//! ### Multi-Signal Reranking
//! ```rust,ignore
//! use rrag::reranking::{MultiSignalConfig, MultiSignalReranker, SignalWeight};
//!
//! # async fn example() -> rrag::RragResult<()> {
//! let multi_reranker = MultiSignalReranker::new(MultiSignalConfig::default())
77//!     .add_signal("relevance", 0.6)
78//!     .add_signal("recency", 0.2)
79//!     .add_signal("popularity", 0.2);
80//!
81//! let reranked = multi_reranker.rerank_multi_signal(
82//!     documents,
83//!     signal_scores
84//! ).await?;
85//! # Ok(())
86//! # }
87//! ```
88
89use crate::RragResult;
90
91pub mod cross_encoder;
92pub mod learning_to_rank;
93pub mod multi_signal;
94pub mod neural_reranker;
95
96// Re-exports
97pub use cross_encoder::{
98    CrossEncoderConfig, CrossEncoderModel, CrossEncoderReranker, RerankedResult, RerankingStrategy,
99    ScoreAggregation,
100};
101pub use learning_to_rank::{
102    FeatureExtractor, FeatureType, LTRConfig, LTRFeatures, LTRModel, LearningToRankReranker,
103    RankingFeature,
104};
105pub use multi_signal::{
106    MultiSignalConfig, MultiSignalReranker, RelevanceSignal, SignalAggregation, SignalType,
107    SignalWeight,
108};
109pub use neural_reranker::{
110    AttentionMechanism, BertReranker, NeuralConfig, NeuralReranker, RobertaReranker,
111    TransformerReranker,
112};
113
114/// Main reranking interface that coordinates different reranking strategies
115pub struct AdvancedReranker {
116    /// Cross-encoder for query-document relevance
117    cross_encoder: Option<CrossEncoderReranker>,
118
119    /// Learning-to-rank model
120    ltr_model: Option<LearningToRankReranker>,
121
122    /// Multi-signal aggregation
123    multi_signal: Option<MultiSignalReranker>,
124
125    /// Neural reranking models
126    neural_reranker: Option<NeuralReranker>,
127
128    /// Configuration
129    config: AdvancedRerankingConfig,
130}
131
132/// Configuration for advanced reranking
133#[derive(Debug, Clone)]
134pub struct AdvancedRerankingConfig {
135    /// Enable cross-encoder reranking
136    pub enable_cross_encoder: bool,
137
138    /// Enable learning-to-rank
139    pub enable_ltr: bool,
140
141    /// Enable multi-signal reranking
142    pub enable_multi_signal: bool,
143
144    /// Enable neural reranking
145    pub enable_neural: bool,
146
147    /// Maximum number of candidates to rerank
148    pub max_candidates: usize,
149
150    /// Minimum score threshold
151    pub score_threshold: f32,
152
153    /// Reranking strategy priority order
154    pub strategy_order: Vec<RerankingStrategyType>,
155
156    /// Score combination method
157    pub score_combination: ScoreCombination,
158
159    /// Cache reranking results
160    pub enable_caching: bool,
161
162    /// Batch size for neural models
163    pub batch_size: usize,
164}
165
166impl Default for AdvancedRerankingConfig {
167    fn default() -> Self {
168        Self {
169            enable_cross_encoder: true,
170            enable_ltr: false,
171            enable_multi_signal: true,
172            enable_neural: false,
173            max_candidates: 100,
174            score_threshold: 0.1,
175            strategy_order: vec![
176                RerankingStrategyType::CrossEncoder,
177                RerankingStrategyType::MultiSignal,
178            ],
179            score_combination: ScoreCombination::Weighted(vec![0.7, 0.3]),
180            enable_caching: true,
181            batch_size: 32,
182        }
183    }
184}
185
/// Reranking strategies that `AdvancedReranker` can run.
///
/// Fieldless, so it is `Copy` and can be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RerankingStrategyType {
    /// Cross-encoder query-document relevance scoring.
    CrossEncoder,
    /// Learning-to-rank model.
    LearningToRank,
    /// Multi-signal score fusion.
    MultiSignal,
    /// Neural reranking model.
    Neural,
}
194
/// Methods for combining the scores produced by several rerankers into one final score.
///
/// Derives `PartialEq` (like `RerankingStrategyType`) so configurations can be compared.
#[derive(Debug, Clone, PartialEq)]
pub enum ScoreCombination {
    /// Arithmetic mean of the component scores.
    Average,
    /// Weighted sum of the component scores, one weight per reranker.
    Weighted(Vec<f32>),
    /// Largest component score.
    Max,
    /// Smallest component score.
    Min,
    /// Learned combination (requires training); `AdvancedReranker` currently handles it
    /// exactly like `Average`.
    Learned,
}
209
210/// Result from advanced reranking
211#[derive(Debug, Clone)]
212pub struct AdvancedRerankedResult {
213    /// Document identifier
214    pub document_id: String,
215
216    /// Final combined score
217    pub final_score: f32,
218
219    /// Individual scores from each reranker
220    pub component_scores: std::collections::HashMap<String, f32>,
221
222    /// Original retrieval rank
223    pub original_rank: usize,
224
225    /// New rank after reranking
226    pub new_rank: usize,
227
228    /// Confidence in the reranking decision
229    pub confidence: f32,
230
231    /// Explanation of the reranking decision
232    pub explanation: Option<String>,
233
234    /// Processing metadata
235    pub metadata: RerankingMetadata,
236}
237
/// Details of how a reranked result was produced.
#[derive(Debug, Clone)]
pub struct RerankingMetadata {
    /// Wall-clock time spent reranking, in milliseconds.
    pub reranking_time_ms: u64,

    /// Names of the rerankers whose scores were used.
    pub rerankers_used: Vec<String>,

    /// Number of features extracted; may be 0 when not reported.
    pub features_extracted: usize,

    /// Model version strings, presumably keyed by reranker name — TODO confirm.
    pub model_versions: std::collections::HashMap<String, String>,

    /// Non-fatal problems encountered, e.g. a reranker that failed.
    pub warnings: Vec<String>,
}
256
257impl AdvancedReranker {
258    /// Create a new advanced reranker
259    pub fn new(config: AdvancedRerankingConfig) -> Self {
260        Self {
261            cross_encoder: if config.enable_cross_encoder {
262                Some(CrossEncoderReranker::new(CrossEncoderConfig::default()))
263            } else {
264                None
265            },
266            ltr_model: if config.enable_ltr {
267                Some(LearningToRankReranker::new(LTRConfig::default()))
268            } else {
269                None
270            },
271            multi_signal: if config.enable_multi_signal {
272                Some(MultiSignalReranker::new(MultiSignalConfig::default()))
273            } else {
274                None
275            },
276            neural_reranker: if config.enable_neural {
277                Some(NeuralReranker::new(NeuralConfig::default()))
278            } else {
279                None
280            },
281            config,
282        }
283    }
284
285    /// Rerank a list of initial retrieval results
286    pub async fn rerank(
287        &self,
288        query: &str,
289        initial_results: Vec<crate::SearchResult>,
290    ) -> RragResult<Vec<AdvancedRerankedResult>> {
291        let start_time = std::time::Instant::now();
292
293        // Limit candidates if needed
294        let candidates: Vec<_> = initial_results
295            .into_iter()
296            .take(self.config.max_candidates)
297            .enumerate()
298            .collect();
299
300        let mut component_scores = std::collections::HashMap::new();
301        let mut rerankers_used = Vec::new();
302        let mut warnings = Vec::new();
303
304        // Apply reranking strategies in order
305        for strategy in &self.config.strategy_order {
306            match strategy {
307                RerankingStrategyType::CrossEncoder => {
308                    if let Some(ref cross_encoder) = self.cross_encoder {
309                        let candidate_results: Vec<_> = candidates
310                            .iter()
311                            .map(|(_, result)| result.clone())
312                            .collect();
313                        match cross_encoder.rerank(query, &candidate_results).await {
314                            Ok(scores) => {
315                                component_scores.insert("cross_encoder".to_string(), scores);
316                                rerankers_used.push("cross_encoder".to_string());
317                            }
318                            Err(e) => {
319                                warnings.push(format!("Cross-encoder failed: {}", e));
320                            }
321                        }
322                    }
323                }
324                RerankingStrategyType::MultiSignal => {
325                    if let Some(ref multi_signal) = self.multi_signal {
326                        let candidate_results: Vec<_> = candidates
327                            .iter()
328                            .map(|(_, result)| result.clone())
329                            .collect();
330                        match multi_signal.rerank(query, &candidate_results).await {
331                            Ok(scores) => {
332                                component_scores.insert("multi_signal".to_string(), scores);
333                                rerankers_used.push("multi_signal".to_string());
334                            }
335                            Err(e) => {
336                                warnings.push(format!("Multi-signal failed: {}", e));
337                            }
338                        }
339                    }
340                }
341                RerankingStrategyType::LearningToRank => {
342                    if let Some(ref ltr) = self.ltr_model {
343                        let candidate_results: Vec<_> = candidates
344                            .iter()
345                            .map(|(_, result)| result.clone())
346                            .collect();
347                        match ltr.rerank(query, &candidate_results).await {
348                            Ok(scores) => {
349                                component_scores.insert("ltr".to_string(), scores);
350                                rerankers_used.push("ltr".to_string());
351                            }
352                            Err(e) => {
353                                warnings.push(format!("LTR failed: {}", e));
354                            }
355                        }
356                    }
357                }
358                RerankingStrategyType::Neural => {
359                    if let Some(ref neural) = self.neural_reranker {
360                        let candidate_results: Vec<_> = candidates
361                            .iter()
362                            .map(|(_, result)| result.clone())
363                            .collect();
364                        match neural.rerank(query, &candidate_results).await {
365                            Ok(scores) => {
366                                component_scores.insert("neural".to_string(), scores);
367                                rerankers_used.push("neural".to_string());
368                            }
369                            Err(e) => {
370                                warnings.push(format!("Neural reranker failed: {}", e));
371                            }
372                        }
373                    }
374                }
375            }
376        }
377
378        // Combine scores
379        let final_scores = self.combine_scores(&component_scores, candidates.len());
380
381        // Create reranked results
382        let mut reranked_results: Vec<_> = candidates
383            .into_iter()
384            .enumerate()
385            .map(|(idx, (original_rank, result))| AdvancedRerankedResult {
386                document_id: result.id.clone(),
387                final_score: final_scores.get(&idx).copied().unwrap_or(result.score),
388                component_scores: component_scores
389                    .iter()
390                    .map(|(name, scores)| (name.clone(), scores.get(&idx).copied().unwrap_or(0.0)))
391                    .collect(),
392                original_rank,
393                new_rank: 0, // Will be filled after sorting
394                confidence: self.calculate_confidence(&component_scores, idx),
395                explanation: self.generate_explanation(&component_scores, idx),
396                metadata: RerankingMetadata {
397                    reranking_time_ms: start_time.elapsed().as_millis() as u64,
398                    rerankers_used: rerankers_used.clone(),
399                    features_extracted: 0, // Would be set by individual rerankers
400                    model_versions: std::collections::HashMap::new(),
401                    warnings: warnings.clone(),
402                },
403            })
404            .collect();
405
406        // Sort by final score
407        reranked_results.sort_by(|a, b| {
408            b.final_score
409                .partial_cmp(&a.final_score)
410                .unwrap_or(std::cmp::Ordering::Equal)
411        });
412
413        // Update new ranks
414        for (idx, result) in reranked_results.iter_mut().enumerate() {
415            result.new_rank = idx;
416        }
417
418        // Filter by score threshold
419        reranked_results.retain(|result| result.final_score >= self.config.score_threshold);
420
421        Ok(reranked_results)
422    }
423
424    /// Combine scores from different rerankers
425    fn combine_scores(
426        &self,
427        component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
428        num_candidates: usize,
429    ) -> std::collections::HashMap<usize, f32> {
430        let mut final_scores = std::collections::HashMap::new();
431
432        for idx in 0..num_candidates {
433            let scores: Vec<f32> = component_scores
434                .values()
435                .map(|scores| scores.get(&idx).copied().unwrap_or(0.0))
436                .collect();
437
438            let final_score = match &self.config.score_combination {
439                ScoreCombination::Average => {
440                    if scores.is_empty() {
441                        0.0
442                    } else {
443                        scores.iter().sum::<f32>() / scores.len() as f32
444                    }
445                }
446                ScoreCombination::Weighted(weights) => scores
447                    .iter()
448                    .zip(weights.iter())
449                    .map(|(score, weight)| score * weight)
450                    .sum::<f32>(),
451                ScoreCombination::Max => scores.iter().fold(0.0f32, |a, &b| a.max(b)),
452                ScoreCombination::Min => scores.iter().fold(1.0f32, |a, &b| a.min(b)),
453                ScoreCombination::Learned => {
454                    // Would use a learned combination model
455                    if scores.is_empty() {
456                        0.0
457                    } else {
458                        scores.iter().sum::<f32>() / scores.len() as f32
459                    }
460                }
461            };
462
463            final_scores.insert(idx, final_score);
464        }
465
466        final_scores
467    }
468
469    /// Calculate confidence in the reranking decision
470    fn calculate_confidence(
471        &self,
472        component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
473        idx: usize,
474    ) -> f32 {
475        // Simple confidence calculation based on score agreement
476        let scores: Vec<f32> = component_scores
477            .values()
478            .map(|scores| scores.get(&idx).copied().unwrap_or(0.0))
479            .collect();
480
481        if scores.len() < 2 {
482            return 0.5; // Low confidence with only one scorer
483        }
484
485        // Calculate standard deviation as inverse confidence
486        let mean = scores.iter().sum::<f32>() / scores.len() as f32;
487        let variance = scores
488            .iter()
489            .map(|score| (score - mean).powi(2))
490            .sum::<f32>()
491            / scores.len() as f32;
492        let std_dev = variance.sqrt();
493
494        // Convert to confidence (lower std_dev = higher confidence)
495        (1.0 - std_dev.min(1.0)).max(0.0)
496    }
497
498    /// Generate explanation for reranking decision
499    fn generate_explanation(
500        &self,
501        component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
502        idx: usize,
503    ) -> Option<String> {
504        let scores: Vec<(String, f32)> = component_scores
505            .iter()
506            .map(|(name, scores)| (name.clone(), scores.get(&idx).copied().unwrap_or(0.0)))
507            .collect();
508
509        if scores.is_empty() {
510            return None;
511        }
512
513        let mut explanations = Vec::new();
514
515        for (reranker, score) in &scores {
516            match reranker.as_str() {
517                "cross_encoder" => {
518                    explanations.push(format!("Cross-encoder relevance: {:.3}", score));
519                }
520                "multi_signal" => {
521                    explanations.push(format!("Multi-signal analysis: {:.3}", score));
522                }
523                "ltr" => {
524                    explanations.push(format!("Learning-to-rank: {:.3}", score));
525                }
526                "neural" => {
527                    explanations.push(format!("Neural reranker: {:.3}", score));
528                }
529                _ => {
530                    explanations.push(format!("{}: {:.3}", reranker, score));
531                }
532            }
533        }
534
535        Some(explanations.join("; "))
536    }
537
538    /// Update configuration
539    pub fn update_config(&mut self, config: AdvancedRerankingConfig) {
540        self.config = config;
541    }
542
543    /// Get current configuration
544    pub fn get_config(&self) -> &AdvancedRerankingConfig {
545        &self.config
546    }
547}