Skip to main content

trueno_rag/
rerank.rs

1//! Reranking module for RAG pipelines
2
3use crate::{retrieve::RetrievalResult, Result};
4use serde::{Deserialize, Serialize};
5
6/// Trait for reranking retrieved results
7pub trait Reranker: Send + Sync {
8    /// Rerank candidates given a query
9    fn rerank(
10        &self,
11        query: &str,
12        candidates: &[RetrievalResult],
13        top_k: usize,
14    ) -> Result<Vec<RetrievalResult>>;
15}
16
17/// Lexical reranker using simple text matching features
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct LexicalReranker {
20    /// Weight for exact query match
21    pub exact_match_weight: f32,
22    /// Weight for query term coverage
23    pub coverage_weight: f32,
24    /// Weight for position bias (earlier terms = better)
25    pub position_weight: f32,
26    /// Whether to lowercase for matching
27    pub case_insensitive: bool,
28}
29
30impl Default for LexicalReranker {
31    fn default() -> Self {
32        Self {
33            exact_match_weight: 0.3,
34            coverage_weight: 0.5,
35            position_weight: 0.2,
36            case_insensitive: true,
37        }
38    }
39}
40
41impl LexicalReranker {
42    /// Create a new lexical reranker
43    #[must_use]
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    /// Set weights
49    #[must_use]
50    pub fn with_weights(mut self, exact_match: f32, coverage: f32, position: f32) -> Self {
51        self.exact_match_weight = exact_match;
52        self.coverage_weight = coverage;
53        self.position_weight = position;
54        self
55    }
56
57    /// Calculate rerank score for a single candidate
58    fn score(&self, query: &str, content: &str) -> f32 {
59        let (query, content) = if self.case_insensitive {
60            (query.to_lowercase(), content.to_lowercase())
61        } else {
62            (query.to_string(), content.to_string())
63        };
64
65        let query_terms: Vec<&str> = query.split_whitespace().collect();
66        if query_terms.is_empty() {
67            return 0.0;
68        }
69
70        // Exact match score
71        let exact_match = if content.contains(&query) { 1.0 } else { 0.0 };
72
73        // Coverage score: what fraction of query terms appear in content
74        let matches = query_terms.iter().filter(|term| content.contains(*term)).count() as f32;
75        let coverage = matches / query_terms.len().max(1) as f32;
76
77        // Position score: how early do query terms appear
78        let position_score = query_terms
79            .iter()
80            .filter_map(|term| content.find(term))
81            .map(|pos| 1.0 / (1.0 + pos as f32 / 100.0))
82            .sum::<f32>()
83            / query_terms.len().max(1) as f32;
84
85        self.exact_match_weight * exact_match
86            + self.coverage_weight * coverage
87            + self.position_weight * position_score
88    }
89}
90
91impl Reranker for LexicalReranker {
92    fn rerank(
93        &self,
94        query: &str,
95        candidates: &[RetrievalResult],
96        top_k: usize,
97    ) -> Result<Vec<RetrievalResult>> {
98        let mut scored: Vec<(RetrievalResult, f32)> = candidates
99            .iter()
100            .map(|c| {
101                let score = self.score(query, &c.chunk.content);
102                (c.clone(), score)
103            })
104            .collect();
105
106        // Sort by rerank score descending
107        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
108
109        // Take top_k and set rerank score
110        Ok(scored
111            .into_iter()
112            .take(top_k)
113            .map(|(mut result, score)| {
114                result.rerank_score = Some(score);
115                result
116            })
117            .collect())
118    }
119}
120
/// Stand-in for a cross-encoder reranker, intended for tests.
///
/// No model is loaded; relevance is approximated by word overlap between the
/// query and the document.
#[derive(Debug, Clone)]
pub struct MockCrossEncoderReranker {
    /// Identifier of the model this mock is labelled with
    model_id: String,
}

