1use std::collections::HashMap;
35use std::hash::Hash;
36
/// Errors reported by the fallible fusion entry points.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
    /// The supplied weights add up to (within `1e-9` of) zero, so they cannot be normalized.
    ZeroWeights,
    /// A configuration value or argument combination was rejected; the payload says why.
    InvalidConfig(String),
}

impl std::fmt::Display for FusionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FusionError::ZeroWeights => f.write_str("weights sum to zero"),
            FusionError::InvalidConfig(msg) => write!(f, "invalid config: {msg}"),
        }
    }
}

impl std::error::Error for FusionError {}

/// Shorthand for a `std::result::Result` whose error type is [`FusionError`].
pub type Result<T> = std::result::Result<T, FusionError>;
63
/// Configuration shared by the rank-based fusers (RRF and ISR).
///
/// `k` is the offset added to each item's 0-based rank in the score denominator;
/// `top_k` optionally truncates the fused output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RrfConfig {
    /// Rank offset added to each 0-based rank before the reciprocal is taken.
    pub k: u32,
    /// Maximum number of fused results to return; `None` keeps all of them.
    pub top_k: Option<usize>,
}

impl Default for RrfConfig {
    /// `k = 60` with no truncation.
    fn default() -> Self {
        RrfConfig::new(60)
    }
}

impl RrfConfig {
    /// Builds a config with rank offset `k` and no `top_k` limit.
    #[must_use]
    pub const fn new(k: u32) -> Self {
        RrfConfig { k, top_k: None }
    }

    /// Returns this config with the rank offset replaced by `k`.
    #[must_use]
    pub const fn with_k(self, k: u32) -> Self {
        RrfConfig { k, ..self }
    }

    /// Returns this config limited to at most `top_k` fused results.
    #[must_use]
    pub const fn with_top_k(self, top_k: usize) -> Self {
        RrfConfig {
            top_k: Some(top_k),
            ..self
        }
    }
}
118
/// Configuration for two-list weighted score fusion.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WeightedConfig {
    /// Weight applied to the first list's scores.
    pub weight_a: f32,
    /// Weight applied to the second list's scores.
    pub weight_b: f32,
    /// When `true`, each list's scores are min-max normalized before weighting.
    pub normalize: bool,
    /// Maximum number of fused results to return; `None` keeps all of them.
    pub top_k: Option<usize>,
}

impl Default for WeightedConfig {
    /// Equal weights (`0.5` / `0.5`), normalization on, no truncation.
    fn default() -> Self {
        WeightedConfig::new(0.5, 0.5)
    }
}

impl WeightedConfig {
    /// Builds a config with the given weights, normalization enabled and no `top_k` limit.
    #[must_use]
    pub const fn new(weight_a: f32, weight_b: f32) -> Self {
        WeightedConfig {
            weight_a,
            weight_b,
            normalize: true,
            top_k: None,
        }
    }

    /// Returns this config with both weights replaced.
    #[must_use]
    pub const fn with_weights(self, weight_a: f32, weight_b: f32) -> Self {
        WeightedConfig {
            weight_a,
            weight_b,
            ..self
        }
    }

    /// Returns this config with min-max normalization switched on or off.
    #[must_use]
    pub const fn with_normalize(self, normalize: bool) -> Self {
        WeightedConfig { normalize, ..self }
    }

    /// Returns this config limited to at most `top_k` fused results.
    #[must_use]
    pub const fn with_top_k(self, top_k: usize) -> Self {
        WeightedConfig {
            top_k: Some(top_k),
            ..self
        }
    }
}
188
/// Configuration for the fusers that need nothing beyond an optional output limit
/// (CombSUM, CombMNZ, Borda, DBSF).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FusionConfig {
    /// Maximum number of fused results to return; `None` keeps all of them.
    pub top_k: Option<usize>,
}

impl FusionConfig {
    /// Returns a config limited to at most `top_k` fused results.
    #[must_use]
    pub const fn with_top_k(self, top_k: usize) -> Self {
        // `top_k` is the only field, so nothing else from `self` needs carrying over.
        FusionConfig { top_k: Some(top_k) }
    }
}
204
/// Convenience re-exports: a glob import of this module brings the two-list fusion
/// functions plus the configuration, error and `FusionMethod` types into scope.
///
/// NOTE(review): the `*_multi` variants, `rrf_into`, `rrf_weighted`, `weighted_multi`
/// and the `*_with_config` variants other than `rrf_with_config`/`isr_with_config`
/// are not re-exported here; import them from the crate root.
pub mod prelude {
    pub use crate::{
        borda, combmnz, combsum, dbsf, isr, isr_with_config, rrf, rrf_with_config, weighted,
    };
    pub use crate::{FusionConfig, FusionError, FusionMethod, Result, RrfConfig, WeightedConfig};
}
220
/// A fusion algorithm plus its parameters, so the algorithm can be chosen at runtime.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: an item at 0-based `rank` contributes `1 / (k + rank)`.
    Rrf {
        /// Rank offset.
        k: u32,
    },
    /// Inverse Square-root Rank fusion: an item contributes `1 / sqrt(k + rank)`.
    Isr {
        /// Rank offset.
        k: u32,
    },
    /// Per-id sum of per-list min-max normalized scores.
    CombSum,
    /// CombSum multiplied by how many times the id occurred across the lists.
    CombMnz,
    /// Borda count: each list awards `len - rank` points.
    Borda,
    /// Weighted sum of two lists' scores, optionally min-max normalized first.
    Weighted {
        /// Weight of the first list.
        weight_a: f32,
        /// Weight of the second list.
        weight_b: f32,
        /// Whether to min-max normalize each list before weighting.
        normalize: bool,
    },
    /// Distribution-based score fusion: sum of per-list z-scores clamped to `[-3, 3]`.
    Dbsf,
}

