Skip to main content

oxirs_embed/
reranker.rs

1//! Result reranking for retrieval pipelines.
2//!
3//! Provides [`Reranker`]: a flexible reranking engine that supports
4//! cross-encoder scoring simulation, BM25 reranking, reciprocal rank fusion,
5//! score normalisation, top-k selection, threshold filtering, and batch
6//! reranking with rich statistics.
7//!
8//! ## Example
9//!
10//! ```rust
11//! use oxirs_embed::reranker::{Reranker, RerankerConfig, RerankMethod, Document};
12//!
13//! let config = RerankerConfig {
14//!     method: RerankMethod::Bm25,
15//!     top_k: 3,
16//!     score_threshold: Some(0.0),
17//!     normalize_scores: true,
18//! };
19//! let reranker = Reranker::new(config);
20//!
21//! let docs = vec![
22//!     Document { id: "d1".into(), text: "Rust programming language".into(), initial_score: 0.6 },
23//!     Document { id: "d2".into(), text: "Rust toolchain cargo".into(), initial_score: 0.4 },
24//! ];
25//! let results = reranker.rerank("cargo build", &docs);
26//! assert!(!results.is_empty());
27//! ```
28
29use std::collections::HashMap;
30
31// ─────────────────────────────────────────────────────────────────────────────
32// Public types
33// ─────────────────────────────────────────────────────────────────────────────
34
/// A candidate document handed to the [`Reranker`].
///
/// Besides its text, each candidate carries the score it received from the
/// first-stage retriever so that reranking methods can blend or compare
/// against it.
#[derive(Clone, Debug, PartialEq)]
pub struct Document {
    /// Unique identifier of the document.
    pub id: String,
    /// Full document text; this is what gets tokenised for scoring.
    pub text: String,
    /// Score assigned by the first-stage retriever (e.g. cosine similarity).
    pub initial_score: f64,
}
45
/// A ranked result produced by the [`Reranker`].
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Document identifier (copied from [`Document::id`]).
    pub id: String,
    /// Reranking score. Its scale depends on the chosen [`RerankMethod`] and
    /// on whether score normalisation is enabled.
    pub score: f64,
    /// 1-based rank position in the final result set.
    pub rank: usize,
    /// Change in rank position relative to the initial retriever ordering
    /// (documents sorted by descending `initial_score`), computed as
    /// `initial_rank - reranked_position`.
    ///
    /// Positive values mean the document moved up and negative values mean
    /// it moved down. `0.0` means its position did not change. This is a
    /// rank-position delta, not a score delta.
    pub rank_shift: f64,
}
58
/// Reranking method to apply.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash` and can be
/// passed by value or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RerankMethod {
    /// Simulated cross-encoder. Blends query/document token overlap, a
    /// length-normalised match ratio and the initial retriever score.
    CrossEncoder,
    /// Okapi BM25 scoring, with document frequencies computed over the
    /// candidate set itself.
    Bm25,
    /// Reciprocal Rank Fusion of the initial retriever ranking and a BM25
    /// ranking of the candidates, using the standard constant `k = 60`.
    ReciprocalRankFusion,
}
71
72/// Configuration for the [`Reranker`].
73#[derive(Debug, Clone)]
74pub struct RerankerConfig {
75    /// Reranking algorithm to use.
76    pub method: RerankMethod,
77    /// Maximum number of results to return after reranking.
78    pub top_k: usize,
79    /// Minimum score to keep a result (applied after normalisation if enabled).
80    /// `None` means no threshold.
81    pub score_threshold: Option<f64>,
82    /// When `true`, final scores are min-max normalised to `[0, 1]`.
83    pub normalize_scores: bool,
84}
85
86impl Default for RerankerConfig {
87    fn default() -> Self {
88        Self {
89            method: RerankMethod::Bm25,
90            top_k: 10,
91            score_threshold: None,
92            normalize_scores: false,
93        }
94    }
95}
96
/// Summary statistics describing the score distribution of a reranked list.
///
/// [`Reranker::compute_stats`] fills every field with zero for an empty list.
#[derive(Clone, Debug)]
pub struct RerankStats {
    /// How many results the list contains.
    pub count: usize,
    /// Smallest score in the list.
    pub min_score: f64,
    /// Largest score in the list.
    pub max_score: f64,
    /// Arithmetic mean of the scores.
    pub mean_score: f64,
    /// Population standard deviation of the scores (divides by `count`).
    pub std_dev: f64,
    /// Average absolute rank shift relative to the initial retriever order.
    pub mean_rank_shift: f64,
}
113
114/// Batch input for reranking multiple queries at once.
115#[derive(Debug, Clone)]
116pub struct BatchRerankInput {
117    /// Query string.
118    pub query: String,
119    /// Candidate documents for this query.
120    pub documents: Vec<Document>,
121}
122
123/// Batch output for a single query's reranked results.
124#[derive(Debug, Clone)]
125pub struct BatchRerankOutput {
126    /// The original query string.
127    pub query: String,
128    /// Ranked results for this query.
129    pub results: Vec<RankedResult>,
130    /// Statistics over the results.
131    pub stats: RerankStats,
132}
133
134// ─────────────────────────────────────────────────────────────────────────────
135// BM25 helpers
136// ─────────────────────────────────────────────────────────────────────────────
137
/// BM25 term-frequency saturation parameter (`k1`).
const BM25_K1: f64 = 1.5;
/// BM25 document-length normalisation strength (`b`).
///
/// With `0.0` length normalisation is disabled; with `1.0` it is fully applied.
const BM25_B: f64 = 0.75;
141
/// Split `text` on whitespace into lowercase terms.
///
/// Non-alphanumeric characters are removed from each whitespace-separated
/// word, so `"high-level"` becomes `"highlevel"`. Words that end up empty
/// (pure punctuation) are skipped.
fn tokenise(text: &str) -> Vec<String> {
    let mut terms = Vec::new();
    for word in text.split_whitespace() {
        let kept: String = word.chars().filter(|ch| ch.is_alphanumeric()).collect();
        if kept.is_empty() {
            continue;
        }
        terms.push(kept.to_lowercase());
    }
    terms
}
154
/// Count how many times each distinct token occurs in `tokens`.
fn term_freq(tokens: &[String]) -> HashMap<String, usize> {
    tokens.iter().fold(HashMap::new(), |mut counts, token| {
        *counts.entry(token.clone()).or_default() += 1;
        counts
    })
}
163
/// Inverse document frequency of a term that occurs in `doc_freq` of
/// `num_docs` documents.
///
/// Uses the smoothed BM25 IDF variant popularised by Lucene:
/// `ln(1 + (N - df + 0.5) / (df + 0.5))`.
///
/// This differs from the classic Robertson/Spärck Jones form
/// `ln((N - df + 0.5) / (df + 0.5))`. The extra `1 +` inside the logarithm
/// keeps the value strictly positive whenever `df <= N`, so very common
/// terms never contribute a negative score.
fn idf(doc_freq: usize, num_docs: usize) -> f64 {
    let n = num_docs as f64;
    let df = doc_freq as f64;
    let odds = (n - df + 0.5) / (df + 0.5);
    (1.0 + odds).ln()
}
171
172/// Score a single document against a query using BM25.
173fn bm25_score(
174    query_terms: &[String],
175    doc_tokens: &[String],
176    df_map: &HashMap<String, usize>,
177    num_docs: usize,
178    avg_dl: f64,
179) -> f64 {
180    let tf_map = term_freq(doc_tokens);
181    let dl = doc_tokens.len() as f64;
182    let mut score = 0.0_f64;
183    for term in query_terms {
184        let tf = *tf_map.get(term).unwrap_or(&0) as f64;
185        if tf == 0.0 {
186            continue;
187        }
188        let df = *df_map.get(term).unwrap_or(&0);
189        let idf_val = idf(df, num_docs);
190        let numerator = tf * (BM25_K1 + 1.0);
191        let denominator = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avg_dl.max(1.0));
192        score += idf_val * numerator / denominator;
193    }
194    score
195}
196
197// ─────────────────────────────────────────────────────────────────────────────
198// Score normalisation helpers
199// ─────────────────────────────────────────────────────────────────────────────
200
/// Rescale `scores` linearly so the smallest maps to `0.0` and the largest to
/// `1.0`.
///
/// When the spread between largest and smallest is below `f64::EPSILON`,
/// every score maps to `1.0`. An empty slice yields an empty vector.
fn min_max_normalize(scores: &[f64]) -> Vec<f64> {
    let (lo, hi) = scores
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &s| {
            (lo.min(s), hi.max(s))
        });
    let spread = hi - lo;
    if spread < f64::EPSILON {
        vec![1.0; scores.len()]
    } else {
        scores.iter().map(|&s| (s - lo) / spread).collect()
    }
}
211
/// Standardise `scores` to zero mean and unit (population) standard
/// deviation.
///
/// Returns an empty vector for empty input. Returns all zeros when the
/// standard deviation is below `f64::EPSILON`, i.e. when all scores are
/// effectively equal.
pub fn z_score_normalize(scores: &[f64]) -> Vec<f64> {
    let count = scores.len();
    if count == 0 {
        return Vec::new();
    }
    let len = count as f64;
    let mean = scores.iter().sum::<f64>() / len;
    let squared_dev_sum = scores.iter().map(|&s| (s - mean).powi(2)).sum::<f64>();
    let sigma = (squared_dev_sum / len).sqrt();
    if sigma < f64::EPSILON {
        vec![0.0; count]
    } else {
        scores.iter().map(|&s| (s - mean) / sigma).collect()
    }
}
226
227// ─────────────────────────────────────────────────────────────────────────────
228// Reranker
229// ─────────────────────────────────────────────────────────────────────────────
230
231/// Reranker supporting cross-encoder simulation, BM25, and reciprocal rank
232/// fusion.
233pub struct Reranker {
234    config: RerankerConfig,
235}
236
237impl Reranker {
238    /// Create a new reranker with the given configuration.
239    pub fn new(config: RerankerConfig) -> Self {
240        Self { config }
241    }
242
243    /// Create a reranker with the default configuration.
244    pub fn with_defaults() -> Self {
245        Self::new(RerankerConfig::default())
246    }
247
248    /// Rerank `docs` for `query` and return top-k [`RankedResult`]s.
249    pub fn rerank(&self, query: &str, docs: &[Document]) -> Vec<RankedResult> {
250        if docs.is_empty() {
251            return Vec::new();
252        }
253        let scores = match self.config.method {
254            RerankMethod::CrossEncoder => self.cross_encoder_scores(query, docs),
255            RerankMethod::Bm25 => self.bm25_scores(query, docs),
256            RerankMethod::ReciprocalRankFusion => self.rrf_scores(query, docs),
257        };
258
259        self.finalize(docs, scores)
260    }
261
262    /// Rerank multiple (query, docs) pairs in a single call.
263    pub fn rerank_batch(&self, inputs: &[BatchRerankInput]) -> Vec<BatchRerankOutput> {
264        inputs
265            .iter()
266            .map(|input| {
267                let results = self.rerank(&input.query, &input.documents);
268                let stats = self.compute_stats(&results);
269                BatchRerankOutput {
270                    query: input.query.clone(),
271                    results,
272                    stats,
273                }
274            })
275            .collect()
276    }
277
278    /// Return statistics for the given ranked results.
279    pub fn compute_stats(&self, results: &[RankedResult]) -> RerankStats {
280        if results.is_empty() {
281            return RerankStats {
282                count: 0,
283                min_score: 0.0,
284                max_score: 0.0,
285                mean_score: 0.0,
286                std_dev: 0.0,
287                mean_rank_shift: 0.0,
288            };
289        }
290        let n = results.len() as f64;
291        let scores: Vec<f64> = results.iter().map(|r| r.score).collect();
292        let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
293        let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
294        let mean_score = scores.iter().sum::<f64>() / n;
295        let variance = scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / n;
296        let std_dev = variance.sqrt();
297        let mean_rank_shift = results.iter().map(|r| r.rank_shift.abs()).sum::<f64>() / n;
298        RerankStats {
299            count: results.len(),
300            min_score,
301            max_score,
302            mean_score,
303            std_dev,
304            mean_rank_shift,
305        }
306    }
307
308    // ── Private scoring methods ──────────────────────────────────────────────
309
310    /// Simulated cross-encoder: token-overlap score normalised by query length.
311    fn cross_encoder_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
312        let query_tokens: Vec<String> = tokenise(query);
313        let q_set: std::collections::HashSet<String> = query_tokens.iter().cloned().collect();
314        docs.iter()
315            .map(|doc| {
316                let doc_tokens = tokenise(&doc.text);
317                if q_set.is_empty() || doc_tokens.is_empty() {
318                    return 0.0;
319                }
320                let matches = doc_tokens.iter().filter(|t| q_set.contains(*t)).count();
321                let tf_norm = matches as f64 / doc_tokens.len() as f64;
322                let idf_weight = (matches as f64 + 1.0).ln() / (q_set.len() as f64 + 1.0).ln();
323                // Blend with initial score for cross-encoder flavour
324                0.6 * tf_norm + 0.2 * idf_weight + 0.2 * doc.initial_score
325            })
326            .collect()
327    }
328
329    /// BM25 scores over the document corpus.
330    fn bm25_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
331        let query_terms = tokenise(query);
332        let tokenised: Vec<Vec<String>> = docs.iter().map(|d| tokenise(&d.text)).collect();
333        let num_docs = docs.len();
334        let total_len: usize = tokenised.iter().map(|t| t.len()).sum();
335        let avg_dl = total_len as f64 / num_docs as f64;
336
337        // Build document-frequency map
338        let mut df_map: HashMap<String, usize> = HashMap::new();
339        for toks in &tokenised {
340            let unique: std::collections::HashSet<&String> = toks.iter().collect();
341            for t in unique {
342                *df_map.entry(t.clone()).or_insert(0) += 1;
343            }
344        }
345
346        tokenised
347            .iter()
348            .map(|toks| bm25_score(&query_terms, toks, &df_map, num_docs, avg_dl))
349            .collect()
350    }
351
352    /// Reciprocal Rank Fusion: combines initial retriever rank + BM25 rank.
353    ///
354    /// RRF formula: `score = Σ 1 / (k + rank_i)` where `k = 60` (standard).
355    fn rrf_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
356        const K: f64 = 60.0;
357
358        // Rank by initial score (descending)
359        let n = docs.len();
360        let mut initial_order: Vec<usize> = (0..n).collect();
361        initial_order.sort_by(|&a, &b| {
362            docs[b]
363                .initial_score
364                .partial_cmp(&docs[a].initial_score)
365                .unwrap_or(std::cmp::Ordering::Equal)
366        });
367        let mut rank_initial = vec![0usize; n];
368        for (rank, &idx) in initial_order.iter().enumerate() {
369            rank_initial[idx] = rank + 1;
370        }
371
372        // Rank by BM25 score (descending)
373        let bm25 = self.bm25_scores(query, docs);
374        let mut bm25_order: Vec<usize> = (0..n).collect();
375        bm25_order.sort_by(|&a, &b| {
376            bm25[b]
377                .partial_cmp(&bm25[a])
378                .unwrap_or(std::cmp::Ordering::Equal)
379        });
380        let mut rank_bm25 = vec![0usize; n];
381        for (rank, &idx) in bm25_order.iter().enumerate() {
382            rank_bm25[idx] = rank + 1;
383        }
384
385        (0..n)
386            .map(|i| 1.0 / (K + rank_initial[i] as f64) + 1.0 / (K + rank_bm25[i] as f64))
387            .collect()
388    }
389
390    /// Apply normalisation, threshold filtering, and top-k selection.
391    fn finalize(&self, docs: &[Document], raw_scores: Vec<f64>) -> Vec<RankedResult> {
392        // Build initial ranks from initial_score ordering
393        let n = docs.len();
394        let mut initial_order: Vec<usize> = (0..n).collect();
395        initial_order.sort_by(|&a, &b| {
396            docs[b]
397                .initial_score
398                .partial_cmp(&docs[a].initial_score)
399                .unwrap_or(std::cmp::Ordering::Equal)
400        });
401        let mut initial_rank = vec![0usize; n];
402        for (rank, &idx) in initial_order.iter().enumerate() {
403            initial_rank[idx] = rank + 1;
404        }
405
406        // Optionally normalise
407        let final_scores = if self.config.normalize_scores {
408            min_max_normalize(&raw_scores)
409        } else {
410            raw_scores.clone()
411        };
412
413        // Sort descending by reranked score
414        let mut order: Vec<usize> = (0..n).collect();
415        order.sort_by(|&a, &b| {
416            final_scores[b]
417                .partial_cmp(&final_scores[a])
418                .unwrap_or(std::cmp::Ordering::Equal)
419        });
420
421        let mut results: Vec<RankedResult> = order
422            .iter()
423            .enumerate()
424            .map(|(new_rank, &idx)| {
425                let rank_shift = initial_rank[idx] as f64 - (new_rank + 1) as f64;
426                RankedResult {
427                    id: docs[idx].id.clone(),
428                    score: final_scores[idx],
429                    rank: new_rank + 1,
430                    rank_shift,
431                }
432            })
433            .collect();
434
435        // Apply score threshold
436        if let Some(threshold) = self.config.score_threshold {
437            results.retain(|r| r.score >= threshold);
438        }
439
440        // Apply top-k
441        results.truncate(self.config.top_k);
442
443        // Re-assign rank after truncation
444        for (i, r) in results.iter_mut().enumerate() {
445            r.rank = i + 1;
446        }
447
448        results
449    }
450}
451
452// ─────────────────────────────────────────────────────────────────────────────
453// Tests
454// ─────────────────────────────────────────────────────────────────────────────
455
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_docs() -> Vec<Document> {
        vec![
            Document {
                id: "d1".into(),
                text: "Rust is a systems programming language".into(),
                initial_score: 0.9,
            },
            Document {
                id: "d2".into(),
                text: "Python is a high-level scripting language".into(),
                initial_score: 0.7,
            },
            Document {
                id: "d3".into(),
                text: "Cargo is the Rust package manager and build tool".into(),
                initial_score: 0.5,
            },
            Document {
                id: "d4".into(),
                text: "JavaScript runs in the browser".into(),
                initial_score: 0.3,
            },
            Document {
                id: "d5".into(),
                text: "Rust ownership model ensures memory safety".into(),
                initial_score: 0.6,
            },
        ]
    }

    /// Config with the given method and `top_k`; everything else default.
    fn cfg(method: RerankMethod, top_k: usize) -> RerankerConfig {
        RerankerConfig {
            method,
            top_k,
            ..Default::default()
        }
    }

    // ── Cross-encoder ────────────────────────────────────────────────────────

    #[test]
    fn test_cross_encoder_rerank_returns_results() {
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 3));
        let results = reranker.rerank("Rust systems language", &sample_docs());
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_cross_encoder_ranks_rust_docs_higher() {
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 5));
        let results = reranker.rerank("Rust programming", &sample_docs());
        // First result should be a Rust-related document
        assert!(results[0].id == "d1" || results[0].id == "d3" || results[0].id == "d5");
    }

    #[test]
    fn test_cross_encoder_rank_order() {
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 5));
        let results = reranker.rerank("Rust", &sample_docs());
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.rank, i + 1);
        }
    }

    #[test]
    fn test_cross_encoder_scores_non_negative() {
        // Must select the cross-encoder explicitly: the default method is BM25.
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 5));
        let results = reranker.rerank("cargo build", &sample_docs());
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    // ── BM25 ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_bm25_rerank_basic() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("Rust ownership memory", &sample_docs());
        assert!(!results.is_empty());
    }

    #[test]
    fn test_bm25_top_k_respected() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 2));
        let results = reranker.rerank("language programming", &sample_docs());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_bm25_term_frequency_effect() {
        // A document containing the query term more often should score higher
        let docs = vec![
            Document {
                id: "rare".into(),
                text: "Rust is a language".into(),
                initial_score: 0.5,
            },
            Document {
                id: "frequent".into(),
                text: "Rust Rust Rust Rust Rust performance systems Rust".into(),
                initial_score: 0.5,
            },
        ];
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 2));
        let results = reranker.rerank("Rust", &docs);
        // The "frequent" document should rank first
        assert_eq!(results[0].id, "frequent");
    }

    #[test]
    fn test_bm25_scores_are_non_negative() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("systems", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    // ── Reciprocal Rank Fusion ───────────────────────────────────────────────

    #[test]
    fn test_rrf_rerank_returns_results() {
        let reranker = Reranker::new(cfg(RerankMethod::ReciprocalRankFusion, 5));
        let results = reranker.rerank("Rust cargo build", &sample_docs());
        assert!(!results.is_empty());
    }

    #[test]
    fn test_rrf_scores_positive() {
        let reranker = Reranker::new(cfg(RerankMethod::ReciprocalRankFusion, 5));
        let results = reranker.rerank("language", &sample_docs());
        for r in &results {
            assert!(r.score > 0.0);
        }
    }

    #[test]
    fn test_rrf_top_k_applied() {
        let reranker = Reranker::new(cfg(RerankMethod::ReciprocalRankFusion, 2));
        let results = reranker.rerank("language", &sample_docs());
        assert!(results.len() <= 2);
    }

    // ── Score normalisation ──────────────────────────────────────────────────

    #[test]
    fn test_min_max_normalize_range() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 5,
            normalize_scores: true,
            ..Default::default()
        };
        let reranker = Reranker::new(config);
        let results = reranker.rerank("Rust", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.0 && r.score <= 1.0 + 1e-10);
        }
    }

    #[test]
    fn test_z_score_normalize_identity_for_equal_values() {
        let scores = vec![5.0, 5.0, 5.0];
        let normalized = z_score_normalize(&scores);
        for v in normalized {
            assert!((v - 0.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_z_score_normalize_basic() {
        let scores = vec![1.0, 2.0, 3.0];
        let normalized = z_score_normalize(&scores);
        assert_eq!(normalized.len(), 3);
        // Mean of normalised should be ~0
        let mean: f64 = normalized.iter().sum::<f64>() / 3.0;
        assert!(mean.abs() < 1e-10);
    }

    #[test]
    fn test_min_max_normalize_empty() {
        let scores: Vec<f64> = vec![];
        let normalized = min_max_normalize(&scores);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_min_max_normalize_single_value() {
        let scores = vec![3.7];
        let normalized = min_max_normalize(&scores);
        assert_eq!(normalized.len(), 1);
        assert!((normalized[0] - 1.0).abs() < 1e-10);
    }

    // ── Score threshold ──────────────────────────────────────────────────────

    #[test]
    fn test_score_threshold_filters_low_scores() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 10,
            normalize_scores: true,
            score_threshold: Some(0.5),
        };
        let reranker = Reranker::new(config);
        let results = reranker.rerank("Rust", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.5);
        }
    }

    #[test]
    fn test_score_threshold_zero_keeps_all_non_negative() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 10,
            score_threshold: Some(0.0),
            ..Default::default()
        };
        let reranker = Reranker::new(config);
        let results = reranker.rerank("Rust", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    // ── Top-k ────────────────────────────────────────────────────────────────

    #[test]
    fn test_top_k_one() {
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 1));
        let results = reranker.rerank("Rust", &sample_docs());
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].rank, 1);
    }

    #[test]
    fn test_top_k_zero_returns_empty() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 0));
        let results = reranker.rerank("Rust", &sample_docs());
        assert!(results.is_empty());
    }

    #[test]
    fn test_top_k_larger_than_docs() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 100));
        let results = reranker.rerank("Rust", &sample_docs());
        assert!(results.len() <= sample_docs().len());
    }

    // ── Rank shift ───────────────────────────────────────────────────────────

    #[test]
    fn test_rank_shift_computed() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("ownership memory", &sample_docs());
        // rank_shift values must be finite
        for r in &results {
            assert!(r.rank_shift.is_finite());
        }
    }

    #[test]
    fn test_rank_shift_is_initial_minus_final_rank() {
        // Retriever order by initial_score: d1, d2, d5, d3, d4. Only d3
        // mentions "cargo", so BM25 lifts it from position 4 to position 1.
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("cargo", &sample_docs());
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].id, "d3");
        assert!((results[0].rank_shift - 3.0).abs() < 1e-10);
        // Every document keeps a slot, so the shifts cancel out.
        let total: f64 = results.iter().map(|r| r.rank_shift).sum();
        assert!(total.abs() < 1e-10);
    }

    // ── Robustness ───────────────────────────────────────────────────────────

    #[test]
    fn test_empty_docs_returns_empty() {
        let reranker = Reranker::with_defaults();
        let results = reranker.rerank("Rust", &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_empty_query_bm25() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("", &sample_docs());
        // All BM25 scores will be 0 — we still get (up to) 5 results
        assert!(results.len() <= 5);
    }

    #[test]
    fn test_nan_initial_score_keeps_all_results() {
        // A NaN retriever score must not break ranking.
        let mut docs = sample_docs();
        docs[1].initial_score = f64::NAN;
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 5));
        let results = reranker.rerank("Rust", &docs);
        assert_eq!(results.len(), 5);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.rank, i + 1);
        }
    }

    // ── Batch reranking ──────────────────────────────────────────────────────

    #[test]
    fn test_batch_rerank_multiple_queries() {
        let reranker = Reranker::with_defaults();
        let inputs = vec![
            BatchRerankInput {
                query: "Rust".into(),
                documents: sample_docs(),
            },
            BatchRerankInput {
                query: "Python".into(),
                documents: sample_docs(),
            },
        ];
        let outputs = reranker.rerank_batch(&inputs);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].query, "Rust");
        assert_eq!(outputs[1].query, "Python");
    }

    #[test]
    fn test_batch_rerank_stats_populated() {
        let reranker = Reranker::with_defaults();
        let inputs = vec![BatchRerankInput {
            query: "Rust".into(),
            documents: sample_docs(),
        }];
        let outputs = reranker.rerank_batch(&inputs);
        let stats = &outputs[0].stats;
        assert!(stats.count > 0);
        assert!(stats.max_score >= stats.min_score);
    }

    #[test]
    fn test_batch_rerank_empty_inputs() {
        let reranker = Reranker::with_defaults();
        let outputs = reranker.rerank_batch(&[]);
        assert!(outputs.is_empty());
    }

    #[test]
    fn test_batch_rerank_stats_count_matches_results() {
        let reranker = Reranker::with_defaults();
        let inputs = vec![BatchRerankInput {
            query: "language".into(),
            documents: sample_docs(),
        }];
        let outputs = reranker.rerank_batch(&inputs);
        assert_eq!(outputs[0].stats.count, outputs[0].results.len());
    }

    // ── Statistics ───────────────────────────────────────────────────────────

    #[test]
    fn test_compute_stats_empty_results() {
        let reranker = Reranker::with_defaults();
        let stats = reranker.compute_stats(&[]);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean_score, 0.0);
    }

    #[test]
    fn test_compute_stats_single_result() {
        let reranker = Reranker::with_defaults();
        let results = vec![RankedResult {
            id: "d1".into(),
            score: 0.75,
            rank: 1,
            rank_shift: 0.0,
        }];
        let stats = reranker.compute_stats(&results);
        assert_eq!(stats.count, 1);
        assert!((stats.min_score - 0.75).abs() < 1e-10);
        assert!((stats.max_score - 0.75).abs() < 1e-10);
        assert!((stats.mean_score - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_compute_stats_std_dev() {
        let reranker = Reranker::with_defaults();
        let results = vec![
            RankedResult {
                id: "a".into(),
                score: 1.0,
                rank: 1,
                rank_shift: 0.0,
            },
            RankedResult {
                id: "b".into(),
                score: 3.0,
                rank: 2,
                rank_shift: 0.0,
            },
        ];
        let stats = reranker.compute_stats(&results);
        assert!((stats.std_dev - 1.0).abs() < 1e-10);
    }

    // ── Config defaults and helpers ──────────────────────────────────────────

    #[test]
    fn test_reranker_config_default() {
        let cfg = RerankerConfig::default();
        assert_eq!(cfg.method, RerankMethod::Bm25);
        assert_eq!(cfg.top_k, 10);
        assert!(cfg.score_threshold.is_none());
        assert!(!cfg.normalize_scores);
    }

    #[test]
    fn test_tokenise_lowercases_and_strips_punct() {
        let tokens = tokenise("Hello, World! Rust.");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"rust".to_string()));
    }

    #[test]
    fn test_z_score_normalize_empty() {
        let result = z_score_normalize(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_idf_formula() {
        // df == 1, num_docs == 10: idf = ln((10-1+0.5)/(1+0.5)+1) = ln(22/3) ≈ 1.9924
        let v = idf(1, 10);
        assert!((v - (22.0_f64 / 3.0).ln()).abs() < 1e-12);
    }

    #[test]
    fn test_term_freq_counts_correctly() {
        let tokens = tokenise("rust rust cargo");
        let tf = term_freq(&tokens);
        assert_eq!(*tf.get("rust").unwrap_or(&0), 2);
        assert_eq!(*tf.get("cargo").unwrap_or(&0), 1);
    }

    // ── Ordering and single-document cases ──────────────────────────────────

    #[test]
    fn test_rerank_rank_contiguous() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("Rust", &sample_docs());
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.rank, i + 1);
        }
    }

    #[test]
    fn test_rerank_descending_score_order() {
        let reranker = Reranker::new(cfg(RerankMethod::Bm25, 5));
        let results = reranker.rerank("Rust", &sample_docs());
        for w in results.windows(2) {
            assert!(w[0].score >= w[1].score - 1e-10);
        }
    }

    #[test]
    fn test_cross_encoder_single_doc() {
        let docs = vec![Document {
            id: "only".into(),
            text: "Rust is great".into(),
            initial_score: 0.8,
        }];
        let reranker = Reranker::new(cfg(RerankMethod::CrossEncoder, 1));
        let results = reranker.rerank("Rust", &docs);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "only");
    }

    #[test]
    fn test_rrf_single_doc() {
        let docs = vec![Document {
            id: "only".into(),
            text: "unique content here".into(),
            initial_score: 1.0,
        }];
        let reranker = Reranker::new(cfg(RerankMethod::ReciprocalRankFusion, 1));
        let results = reranker.rerank("content", &docs);
        assert_eq!(results.len(), 1);
    }

    // ── Plain data types ─────────────────────────────────────────────────────

    #[test]
    fn test_document_clone() {
        let doc = Document {
            id: "d1".into(),
            text: "Rust language".into(),
            initial_score: 0.9,
        };
        let cloned = doc.clone();
        assert_eq!(cloned.id, "d1");
        assert!((cloned.initial_score - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_ranked_result_fields() {
        let r = RankedResult {
            id: "x".into(),
            score: 0.5,
            rank: 2,
            rank_shift: -1.0,
        };
        assert_eq!(r.id, "x");
        assert_eq!(r.rank, 2);
        assert!((r.rank_shift + 1.0).abs() < 1e-10);
    }
}