1use std::collections::HashMap;
8
/// Parameters for reciprocal rank fusion (RRF).
///
/// Derives the common comparison traits so configurations can be compared in
/// assertions, deduplicated, or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RrfConfig {
    /// Smoothing constant added to an item's one-based rank before taking the
    /// reciprocal. Larger values shrink the score gap between top-ranked and
    /// lower-ranked items.
    pub k: u32,
}
17
18impl Default for RrfConfig {
19 fn default() -> Self {
20 Self { k: 60 }
21 }
22}
23
24impl RrfConfig {
25 #[must_use]
27 pub const fn new(k: u32) -> Self {
28 Self { k }
29 }
30}
31
32#[must_use]
64#[allow(clippy::cast_possible_truncation)]
65pub fn reciprocal_rank_fusion(ranked_lists: &[&[i64]], config: &RrfConfig) -> Vec<(i64, f64)> {
66 let mut scores: HashMap<i64, f64> = HashMap::new();
67
68 for list in ranked_lists {
69 for (rank, &item_id) in list.iter().enumerate() {
70 let rrf_score = 1.0 / f64::from(config.k + (rank as u32) + 1);
73 *scores.entry(item_id).or_insert(0.0) += rrf_score;
74 }
75 }
76
77 let mut results: Vec<(i64, f64)> = scores.into_iter().collect();
79 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
80
81 results
82}
83
84#[must_use]
97#[allow(clippy::cast_possible_truncation)]
98pub fn weighted_rrf(ranked_lists: &[(&[i64], f64)], config: &RrfConfig) -> Vec<(i64, f64)> {
99 let mut scores: HashMap<i64, f64> = HashMap::new();
100
101 for (list, weight) in ranked_lists {
102 for (rank, &item_id) in list.iter().enumerate() {
103 let rrf_score = weight / f64::from(config.k + (rank as u32) + 1);
104 *scores.entry(item_id).or_insert(0.0) += rrf_score;
105 }
106 }
107
108 let mut results: Vec<(i64, f64)> = scores.into_iter().collect();
109 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
110
111 results
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rrf_single_list() {
        let list = vec![1, 2, 3];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list], &config);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 1);
        assert!(results[0].1 > results[1].1);
        assert!(results[1].1 > results[2].1);
    }

    #[test]
    fn test_rrf_multiple_lists() {
        let list1 = vec![1, 2, 3];
        let list2 = vec![3, 2, 1];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list1, &list2], &config);

        assert_eq!(results.len(), 3);
        let ids: std::collections::HashSet<i64> = results.iter().map(|(id, _)| *id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&2));
        assert!(ids.contains(&3));
    }

    #[test]
    fn test_rrf_disjoint_lists() {
        let list1 = vec![1, 2];
        let list2 = vec![3, 4];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list1, &list2], &config);

        assert_eq!(results.len(), 4);
        // Items at the same rank in different lists receive identical scores.
        let score1 = results.iter().find(|(id, _)| *id == 1).unwrap().1;
        let score3 = results.iter().find(|(id, _)| *id == 3).unwrap().1;
        assert!((score1 - score3).abs() < f64::EPSILON);
    }

    #[test]
    fn test_rrf_empty_lists() {
        let list1: Vec<i64> = vec![];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list1], &config);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_no_lists() {
        // No input lists at all (as opposed to one empty list).
        let config = RrfConfig::default();
        assert!(reciprocal_rank_fusion(&[], &config).is_empty());
        assert!(weighted_rrf(&[], &config).is_empty());
    }

    #[test]
    fn test_rrf_k_parameter() {
        let list = vec![1, 2];
        let config_low_k = RrfConfig::new(1);
        let config_high_k = RrfConfig::new(100);

        let results_low = reciprocal_rank_fusion(&[&list], &config_low_k);
        let results_high = reciprocal_rank_fusion(&[&list], &config_high_k);

        // A smaller k widens the gap between adjacent ranks.
        let diff_low = results_low[0].1 - results_low[1].1;
        let diff_high = results_high[0].1 - results_high[1].1;

        assert!(diff_low > diff_high);
    }

    #[test]
    fn test_rrf_zero_k() {
        // With k = 0 the score is simply 1 / (rank + 1).
        let list = vec![1, 2];
        let config = RrfConfig::new(0);

        let results = reciprocal_rank_fusion(&[&list], &config);

        assert_eq!(results[0].0, 1);
        assert!((results[0].1 - 1.0).abs() < f64::EPSILON);
        assert!((results[1].1 - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_weighted_rrf() {
        let list1 = vec![1, 2];
        let list2 = vec![2, 1];
        let config = RrfConfig::new(60);

        let results = weighted_rrf(&[(&list1, 2.0), (&list2, 1.0)], &config);

        // The heavier first list puts item 1 ahead.
        assert_eq!(results[0].0, 1);
    }

    #[test]
    fn test_weighted_rrf_unit_weights_match_unweighted() {
        let list1 = vec![1, 2, 3];
        let list2 = vec![3, 4];
        let config = RrfConfig::new(60);

        // Compare per item rather than by position, so tie order does not matter.
        let plain: std::collections::HashMap<i64, f64> =
            reciprocal_rank_fusion(&[&list1, &list2], &config).into_iter().collect();
        let weighted: std::collections::HashMap<i64, f64> =
            weighted_rrf(&[(&list1, 1.0), (&list2, 1.0)], &config).into_iter().collect();

        assert_eq!(plain.len(), weighted.len());
        for (id, score) in &plain {
            assert!((weighted[id] - score).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_rrf_score_formula() {
        let list = vec![1];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list], &config);

        // Rank 0 with k = 60 scores 1 / (60 + 0 + 1).
        let expected = 1.0 / 61.0;
        assert!((results[0].1 - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_rrf_combined_score() {
        let list1 = vec![1];
        let list2 = vec![1];
        let config = RrfConfig::new(60);

        let results = reciprocal_rank_fusion(&[&list1, &list2], &config);

        // Contributions from both lists are summed.
        let expected = 2.0 / 61.0;
        assert!((results[0].1 - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_rrf_config_default() {
        let config = RrfConfig::default();
        assert_eq!(config.k, 60);
    }
}