1use crate::RragResult;
90
91pub mod cross_encoder;
92pub mod learning_to_rank;
93pub mod multi_signal;
94pub mod neural_reranker;
95
96pub use cross_encoder::{
98 CrossEncoderConfig, CrossEncoderModel, CrossEncoderReranker, RerankedResult, RerankingStrategy,
99 ScoreAggregation,
100};
101pub use learning_to_rank::{
102 FeatureExtractor, FeatureType, LTRConfig, LTRFeatures, LTRModel, LearningToRankReranker,
103 RankingFeature,
104};
105pub use multi_signal::{
106 MultiSignalConfig, MultiSignalReranker, RelevanceSignal, SignalAggregation, SignalType,
107 SignalWeight,
108};
109pub use neural_reranker::{
110 AttentionMechanism, BertReranker, NeuralConfig, NeuralReranker, RobertaReranker,
111 TransformerReranker,
112};
113
/// Multi-stage reranker that runs component rerankers in the order given by
/// [`AdvancedRerankingConfig::strategy_order`] and merges their per-candidate scores.
///
/// Each component slot is populated in `new` from the matching `enable_*` flag of the
/// configuration; a `None` slot means that component is not available.
pub struct AdvancedReranker {
    /// Cross-encoder component, run for [`RerankingStrategyType::CrossEncoder`].
    cross_encoder: Option<CrossEncoderReranker>,

    /// Learning-to-rank component, run for [`RerankingStrategyType::LearningToRank`].
    ltr_model: Option<LearningToRankReranker>,

    /// Multi-signal component, run for [`RerankingStrategyType::MultiSignal`].
    multi_signal: Option<MultiSignalReranker>,

    /// Neural component, run for [`RerankingStrategyType::Neural`].
    neural_reranker: Option<NeuralReranker>,

    /// Active configuration (strategy order, score combination, limits).
    config: AdvancedRerankingConfig,
}
131
/// Configuration for [`AdvancedReranker`].
#[derive(Debug, Clone)]
pub struct AdvancedRerankingConfig {
    /// Build the cross-encoder component when the reranker is constructed.
    pub enable_cross_encoder: bool,

    /// Build the learning-to-rank component when the reranker is constructed.
    pub enable_ltr: bool,

    /// Build the multi-signal component when the reranker is constructed.
    pub enable_multi_signal: bool,

    /// Build the neural component when the reranker is constructed.
    pub enable_neural: bool,

    /// Only the first `max_candidates` initial results are reranked; results beyond
    /// this limit are not returned.
    pub max_candidates: usize,

    /// Results whose final score is below this value are dropped from the output.
    pub score_threshold: f32,

    /// Order in which the component rerankers are run. A strategy whose component
    /// was not built is skipped.
    pub strategy_order: Vec<RerankingStrategyType>,

    /// How the component scores are merged into each result's final score.
    pub score_combination: ScoreCombination,

    /// NOTE(review): not read by `AdvancedReranker` itself; confirm whether any
    /// component honours it.
    pub enable_caching: bool,

    /// NOTE(review): not read by `AdvancedReranker` itself; presumably a batch size
    /// for component model inference — confirm.
    pub batch_size: usize,
}
165
166impl Default for AdvancedRerankingConfig {
167 fn default() -> Self {
168 Self {
169 enable_cross_encoder: true,
170 enable_ltr: false,
171 enable_multi_signal: true,
172 enable_neural: false,
173 max_candidates: 100,
174 score_threshold: 0.1,
175 strategy_order: vec![
176 RerankingStrategyType::CrossEncoder,
177 RerankingStrategyType::MultiSignal,
178 ],
179 score_combination: ScoreCombination::Weighted(vec![0.7, 0.3]),
180 enable_caching: true,
181 batch_size: 32,
182 }
183 }
184}
185
/// A component reranker that can appear in
/// [`AdvancedRerankingConfig::strategy_order`].
///
/// The key under which each strategy's scores are reported in
/// `AdvancedRerankedResult::component_scores` is noted on the variant.
#[derive(Debug, Clone, PartialEq)]
pub enum RerankingStrategyType {
    /// Cross-encoder reranking (key `"cross_encoder"`).
    CrossEncoder,
    /// Learning-to-rank reranking (key `"ltr"`).
    LearningToRank,
    /// Multi-signal reranking (key `"multi_signal"`).
    MultiSignal,
    /// Neural reranking (key `"neural"`).
    Neural,
}
194
/// Rule for merging a candidate's per-component scores into its final score.
#[derive(Debug, Clone)]
pub enum ScoreCombination {
    /// Arithmetic mean of the component scores.
    Average,