impl Default for FusionMethod {
    /// RRF with `k = 60`.
    fn default() -> Self {
        FusionMethod::Rrf { k: 60 }
    }
}
279
280impl FusionMethod {
281 #[must_use]
283 pub const fn rrf() -> Self {
284 Self::Rrf { k: 60 }
285 }
286
287 #[must_use]
289 pub const fn rrf_with_k(k: u32) -> Self {
290 Self::Rrf { k }
291 }
292
293 #[must_use]
295 pub const fn isr() -> Self {
296 Self::Isr { k: 1 }
297 }
298
299 #[must_use]
301 pub const fn isr_with_k(k: u32) -> Self {
302 Self::Isr { k }
303 }
304
305 #[must_use]
307 pub const fn weighted(weight_a: f32, weight_b: f32) -> Self {
308 Self::Weighted {
309 weight_a,
310 weight_b,
311 normalize: true,
312 }
313 }
314
315 #[must_use]
324 pub fn fuse<I: Clone + Eq + Hash>(&self, a: &[(I, f32)], b: &[(I, f32)]) -> Vec<(I, f32)> {
325 match self {
326 Self::Rrf { k } => crate::rrf_multi(&[a, b], RrfConfig::new(*k)),
327 Self::Isr { k } => crate::isr_multi(&[a, b], RrfConfig::new(*k)),
328 Self::CombSum => crate::combsum(a, b),
329 Self::CombMnz => crate::combmnz(a, b),
330 Self::Borda => crate::borda(a, b),
331 Self::Weighted {
332 weight_a,
333 weight_b,
334 normalize,
335 } => crate::weighted(
336 a,
337 b,
338 WeightedConfig::new(*weight_a, *weight_b).with_normalize(*normalize),
339 ),
340 Self::Dbsf => crate::dbsf(a, b),
341 }
342 }
343
344 #[must_use]
352 pub fn fuse_multi<I, L>(&self, lists: &[L]) -> Vec<(I, f32)>
353 where
354 I: Clone + Eq + Hash,
355 L: AsRef<[(I, f32)]>,
356 {
357 match self {
358 Self::Rrf { k } => crate::rrf_multi(lists, RrfConfig::new(*k)),
359 Self::Isr { k } => crate::isr_multi(lists, RrfConfig::new(*k)),
360 Self::CombSum => crate::combsum_multi(lists, FusionConfig::default()),
361 Self::CombMnz => crate::combmnz_multi(lists, FusionConfig::default()),
362 Self::Borda => crate::borda_multi(lists, FusionConfig::default()),
363 Self::Weighted { .. } => {
364 if lists.len() == 2 {
367 self.fuse(lists[0].as_ref(), lists[1].as_ref())
368 } else {
369 crate::combsum_multi(lists, FusionConfig::default())
371 }
372 }
373 Self::Dbsf => crate::dbsf_multi(lists, FusionConfig::default()),
374 }
375 }
376}
377
378#[must_use]
414pub fn rrf<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
415 rrf_with_config(results_a, results_b, RrfConfig::default())
416}
417
418#[must_use]
439#[allow(clippy::cast_precision_loss)]
440pub fn rrf_with_config<I: Clone + Eq + Hash>(
441 results_a: &[(I, f32)],
442 results_b: &[(I, f32)],
443 config: RrfConfig,
444) -> Vec<(I, f32)> {
445 let k = config.k as f32;
446 let mut scores: HashMap<I, f32> = HashMap::new();
447
448 for (rank, (id, _)) in results_a.iter().enumerate() {
449 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32);
450 }
451 for (rank, (id, _)) in results_b.iter().enumerate() {
452 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32);
453 }
454
455 finalize(scores, config.top_k)
456}
457
458#[allow(clippy::cast_precision_loss)]
460pub fn rrf_into<I: Clone + Eq + Hash>(
461 results_a: &[(I, f32)],
462 results_b: &[(I, f32)],
463 config: RrfConfig,
464 output: &mut Vec<(I, f32)>,
465) {
466 output.clear();
467 let k = config.k as f32;
468 let mut scores: HashMap<I, f32> = HashMap::with_capacity(results_a.len() + results_b.len());
469
470 for (rank, (id, _)) in results_a.iter().enumerate() {
471 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32);
472 }
473 for (rank, (id, _)) in results_b.iter().enumerate() {
474 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32);
475 }
476
477 output.extend(scores);
478 sort_scored_desc(output);
479 if let Some(top_k) = config.top_k {
480 output.truncate(top_k);
481 }
482}
483
484#[must_use]
486#[allow(clippy::cast_precision_loss)]
487pub fn rrf_multi<I, L>(lists: &[L], config: RrfConfig) -> Vec<(I, f32)>
488where
489 I: Clone + Eq + Hash,
490 L: AsRef<[(I, f32)]>,
491{
492 let k = config.k as f32;
493 let mut scores: HashMap<I, f32> = HashMap::new();
494
495 for list in lists {
496 for (rank, (id, _)) in list.as_ref().iter().enumerate() {
497 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32);
498 }
499 }
500
501 finalize(scores, config.top_k)
502}
503
504#[allow(clippy::cast_precision_loss)]
529pub fn rrf_weighted<I, L>(lists: &[L], weights: &[f32], config: RrfConfig) -> Result<Vec<(I, f32)>>
530where
531 I: Clone + Eq + Hash,
532 L: AsRef<[(I, f32)]>,
533{
534 let weight_sum: f32 = weights.iter().sum();
535 if weight_sum.abs() < 1e-9 {
536 return Err(FusionError::ZeroWeights);
537 }
538
539 let k = config.k as f32;
540 let mut scores: HashMap<I, f32> = HashMap::new();
541
542 for (list, &weight) in lists.iter().zip(weights.iter()) {
543 let normalized_weight = weight / weight_sum;
544 for (rank, (id, _)) in list.as_ref().iter().enumerate() {
545 *scores.entry(id.clone()).or_default() += normalized_weight / (k + rank as f32);
546 }
547 }
548
549 Ok(finalize(scores, config.top_k))
550}
551
552#[must_use]
582pub fn isr<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
583 isr_with_config(results_a, results_b, RrfConfig::new(1))
584}
585
586#[must_use]
603#[allow(clippy::cast_precision_loss)]
604pub fn isr_with_config<I: Clone + Eq + Hash>(
605 results_a: &[(I, f32)],
606 results_b: &[(I, f32)],
607 config: RrfConfig,
608) -> Vec<(I, f32)> {
609 let k = config.k as f32;
610 let mut scores: HashMap<I, f32> = HashMap::new();
611
612 for (rank, (id, _)) in results_a.iter().enumerate() {
613 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32).sqrt();
614 }
615 for (rank, (id, _)) in results_b.iter().enumerate() {
616 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32).sqrt();
617 }
618
619 finalize(scores, config.top_k)
620}
621
622#[must_use]
624#[allow(clippy::cast_precision_loss)]
625pub fn isr_multi<I, L>(lists: &[L], config: RrfConfig) -> Vec<(I, f32)>
626where
627 I: Clone + Eq + Hash,
628 L: AsRef<[(I, f32)]>,
629{
630 let k = config.k as f32;
631 let mut scores: HashMap<I, f32> = HashMap::new();
632
633 for list in lists {
634 for (rank, (id, _)) in list.as_ref().iter().enumerate() {
635 *scores.entry(id.clone()).or_default() += 1.0 / (k + rank as f32).sqrt();
636 }
637 }
638
639 finalize(scores, config.top_k)
640}
641
642#[must_use]
657pub fn weighted<I: Clone + Eq + Hash>(
658 results_a: &[(I, f32)],
659 results_b: &[(I, f32)],
660 config: WeightedConfig,
661) -> Vec<(I, f32)> {
662 weighted_impl(
663 &[(results_a, config.weight_a), (results_b, config.weight_b)],
664 config.normalize,
665 config.top_k,
666 )
667}
668
669pub fn weighted_multi<I, L>(
677 lists: &[(L, f32)],
678 normalize: bool,
679 top_k: Option<usize>,
680) -> Result<Vec<(I, f32)>>
681where
682 I: Clone + Eq + Hash,
683 L: AsRef<[(I, f32)]>,
684{
685 let total_weight: f32 = lists.iter().map(|(_, w)| w).sum();
686 if total_weight.abs() < 1e-9 {
687 return Err(FusionError::ZeroWeights);
688 }
689
690 let mut scores: HashMap<I, f32> = HashMap::new();
691
692 for (list, weight) in lists {
693 let items = list.as_ref();
694 let w = weight / total_weight;
695 let (norm, off) = if normalize {
696 min_max_params(items)
697 } else {
698 (1.0, 0.0)
699 };
700 for (id, s) in items {
701 *scores.entry(id.clone()).or_default() += w * (s - off) * norm;
702 }
703 }
704
705 Ok(finalize(scores, top_k))
706}
707
708fn weighted_impl<I, L>(lists: &[(L, f32)], normalize: bool, top_k: Option<usize>) -> Vec<(I, f32)>
710where
711 I: Clone + Eq + Hash,
712 L: AsRef<[(I, f32)]>,
713{
714 let total_weight: f32 = lists.iter().map(|(_, w)| w).sum();
715 if total_weight.abs() < 1e-9 {
716 return Vec::new();
717 }
718
719 let mut scores: HashMap<I, f32> = HashMap::new();
720
721 for (list, weight) in lists {
722 let items = list.as_ref();
723 let w = weight / total_weight;
724 let (norm, off) = if normalize {
725 min_max_params(items)
726 } else {
727 (1.0, 0.0)
728 };
729 for (id, s) in items {
730 *scores.entry(id.clone()).or_default() += w * (s - off) * norm;
731 }
732 }
733
734 finalize(scores, top_k)
735}
736
737#[must_use]
747pub fn combsum<I: Clone + Eq + Hash>(
748 results_a: &[(I, f32)],
749 results_b: &[(I, f32)],
750) -> Vec<(I, f32)> {
751 combsum_with_config(results_a, results_b, FusionConfig::default())
752}
753
754#[must_use]
756pub fn combsum_with_config<I: Clone + Eq + Hash>(
757 results_a: &[(I, f32)],
758 results_b: &[(I, f32)],
759 config: FusionConfig,
760) -> Vec<(I, f32)> {
761 combsum_multi(&[results_a, results_b], config)
762}
763
764#[must_use]
766pub fn combsum_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
767where
768 I: Clone + Eq + Hash,
769 L: AsRef<[(I, f32)]>,
770{
771 let mut scores: HashMap<I, f32> = HashMap::new();
772
773 for list in lists {
774 let items = list.as_ref();
775 let (norm, off) = min_max_params(items);
776 for (id, s) in items {
777 *scores.entry(id.clone()).or_default() += (s - off) * norm;
778 }
779 }
780
781 finalize(scores, config.top_k)
782}
783
784#[must_use]
795pub fn combmnz<I: Clone + Eq + Hash>(
796 results_a: &[(I, f32)],
797 results_b: &[(I, f32)],
798) -> Vec<(I, f32)> {
799 combmnz_with_config(results_a, results_b, FusionConfig::default())
800}
801
802#[must_use]
804pub fn combmnz_with_config<I: Clone + Eq + Hash>(
805 results_a: &[(I, f32)],
806 results_b: &[(I, f32)],
807 config: FusionConfig,
808) -> Vec<(I, f32)> {
809 combmnz_multi(&[results_a, results_b], config)
810}
811
812#[must_use]
814#[allow(clippy::cast_precision_loss)]
815pub fn combmnz_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
816where
817 I: Clone + Eq + Hash,
818 L: AsRef<[(I, f32)]>,
819{
820 let mut scores: HashMap<I, (f32, u32)> = HashMap::new();
821
822 for list in lists {
823 let items = list.as_ref();
824 let (norm, off) = min_max_params(items);
825 for (id, s) in items {
826 let e = scores.entry(id.clone()).or_default();
827 e.0 += (s - off) * norm;
828 e.1 += 1;
829 }
830 }
831
832 let mut results: Vec<_> = scores
833 .into_iter()
834 .map(|(id, (sum, n))| (id, sum * n as f32))
835 .collect();
836 sort_scored_desc(&mut results);
837 if let Some(top_k) = config.top_k {
838 results.truncate(top_k);
839 }
840 results
841}
842
843#[must_use]
858pub fn borda<I: Clone + Eq + Hash>(
859 results_a: &[(I, f32)],
860 results_b: &[(I, f32)],
861) -> Vec<(I, f32)> {
862 borda_with_config(results_a, results_b, FusionConfig::default())
863}
864
865#[must_use]
867pub fn borda_with_config<I: Clone + Eq + Hash>(
868 results_a: &[(I, f32)],
869 results_b: &[(I, f32)],
870 config: FusionConfig,
871) -> Vec<(I, f32)> {
872 borda_multi(&[results_a, results_b], config)
873}
874
875#[must_use]
877#[allow(clippy::cast_precision_loss)]
878pub fn borda_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
879where
880 I: Clone + Eq + Hash,
881 L: AsRef<[(I, f32)]>,
882{
883 let mut scores: HashMap<I, f32> = HashMap::new();
884
885 for list in lists {
886 let items = list.as_ref();
887 let n = items.len() as f32;
888 for (rank, (id, _)) in items.iter().enumerate() {
889 *scores.entry(id.clone()).or_default() += n - rank as f32;
890 }
891 }
892
893 finalize(scores, config.top_k)
894}
895
896#[must_use]
922pub fn dbsf<I: Clone + Eq + Hash>(results_a: &[(I, f32)], results_b: &[(I, f32)]) -> Vec<(I, f32)> {
923 dbsf_with_config(results_a, results_b, FusionConfig::default())
924}
925
926#[must_use]
928pub fn dbsf_with_config<I: Clone + Eq + Hash>(
929 results_a: &[(I, f32)],
930 results_b: &[(I, f32)],
931 config: FusionConfig,
932) -> Vec<(I, f32)> {
933 dbsf_multi(&[results_a, results_b], config)
934}
935
936#[must_use]
938pub fn dbsf_multi<I, L>(lists: &[L], config: FusionConfig) -> Vec<(I, f32)>
939where
940 I: Clone + Eq + Hash,
941 L: AsRef<[(I, f32)]>,
942{
943 let mut scores: HashMap<I, f32> = HashMap::new();
944
945 for list in lists {
946 let items = list.as_ref();
947 let (mean, std) = zscore_params(items);
948
949 for (id, s) in items {
950 let z = if std > 1e-9 {
952 ((s - mean) / std).clamp(-3.0, 3.0)
953 } else {
954 0.0 };
956 *scores.entry(id.clone()).or_default() += z;
957 }
958 }
959
960 finalize(scores, config.top_k)
961}
962
/// Returns `(mean, population standard deviation)` of the scores in `results`.
///
/// An empty slice yields `(0.0, 1.0)`.
#[inline]
#[allow(clippy::cast_precision_loss)]
fn zscore_params<I>(results: &[(I, f32)]) -> (f32, f32) {
    if results.is_empty() {
        return (0.0, 1.0);
    }

    let count = results.len() as f32;
    let mean = results.iter().map(|(_, s)| s).sum::<f32>() / count;
    let variance = results
        .iter()
        .map(|(_, s)| (s - mean).powi(2))
        .sum::<f32>()
        / count;

    (mean, variance.sqrt())
}
977
/// Turns accumulated per-id scores into the public result shape.
///
/// The result is a vector sorted by score, highest first, truncated to `top_k` entries when
/// given.
#[inline]
fn finalize<I>(scores: HashMap<I, f32>, top_k: Option<usize>) -> Vec<(I, f32)> {
    let mut results: Vec<_> = scores.into_iter().collect();
    sort_scored_desc(&mut results);
    if let Some(k) = top_k {
        results.truncate(k);
    }
    results
}

