/// Default RRF smoothing constant `k`, used by `reciprocal_rank_fusion` and
/// `reciprocal_rank_fusion_linear` when their `k` argument is `None`.
///
/// A result's contribution is `1 / (k + rank + 1)`, so a larger `k` shrinks
/// the gap between the contributions of top and lower ranks.
pub const DEFAULT_RRF_K: f64 = 60.0;
13
/// One entry of a single retriever's ranked output, as consumed by the RRF
/// fusion functions in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier used to merge the same document across different lists.
    pub document_id: String,
    /// Position of this result within its list. Fusion computes
    /// `1 / (k + rank + 1)`, so a rank of 0 denotes the top hit.
    pub rank: usize,
    /// The retriever's own relevance score. The RRF functions here fuse on
    /// `rank` only and do not read this value.
    pub score: f32,
    /// Label of the retriever that produced this result (e.g. `"vector"`);
    /// copied into `FusedResult::contributions`.
    pub source: &'static str,
}
26
/// A document's combined Reciprocal Rank Fusion result.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Identifier shared by all the input results merged into this one.
    pub document_id: String,
    /// Sum of the values in `contributions`; fused results are ordered by
    /// this score, highest first.
    pub rrf_score: f64,
    /// `(source, contribution)` for every input occurrence of this document,
    /// in the order the input lists were visited.
    pub contributions: Vec<(&'static str, f64)>,
}
35
36pub fn reciprocal_rank_fusion(
41 ranked_lists: &[Vec<RankedResult>],
42 k: Option<f64>,
43 top_k: usize,
44) -> Vec<FusedResult> {
45 let k = k.unwrap_or(DEFAULT_RRF_K);
46
47 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
48 std::collections::HashMap::new();
49
50 for list in ranked_lists {
51 for result in list {
52 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
53 scores
54 .entry(result.document_id.clone())
55 .or_default()
56 .push((result.source, contribution));
57 }
58 }
59
60 let mut fused: Vec<FusedResult> = scores
61 .into_iter()
62 .map(|(doc_id, contributions)| {
63 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
64 FusedResult {
65 document_id: doc_id,
66 rrf_score,
67 contributions,
68 }
69 })
70 .collect();
71
72 fused.sort_unstable_by(|a, b| {
73 b.rrf_score
74 .partial_cmp(&a.rrf_score)
75 .unwrap_or(std::cmp::Ordering::Equal)
76 .then_with(|| a.document_id.cmp(&b.document_id))
81 });
82 fused.truncate(top_k);
83 fused
84}
85
86pub fn reciprocal_rank_fusion_linear(
101 ranked_lists: &[Vec<RankedResult>],
102 k: Option<f64>,
103 weights: &[f64],
104 top_k: usize,
105) -> Vec<FusedResult> {
106 assert_eq!(
107 ranked_lists.len(),
108 weights.len(),
109 "weights length must match ranked_lists length"
110 );
111 let k = k.unwrap_or(DEFAULT_RRF_K);
112
113 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
114 std::collections::HashMap::new();
115
116 for (list_idx, list) in ranked_lists.iter().enumerate() {
117 let w = weights[list_idx];
118 for result in list {
119 let contribution = w / (k + result.rank as f64 + 1.0);
120 scores
121 .entry(result.document_id.clone())
122 .or_default()
123 .push((result.source, contribution));
124 }
125 }
126
127 let mut fused: Vec<FusedResult> = scores
128 .into_iter()
129 .map(|(doc_id, contributions)| {
130 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
131 FusedResult {
132 document_id: doc_id,
133 rrf_score,
134 contributions,
135 }
136 })
137 .collect();
138
139 fused.sort_unstable_by(|a, b| {
140 b.rrf_score
141 .partial_cmp(&a.rrf_score)
142 .unwrap_or(std::cmp::Ordering::Equal)
143 .then_with(|| a.document_id.cmp(&b.document_id))
145 });
146 fused.truncate(top_k);
147 fused
148}
149
150pub fn reciprocal_rank_fusion_weighted(
159 ranked_lists: &[Vec<RankedResult>],
160 k_per_list: &[f64],
161 top_k: usize,
162) -> Vec<FusedResult> {
163 assert_eq!(
164 ranked_lists.len(),
165 k_per_list.len(),
166 "k_per_list length must match ranked_lists length"
167 );
168
169 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
170 std::collections::HashMap::new();
171
172 for (list_idx, list) in ranked_lists.iter().enumerate() {
173 let k = k_per_list[list_idx];
174 for result in list {
175 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
176 scores
177 .entry(result.document_id.clone())
178 .or_default()
179 .push((result.source, contribution));
180 }
181 }
182
183 let mut fused: Vec<FusedResult> = scores
184 .into_iter()
185 .map(|(doc_id, contributions)| {
186 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
187 FusedResult {
188 document_id: doc_id,
189 rrf_score,
190 contributions,
191 }
192 })
193 .collect();
194
195 fused.sort_unstable_by(|a, b| {
196 b.rrf_score
197 .partial_cmp(&a.rrf_score)
198 .unwrap_or(std::cmp::Ordering::Equal)
199 .then_with(|| a.document_id.cmp(&b.document_id))
201 });
202 fused.truncate(top_k);
203 fused
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ranked list whose ranks follow slice order (0 = best).
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        doc_ids
            .iter()
            .enumerate()
            .map(|(rank, &id)| RankedResult {
                document_id: id.to_string(),
                rank,
                score: 1.0 - (rank as f32 * 0.1),
                source,
            })
            .collect()
    }

    /// Fused document ids in output order.
    fn ids(fused: &[FusedResult]) -> Vec<&str> {
        fused.iter().map(|f| f.document_id.as_str()).collect()
    }

    /// The contribution `source` made to `result`; panics if absent.
    fn contribution_from(result: &FusedResult, source: &str) -> f64 {
        result
            .contributions
            .iter()
            .find(|(src, _)| *src == source)
            .map(|(_, s)| *s)
            .unwrap()
    }

    #[test]
    fn single_list_preserves_order() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list], None, 10);
        assert_eq!(ids(&fused), ["d1", "d2", "d3"]);
    }

    #[test]
    fn overlapping_lists_boost_common_docs() {
        let vector = make_ranked(&["d1", "d2", "d3"], "vector");
        let sparse = make_ranked(&["d2", "d1", "d4"], "sparse");
        let fused = reciprocal_rank_fusion(&[vector, sparse], None, 10);
        let top2_ids: Vec<&str> = fused[..2].iter().map(|f| f.document_id.as_str()).collect();
        assert!(top2_ids.contains(&"d1"));
        assert!(top2_ids.contains(&"d2"));
    }

    #[test]
    fn ties_are_broken_by_document_id() {
        // Both documents sit at rank 0 of a list, so their scores are equal.
        let first = make_ranked(&["b"], "vector");
        let second = make_ranked(&["a"], "sparse");
        let fused = reciprocal_rank_fusion(&[first, second], None, 10);
        assert_eq!(ids(&fused), ["a", "b"]);
    }

    #[test]
    fn top_k_truncates() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list], None, 2);
        assert_eq!(ids(&fused), ["d1", "d2"]);
        assert!(reciprocal_rank_fusion(&[make_ranked(&["d1"], "vector")], None, 0).is_empty());
    }

    #[test]
    fn plain_score_matches_formula() {
        let list = make_ranked(&["d1", "d2"], "vector");
        let fused = reciprocal_rank_fusion(&[list], Some(10.0), 10);
        assert!((fused[0].rrf_score - 1.0 / 11.0).abs() < 1e-12);
        assert!((fused[1].rrf_score - 1.0 / 12.0).abs() < 1e-12);
    }

    #[test]
    fn unit_weights_match_plain_rrf() {
        let vector = make_ranked(&["d1", "d2", "d3"], "vector");
        let sparse = make_ranked(&["d2", "d1", "d4"], "sparse");
        let lists = [vector, sparse];
        let plain = reciprocal_rank_fusion(&lists, None, 10);
        let linear = reciprocal_rank_fusion_linear(&lists, None, &[1.0, 1.0], 10);
        assert_eq!(ids(&plain), ids(&linear));
        for (p, l) in plain.iter().zip(&linear) {
            assert!((p.rrf_score - l.rrf_score).abs() < 1e-15);
        }
    }

    #[test]
    fn weighted_rrf() {
        let list_a = make_ranked(&["a1", "a2"], "vector");
        let list_b = make_ranked(&["b1", "a1"], "text");
        let fused = reciprocal_rank_fusion_weighted(&[list_a, list_b], &[30.0, 120.0], 10);
        let a1 = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(a1.contributions.len(), 2);
        // Each list uses its own k: "vector" rank 0 with k=30, "text" rank 1 with k=120.
        assert!((contribution_from(a1, "vector") - 1.0 / 31.0).abs() < 1e-12);
        assert!((contribution_from(a1, "text") - 1.0 / 122.0).abs() < 1e-12);
        assert_eq!(ids(&fused), ["a1", "a2", "b1"]);
    }

    #[test]
    #[should_panic(expected = "k_per_list length must match ranked_lists length")]
    fn weighted_mismatched_k_panics() {
        let list = make_ranked(&["d1"], "vector");
        let _ = reciprocal_rank_fusion_weighted(&[list], &[60.0, 60.0], 10);
    }

    #[test]
    fn linear_weight_lets_strong_source_dominate() {
        let strong = make_ranked(&["a1", "a2"], "strong");
        let weak = make_ranked(&["b1", "a1"], "weak");
        let fused = reciprocal_rank_fusion_linear(&[strong, weak], None, &[4.0, 0.25], 10);

        let a1 = fused.iter().position(|f| f.document_id == "a1").unwrap();
        let b1 = fused.iter().position(|f| f.document_id == "b1").unwrap();
        assert!(a1 < b1, "a1 (rank {a1}) should outrank b1 (rank {b1})");

        let a1_res = &fused[a1];
        assert_eq!(a1_res.contributions.len(), 2);
        let expected = 4.0 / (DEFAULT_RRF_K + 0.0 + 1.0);
        assert!((contribution_from(a1_res, "strong") - expected).abs() < 1e-12);
    }

    #[test]
    #[should_panic(expected = "weights length must match ranked_lists length")]
    fn linear_mismatched_weights_panics() {
        let list = make_ranked(&["d1"], "vector");
        let _ = reciprocal_rank_fusion_linear(&[list], None, &[1.0, 2.0], 10);
    }

    #[test]
    fn empty() {
        assert!(reciprocal_rank_fusion(&[], None, 10).is_empty());
        assert!(reciprocal_rank_fusion_linear(&[], None, &[], 10).is_empty());
        assert!(reciprocal_rank_fusion_weighted(&[], &[], 10).is_empty());
    }
}