impl MockCrossEncoderReranker {
    /// Build a mock cross-encoder labelled with `model_id`
    #[must_use]
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self { model_id }
    }

    /// Borrow the model identifier supplied at construction
    #[must_use]
    pub fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Mock relevance of `document` for `query`.
    ///
    /// Both inputs are lowercased and split on whitespace into sets of
    /// distinct words. The result is the share of distinct query words that
    /// also appear as whole words in the document, or 0.0 when either side
    /// has no words.
    #[allow(clippy::unused_self)]
    fn score_pair(&self, query: &str, document: &str) -> f32 {
        use std::collections::HashSet;

        let query_lower = query.to_lowercase();
        let doc_lower = document.to_lowercase();

        let wanted: HashSet<&str> = query_lower.split_whitespace().collect();
        let present: HashSet<&str> = doc_lower.split_whitespace().collect();

        if wanted.is_empty() || present.is_empty() {
            return 0.0;
        }

        let shared = wanted.iter().filter(|term| present.contains(*term)).count();
        shared as f32 / wanted.len() as f32
    }
}
159
160impl Reranker for MockCrossEncoderReranker {
161    fn rerank(
162        &self,
163        query: &str,
164        candidates: &[RetrievalResult],
165        top_k: usize,
166    ) -> Result<Vec<RetrievalResult>> {
167        let mut scored: Vec<(RetrievalResult, f32)> = candidates
168            .iter()
169            .map(|c| {
170                let score = self.score_pair(query, &c.chunk.content);
171                (c.clone(), score)
172            })
173            .collect();
174
175        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
176
177        Ok(scored
178            .into_iter()
179            .take(top_k)
180            .map(|(mut result, score)| {
181                result.rerank_score = Some(score);
182                result
183            })
184            .collect())
185    }
186}
187
188/// Composite reranker that combines multiple rerankers
189pub struct CompositeReranker {
190    rerankers: Vec<(Box<dyn Reranker>, f32)>,
191}
192
193impl CompositeReranker {
194    /// Create a new composite reranker
195    #[must_use]
196    pub fn new() -> Self {
197        Self { rerankers: Vec::new() }
198    }
199
200    /// Add a reranker with a weight
201    #[must_use]
202    pub fn with_reranker(mut self, reranker: Box<dyn Reranker>, weight: f32) -> Self {
203        self.rerankers.push((reranker, weight));
204        self
205    }
206}
207
208impl Default for CompositeReranker {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
214impl Reranker for CompositeReranker {
215    fn rerank(
216        &self,
217        query: &str,
218        candidates: &[RetrievalResult],
219        top_k: usize,
220    ) -> Result<Vec<RetrievalResult>> {
221        if self.rerankers.is_empty() {
222            return Ok(candidates.iter().take(top_k).cloned().collect());
223        }
224
225        // Get scores from each reranker
226        let mut combined_scores: std::collections::HashMap<usize, f32> =
227            std::collections::HashMap::new();
228
229        for (reranker, weight) in &self.rerankers {
230            let reranked = reranker.rerank(query, candidates, candidates.len())?;
231            for result in &reranked {
232                // Find original index
233                for (orig_idx, orig) in candidates.iter().enumerate() {
234                    if orig.chunk.id == result.chunk.id {
235                        let score = result.rerank_score.unwrap_or(0.0);
236                        *combined_scores.entry(orig_idx).or_insert(0.0) += weight * score;
237                        break;
238                    }
239                }
240            }
241        }
242
243        // Sort by combined score
244        let mut indexed: Vec<_> = combined_scores.into_iter().collect();
245        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246
247        Ok(indexed
248            .into_iter()
249            .take(top_k)
250            .map(|(idx, score)| {
251                let mut result = candidates[idx].clone();
252                result.rerank_score = Some(score);
253                result
254            })
255            .collect())
256    }
257}
258
/// Pass-through reranker: candidates are returned in the order they came in
#[derive(Debug, Clone, Default)]
pub struct NoOpReranker;

impl NoOpReranker {
    /// Construct the pass-through reranker
    #[must_use]
    pub fn new() -> Self {
        NoOpReranker
    }
}
270
271impl Reranker for NoOpReranker {
272    fn rerank(
273        &self,
274        _query: &str,
275        candidates: &[RetrievalResult],
276        top_k: usize,
277    ) -> Result<Vec<RetrievalResult>> {
278        Ok(candidates.iter().take(top_k).cloned().collect())
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chunk, DocumentId};
    use proptest::prelude::*;

    /// Build a retrieval result around a new chunk holding `content`.
    fn create_result(content: &str) -> RetrievalResult {
        let chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        RetrievalResult::new(chunk)
    }

    /// Same as `create_result`, with the dense score set to `dense`.
    fn create_result_with_score(content: &str, dense: f32) -> RetrievalResult {
        create_result(content).with_dense_score(dense)
    }

    // ============ LexicalReranker Tests ============

    #[test]
    fn test_lexical_reranker_default() {
        let reranker = LexicalReranker::default();
        assert!((reranker.exact_match_weight - 0.3).abs() < 0.01);
        assert!((reranker.coverage_weight - 0.5).abs() < 0.01);
        assert!((reranker.position_weight - 0.2).abs() < 0.01);
        assert!(reranker.case_insensitive);
    }

    #[test]
    fn test_lexical_reranker_with_weights() {
        let reranker = LexicalReranker::new().with_weights(0.5, 0.3, 0.2);
        assert!((reranker.exact_match_weight - 0.5).abs() < 0.01);
        assert!((reranker.coverage_weight - 0.3).abs() < 0.01);
        assert!((reranker.position_weight - 0.2).abs() < 0.01);
    }

    #[test]
    fn test_lexical_reranker_exact_match() {
        let reranker = LexicalReranker::new();
        let candidates = vec![
            create_result("This mentions machine and learning separately"),
            create_result("This contains the exact query machine learning"),
        ];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();

        // The output is sorted, so also check *which* candidate came first:
        // the exact phrase match must be promoted from second to first place.
        assert_eq!(results.len(), 2);
        assert!(results[0].chunk.content.contains("exact query"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_lexical_reranker_coverage() {
        let reranker = LexicalReranker::new();
        let candidates =
            vec![create_result("machine only here"), create_result("machine learning algorithms")];

        let results = reranker.rerank("machine learning neural networks", &candidates, 2).unwrap();

        // One candidate matches 2 query terms, the other only 1; the one with
        // higher coverage must be promoted to first place.
        assert_eq!(results.len(), 2);
        assert!(results[0].chunk.content.contains("algorithms"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_lexical_reranker_top_k() {
        let reranker = LexicalReranker::new();
        let candidates: Vec<_> = (0..10).map(|i| create_result(&format!("doc {i}"))).collect();

        let results = reranker.rerank("doc", &candidates, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_lexical_reranker_empty_query() {
        let reranker = LexicalReranker::new();
        let candidates = vec![create_result("test content")];

        let results = reranker.rerank("", &candidates, 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].rerank_score.unwrap() - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_lexical_reranker_case_insensitive() {
        let reranker = LexicalReranker::new();
        let candidates = vec![create_result("MACHINE LEARNING"), create_result("machine learning")];

        let results = reranker.rerank("Machine Learning", &candidates, 2).unwrap();

        // Both should score the same (case insensitive)
        let diff = (results[0].rerank_score.unwrap() - results[1].rerank_score.unwrap()).abs();
        assert!(diff < 0.01);
    }

    // ============ MockCrossEncoderReranker Tests ============

    #[test]
    fn test_mock_cross_encoder_new() {
        let reranker = MockCrossEncoderReranker::new("ms-marco-MiniLM");
        assert_eq!(reranker.model_id(), "ms-marco-MiniLM");
    }

    #[test]
    fn test_mock_cross_encoder_rerank() {
        let reranker = MockCrossEncoderReranker::new("test-model");
        let candidates =
            vec![create_result("cooking recipes"), create_result("machine learning algorithms")];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();

        // The candidate with more term overlap must be promoted to first place
        assert_eq!(results.len(), 2);
        assert!(results[0].chunk.content.contains("machine"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_mock_cross_encoder_empty() {
        let reranker = MockCrossEncoderReranker::new("test-model");
        let results = reranker.rerank("test", &[], 10).unwrap();
        assert!(results.is_empty());
    }

    // ============ CompositeReranker Tests ============

    #[test]
    fn test_composite_reranker_empty() {
        let reranker = CompositeReranker::new();
        let candidates = vec![create_result("test")];

        let results = reranker.rerank("test", &candidates, 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_composite_reranker_single() {
        let lexical = Box::new(LexicalReranker::new());
        let reranker = CompositeReranker::new().with_reranker(lexical, 1.0);

        let candidates =
            vec![create_result("no match at all"), create_result("exact match query here")];

        let results = reranker.rerank("query", &candidates, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].rerank_score.is_some());
        // The only candidate containing the query term must rank first
        assert!(results[0].chunk.content.contains("query"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_composite_reranker_multiple() {
        let lexical = Box::new(LexicalReranker::new());
        let cross = Box::new(MockCrossEncoderReranker::new("test"));

        let reranker =
            CompositeReranker::new().with_reranker(lexical, 0.5).with_reranker(cross, 0.5);

        let candidates =
            vec![create_result("unrelated content"), create_result("machine learning test")];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();
        assert_eq!(results.len(), 2);
        // Both members favour the matching candidate, so it must rank first
        assert!(results[0].chunk.content.contains("machine"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    // ============ NoOpReranker Tests ============

    #[test]
    fn test_noop_reranker() {
        let reranker = NoOpReranker::new();
        let candidates =
            vec![create_result_with_score("first", 0.9), create_result_with_score("second", 0.8)];

        let results = reranker.rerank("query", &candidates, 10).unwrap();

        assert_eq!(results.len(), 2);
        // Order should be preserved
        assert!(results[0].chunk.content.contains("first"));
        assert!(results[1].chunk.content.contains("second"));
    }

    #[test]
    fn test_noop_reranker_top_k() {
        let reranker = NoOpReranker::new();
        let candidates: Vec<_> = (0..10).map(|i| create_result(&format!("doc {i}"))).collect();

        let results = reranker.rerank("query", &candidates, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_lexical_rerank_scores_bounded(
            query in "[a-zA-Z ]{1,20}",
            content in "[a-zA-Z ]{1,100}"
        ) {
            let reranker = LexicalReranker::new();
            let candidates = vec![create_result(&content)];

            let results = reranker.rerank(&query, &candidates, 1).unwrap();
            let score = results[0].rerank_score.unwrap();

            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        #[test]
        fn prop_rerank_respects_top_k(k in 1usize..10, n in 1usize..20) {
            let reranker = LexicalReranker::new();
            let candidates: Vec<_> = (0..n)
                .map(|i| create_result(&format!("document {i}")))
                .collect();

            let results = reranker.rerank("document", &candidates, k).unwrap();
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= n);
        }

        #[test]
        fn prop_noop_preserves_order(n in 1usize..10) {
            let reranker = NoOpReranker::new();
            let candidates: Vec<_> = (0..n)
                .map(|i| create_result(&format!("doc {i}")))
                .collect();

            let results = reranker.rerank("query", &candidates, n).unwrap();

            for (i, result) in results.iter().enumerate() {
                let expected = format!("doc {i}");
                prop_assert!(result.chunk.content.contains(&expected));
            }
        }
    }
}