/// Sorts `(id, score)` pairs by score, highest first, with every NaN score placed last.
///
/// On its own, `f32::total_cmp` orders positive NaN above `+inf` and negative NaN below
/// `-inf`. A NaN produced upstream could therefore land at the top of a fused ranking, for
/// example `0 * inf` during min-max normalization of a list that contains an infinite score.
/// Non-NaN scores keep the `total_cmp` order, so `0.0` sorts before `-0.0`. The sort is
/// stable, so equal scores keep their input order.
#[inline]
fn sort_scored_desc<I>(results: &mut [(I, f32)]) {
    results.sort_by(|a, b| {
        // `false < true`, so non-NaN entries come first; within each group, descending.
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.total_cmp(&a.1))
    });
}
1002
/// Returns `(scale, offset)` such that `(score - offset) * scale` min-max normalizes the
/// scores of `results` into `[0, 1]`.
///
/// Falls back to the identity `(1.0, 0.0)` when `results` is empty or when `max - min` is
/// below `1e-9`. Such lists therefore pass through un-normalized. NaN scores are skipped when
/// finding the extremes, because `f32::min`/`f32::max` return the non-NaN operand.
#[inline]
fn min_max_params<I>(results: &[(I, f32)]) -> (f32, f32) {
    const IDENTITY: (f32, f32) = (1.0, 0.0);

    if results.is_empty() {
        return IDENTITY;
    }

    let mut lo = f32::INFINITY;
    let mut hi = f32::NEG_INFINITY;
    for &(_, score) in results {
        lo = lo.min(score);
        hi = hi.max(score);
    }

    let range = hi - lo;
    if range < 1e-9 {
        IDENTITY
    } else {
        (1.0 / range, lo)
    }
}
1027
// Example-based unit tests for the fusion functions, their configs and `FusionMethod`.
// Expected values below follow from the 0-based-rank formulas used by the implementations.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a ranked list from `ids` with strictly decreasing scores 1.0, 0.9, 0.8, ...
    // so that list order and score order agree.
    fn ranked<'a>(ids: &[&'a str]) -> Vec<(&'a str, f32)> {
        ids.iter()
            .enumerate()
            .map(|(i, &id)| (id, 1.0 - i as f32 * 0.1))
            .collect()
    }

    // d2 appears in both lists (ranks 1 and 0), so it must fuse into the top two.
    #[test]
    fn rrf_basic() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d2", "d3", "d4"]);
        let f = rrf(&a, &b);

        assert!(f.iter().position(|(id, _)| *id == "d2").unwrap() < 2);
    }

    // Four distinct ids, truncated to two.
    #[test]
    fn rrf_with_top_k() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d2", "d3", "d4"]);
        let f = rrf_with_config(&a, &b, RrfConfig::default().with_top_k(2));

        assert_eq!(f.len(), 2);
    }

    // The buffer-filling variant yields one entry per distinct id, with the overlap first.
    #[test]
    fn rrf_into_works() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);
        let mut out = Vec::new();

        rrf_into(&a, &b, RrfConfig::default(), &mut out);

        assert_eq!(out.len(), 3);
        assert_eq!(out[0].0, "d2");
    }

    // A single item at rank 0 with k = 60 scores 1 / (60 + 0).
    #[test]
    fn rrf_score_formula() {
        let a = vec![("d1", 1.0)];
        let b: Vec<(&str, f32)> = vec![];
        let f = rrf_with_config(&a, &b, RrfConfig::new(60));

        let expected = 1.0 / 60.0;
        assert!((f[0].1 - expected).abs() < 1e-6);
    }

    // d1 sits at rank 0 in `a` and rank 2 in `b`: 1/60 + 1/62.
    #[test]
    fn rrf_exact_score_computation() {
        let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
        let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7)];

        let f = rrf_with_config(&a, &b, RrfConfig::new(60));

        let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
        let expected = 1.0 / 60.0 + 1.0 / 62.0;
        assert!(
            (d1_score - expected).abs() < 1e-6,
            "d1 score {} != expected {}",
            d1_score,
            expected
        );
    }

    // Same layout as above with ISR and k = 1: 1/sqrt(1 + 0) + 1/sqrt(1 + 2).
    #[test]
    fn isr_exact_score_computation() {
        let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
        let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7)];

        let f = isr_with_config(&a, &b, RrfConfig::new(1));

        let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
        let expected = 1.0 / 1.0_f32.sqrt() + 1.0 / 3.0_f32.sqrt();

        assert!(
            (d1_score - expected).abs() < 1e-6,
            "d1 score {} != expected {}",
            d1_score,
            expected
        );
    }

    // Borda: d1 is rank 0 of a 3-item list (3 - 0 = 3) and rank 2 of a 4-item list (4 - 2 = 2).
    #[test]
    fn borda_exact_score_computation() {
        let a = vec![("d1", 0.9), ("d2", 0.8), ("d3", 0.7)];
        let b = vec![("d4", 0.9), ("d5", 0.8), ("d1", 0.7), ("d6", 0.6)];

        let f = borda(&a, &b);

        let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
        let expected = 3.0 + 2.0;
        assert!(
            (d1_score - expected).abs() < 1e-6,
            "d1 score {} != expected {}",
            d1_score,
            expected
        );
    }

    // Both items are at rank 0; the normalized weights 0.25 vs 0.75 make d2 score 3x d1.
    #[test]
    fn rrf_weighted_applies_weights() {
        let list_a = vec![("d1", 0.0)];
        let list_b = vec![("d2", 0.0)];

        let weights = [0.25, 0.75];
        let f = rrf_weighted(&[&list_a[..], &list_b[..]], &weights, RrfConfig::new(60)).unwrap();

        assert_eq!(f[0].0, "d2", "weighted RRF should favor higher-weight list");

        let d1_score = f.iter().find(|(id, _)| *id == "d1").unwrap().1;
        let d2_score = f.iter().find(|(id, _)| *id == "d2").unwrap().1;
        assert!(
            d2_score > d1_score * 2.0,
            "d2 should score ~3x higher than d1"
        );
    }

    // Weights summing to zero cannot be normalized and must be reported as an error.
    #[test]
    fn rrf_weighted_zero_weights_error() {
        let list_a = vec![("d1", 0.0)];
        let list_b = vec![("d2", 0.0)];
        let weights = [0.0, 0.0];

        let result = rrf_weighted(&[&list_a[..], &list_b[..]], &weights, RrfConfig::default());
        assert!(matches!(result, Err(FusionError::ZeroWeights)));
    }

    // ISR counterpart of `rrf_basic`.
    #[test]
    fn isr_basic() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d2", "d3", "d4"]);
        let f = isr(&a, &b);

        assert!(f.iter().position(|(id, _)| *id == "d2").unwrap() < 2);
    }

    // A single item at rank 0 with k = 1 scores 1 / sqrt(1).
    #[test]
    fn isr_score_formula() {
        let a = vec![("d1", 1.0)];
        let b: Vec<(&str, f32)> = vec![];
        let f = isr_with_config(&a, &b, RrfConfig::new(1));

        let expected = 1.0 / 1.0_f32.sqrt();
        assert!((f[0].1 - expected).abs() < 1e-6);
    }

    // With k = 1, rank 0 vs rank 3: RRF ratio is (1/1)/(1/4) = 4, ISR ratio is 1/(1/2) = 2.
    #[test]
    fn isr_gentler_decay_than_rrf() {
        let a = vec![("d1", 1.0), ("d2", 0.9), ("d3", 0.8), ("d4", 0.7)];
        let b: Vec<(&str, f32)> = vec![];

        let rrf_result = rrf_with_config(&a, &b, RrfConfig::new(1));
        let isr_result = isr_with_config(&a, &b, RrfConfig::new(1));

        let rrf_ratio = rrf_result[0].1 / rrf_result[3].1;
        let isr_ratio = isr_result[0].1 / isr_result[3].1;

        assert!(
            isr_ratio < rrf_ratio,
            "ISR should have gentler decay: ISR ratio={}, RRF ratio={}",
            isr_ratio,
            rrf_ratio
        );
    }

    // d2 and d3 each appear in two of the three lists, so they lead.
    #[test]
    fn isr_multi_works() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);
        let c = ranked(&["d3", "d4"]);
        let f = isr_multi(&[&a, &b, &c], RrfConfig::new(1));

        assert_eq!(f.len(), 4);
        let top_2: Vec<_> = f.iter().take(2).map(|(id, _)| *id).collect();
        assert!(top_2.contains(&"d2") && top_2.contains(&"d3"));
    }

    #[test]
    fn isr_with_top_k() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d2", "d3", "d4"]);
        let f = isr_with_config(&a, &b, RrfConfig::new(1).with_top_k(2));

        assert_eq!(f.len(), 2);
    }

    // Empty inputs contribute nothing; they must not panic or invent ids.
    #[test]
    fn isr_empty_lists() {
        let empty: Vec<(&str, f32)> = vec![];
        let non_empty = ranked(&["d1"]);

        assert_eq!(isr(&empty, &non_empty).len(), 1);
        assert_eq!(isr(&non_empty, &empty).len(), 1);
        assert_eq!(isr(&empty, &empty).len(), 0);
    }

    // The enum wrapper must agree with the free function for default and custom k.
    #[test]
    fn fusion_method_isr() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = FusionMethod::isr().fuse(&a, &b);
        assert_eq!(f[0].0, "d2");

        let f = FusionMethod::isr_with_k(10).fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    #[test]
    fn fusion_method_isr_multi() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);
        let c = ranked(&["d3", "d4"]);
        let lists = [&a[..], &b[..], &c[..]];

        let f = FusionMethod::isr().fuse_multi(&lists);
        assert!(!f.is_empty());
    }

    // d2 occurs in both lists, so its sum is doubled by the occurrence count.
    #[test]
    fn combmnz_rewards_overlap() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);
        let f = combmnz(&a, &b);

        assert_eq!(f[0].0, "d2");
    }

    // After min-max normalization d2 scores 1.0 in both lists (sum 2.0); d1 and d3 get 0.
    #[test]
    fn combsum_basic() {
        let a = vec![("d1", 0.5), ("d2", 1.0)];
        let b = vec![("d2", 1.0), ("d3", 0.5)];
        let f = combsum(&a, &b);

        assert_eq!(f[0].0, "d2");
    }

    // Without normalization, the heavier-weighted list's item wins in both directions.
    #[test]
    fn weighted_skewed() {
        let a = vec![("d1", 1.0)];
        let b = vec![("d2", 1.0)];

        let f = weighted(
            &a,
            &b,
            WeightedConfig::default()
                .with_weights(0.9, 0.1)
                .with_normalize(false),
        );
        assert_eq!(f[0].0, "d1");

        let f = weighted(
            &a,
            &b,
            WeightedConfig::default()
                .with_weights(0.1, 0.9)
                .with_normalize(false),
        );
        assert_eq!(f[0].0, "d2");
    }

    // Reversed 3-item lists give every id 3 + 1 = 2 + 2 = 4 Borda points.
    #[test]
    fn borda_symmetric() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d3", "d2", "d1"]);
        let f = borda(&a, &b);

        let scores: Vec<f32> = f.iter().map(|(_, s)| *s).collect();
        assert!((scores[0] - scores[1]).abs() < 0.01);
        assert!((scores[1] - scores[2]).abs() < 0.01);
    }

    #[test]
    fn rrf_multi_works() {
        let lists: Vec<Vec<(&str, f32)>> = vec![
            ranked(&["d1", "d2"]),
            ranked(&["d2", "d3"]),
            ranked(&["d1", "d3"]),
        ];
        let f = rrf_multi(&lists, RrfConfig::default());

        assert_eq!(f.len(), 3);
    }

    // Borda points: d1 = 2 + 2 = 4, d2 = 1 + 2 = 3, d3 = 1 + 1 = 2.
    #[test]
    fn borda_multi_works() {
        let lists: Vec<Vec<(&str, f32)>> = vec![
            ranked(&["d1", "d2"]),
            ranked(&["d2", "d3"]),
            ranked(&["d1", "d3"]),
        ];
        let f = borda_multi(&lists, FusionConfig::default());
        assert_eq!(f.len(), 3);
        assert_eq!(f[0].0, "d1");
    }

    #[test]
    fn combsum_multi_works() {
        let lists: Vec<Vec<(&str, f32)>> = vec![
            vec![("d1", 1.0), ("d2", 0.5)],
            vec![("d2", 1.0), ("d3", 0.5)],
            vec![("d1", 1.0), ("d3", 0.5)],
        ];
        let f = combsum_multi(&lists, FusionConfig::default());
        assert_eq!(f.len(), 3);
    }

    // d1 occurs in all three lists with the top normalized score, so it dominates.
    #[test]
    fn combmnz_multi_works() {
        let lists: Vec<Vec<(&str, f32)>> = vec![
            vec![("d1", 1.0)],
            vec![("d1", 1.0), ("d2", 0.5)],
            vec![("d1", 1.0), ("d2", 0.5)],
        ];
        let f = combmnz_multi(&lists, FusionConfig::default());
        assert_eq!(f[0].0, "d1");
    }

    // Equal weights keep all three ids; a 10:1:1 weighting puts d1 first.
    #[test]
    fn weighted_multi_works() {
        let a = vec![("d1", 1.0)];
        let b = vec![("d2", 1.0)];
        let c = vec![("d3", 1.0)];

        let f = weighted_multi(&[(&a, 1.0), (&b, 1.0), (&c, 1.0)], false, None).unwrap();
        assert_eq!(f.len(), 3);

        let f = weighted_multi(&[(&a, 10.0), (&b, 1.0), (&c, 1.0)], false, None).unwrap();
        assert_eq!(f[0].0, "d1");
    }

    #[test]
    fn weighted_multi_zero_weights() {
        let a = vec![("d1", 1.0)];
        let result = weighted_multi(&[(&a, 0.0)], false, None);
        assert!(matches!(result, Err(FusionError::ZeroWeights)));
    }

    #[test]
    fn empty_inputs() {
        let empty: Vec<(&str, f32)> = vec![];
        let non_empty = ranked(&["d1"]);

        assert_eq!(rrf(&empty, &non_empty).len(), 1);
        assert_eq!(rrf(&non_empty, &empty).len(), 1);
    }

    #[test]
    fn both_empty() {
        let empty: Vec<(&str, f32)> = vec![];
        assert_eq!(rrf(&empty, &empty).len(), 0);
        assert_eq!(combsum(&empty, &empty).len(), 0);
        assert_eq!(borda(&empty, &empty).len(), 0);
    }

    // A repeated id accumulates once per occurrence: ranks 0 and 1 give 1/60 + 1/61.
    #[test]
    fn duplicate_ids_in_same_list() {
        let a = vec![("d1", 1.0), ("d1", 0.5)];
        let b: Vec<(&str, f32)> = vec![];
        let f = rrf_with_config(&a, &b, RrfConfig::new(60));

        assert_eq!(f.len(), 1);
        let expected = 1.0 / 60.0 + 1.0 / 61.0;
        assert!((f[0].1 - expected).abs() < 1e-6);
    }

    #[test]
    fn builder_pattern() {
        let config = RrfConfig::default().with_k(30).with_top_k(5);
        assert_eq!(config.k, 30);
        assert_eq!(config.top_k, Some(5));

        let config = WeightedConfig::default()
            .with_weights(0.8, 0.2)
            .with_normalize(false)
            .with_top_k(10);
        assert_eq!(config.weight_a, 0.8);
        assert!(!config.normalize);
        assert_eq!(config.top_k, Some(10));
    }

    // Smoke test: NaN inputs must not panic (results are intentionally not inspected).
    #[test]
    fn nan_scores_handled() {
        let a = vec![("d1", f32::NAN), ("d2", 0.5)];
        let b = vec![("d2", 0.9), ("d3", 0.1)];

        let _ = rrf(&a, &b);
        let _ = combsum(&a, &b);
        let _ = combmnz(&a, &b);
        let _ = borda(&a, &b);
    }

    // Smoke test: infinite inputs must not panic.
    #[test]
    fn inf_scores_handled() {
        let a = vec![("d1", f32::INFINITY), ("d2", 0.5)];
        let b = vec![("d2", f32::NEG_INFINITY), ("d3", 0.1)];

        let _ = rrf(&a, &b);
        let _ = combsum(&a, &b);
    }

    // Constant (zero-range) lists skip normalization but still keep every id.
    #[test]
    fn zero_scores() {
        let a = vec![("d1", 0.0), ("d2", 0.0)];
        let b = vec![("d2", 0.0), ("d3", 0.0)];

        let f = combsum(&a, &b);
        assert_eq!(f.len(), 3);
    }

    #[test]
    fn negative_scores() {
        let a = vec![("d1", -1.0), ("d2", -0.5)];
        let b = vec![("d2", -0.9), ("d3", -0.1)];

        let f = combsum(&a, &b);
        assert_eq!(f.len(), 3);
    }

    // u32::MAX as k must not overflow or panic (it is converted to f32).
    #[test]
    fn large_k_value() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = rrf_with_config(&a, &b, RrfConfig::new(u32::MAX));
        assert!(!f.is_empty());
    }

    // The same id in both lists is merged into one entry by every method.
    #[test]
    fn single_item_lists() {
        let a = vec![("d1", 1.0)];
        let b = vec![("d1", 1.0)];

        let f = rrf(&a, &b);
        assert_eq!(f.len(), 1);

        let f = combsum(&a, &b);
        assert_eq!(f.len(), 1);

        let f = borda(&a, &b);
        assert_eq!(f.len(), 1);
    }

    #[test]
    fn disjoint_lists() {
        let a = vec![("d1", 1.0), ("d2", 0.9)];
        let b = vec![("d3", 1.0), ("d4", 0.9)];

        let f = rrf(&a, &b);
        assert_eq!(f.len(), 4);

        let f = combmnz(&a, &b);
        assert_eq!(f.len(), 4);
    }

    // Identical inputs must reproduce the input order.
    #[test]
    fn identical_lists() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d1", "d2", "d3"]);

        let f = rrf(&a, &b);
        assert_eq!(f[0].0, "d1");
        assert_eq!(f[1].0, "d2");
        assert_eq!(f[2].0, "d3");
    }

    #[test]
    fn reversed_lists() {
        let a = ranked(&["d1", "d2", "d3"]);
        let b = ranked(&["d3", "d2", "d1"]);

        let f = rrf(&a, &b);
        assert_eq!(f.len(), 3);
    }

    // Truncation never pads: a limit above the result size returns everything.
    #[test]
    fn top_k_larger_than_result() {
        let a = ranked(&["d1"]);
        let b = ranked(&["d2"]);

        let f = rrf_with_config(&a, &b, RrfConfig::default().with_top_k(100));
        assert_eq!(f.len(), 2);
    }

    #[test]
    fn top_k_zero() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = rrf_with_config(&a, &b, RrfConfig::default().with_top_k(0));
        assert_eq!(f.len(), 0);
    }

    #[test]
    fn fusion_method_rrf() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = FusionMethod::rrf().fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    // Normalized a: d1 = 1.0, d2 = 0.5, d4 = 0.0; normalized b: d2 = 1.0, d3 = 0.0.
    // d2 therefore sums to 1.5 and leads.
    #[test]
    fn fusion_method_combsum() {
        let a = vec![("d1", 1.0_f32), ("d2", 0.6), ("d4", 0.2)];
        let b = vec![("d2", 1.0_f32), ("d3", 0.5)];
        let f = FusionMethod::CombSum.fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    #[test]
    fn fusion_method_combmnz() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = FusionMethod::CombMnz.fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    #[test]
    fn fusion_method_borda() {
        let a = ranked(&["d1", "d2"]);
        let b = ranked(&["d2", "d3"]);

        let f = FusionMethod::Borda.fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    // Single-item lists skip normalization (zero range), so the weights decide the winner.
    #[test]
    fn fusion_method_weighted() {
        let a = vec![("d1", 1.0f32)];
        let b = vec![("d2", 1.0f32)];

        let f = FusionMethod::weighted(0.9, 0.1).fuse(&a, &b);
        assert_eq!(f[0].0, "d1");

        let f = FusionMethod::weighted(0.1, 0.9).fuse(&a, &b);
        assert_eq!(f[0].0, "d2");
    }

    #[test]
    fn fusion_method_multi() {
        let lists: Vec<Vec<(&str, f32)>> = vec![
            ranked(&["d1", "d2"]),
            ranked(&["d2", "d3"]),
            ranked(&["d1", "d3"]),
        ];

        let f = FusionMethod::rrf().fuse_multi(&lists);
        assert_eq!(f.len(), 3);
    }

    #[test]
    fn fusion_method_default_is_rrf() {
        let method = FusionMethod::default();
        assert!(matches!(method, FusionMethod::Rrf { k: 60 }));
    }
}
1664
#[cfg(test)]
mod proptests {
    //! Property-based tests for the fusion functions.
    //!
    //! Each property runs against randomly generated `(id, score)` lists. Shared
    //! assertions live in small `check_*` helpers that return `CheckResult`, so
    //! they can be used with `?` inside `proptest!` bodies.

    use super::*;
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;

    /// Outcome of a reusable property check; `Err` carries the proptest failure.
    ///
    /// Spelled via `std::result::Result` because `super::*` brings in this crate's
    /// single-parameter `Result<T>` alias, which shadows the std prelude type here.
    type CheckResult = std::result::Result<(), TestCaseError>;

    /// Strategy producing a list of `(id, score)` pairs.
    ///
    /// Ids come from `0..100`, so a single list may repeat an id. Scores come from
    /// `0.0..1.0`. The length is drawn from `0..max_len`, so `max_len` is an
    /// exclusive upper bound.
    fn arb_results(max_len: usize) -> impl Strategy<Value = Vec<(u32, f32)>> {
        proptest::collection::vec((0u32..100, 0.0f32..1.0), 0..max_len)
    }

    /// Returns the score of `id` in `results`, or `0.0` when `id` is absent.
    ///
    /// If `id` occurs more than once, the first occurrence is used.
    fn score_of(results: &[(u32, f32)], id: u32) -> f32 {
        results
            .iter()
            .find(|&&(candidate, _)| candidate == id)
            .map_or(0.0, |&(_, score)| score)
    }

    /// Fails unless `results` is ordered by non-increasing score.
    ///
    /// Uses `f32::total_cmp`, so the check stays well-defined when a score is NaN.
    fn check_sorted_desc(results: &[(u32, f32)]) -> CheckResult {
        for pair in results.windows(2) {
            let (higher, lower) = (pair[0].1, pair[1].1);
            prop_assert!(
                higher.total_cmp(&lower) != std::cmp::Ordering::Less,
                "Not sorted: {:?} < {:?}",
                higher,
                lower
            );
        }
        Ok(())
    }

    /// Fails unless `left` and `right` contain exactly the same ids, and each
    /// id's scores differ by less than `tol`.
    ///
    /// Both results are collected into maps first. If an id repeats within one
    /// result, its last occurrence is the one compared.
    fn check_same_scores(
        label: &str,
        left: Vec<(u32, f32)>,
        right: Vec<(u32, f32)>,
        tol: f32,
    ) -> CheckResult {
        let left_map: HashMap<u32, f32> = left.into_iter().collect();
        let right_map: HashMap<u32, f32> = right.into_iter().collect();

        prop_assert_eq!(left_map.len(), right_map.len());
        for (id, &score_left) in &left_map {
            // Equal sizes plus containment of every key imply identical key sets.
            prop_assert!(
                right_map.contains_key(id),
                "{}: id {:?} present in one result but not the other",
                label,
                id
            );
            let score_right = right_map[id];
            prop_assert!(
                (score_left - score_right).abs() < tol,
                "{} scores differ for id {:?}: {} vs {}",
                label,
                id,
                score_left,
                score_right
            );
        }
        Ok(())
    }

    proptest! {
        /// RRF never yields more entries than its two inputs combined.
        #[test]
        fn rrf_output_bounded(a in arb_results(50), b in arb_results(50)) {
            let result = rrf(&a, &b);
            prop_assert!(result.len() <= a.len() + b.len());
        }

        /// Every RRF score is strictly positive.
        #[test]
        fn rrf_scores_positive(a in arb_results(50), b in arb_results(50)) {
            for (_, score) in &rrf(&a, &b) {
                prop_assert!(*score > 0.0);
            }
        }

        /// Swapping the two input lists does not change RRF output.
        #[test]
        fn rrf_commutative(a in arb_results(20), b in arb_results(20)) {
            let ab = rrf(&a, &b);
            let ba = rrf(&b, &a);
            prop_assert_eq!(ab.len(), ba.len());
            check_same_scores("RRF", ab, ba, 1e-6)?;
        }

        /// RRF output is ordered by descending score.
        #[test]
        fn rrf_sorted_descending(a in arb_results(50), b in arb_results(50)) {
            check_sorted_desc(&rrf(&a, &b))?;
        }

        /// `top_k` caps the length of RRF output.
        #[test]
        fn rrf_top_k_respected(a in arb_results(50), b in arb_results(50), k in 1usize..20) {
            let result = rrf_with_config(&a, &b, RrfConfig::default().with_top_k(k));
            prop_assert!(result.len() <= k);
        }

        /// Swapping the two input lists does not change Borda output.
        #[test]
        fn borda_commutative(a in arb_results(20), b in arb_results(20)) {
            check_same_scores("Borda", borda(&a, &b), borda(&b, &a), 1e-5)?;
        }

        /// Swapping the two input lists does not change CombSUM output.
        #[test]
        fn combsum_commutative(a in arb_results(20), b in arb_results(20)) {
            check_same_scores("CombSUM", combsum(&a, &b), combsum(&b, &a), 1e-5)?;
        }

        /// Swapping the two input lists does not change CombMNZ output.
        #[test]
        fn combmnz_commutative(a in arb_results(20), b in arb_results(20)) {
            check_same_scores("CombMNZ", combmnz(&a, &b), combmnz(&b, &a), 1e-5)?;
        }

        /// Swapping the two input lists does not change DBSF output.
        #[test]
        fn dbsf_commutative(a in arb_results(20), b in arb_results(20)) {
            check_same_scores("DBSF", dbsf(&a, &b), dbsf(&b, &a), 1e-5)?;
        }

        /// A larger `k` flattens RRF scores. The best-to-worst spread with k = 1000
        /// is no larger than with k = 1.
        #[test]
        fn rrf_k_uniformity(a in arb_results(10).prop_filter("need items", |v| v.len() >= 2)) {
            let b: Vec<(u32, f32)> = vec![];

            let low_k = rrf_with_config(&a, &b, RrfConfig::new(1));
            let high_k = rrf_with_config(&a, &b, RrfConfig::new(1000));

            // The output may still hold fewer than two entries (e.g. when `a`
            // repeats a single id), in which case there is no spread to compare.
            if low_k.len() >= 2 && high_k.len() >= 2 {
                let low_k_range = low_k[0].1 - low_k[low_k.len() - 1].1;
                let high_k_range = high_k[0].1 - high_k[high_k.len() - 1].1;
                prop_assert!(high_k_range <= low_k_range);
            }
        }

        /// An id that appears in both lists scores at least as high under CombMNZ
        /// as when it appears in only one.
        #[test]
        fn combmnz_overlap_bonus(id in 0u32..100, score in 0.1f32..1.0) {
            let a = vec![(id, score)];
            let b = vec![(id, score)];
            let c = vec![(id + 1, score)];

            let overlapped = combmnz(&a, &b);
            let disjoint = combmnz(&a, &c);

            prop_assert!(score_of(&overlapped, id) >= score_of(&disjoint, id));
        }

        /// Without normalization, the more heavily weighted list's item scores higher.
        #[test]
        fn weighted_extreme_weights(a_id in 0u32..50, b_id in 50u32..100) {
            let a = vec![(a_id, 1.0f32)];
            let b = vec![(b_id, 1.0f32)];

            let high_a = weighted(&a, &b, WeightedConfig::new(0.99, 0.01).with_normalize(false));
            let high_b = weighted(&a, &b, WeightedConfig::new(0.01, 0.99).with_normalize(false));

            prop_assert!(score_of(&high_a, a_id) > score_of(&high_a, b_id));
            prop_assert!(score_of(&high_b, b_id) > score_of(&high_b, a_id));
        }

        /// Fusing a non-empty list with an empty one never yields an empty result.
        #[test]
        fn nonempty_output(a in arb_results(5).prop_filter("need items", |v| !v.is_empty())) {
            let b: Vec<(u32, f32)> = vec![];

            prop_assert!(!rrf(&a, &b).is_empty());
            prop_assert!(!isr(&a, &b).is_empty());
            prop_assert!(!combsum(&a, &b).is_empty());
            prop_assert!(!combmnz(&a, &b).is_empty());
            prop_assert!(!borda(&a, &b).is_empty());
        }

        /// ISR never yields more entries than its two inputs combined.
        #[test]
        fn isr_output_bounded(a in arb_results(50), b in arb_results(50)) {
            let result = isr(&a, &b);
            prop_assert!(result.len() <= a.len() + b.len());
        }

        /// Every ISR score is strictly positive.
        #[test]
        fn isr_scores_positive(a in arb_results(50), b in arb_results(50)) {
            for (_, score) in &isr(&a, &b) {
                prop_assert!(*score > 0.0);
            }
        }

        /// Swapping the two input lists does not change ISR output.
        #[test]
        fn isr_commutative(a in arb_results(20), b in arb_results(20)) {
            let ab = isr(&a, &b);
            let ba = isr(&b, &a);
            prop_assert_eq!(ab.len(), ba.len());
            check_same_scores("ISR", ab, ba, 1e-6)?;
        }

        /// ISR output is ordered by descending score.
        #[test]
        fn isr_sorted_descending(a in arb_results(50), b in arb_results(50)) {
            check_sorted_desc(&isr(&a, &b))?;
        }

        /// `top_k` caps the length of ISR output.
        #[test]
        fn isr_top_k_respected(a in arb_results(50), b in arb_results(50), k in 1usize..20) {
            let result = isr_with_config(&a, &b, RrfConfig::new(1).with_top_k(k));
            prop_assert!(result.len() <= k);
        }

        /// On a single list at the same k, ISR's top-to-bottom score ratio is
        /// strictly smaller than RRF's.
        #[test]
        fn isr_gentler_than_rrf(n in 3usize..20) {
            let a: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, 1.0)).collect();
            let b: Vec<(u32, f32)> = vec![];

            let rrf_result = rrf_with_config(&a, &b, RrfConfig::new(1));
            let isr_result = isr_with_config(&a, &b, RrfConfig::new(1));

            prop_assert_eq!(rrf_result.len(), n);
            prop_assert_eq!(isr_result.len(), n);

            // Both lengths are exactly `n` (asserted above), so `n - 1` is the last index.
            let rrf_ratio = rrf_result[0].1 / rrf_result[n - 1].1;
            let isr_ratio = isr_result[0].1 / isr_result[n - 1].1;

            prop_assert!(
                isr_ratio < rrf_ratio,
                "ISR ratio {} should be < RRF ratio {} for n={}",
                isr_ratio,
                rrf_ratio,
                n
            );
        }

        /// `rrf_multi` over `[a, b]` with the default config agrees with `rrf(a, b)`.
        #[test]
        fn multi_matches_two_list(a in arb_results(10), b in arb_results(10)) {
            let two_list = rrf(&a, &b);
            let multi = rrf_multi(&[a.clone(), b.clone()], RrfConfig::default());

            prop_assert_eq!(two_list.len(), multi.len());
            check_same_scores("rrf vs rrf_multi", two_list, multi, 1e-6)?;
        }

        /// An id ranked first in both lists is ranked first by Borda.
        #[test]
        fn borda_top_position_wins(n in 2usize..10) {
            let top_id = 999u32;
            let tail_len = n as u32 - 1;
            let a: Vec<(u32, f32)> = std::iter::once((top_id, 1.0))
                .chain((0..tail_len).map(|i| (i, 0.9 - i as f32 * 0.1)))
                .collect();
            let b: Vec<(u32, f32)> = std::iter::once((top_id, 1.0))
                .chain((100..100 + tail_len).map(|i| (i, 0.9)))
                .collect();

            let fused = borda(&a, &b);
            prop_assert_eq!(fused[0].0, top_id);
        }

        /// A NaN input score does not break CombSUM's output ordering. Ordering
        /// is checked with `total_cmp`.
        #[test]
        fn nan_does_not_corrupt_sorting(a in arb_results(10)) {
            let mut with_nan = a;
            if let Some(first) = with_nan.first_mut() {
                first.1 = f32::NAN;
            }
            let b: Vec<(u32, f32)> = vec![];

            check_sorted_desc(&combsum(&with_nan, &b))?;
        }

        /// Infinite input scores do not drop entries from CombSUM output.
        #[test]
        fn infinity_handled_gracefully(id in 0u32..50) {
            let a = vec![(id, f32::INFINITY)];
            let b = vec![(id + 100, f32::NEG_INFINITY)];

            let result = combsum(&a, &b);
            prop_assert_eq!(result.len(), 2);
        }

        /// RRF, CombSUM, CombMNZ and Borda all return descending-score output.
        #[test]
        fn output_always_sorted(a in arb_results(20), b in arb_results(20)) {
            for result in [rrf(&a, &b), combsum(&a, &b), combmnz(&a, &b), borda(&a, &b)] {
                check_sorted_desc(&result)?;
            }
        }

        /// RRF output never lists the same id twice.
        #[test]
        fn unique_ids_in_output(a in arb_results(20), b in arb_results(20)) {
            let result = rrf(&a, &b);
            let mut seen = std::collections::HashSet::new();
            for (id, _) in &result {
                prop_assert!(seen.insert(id), "Duplicate ID in output: {:?}", id);
            }
        }

        /// With inputs in `0.0..1.0`, CombSUM scores are (approximately) non-negative.
        /// NaN scores are skipped.
        #[test]
        fn combsum_scores_nonnegative(a in arb_results(10), b in arb_results(10)) {
            for (_, score) in &combsum(&a, &b) {
                if !score.is_nan() {
                    prop_assert!(*score >= -0.01, "Score {} is negative", score);
                }
            }
        }

        /// With the default (equal) weights, swapping the inputs to `weighted`
        /// yields exactly the same ids with the same scores.
        #[test]
        fn equal_weights_symmetric(a in arb_results(10), b in arb_results(10)) {
            let ab = weighted(&a, &b, WeightedConfig::default());
            let ba = weighted(&b, &a, WeightedConfig::default());
            check_same_scores("Weighted", ab, ba, 1e-5)?;
        }

        /// An id ranked first in both lists scores no more than `2 / k`.
        #[test]
        fn rrf_score_bounded(k in 1u32..1000) {
            let a = vec![(1u32, 1.0)];
            let b = vec![(1u32, 1.0)];

            let result = rrf_with_config(&a, &b, RrfConfig::new(k));
            let max_possible = 2.0 / k as f32;
            prop_assert!(result[0].1 <= max_possible + 1e-6);
        }

        /// Fusing with an empty list keeps every id from the non-empty one.
        #[test]
        fn empty_list_preserves_ids(n in 1usize..10) {
            let a: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, 1.0 - i as f32 * 0.1)).collect();
            let empty: Vec<(u32, f32)> = vec![];

            let rrf_result = rrf(&a, &empty);
            prop_assert_eq!(rrf_result.len(), n);

            for (id, _) in &a {
                prop_assert!(rrf_result.iter().any(|(rid, _)| rid == id), "Missing ID {:?}", id);
            }
        }
    }
}