1use std::collections::HashMap;
30
/// A candidate document to be reranked.
#[derive(Debug, Clone, PartialEq)]
pub struct Document {
    /// Identifier, copied verbatim into each `RankedResult`.
    pub id: String,
    /// Raw text; split on whitespace and cleaned by `tokenise` for lexical scoring.
    pub text: String,
    /// Score assigned before reranking (presumably by an upstream retriever — confirm).
    /// It defines the "initial" ranking used for `rank_shift` and RRF, and is blended
    /// into the cross-encoder heuristic.
    pub initial_score: f64,
}
45
/// One document's position in a reranked list.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// The source document's `id`.
    pub id: String,
    /// Final score (min-max normalised when `RerankerConfig::normalize_scores` is set).
    pub score: f64,
    /// 1-based position in the returned list.
    pub rank: usize,
    /// Initial rank (by descending `initial_score`) minus `rank`; a positive value means
    /// the document moved up.
    pub rank_shift: f64,
}
58
/// Scoring strategy used by `Reranker::rerank`.
///
/// The enum carries no data, so it is `Copy` and `Hash`. It can be passed by value and
/// used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RerankMethod {
    /// Lexical-overlap heuristic blended with each document's `initial_score`; no model
    /// is actually invoked.
    CrossEncoder,
    /// Okapi BM25, with document frequencies computed over the candidate set itself.
    Bm25,
    /// Reciprocal rank fusion of the initial ranking and the BM25 ranking.
    ReciprocalRankFusion,
}
71
/// Configuration for a `Reranker`.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Scoring strategy.
    pub method: RerankMethod,
    /// Maximum number of results returned; applied after `score_threshold`.
    pub top_k: usize,
    /// When set, results scoring below this value are dropped. Compared against the final
    /// score, i.e. after normalisation when `normalize_scores` is true.
    pub score_threshold: Option<f64>,
    /// When true, scores are min-max normalised into `[0, 1]` before thresholding.
    pub normalize_scores: bool,
}
85
86impl Default for RerankerConfig {
87 fn default() -> Self {
88 Self {
89 method: RerankMethod::Bm25,
90 top_k: 10,
91 score_threshold: None,
92 normalize_scores: false,
93 }
94 }
95}
96
/// Summary statistics over one ranked result list, as produced by
/// `Reranker::compute_stats`. Every field is 0 for an empty list.
#[derive(Debug, Clone)]
pub struct RerankStats {
    /// Number of results summarised.
    pub count: usize,
    /// Lowest result score.
    pub min_score: f64,
    /// Highest result score.
    pub max_score: f64,
    /// Arithmetic mean of the scores.
    pub mean_score: f64,
    /// Population standard deviation of the scores (divides by `count`).
    pub std_dev: f64,
    /// Mean of the absolute `rank_shift` values.
    pub mean_rank_shift: f64,
}
113
/// One query and its candidate documents, for `Reranker::rerank_batch`.
#[derive(Debug, Clone)]
pub struct BatchRerankInput {
    /// Query text to score the documents against.
    pub query: String,
    /// Candidates to rerank for `query`.
    pub documents: Vec<Document>,
}
122
/// Result of reranking one `BatchRerankInput`.
#[derive(Debug, Clone)]
pub struct BatchRerankOutput {
    /// The input query, echoed back.
    pub query: String,
    /// Ranked results, as returned by `Reranker::rerank`.
    pub results: Vec<RankedResult>,
    /// Statistics computed over `results`.
    pub stats: RerankStats,
}
133
/// BM25 term-frequency saturation parameter (`k1`): larger values let repeated terms keep
/// adding score for longer.
const BM25_K1: f64 = 1.5;
/// BM25 document-length normalisation strength (`b`); 0 would disable length
/// normalisation, 1 applies it fully.
const BM25_B: f64 = 0.75;
141
/// Splits `text` on whitespace, strips every non-alphanumeric character from each word,
/// lowercases what remains, and skips words that end up empty.
///
/// For example, `"high-level"` becomes `"highlevel"` and a lone `"--"` is dropped.
fn tokenise(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split_whitespace() {
        let cleaned: String = word.chars().filter(|c| c.is_alphanumeric()).collect();
        if cleaned.is_empty() {
            continue;
        }
        tokens.push(cleaned.to_lowercase());
    }
    tokens
}
154
/// Counts how many times each distinct token occurs in `tokens`.
fn term_freq(tokens: &[String]) -> HashMap<String, usize> {
    tokens.iter().fold(HashMap::new(), |mut counts, token| {
        *counts.entry(token.clone()).or_default() += 1;
        counts
    })
}
163
/// BM25 inverse document frequency, `ln(1 + (N - df + 0.5) / (df + 0.5))`.
///
/// The `+1` inside the logarithm keeps the result strictly positive whenever
/// `doc_freq <= num_docs`, so very common terms never subtract from a score.
fn idf(doc_freq: usize, num_docs: usize) -> f64 {
    let (total, containing) = (num_docs as f64, doc_freq as f64);
    let ratio = (total - containing + 0.5) / (containing + 0.5);
    (1.0 + ratio).ln()
}
171
172fn bm25_score(
174 query_terms: &[String],
175 doc_tokens: &[String],
176 df_map: &HashMap<String, usize>,
177 num_docs: usize,
178 avg_dl: f64,
179) -> f64 {
180 let tf_map = term_freq(doc_tokens);
181 let dl = doc_tokens.len() as f64;
182 let mut score = 0.0_f64;
183 for term in query_terms {
184 let tf = *tf_map.get(term).unwrap_or(&0) as f64;
185 if tf == 0.0 {
186 continue;
187 }
188 let df = *df_map.get(term).unwrap_or(&0);
189 let idf_val = idf(df, num_docs);
190 let numerator = tf * (BM25_K1 + 1.0);
191 let denominator = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * dl / avg_dl.max(1.0));
192 score += idf_val * numerator / denominator;
193 }
194 score
195}
196
/// Rescales `scores` linearly so the minimum maps to 0.0 and the maximum to 1.0.
///
/// When all values are (nearly) equal — range below `f64::EPSILON`, which includes a
/// single value — every output is 1.0. An empty slice yields an empty vector.
fn min_max_normalize(scores: &[f64]) -> Vec<f64> {
    let (lo, hi) = scores
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &s| {
            (lo.min(s), hi.max(s))
        });
    let span = hi - lo;
    if span < f64::EPSILON {
        // Also reached for empty input, where span is -inf.
        vec![1.0; scores.len()]
    } else {
        scores.iter().map(|&s| (s - lo) / span).collect()
    }
}
211
/// Standardises `scores` to zero mean and unit population standard deviation.
///
/// Returns an empty vector for empty input, and all zeros when the standard deviation is
/// below `f64::EPSILON` (e.g. every score is equal).
pub fn z_score_normalize(scores: &[f64]) -> Vec<f64> {
    let len = scores.len();
    if len == 0 {
        return Vec::new();
    }
    let count = len as f64;
    let mean = scores.iter().sum::<f64>() / count;
    let squared_dev_sum = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
    let spread = (squared_dev_sum / count).sqrt();
    if spread < f64::EPSILON {
        vec![0.0; len]
    } else {
        scores.iter().map(|&x| (x - mean) / spread).collect()
    }
}
226
/// Reranks candidate documents for a query according to a `RerankerConfig`.
pub struct Reranker {
    /// Scoring method, result cap, threshold and normalisation settings.
    config: RerankerConfig,
}
236
237impl Reranker {
238 pub fn new(config: RerankerConfig) -> Self {
240 Self { config }
241 }
242
243 pub fn with_defaults() -> Self {
245 Self::new(RerankerConfig::default())
246 }
247
248 pub fn rerank(&self, query: &str, docs: &[Document]) -> Vec<RankedResult> {
250 if docs.is_empty() {
251 return Vec::new();
252 }
253 let scores = match self.config.method {
254 RerankMethod::CrossEncoder => self.cross_encoder_scores(query, docs),
255 RerankMethod::Bm25 => self.bm25_scores(query, docs),
256 RerankMethod::ReciprocalRankFusion => self.rrf_scores(query, docs),
257 };
258
259 self.finalize(docs, scores)
260 }
261
262 pub fn rerank_batch(&self, inputs: &[BatchRerankInput]) -> Vec<BatchRerankOutput> {
264 inputs
265 .iter()
266 .map(|input| {
267 let results = self.rerank(&input.query, &input.documents);
268 let stats = self.compute_stats(&results);
269 BatchRerankOutput {
270 query: input.query.clone(),
271 results,
272 stats,
273 }
274 })
275 .collect()
276 }
277
278 pub fn compute_stats(&self, results: &[RankedResult]) -> RerankStats {
280 if results.is_empty() {
281 return RerankStats {
282 count: 0,
283 min_score: 0.0,
284 max_score: 0.0,
285 mean_score: 0.0,
286 std_dev: 0.0,
287 mean_rank_shift: 0.0,
288 };
289 }
290 let n = results.len() as f64;
291 let scores: Vec<f64> = results.iter().map(|r| r.score).collect();
292 let min_score = scores.iter().cloned().fold(f64::INFINITY, f64::min);
293 let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
294 let mean_score = scores.iter().sum::<f64>() / n;
295 let variance = scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / n;
296 let std_dev = variance.sqrt();
297 let mean_rank_shift = results.iter().map(|r| r.rank_shift.abs()).sum::<f64>() / n;
298 RerankStats {
299 count: results.len(),
300 min_score,
301 max_score,
302 mean_score,
303 std_dev,
304 mean_rank_shift,
305 }
306 }
307
308 fn cross_encoder_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
312 let query_tokens: Vec<String> = tokenise(query);
313 let q_set: std::collections::HashSet<String> = query_tokens.iter().cloned().collect();
314 docs.iter()
315 .map(|doc| {
316 let doc_tokens = tokenise(&doc.text);
317 if q_set.is_empty() || doc_tokens.is_empty() {
318 return 0.0;
319 }
320 let matches = doc_tokens.iter().filter(|t| q_set.contains(*t)).count();
321 let tf_norm = matches as f64 / doc_tokens.len() as f64;
322 let idf_weight = (matches as f64 + 1.0).ln() / (q_set.len() as f64 + 1.0).ln();
323 0.6 * tf_norm + 0.2 * idf_weight + 0.2 * doc.initial_score
325 })
326 .collect()
327 }
328
329 fn bm25_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
331 let query_terms = tokenise(query);
332 let tokenised: Vec<Vec<String>> = docs.iter().map(|d| tokenise(&d.text)).collect();
333 let num_docs = docs.len();
334 let total_len: usize = tokenised.iter().map(|t| t.len()).sum();
335 let avg_dl = total_len as f64 / num_docs as f64;
336
337 let mut df_map: HashMap<String, usize> = HashMap::new();
339 for toks in &tokenised {
340 let unique: std::collections::HashSet<&String> = toks.iter().collect();
341 for t in unique {
342 *df_map.entry(t.clone()).or_insert(0) += 1;
343 }
344 }
345
346 tokenised
347 .iter()
348 .map(|toks| bm25_score(&query_terms, toks, &df_map, num_docs, avg_dl))
349 .collect()
350 }
351
352 fn rrf_scores(&self, query: &str, docs: &[Document]) -> Vec<f64> {
356 const K: f64 = 60.0;
357
358 let n = docs.len();
360 let mut initial_order: Vec<usize> = (0..n).collect();
361 initial_order.sort_by(|&a, &b| {
362 docs[b]
363 .initial_score
364 .partial_cmp(&docs[a].initial_score)
365 .unwrap_or(std::cmp::Ordering::Equal)
366 });
367 let mut rank_initial = vec![0usize; n];
368 for (rank, &idx) in initial_order.iter().enumerate() {
369 rank_initial[idx] = rank + 1;
370 }
371
372 let bm25 = self.bm25_scores(query, docs);
374 let mut bm25_order: Vec<usize> = (0..n).collect();
375 bm25_order.sort_by(|&a, &b| {
376 bm25[b]
377 .partial_cmp(&bm25[a])
378 .unwrap_or(std::cmp::Ordering::Equal)
379 });
380 let mut rank_bm25 = vec![0usize; n];
381 for (rank, &idx) in bm25_order.iter().enumerate() {
382 rank_bm25[idx] = rank + 1;
383 }
384
385 (0..n)
386 .map(|i| 1.0 / (K + rank_initial[i] as f64) + 1.0 / (K + rank_bm25[i] as f64))
387 .collect()
388 }
389
390 fn finalize(&self, docs: &[Document], raw_scores: Vec<f64>) -> Vec<RankedResult> {
392 let n = docs.len();
394 let mut initial_order: Vec<usize> = (0..n).collect();
395 initial_order.sort_by(|&a, &b| {
396 docs[b]
397 .initial_score
398 .partial_cmp(&docs[a].initial_score)
399 .unwrap_or(std::cmp::Ordering::Equal)
400 });
401 let mut initial_rank = vec![0usize; n];
402 for (rank, &idx) in initial_order.iter().enumerate() {
403 initial_rank[idx] = rank + 1;
404 }
405
406 let final_scores = if self.config.normalize_scores {
408 min_max_normalize(&raw_scores)
409 } else {
410 raw_scores.clone()
411 };
412
413 let mut order: Vec<usize> = (0..n).collect();
415 order.sort_by(|&a, &b| {
416 final_scores[b]
417 .partial_cmp(&final_scores[a])
418 .unwrap_or(std::cmp::Ordering::Equal)
419 });
420
421 let mut results: Vec<RankedResult> = order
422 .iter()
423 .enumerate()
424 .map(|(new_rank, &idx)| {
425 let rank_shift = initial_rank[idx] as f64 - (new_rank + 1) as f64;
426 RankedResult {
427 id: docs[idx].id.clone(),
428 score: final_scores[idx],
429 rank: new_rank + 1,
430 rank_shift,
431 }
432 })
433 .collect();
434
435 if let Some(threshold) = self.config.score_threshold {
437 results.retain(|r| r.score >= threshold);
438 }
439
440 results.truncate(self.config.top_k);
442
443 for (i, r) in results.iter_mut().enumerate() {
445 r.rank = i + 1;
446 }
447
448 results
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Five documents with distinct initial scores; d1, d3 and d5 mention "Rust".
    fn sample_docs() -> Vec<Document> {
        vec![
            Document {
                id: "d1".into(),
                text: "Rust is a systems programming language".into(),
                initial_score: 0.9,
            },
            Document {
                id: "d2".into(),
                text: "Python is a high-level scripting language".into(),
                initial_score: 0.7,
            },
            Document {
                id: "d3".into(),
                text: "Cargo is the Rust package manager and build tool".into(),
                initial_score: 0.5,
            },
            Document {
                id: "d4".into(),
                text: "JavaScript runs in the browser".into(),
                initial_score: 0.3,
            },
            Document {
                id: "d5".into(),
                text: "Rust ownership model ensures memory safety".into(),
                initial_score: 0.6,
            },
        ]
    }

    /// Reranker for `method` capped at `top_k`, with every other option at its default.
    fn reranker_with(method: RerankMethod, top_k: usize) -> Reranker {
        Reranker::new(RerankerConfig {
            method,
            top_k,
            ..Default::default()
        })
    }

    // ---- Cross-encoder heuristic ----

    #[test]
    fn test_cross_encoder_rerank_returns_results() {
        let results = reranker_with(RerankMethod::CrossEncoder, 3)
            .rerank("Rust systems language", &sample_docs());
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_cross_encoder_ranks_rust_docs_higher() {
        let results = reranker_with(RerankMethod::CrossEncoder, 5)
            .rerank("Rust programming", &sample_docs());
        assert!(["d1", "d3", "d5"].contains(&results[0].id.as_str()));
    }

    #[test]
    fn test_cross_encoder_rank_order() {
        let results = reranker_with(RerankMethod::CrossEncoder, 5).rerank("Rust", &sample_docs());
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.rank, i + 1);
        }
    }

    #[test]
    fn test_cross_encoder_scores_non_negative() {
        // The default method is BM25, so the cross-encoder must be selected explicitly.
        let results =
            reranker_with(RerankMethod::CrossEncoder, 5).rerank("cargo build", &sample_docs());
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    // ---- BM25 ----

    #[test]
    fn test_bm25_rerank_basic() {
        let results =
            reranker_with(RerankMethod::Bm25, 5).rerank("Rust ownership memory", &sample_docs());
        assert!(!results.is_empty());
    }

    #[test]
    fn test_bm25_top_k_respected() {
        let results =
            reranker_with(RerankMethod::Bm25, 2).rerank("language programming", &sample_docs());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_bm25_term_frequency_effect() {
        // The document repeating the query term should outrank the one mentioning it once.
        let docs = vec![
            Document {
                id: "rare".into(),
                text: "Rust is a language".into(),
                initial_score: 0.5,
            },
            Document {
                id: "frequent".into(),
                text: "Rust Rust Rust Rust Rust performance systems Rust".into(),
                initial_score: 0.5,
            },
        ];
        let results = reranker_with(RerankMethod::Bm25, 2).rerank("Rust", &docs);
        assert_eq!(results[0].id, "frequent");
    }

    #[test]
    fn test_bm25_scores_are_non_negative() {
        let results = reranker_with(RerankMethod::Bm25, 5).rerank("systems", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    // ---- Reciprocal rank fusion ----

    #[test]
    fn test_rrf_rerank_returns_results() {
        let results = reranker_with(RerankMethod::ReciprocalRankFusion, 5)
            .rerank("Rust cargo build", &sample_docs());
        assert!(!results.is_empty());
    }

    #[test]
    fn test_rrf_scores_positive() {
        let results =
            reranker_with(RerankMethod::ReciprocalRankFusion, 5).rerank("language", &sample_docs());
        for r in &results {
            assert!(r.score > 0.0);
        }
    }

    #[test]
    fn test_rrf_top_k_applied() {
        let results =
            reranker_with(RerankMethod::ReciprocalRankFusion, 2).rerank("language", &sample_docs());
        assert!(results.len() <= 2);
    }

    // ---- Normalisation ----

    #[test]
    fn test_min_max_normalize_range() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 5,
            normalize_scores: true,
            ..Default::default()
        };
        let results = Reranker::new(config).rerank("Rust", &sample_docs());
        assert!(!results.is_empty());
        // The best-scoring document maps exactly to the top of the range.
        assert_eq!(results[0].score, 1.0);
        for r in &results {
            assert!(r.score >= 0.0 && r.score <= 1.0 + 1e-10);
        }
    }

    #[test]
    fn test_z_score_normalize_identity_for_equal_values() {
        let normalized = z_score_normalize(&[5.0, 5.0, 5.0]);
        for v in normalized {
            assert!(v.abs() < 1e-10);
        }
    }

    #[test]
    fn test_z_score_normalize_basic() {
        let normalized = z_score_normalize(&[1.0, 2.0, 3.0]);
        assert_eq!(normalized.len(), 3);
        let mean: f64 = normalized.iter().sum::<f64>() / 3.0;
        assert!(mean.abs() < 1e-10);
        // Population variance of standardised values is 1.
        let variance: f64 = normalized.iter().map(|v| v * v).sum::<f64>() / 3.0;
        assert!((variance - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_min_max_normalize_empty() {
        let scores: Vec<f64> = vec![];
        assert!(min_max_normalize(&scores).is_empty());
    }

    #[test]
    fn test_min_max_normalize_single_value() {
        let normalized = min_max_normalize(&[3.7]);
        assert_eq!(normalized.len(), 1);
        assert!((normalized[0] - 1.0).abs() < 1e-10);
    }

    // ---- Threshold and top_k ----

    #[test]
    fn test_score_threshold_filters_low_scores() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 10,
            normalize_scores: true,
            score_threshold: Some(0.5),
        };
        let results = Reranker::new(config).rerank("Rust", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.5);
        }
    }

    #[test]
    fn test_score_threshold_zero_keeps_all_non_negative() {
        let config = RerankerConfig {
            method: RerankMethod::Bm25,
            top_k: 10,
            score_threshold: Some(0.0),
            ..Default::default()
        };
        let results = Reranker::new(config).rerank("Rust", &sample_docs());
        for r in &results {
            assert!(r.score >= 0.0);
        }
    }

    #[test]
    fn test_top_k_one() {
        let results = reranker_with(RerankMethod::CrossEncoder, 1).rerank("Rust", &sample_docs());
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].rank, 1);
    }

    #[test]
    fn test_top_k_larger_than_docs() {
        let results = reranker_with(RerankMethod::Bm25, 100).rerank("Rust", &sample_docs());
        assert!(results.len() <= sample_docs().len());
    }

    // ---- Rank shift and edge cases ----

    #[test]
    fn test_rank_shift_computed() {
        let results =
            reranker_with(RerankMethod::Bm25, 5).rerank("ownership memory", &sample_docs());
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.rank_shift.is_finite());
        }
        // Over a complete ranking both rank sets are 1..=n, so the shifts cancel out.
        let total_shift: f64 = results.iter().map(|r| r.rank_shift).sum();
        assert_eq!(total_shift, 0.0);
    }

    #[test]
    fn test_empty_docs_returns_empty() {
        let results = Reranker::with_defaults().rerank("Rust", &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_empty_query_bm25() {
        let results = reranker_with(RerankMethod::Bm25, 5).rerank("", &sample_docs());
        // No query terms: every document scores zero but is still returned.
        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|r| r.score == 0.0));
    }

    // ---- Batch and statistics ----

    #[test]
    fn test_batch_rerank_multiple_queries() {
        let inputs = vec![
            BatchRerankInput {
                query: "Rust".into(),
                documents: sample_docs(),
            },
            BatchRerankInput {
                query: "Python".into(),
                documents: sample_docs(),
            },
        ];
        let outputs = Reranker::with_defaults().rerank_batch(&inputs);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].query, "Rust");
        assert_eq!(outputs[1].query, "Python");
    }

    #[test]
    fn test_batch_rerank_stats_populated() {
        let inputs = vec![BatchRerankInput {
            query: "Rust".into(),
            documents: sample_docs(),
        }];
        let outputs = Reranker::with_defaults().rerank_batch(&inputs);
        let stats = &outputs[0].stats;
        assert!(stats.count > 0);
        assert!(stats.max_score >= stats.min_score);
    }

    #[test]
    fn test_batch_rerank_empty_inputs() {
        let outputs = Reranker::with_defaults().rerank_batch(&[]);
        assert!(outputs.is_empty());
    }

    #[test]
    fn test_compute_stats_empty_results() {
        let stats = Reranker::with_defaults().compute_stats(&[]);
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean_score, 0.0);
    }

    #[test]
    fn test_compute_stats_single_result() {
        let results = vec![RankedResult {
            id: "d1".into(),
            score: 0.75,
            rank: 1,
            rank_shift: 0.0,
        }];
        let stats = Reranker::with_defaults().compute_stats(&results);
        assert_eq!(stats.count, 1);
        assert!((stats.min_score - 0.75).abs() < 1e-10);
        assert!((stats.max_score - 0.75).abs() < 1e-10);
        assert!((stats.mean_score - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_compute_stats_std_dev() {
        let results = vec![
            RankedResult {
                id: "a".into(),
                score: 1.0,
                rank: 1,
                rank_shift: 0.0,
            },
            RankedResult {
                id: "b".into(),
                score: 3.0,
                rank: 2,
                rank_shift: 0.0,
            },
        ];
        let stats = Reranker::with_defaults().compute_stats(&results);
        // Population standard deviation of {1, 3} is exactly 1.
        assert!((stats.std_dev - 1.0).abs() < 1e-10);
    }

    // ---- Configuration and helpers ----

    #[test]
    fn test_reranker_config_default() {
        let cfg = RerankerConfig::default();
        assert_eq!(cfg.method, RerankMethod::Bm25);
        assert_eq!(cfg.top_k, 10);
        assert!(cfg.score_threshold.is_none());
        assert!(!cfg.normalize_scores);
    }

    #[test]
    fn test_tokenise_lowercases_and_strips_punct() {
        let tokens = tokenise("Hello, World! Rust.");
        assert_eq!(tokens, vec!["hello", "world", "rust"]);
    }

    #[test]
    fn test_z_score_normalize_empty() {
        assert!(z_score_normalize(&[]).is_empty());
    }

    #[test]
    fn test_idf_formula() {
        let v = idf(1, 10);
        let expected = ((10.0_f64 - 1.0 + 0.5) / (1.0 + 0.5) + 1.0).ln();
        assert!(v > 0.0);
        assert!((v - expected).abs() < 1e-12);
        // Rarer terms weigh more.
        assert!(idf(1, 10) > idf(5, 10));
    }

    #[test]
    fn test_rerank_rank_contiguous() {
        let results = reranker_with(RerankMethod::Bm25, 5).rerank("Rust", &sample_docs());
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.rank, i + 1);
        }
    }

    #[test]
    fn test_cross_encoder_single_doc() {
        let docs = vec![Document {
            id: "only".into(),
            text: "Rust is great".into(),
            initial_score: 0.8,
        }];
        let results = reranker_with(RerankMethod::CrossEncoder, 1).rerank("Rust", &docs);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "only");
    }

    #[test]
    fn test_rrf_single_doc() {
        let docs = vec![Document {
            id: "only".into(),
            text: "unique content here".into(),
            initial_score: 1.0,
        }];
        let results = reranker_with(RerankMethod::ReciprocalRankFusion, 1).rerank("content", &docs);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_term_freq_counts_correctly() {
        let tokens = tokenise("rust rust cargo");
        let tf = term_freq(&tokens);
        assert_eq!(*tf.get("rust").unwrap_or(&0), 2);
        assert_eq!(*tf.get("cargo").unwrap_or(&0), 1);
    }

    #[test]
    fn test_document_clone() {
        let doc = Document {
            id: "d1".into(),
            text: "Rust language".into(),
            initial_score: 0.9,
        };
        let cloned = doc.clone();
        assert_eq!(cloned.id, "d1");
        assert!((cloned.initial_score - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_ranked_result_fields() {
        let r = RankedResult {
            id: "x".into(),
            score: 0.5,
            rank: 2,
            rank_shift: -1.0,
        };
        assert_eq!(r.id, "x");
        assert_eq!(r.rank, 2);
        assert!((r.rank_shift + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_batch_rerank_stats_count_matches_results() {
        let inputs = vec![BatchRerankInput {
            query: "language".into(),
            documents: sample_docs(),
        }];
        let outputs = Reranker::with_defaults().rerank_batch(&inputs);
        assert_eq!(outputs[0].stats.count, outputs[0].results.len());
    }

    #[test]
    fn test_rerank_descending_score_order() {
        let results = reranker_with(RerankMethod::Bm25, 5).rerank("Rust", &sample_docs());
        for w in results.windows(2) {
            assert!(w[0].score >= w[1].score - 1e-10);
        }
    }
}
1022}