    /// Sum of each component score multiplied by a weight.
    /// NOTE(review): which weight pairs with which component is decided in
    /// `AdvancedReranker::combine_scores`; presumably one weight per
    /// `strategy_order` entry — confirm there.
    Weighted(Vec<f32>),

    /// Maximum-based combination; see `AdvancedReranker::combine_scores` for the
    /// exact semantics.
    Max,

    /// Minimum-based combination; see `AdvancedReranker::combine_scores` for the
    /// exact semantics.
    Min,

    /// Placeholder for a learned combiner; `AdvancedReranker::combine_scores`
    /// currently computes the same arithmetic mean as `Average`.
    Learned,
}
209
/// One candidate after advanced reranking.
#[derive(Debug, Clone)]
pub struct AdvancedRerankedResult {
    /// `id` of the originating search result.
    pub document_id: String,

    /// Combined reranking score used for ordering. The original retrieval score is
    /// used if no combined score was produced for this candidate.
    pub final_score: f32,

    /// Score from each component reranker that ran, keyed by component name
    /// (`"cross_encoder"`, `"multi_signal"`, `"ltr"`, `"neural"`); 0.0 when that
    /// component gave this candidate no score.
    pub component_scores: std::collections::HashMap<String, f32>,

    /// 0-based position of this candidate in the initial results.
    pub original_rank: usize,

    /// 0-based position after sorting by `final_score`, highest first.
    pub new_rank: usize,

    /// Agreement between components: 1 minus the population standard deviation of
    /// the component scores, clamped to `[0, 1]`; 0.5 when fewer than two
    /// components scored.
    pub confidence: f32,

    /// Human-readable per-component score summary, or `None` when no component
    /// produced scores.
    pub explanation: Option<String>,

    /// Bookkeeping about the reranking run that produced this result.
    pub metadata: RerankingMetadata,
}
237
/// Bookkeeping attached to every [`AdvancedRerankedResult`] of one reranking run.
#[derive(Debug, Clone)]
pub struct RerankingMetadata {
    /// Milliseconds elapsed in `AdvancedReranker::rerank` when this result was
    /// assembled.
    pub reranking_time_ms: u64,

    /// Names of the component rerankers that returned scores, in execution order.
    pub rerankers_used: Vec<String>,

    /// Number of extracted features; currently always 0 when produced by
    /// `AdvancedReranker::rerank`.
    pub features_extracted: usize,

    /// Model name → version; currently always empty when produced by
    /// `AdvancedReranker::rerank`.
    pub model_versions: std::collections::HashMap<String, String>,

