1use std::collections::HashMap;
47use std::hash::Hash;
48
/// A single entry of one ranked list handed to [`fuse`].
///
/// The candidate's position inside [`Bucket::candidates`] is what determines its rank during
/// fusion; `score` is only consulted when the owning bucket sets a `min_score` floor.
#[derive(Clone, Debug, PartialEq)]
pub struct Candidate<Id> {
    /// Identifier used to merge occurrences of the same item across buckets.
    pub id: Id,
    /// Score assigned by whatever produced this list; compared against the bucket's floor.
    pub score: f64,
}
57
58#[derive(Debug, Clone)]
61pub struct Bucket<Id> {
62 pub candidates: Vec<Candidate<Id>>,
63 pub min_score: Option<f64>,
67}
68
/// One entry of the fused ranking produced by [`fuse`].
#[derive(Clone, Debug, PartialEq)]
pub struct FusedItem<Id> {
    /// Identifier of the fused item.
    pub id: Id,
    /// Sum of `1 / (k + rank)` over every ranked occurrence of `id` across the input buckets.
    pub rrf_score: f64,
}
76
/// Default value for the `k` smoothing constant passed to [`fuse`].
///
/// Each ranked occurrence contributes `1 / (k + rank)`, so a larger `k` shrinks the gap
/// between the contributions of top-ranked and lower-ranked candidates.
pub const RRF_K_DEFAULT: u32 = 60;
80
81pub fn fuse<Id>(buckets: &[Bucket<Id>], k: u32, total_k: usize) -> Vec<FusedItem<Id>>
88where
89 Id: Clone + Eq + Hash + Ord,
90{
91 if total_k == 0 {
92 return Vec::new();
93 }
94
95 let k_f = f64::from(k);
96 let mut scores: HashMap<Id, f64> = HashMap::new();
97
98 for bucket in buckets {
99 let mut rank: u32 = 0;
100 for cand in &bucket.candidates {
101 if let Some(floor) = bucket.min_score {
102 if cand.score < floor {
103 continue;
104 }
105 }
106 rank += 1;
107 let contribution = 1.0 / (k_f + f64::from(rank));
108 scores
109 .entry(cand.id.clone())
110 .and_modify(|s| *s += contribution)
111 .or_insert(contribution);
112 }
113 }
114
115 let mut fused: Vec<FusedItem<Id>> = scores
116 .into_iter()
117 .map(|(id, rrf_score)| FusedItem { id, rrf_score })
118 .collect();
119
120 fused.sort_by(|a, b| {
124 b.rrf_score
125 .partial_cmp(&a.rrf_score)
126 .unwrap_or(std::cmp::Ordering::Equal)
127 .then_with(|| a.id.cmp(&b.id))
128 });
129
130 fused.truncate(total_k);
131 fused
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing RRF scores against closed-form values.
    const EPS: f64 = 1e-12;

    fn cand<Id>(id: Id, score: f64) -> Candidate<Id> {
        Candidate { id, score }
    }

    fn bucket_no_floor<Id>(cs: Vec<Candidate<Id>>) -> Bucket<Id> {
        Bucket {
            candidates: cs,
            min_score: None,
        }
    }

    /// Asserts that `actual` is within [`EPS`] of `expected`, reporting both values on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn rrf_single_list_matches_reference() {
        // With one list, rank r contributes exactly 1 / (k + r).
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5), cand("c", 0.1)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 3);
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_close(out[1].rrf_score, 1.0 / 62.0);
        assert_close(out[2].rrf_score, 1.0 / 63.0);
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
        assert_eq!(out[2].id, "c");
    }

    #[test]
    fn rrf_two_lists_sums_contributions() {
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.9), cand("c", 0.8)]);
        let b2 = bucket_no_floor(vec![cand("a", 0.95), cand("b", 0.85), cand("d", 0.7)]);
        let out = fuse(&[b1, b2], 60, 10);
        let by_id: std::collections::HashMap<_, _> =
            out.iter().map(|f| (f.id, f.rrf_score)).collect();
        assert_close(by_id["a"], 2.0 / 61.0);
        assert_close(by_id["b"], 2.0 / 62.0);
        assert_close(by_id["c"], 1.0 / 63.0);
        assert_close(by_id["d"], 1.0 / 63.0);
        // "c" and "d" tie on score, so the ascending-id tie-break decides their order.
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
        assert_eq!(out[2].id, "c");
        assert_eq!(out[3].id, "d");
    }

    #[test]
    fn rrf_k_default_is_60() {
        assert_eq!(RRF_K_DEFAULT, 60);
    }

    #[test]
    fn alternate_k_changes_scores() {
        // k = 1, rank 1 -> 1 / (1 + 1).
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 1, 10);
        assert_close(out[0].rrf_score, 0.5);
    }

    #[test]
    fn total_k_caps_output() {
        let bucket = bucket_no_floor(vec![
            cand("a", 1.0),
            cand("b", 0.9),
            cand("c", 0.8),
            cand("d", 0.7),
        ]);
        let out = fuse(&[bucket], 60, 2);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].id, "a");
        assert_eq!(out[1].id, "b");
    }

    #[test]
    fn total_k_zero_returns_empty() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 60, 0);
        assert!(out.is_empty());
    }

    #[test]
    fn total_k_larger_than_candidates_returns_all() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let out = fuse(&[bucket], 60, 100);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn min_score_drops_items_before_ranking() {
        // "b" falls below the floor, so "c" is promoted to rank 2 rather than rank 3.
        let bucket = Bucket {
            candidates: vec![cand("a", 0.9), cand("b", 0.4), cand("c", 0.6)],
            min_score: Some(0.5),
        };
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].id, "a");
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_eq!(out[1].id, "c");
        assert_close(out[1].rrf_score, 1.0 / 62.0);
    }

    #[test]
    fn min_score_independent_per_bucket() {
        // Each bucket applies its own floor; "y" fails both and disappears entirely.
        let bm25 = Bucket {
            candidates: vec![cand("x", 0.5), cand("y", 0.3)],
            min_score: Some(0.4),
        };
        let vec_bucket = Bucket {
            candidates: vec![cand("x", 0.85), cand("y", 0.6)],
            min_score: Some(0.7),
        };
        let out = fuse(&[bm25, vec_bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].id, "x");
        assert_close(out[0].rrf_score, 2.0 / 61.0);
    }

    #[test]
    fn min_score_none_keeps_everything() {
        let bucket = bucket_no_floor(vec![cand("a", -10.0), cand("b", 0.0)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn tie_break_is_id_ascending() {
        // All three ids sit at rank 1 of their own bucket, so their scores tie exactly.
        let b1 = bucket_no_floor(vec![cand("zebra", 1.0)]);
        let b2 = bucket_no_floor(vec![cand("apple", 1.0)]);
        let b3 = bucket_no_floor(vec![cand("mango", 1.0)]);
        let out = fuse(&[b1, b2, b3], 60, 10);
        assert_eq!(
            out.iter().map(|f| f.id).collect::<Vec<_>>(),
            vec!["apple", "mango", "zebra"]
        );
    }

    #[test]
    fn fuse_is_deterministic_across_calls() {
        // HashMap iteration order varies between maps; the final sort must hide that.
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let a = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let c = fuse(&[b1, b2], 60, 10);
        assert_eq!(a, c);
    }

    #[test]
    fn fuse_is_order_independent_across_buckets() {
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let forward = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let reverse = fuse(&[b2, b1], 60, 10);
        assert_eq!(forward, reverse);
    }

    #[test]
    fn empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> = vec![];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn all_empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> =
            vec![bucket_no_floor(vec![]), bucket_no_floor(vec![])];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn duplicate_id_within_one_bucket_keeps_both_ranks() {
        // A repeated id is not deduplicated: it collects the contributions of ranks 1 and 2.
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("a", 0.5)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_close(out[0].rrf_score, 1.0 / 61.0 + 1.0 / 62.0);
    }

    #[test]
    fn integer_ids_supported() {
        // Id 2 appears in both lists (ranks 2 and 1) and therefore outranks 1 and 3.
        let b1 = bucket_no_floor(vec![cand(1u64, 1.0), cand(2u64, 0.5)]);
        let b2 = bucket_no_floor(vec![cand(2u64, 0.9), cand(3u64, 0.4)]);
        let out = fuse(&[b1, b2], 60, 10);
        assert_eq!(out[0].id, 2);
    }
}