rag_plusplus_core/retrieval/
rerank.rs

1//! Reranking Module
2//!
3//! Provides reranking algorithms for retrieved results.
4//!
5//! # Reranking Strategies
6//!
7//! - **Outcome-weighted**: Rerank based on historical outcome statistics
8//! - **Recency**: Boost more recent records
9//! - **MMR (Maximal Marginal Relevance)**: Diversify results
//! - **Composite**: Blend the original score with both the outcome and recency signals
11
12use crate::retrieval::engine::RetrievedRecord;
13use ordered_float::OrderedFloat;
14
/// Selects which reranking algorithm a [`Reranker`] applies.
///
/// The default, [`RerankerType::None`], leaves results in their incoming order.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RerankerType {
    /// Pass-through: results are returned exactly as received.
    #[default]
    None,
    /// Blend the retrieval score with a conservative estimate derived from the
    /// record's historical outcome statistics.
    OutcomeWeighted,
    /// Blend the retrieval score with an exponential-decay freshness score.
    Recency,
    /// Maximal Marginal Relevance: trade relevance against similarity to results
    /// that were already selected, to diversify the list.
    MMR,
    /// Blend the retrieval score with both the outcome and the recency signals.
    Composite,
}
30
31/// Reranker configuration.
32#[derive(Debug, Clone)]
33pub struct RerankerConfig {
34    /// Reranking strategy
35    pub strategy: RerankerType,
36    /// Weight for original score (0-1)
37    pub original_weight: f32,
38    /// Weight for outcome score (0-1)
39    pub outcome_weight: f32,
40    /// Weight for recency (0-1)
41    pub recency_weight: f32,
42    /// Recency decay half-life in seconds
43    pub recency_half_life: f64,
44    /// MMR lambda (0 = pure diversity, 1 = pure relevance)
45    pub mmr_lambda: f32,
46    /// Minimum sample count for outcome weighting
47    pub min_samples: u64,
48}
49
50impl Default for RerankerConfig {
51    fn default() -> Self {
52        Self {
53            strategy: RerankerType::OutcomeWeighted,
54            original_weight: 0.5,
55            outcome_weight: 0.3,
56            recency_weight: 0.2,
57            recency_half_life: 86400.0 * 7.0, // 7 days
58            mmr_lambda: 0.7,
59            min_samples: 3,
60        }
61    }
62}
63
64impl RerankerConfig {
65    /// Create new config with defaults.
66    #[must_use]
67    pub fn new() -> Self {
68        Self::default()
69    }
70
71    /// Set strategy.
72    #[must_use]
73    pub const fn with_strategy(mut self, strategy: RerankerType) -> Self {
74        self.strategy = strategy;
75        self
76    }
77
78    /// Set outcome weight.
79    #[must_use]
80    pub const fn with_outcome_weight(mut self, weight: f32) -> Self {
81        self.outcome_weight = weight;
82        self
83    }
84
85    /// Set MMR lambda.
86    #[must_use]
87    pub const fn with_mmr_lambda(mut self, lambda: f32) -> Self {
88        self.mmr_lambda = lambda;
89        self
90    }
91}
92
93/// Reranker for improving result ordering.
94pub struct Reranker {
95    config: RerankerConfig,
96}
97
98impl Reranker {
99    /// Create a new reranker.
100    #[must_use]
101    pub fn new(config: RerankerConfig) -> Self {
102        Self { config }
103    }
104
105    /// Rerank results according to configured strategy.
106    #[must_use]
107    pub fn rerank(&self, results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
108        match self.config.strategy {
109            RerankerType::None => results,
110            RerankerType::OutcomeWeighted => self.rerank_by_outcome(results),
111            RerankerType::Recency => self.rerank_by_recency(results),
112            RerankerType::MMR => self.rerank_mmr(results),
113            RerankerType::Composite => self.rerank_composite(results),
114        }
115    }
116
117    /// Rerank by outcome statistics.
118    fn rerank_by_outcome(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
119        for result in &mut results {
120            let outcome_score = self.compute_outcome_score(&result.record);
121            result.score = self.config.original_weight * result.score
122                + self.config.outcome_weight * outcome_score;
123        }
124
125        results.sort_by(|a, b| {
126            OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
127        });
128
129        results
130    }
131
132    /// Compute outcome score for a record.
133    fn compute_outcome_score(&self, record: &crate::types::MemoryRecord) -> f32 {
134        let stats = &record.stats;
135
136        if stats.count() < self.config.min_samples {
137            // Not enough data, use initial outcome
138            return record.outcome as f32;
139        }
140
141        // Use lower bound of confidence interval for conservative estimate
142        // This implements "optimistic pessimism" - we're optimistic about
143        // exploring but pessimistic in our estimates
144        if let Some((lower, _upper)) = stats.confidence_interval(0.90) {
145            // Return first dimension's lower bound
146            lower.first().copied().unwrap_or(record.outcome as f32)
147        } else {
148            record.outcome as f32
149        }
150    }
151
152    /// Rerank by recency.
153    fn rerank_by_recency(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
154        let now = current_time_secs();
155
156        for result in &mut results {
157            let age_secs = (now - result.record.created_at) as f64;
158            let recency_score = self.compute_recency_score(age_secs);
159
160            result.score = self.config.original_weight * result.score
161                + self.config.recency_weight * recency_score;
162        }
163
164        results.sort_by(|a, b| {
165            OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
166        });
167
168        results
169    }
170
171    /// Compute recency score with exponential decay.
172    fn compute_recency_score(&self, age_secs: f64) -> f32 {
173        // Exponential decay: score = exp(-age / half_life * ln(2))
174        let decay = (-age_secs / self.config.recency_half_life * std::f64::consts::LN_2).exp();
175        decay as f32
176    }
177
178    /// Rerank using MMR for diversity.
179    fn rerank_mmr(&self, results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
180        if results.len() <= 1 {
181            return results;
182        }
183
184        let lambda = self.config.mmr_lambda;
185        let mut reranked = Vec::with_capacity(results.len());
186        let mut remaining: Vec<_> = results.into_iter().collect();
187
188        // Select first by pure relevance
189        remaining.sort_by(|a, b| OrderedFloat(b.score).cmp(&OrderedFloat(a.score)));
190        reranked.push(remaining.remove(0));
191
192        // Select remaining by MMR
193        while !remaining.is_empty() {
194            let mut best_idx = 0;
195            let mut best_mmr = f32::NEG_INFINITY;
196
197            for (i, candidate) in remaining.iter().enumerate() {
198                // Relevance term
199                let relevance = candidate.score;
200
201                // Diversity term (max similarity to already selected)
202                let max_sim = reranked
203                    .iter()
204                    .map(|r| self.embedding_similarity(&candidate.record.embedding, &r.record.embedding))
205                    .fold(0.0f32, f32::max);
206
207                // MMR score
208                let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
209
210                if mmr > best_mmr {
211                    best_mmr = mmr;
212                    best_idx = i;
213                }
214            }
215
216            reranked.push(remaining.remove(best_idx));
217        }
218
219        reranked
220    }
221
222    /// Compute cosine similarity between embeddings.
223    fn embedding_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
224        if a.len() != b.len() {
225            return 0.0;
226        }
227
228        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
229        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
230        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
231
232        if norm_a > 0.0 && norm_b > 0.0 {
233            dot / (norm_a * norm_b)
234        } else {
235            0.0
236        }
237    }
238
239    /// Composite reranking (outcome + recency).
240    fn rerank_composite(&self, mut results: Vec<RetrievedRecord>) -> Vec<RetrievedRecord> {
241        let now = current_time_secs();
242
243        for result in &mut results {
244            let outcome_score = self.compute_outcome_score(&result.record);
245            let age_secs = (now - result.record.created_at) as f64;
246            let recency_score = self.compute_recency_score(age_secs);
247
248            result.score = self.config.original_weight * result.score
249                + self.config.outcome_weight * outcome_score
250                + self.config.recency_weight * recency_score;
251        }
252
253        results.sort_by(|a, b| {
254            OrderedFloat(b.score).cmp(&OrderedFloat(a.score))
255        });
256
257        results
258    }
259}
260
/// Whole seconds elapsed since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`
/// instead of failing.
fn current_time_secs() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    since_epoch.map_or(0, |elapsed| elapsed.as_secs())
}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::OutcomeStats;
    use crate::types::{MemoryRecord, RecordStatus};

    /// Build a result with a fixed `[1, 0, 0]` embedding, fresh (empty) outcome
    /// stats, and a `created_at` that lies `age_secs` seconds in the past.
    fn create_test_result(id: &str, score: f32, outcome: f64, age_secs: u64) -> RetrievedRecord {
        let now = current_time_secs();
        let created_at = now.saturating_sub(age_secs);

        RetrievedRecord {
            record: MemoryRecord {
                id: id.into(),
                embedding: vec![1.0, 0.0, 0.0],
                context: format!("Context {id}"),
                outcome,
                metadata: Default::default(),
                created_at,
                status: RecordStatus::Active,
                stats: OutcomeStats::new(1),
            },
            score,
            rank: 0,
            source_index: "test".into(),
        }
    }

    /// Build a result created "now" whose outcome stats have been fed `outcomes`.
    fn create_result_with_stats(id: &str, score: f32, outcomes: &[f64]) -> RetrievedRecord {
        let mut stats = OutcomeStats::new(1);
        for &o in outcomes {
            stats.update_scalar(o);
        }

        RetrievedRecord {
            record: MemoryRecord {
                id: id.into(),
                embedding: vec![1.0, 0.0, 0.0],
                context: format!("Context {id}"),
                outcome: outcomes.first().copied().unwrap_or(0.5),
                metadata: Default::default(),
                created_at: current_time_secs(),
                status: RecordStatus::Active,
                stats,
            },
            score,
            rank: 0,
            source_index: "test".into(),
        }
    }

    #[test]
    fn test_defaults() {
        assert_eq!(RerankerType::default(), RerankerType::None);

        let config = RerankerConfig::new();
        assert_eq!(config.strategy, RerankerType::OutcomeWeighted);
        assert_eq!(config.min_samples, 3);
    }

    #[test]
    fn test_no_reranking() {
        let reranker = Reranker::new(RerankerConfig::new().with_strategy(RerankerType::None));

        let results = vec![
            create_test_result("a", 0.9, 0.5, 0),
            create_test_result("b", 0.8, 0.9, 0),
        ];

        let reranked = reranker.rerank(results);

        assert_eq!(reranked[0].record.id.as_str(), "a");
        assert_eq!(reranked[1].record.id.as_str(), "b");
    }

    #[test]
    fn test_outcome_reranking() {
        let reranker = Reranker::new(
            RerankerConfig::new()
                .with_strategy(RerankerType::OutcomeWeighted)
                .with_outcome_weight(0.8),
        );

        // b has better outcome stats.
        let results = vec![
            create_result_with_stats("a", 0.9, &[0.3, 0.4, 0.3, 0.4]),
            create_result_with_stats("b", 0.8, &[0.9, 0.8, 0.9, 0.85]),
        ];

        let reranked = reranker.rerank(results);

        // b should rank higher because of its better outcomes.
        assert_eq!(reranked[0].record.id.as_str(), "b");
    }

    #[test]
    fn test_recency_reranking() {
        let reranker =
            Reranker::new(RerankerConfig::new().with_strategy(RerankerType::Recency));

        let results = vec![
            create_test_result("old", 0.9, 0.5, 86400 * 30), // 30 days old
            create_test_result("new", 0.8, 0.5, 3600),       // 1 hour old
        ];

        let reranked = reranker.rerank(results);

        // "new" should rank higher because it is more recent.
        assert_eq!(reranked[0].record.id.as_str(), "new");
    }

    #[test]
    fn test_mmr_diversity() {
        let reranker = Reranker::new(
            RerankerConfig::new()
                .with_strategy(RerankerType::MMR)
                .with_mmr_lambda(0.5),
        );

        // "a" and "b" have near-identical embeddings; "c" is orthogonal to both.
        let results = vec![
            RetrievedRecord {
                record: MemoryRecord {
                    id: "a".into(),
                    embedding: vec![1.0, 0.0, 0.0],
                    context: "a".into(),
                    outcome: 0.5,
                    metadata: Default::default(),
                    created_at: 0,
                    status: RecordStatus::Active,
                    stats: OutcomeStats::new(1),
                },
                score: 0.95,
                rank: 0,
                source_index: "test".into(),
            },
            RetrievedRecord {
                record: MemoryRecord {
                    id: "b".into(),
                    embedding: vec![0.99, 0.01, 0.0], // Very similar to a
                    context: "b".into(),
                    outcome: 0.5,
                    metadata: Default::default(),
                    created_at: 0,
                    status: RecordStatus::Active,
                    stats: OutcomeStats::new(1),
                },
                score: 0.9,
                rank: 0,
                source_index: "test".into(),
            },
            RetrievedRecord {
                record: MemoryRecord {
                    id: "c".into(),
                    embedding: vec![0.0, 1.0, 0.0], // Different from a
                    context: "c".into(),
                    outcome: 0.5,
                    metadata: Default::default(),
                    created_at: 0,
                    status: RecordStatus::Active,
                    stats: OutcomeStats::new(1),
                },
                score: 0.85,
                rank: 0,
                source_index: "test".into(),
            },
        ];

        let reranked = reranker.rerank(results);

        // Every input record is kept.
        assert_eq!(reranked.len(), 3);

        // The first pick is "a" (highest score).
        assert_eq!(reranked[0].record.id.as_str(), "a");

        // The second pick is "c" (diverse) despite its lower score,
        // because "b" is too similar to "a".
        assert_eq!(reranked[1].record.id.as_str(), "c");
        assert_eq!(reranked[2].record.id.as_str(), "b");
    }

    #[test]
    fn test_composite_reranking() {
        let reranker =
            Reranker::new(RerankerConfig::new().with_strategy(RerankerType::Composite));

        let results = vec![
            create_test_result("a", 0.9, 0.5, 86400 * 30),
            create_test_result("b", 0.7, 0.9, 3600),
        ];

        let reranked = reranker.rerank(results);

        // With the default weights (0.5 / 0.3 / 0.2):
        // a ≈ 0.45 + 0.15 + 0.01 ≈ 0.61 (strong retrieval score, stale, mediocre outcome)
        // b ≈ 0.35 + 0.27 + 0.20 ≈ 0.82 (weaker retrieval score, fresh, good outcome)
        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].record.id.as_str(), "b");
        assert_eq!(reranked[1].record.id.as_str(), "a");
    }

    #[test]
    fn test_empty_results() {
        let reranker = Reranker::new(RerankerConfig::new());
        let results = Vec::new();
        let reranked = reranker.rerank(results);
        assert!(reranked.is_empty());
    }

    #[test]
    fn test_single_result() {
        let reranker = Reranker::new(RerankerConfig::new());
        let results = vec![create_test_result("a", 0.9, 0.5, 0)];
        let reranked = reranker.rerank(results);
        assert_eq!(reranked.len(), 1);
    }
}