    /// Non-fatal problems encountered during the run, e.g. a component reranker
    /// that failed and therefore did not contribute scores.
    pub warnings: Vec<String>,
}
256
257impl AdvancedReranker {
258 pub fn new(config: AdvancedRerankingConfig) -> Self {
260 Self {
261 cross_encoder: if config.enable_cross_encoder {
262 Some(CrossEncoderReranker::new(CrossEncoderConfig::default()))
263 } else {
264 None
265 },
266 ltr_model: if config.enable_ltr {
267 Some(LearningToRankReranker::new(LTRConfig::default()))
268 } else {
269 None
270 },
271 multi_signal: if config.enable_multi_signal {
272 Some(MultiSignalReranker::new(MultiSignalConfig::default()))
273 } else {
274 None
275 },
276 neural_reranker: if config.enable_neural {
277 Some(NeuralReranker::new(NeuralConfig::default()))
278 } else {
279 None
280 },
281 config,
282 }
283 }
284
285 pub async fn rerank(
287 &self,
288 query: &str,
289 initial_results: Vec<crate::SearchResult>,
290 ) -> RragResult<Vec<AdvancedRerankedResult>> {
291 let start_time = std::time::Instant::now();
292
293 let candidates: Vec<_> = initial_results
295 .into_iter()
296 .take(self.config.max_candidates)
297 .enumerate()
298 .collect();
299
300 let mut component_scores = std::collections::HashMap::new();
301 let mut rerankers_used = Vec::new();
302 let mut warnings = Vec::new();
303
304 for strategy in &self.config.strategy_order {
306 match strategy {
307 RerankingStrategyType::CrossEncoder => {
308 if let Some(ref cross_encoder) = self.cross_encoder {
309 let candidate_results: Vec<_> = candidates
310 .iter()
311 .map(|(_, result)| result.clone())
312 .collect();
313 match cross_encoder.rerank(query, &candidate_results).await {
314 Ok(scores) => {
315 component_scores.insert("cross_encoder".to_string(), scores);
316 rerankers_used.push("cross_encoder".to_string());
317 }
318 Err(e) => {
319 warnings.push(format!("Cross-encoder failed: {}", e));
320 }
321 }
322 }
323 }
324 RerankingStrategyType::MultiSignal => {
325 if let Some(ref multi_signal) = self.multi_signal {
326 let candidate_results: Vec<_> = candidates
327 .iter()
328 .map(|(_, result)| result.clone())
329 .collect();
330 match multi_signal.rerank(query, &candidate_results).await {
331 Ok(scores) => {
332 component_scores.insert("multi_signal".to_string(), scores);
333 rerankers_used.push("multi_signal".to_string());
334 }
335 Err(e) => {
336 warnings.push(format!("Multi-signal failed: {}", e));
337 }
338 }
339 }
340 }
341 RerankingStrategyType::LearningToRank => {
342 if let Some(ref ltr) = self.ltr_model {
343 let candidate_results: Vec<_> = candidates
344 .iter()
345 .map(|(_, result)| result.clone())
346 .collect();
347 match ltr.rerank(query, &candidate_results).await {
348 Ok(scores) => {
349 component_scores.insert("ltr".to_string(), scores);
350 rerankers_used.push("ltr".to_string());
351 }
352 Err(e) => {
353 warnings.push(format!("LTR failed: {}", e));
354 }
355 }
356 }
357 }
358 RerankingStrategyType::Neural => {
359 if let Some(ref neural) = self.neural_reranker {
360 let candidate_results: Vec<_> = candidates
361 .iter()
362 .map(|(_, result)| result.clone())
363 .collect();
364 match neural.rerank(query, &candidate_results).await {
365 Ok(scores) => {
366 component_scores.insert("neural".to_string(), scores);
367 rerankers_used.push("neural".to_string());
368 }
369 Err(e) => {
370 warnings.push(format!("Neural reranker failed: {}", e));
371 }
372 }
373 }
374 }
375 }
376 }
377
378 let final_scores = self.combine_scores(&component_scores, candidates.len());
380
381 let mut reranked_results: Vec<_> = candidates
383 .into_iter()
384 .enumerate()
385 .map(|(idx, (original_rank, result))| AdvancedRerankedResult {
386 document_id: result.id.clone(),
387 final_score: final_scores.get(&idx).copied().unwrap_or(result.score),
388 component_scores: component_scores
389 .iter()
390 .map(|(name, scores)| (name.clone(), scores.get(&idx).copied().unwrap_or(0.0)))
391 .collect(),
392 original_rank,
393 new_rank: 0, confidence: self.calculate_confidence(&component_scores, idx),
395 explanation: self.generate_explanation(&component_scores, idx),
396 metadata: RerankingMetadata {
397 reranking_time_ms: start_time.elapsed().as_millis() as u64,
398 rerankers_used: rerankers_used.clone(),
399 features_extracted: 0, model_versions: std::collections::HashMap::new(),
401 warnings: warnings.clone(),
402 },
403 })
404 .collect();
405
406 reranked_results.sort_by(|a, b| {
408 b.final_score
409 .partial_cmp(&a.final_score)
410 .unwrap_or(std::cmp::Ordering::Equal)
411 });
412
413 for (idx, result) in reranked_results.iter_mut().enumerate() {
415 result.new_rank = idx;
416 }
417
418 reranked_results.retain(|result| result.final_score >= self.config.score_threshold);
420
421 Ok(reranked_results)
422 }
423
424 fn combine_scores(
426 &self,
427 component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
428 num_candidates: usize,
429 ) -> std::collections::HashMap<usize, f32> {
430 let mut final_scores = std::collections::HashMap::new();
431
432 for idx in 0..num_candidates {
433 let scores: Vec<f32> = component_scores
434 .values()
435 .map(|scores| scores.get(&idx).copied().unwrap_or(0.0))
436 .collect();
437
438 let final_score = match &self.config.score_combination {
439 ScoreCombination::Average => {
440 if scores.is_empty() {
441 0.0
442 } else {
443 scores.iter().sum::<f32>() / scores.len() as f32
444 }
445 }
446 ScoreCombination::Weighted(weights) => scores
447 .iter()
448 .zip(weights.iter())
449 .map(|(score, weight)| score * weight)
450 .sum::<f32>(),
451 ScoreCombination::Max => scores.iter().fold(0.0f32, |a, &b| a.max(b)),
452 ScoreCombination::Min => scores.iter().fold(1.0f32, |a, &b| a.min(b)),
453 ScoreCombination::Learned => {
454 if scores.is_empty() {
456 0.0
457 } else {
458 scores.iter().sum::<f32>() / scores.len() as f32
459 }
460 }
461 };
462
463 final_scores.insert(idx, final_score);
464 }
465
466 final_scores
467 }
468
469 fn calculate_confidence(
471 &self,
472 component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
473 idx: usize,
474 ) -> f32 {
475 let scores: Vec<f32> = component_scores
477 .values()
478 .map(|scores| scores.get(&idx).copied().unwrap_or(0.0))
479 .collect();
480
481 if scores.len() < 2 {
482 return 0.5; }
484
485 let mean = scores.iter().sum::<f32>() / scores.len() as f32;
487 let variance = scores
488 .iter()
489 .map(|score| (score - mean).powi(2))
490 .sum::<f32>()
491 / scores.len() as f32;
492 let std_dev = variance.sqrt();
493
494 (1.0 - std_dev.min(1.0)).max(0.0)
496 }
497
498 fn generate_explanation(
500 &self,
501 component_scores: &std::collections::HashMap<String, std::collections::HashMap<usize, f32>>,
502 idx: usize,
503 ) -> Option<String> {
504 let scores: Vec<(String, f32)> = component_scores
505 .iter()
506 .map(|(name, scores)| (name.clone(), scores.get(&idx).copied().unwrap_or(0.0)))
507 .collect();
508
509 if scores.is_empty() {
510 return None;
511 }
512
513 let mut explanations = Vec::new();
514
515 for (reranker, score) in &scores {
516 match reranker.as_str() {
517 "cross_encoder" => {
518 explanations.push(format!("Cross-encoder relevance: {:.3}", score));
519 }
520 "multi_signal" => {
521 explanations.push(format!("Multi-signal analysis: {:.3}", score));
522 }
523 "ltr" => {
524 explanations.push(format!("Learning-to-rank: {:.3}", score));
525 }
526 "neural" => {
527 explanations.push(format!("Neural reranker: {:.3}", score));
528 }
529 _ => {
530 explanations.push(format!("{}: {:.3}", reranker, score));
531 }
532 }
533 }
534
535 Some(explanations.join("; "))
536 }
537
538 pub fn update_config(&mut self, config: AdvancedRerankingConfig) {
540 self.config = config;
541 }
542
543 pub fn get_config(&self) -> &AdvancedRerankingConfig {
545 &self.config
546 }
547}