/// Smoothing constant `k` that [`reciprocal_rank_fusion`] falls back to when
/// the caller passes `None`.
///
/// The constant is added to every rank in the denominator of the per-list
/// contribution `1 / (k + rank + 1)`.
pub const DEFAULT_RRF_K: f64 = 60.0;
11
/// One entry of a single retriever's ranked result list, used as input to the
/// fusion functions in this file.
#[derive(Clone, Debug)]
pub struct RankedResult {
    /// Identifier of the document; entries with equal ids are merged by fusion.
    pub document_id: String,
    /// Position of the document in its list. The fusion functions compute
    /// `1 / (k + rank + 1)`, so a smaller rank yields a larger contribution.
    pub rank: usize,
    /// Score reported alongside the entry. The fusion functions in this file
    /// do not read it; only `rank` affects the fused score.
    pub score: f32,
    /// Label of the list this entry came from; copied into
    /// `FusedResult::contributions`.
    pub source: &'static str,
}
24
/// A document after fusion: its combined score plus the individual
/// per-occurrence contributions that produced it.
#[derive(Clone, Debug)]
pub struct FusedResult {
    /// Identifier shared by all the input entries that were merged.
    pub document_id: String,
    /// Sum of the second element of every pair in `contributions`.
    pub rrf_score: f64,
    /// One `(source, contribution)` pair for each time the document appeared
    /// in an input list, in the order the lists were processed.
    pub contributions: Vec<(&'static str, f64)>,
}
33
34pub fn reciprocal_rank_fusion(
39 ranked_lists: &[Vec<RankedResult>],
40 k: Option<f64>,
41 top_k: usize,
42) -> Vec<FusedResult> {
43 let k = k.unwrap_or(DEFAULT_RRF_K);
44
45 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
46 std::collections::HashMap::new();
47
48 for list in ranked_lists {
49 for result in list {
50 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
51 scores
52 .entry(result.document_id.clone())
53 .or_default()
54 .push((result.source, contribution));
55 }
56 }
57
58 let mut fused: Vec<FusedResult> = scores
59 .into_iter()
60 .map(|(doc_id, contributions)| {
61 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
62 FusedResult {
63 document_id: doc_id,
64 rrf_score,
65 contributions,
66 }
67 })
68 .collect();
69
70 fused.sort_unstable_by(|a, b| {
71 b.rrf_score
72 .partial_cmp(&a.rrf_score)
73 .unwrap_or(std::cmp::Ordering::Equal)
74 });
75 fused.truncate(top_k);
76 fused
77}
78
79pub fn reciprocal_rank_fusion_weighted(
88 ranked_lists: &[Vec<RankedResult>],
89 k_per_list: &[f64],
90 top_k: usize,
91) -> Vec<FusedResult> {
92 assert_eq!(
93 ranked_lists.len(),
94 k_per_list.len(),
95 "k_per_list length must match ranked_lists length"
96 );
97
98 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
99 std::collections::HashMap::new();
100
101 for (list_idx, list) in ranked_lists.iter().enumerate() {
102 let k = k_per_list[list_idx];
103 for result in list {
104 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
105 scores
106 .entry(result.document_id.clone())
107 .or_default()
108 .push((result.source, contribution));
109 }
110 }
111
112 let mut fused: Vec<FusedResult> = scores
113 .into_iter()
114 .map(|(doc_id, contributions)| {
115 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
116 FusedResult {
117 document_id: doc_id,
118 rrf_score,
119 contributions,
120 }
121 })
122 .collect();
123
124 fused.sort_unstable_by(|a, b| {
125 b.rrf_score
126 .partial_cmp(&a.rrf_score)
127 .unwrap_or(std::cmp::Ordering::Equal)
128 });
129 fused.truncate(top_k);
130 fused
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ranked list from `doc_ids`, assigning ranks 0, 1, 2, ... in
    /// slice order and a score that drops by 0.1 per position.
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        let mut ranked = Vec::with_capacity(doc_ids.len());
        for (position, id) in doc_ids.iter().enumerate() {
            ranked.push(RankedResult {
                document_id: (*id).to_string(),
                rank: position,
                score: 1.0 - (position as f32 * 0.1),
                source,
            });
        }
        ranked
    }

    /// Fusing a single list keeps all entries and leaves its best entry first.
    #[test]
    fn single_list_preserves_order() {
        let only_list = [make_ranked(&["d1", "d2", "d3"], "vector")];
        let fused = reciprocal_rank_fusion(&only_list, None, 10);
        assert_eq!(fused.len(), 3);
        assert_eq!(fused[0].document_id, "d1");
    }

    /// Documents present in both lists occupy the two leading positions.
    #[test]
    fn overlapping_lists_boost_common_docs() {
        let lists = [
            make_ranked(&["d1", "d2", "d3"], "vector"),
            make_ranked(&["d2", "d1", "d4"], "sparse"),
        ];
        let fused = reciprocal_rank_fusion(&lists, None, 10);
        let leaders = &fused[..2];
        assert!(leaders.iter().any(|f| f.document_id == "d1"));
        assert!(leaders.iter().any(|f| f.document_id == "d2"));
    }

    /// A document found in both lists records one contribution per list.
    #[test]
    fn weighted_rrf() {
        let lists = [
            make_ranked(&["a1", "a2"], "vector"),
            make_ranked(&["b1", "a1"], "text"),
        ];
        let k_values = [30.0, 120.0];
        let fused = reciprocal_rank_fusion_weighted(&lists, &k_values, 10);
        let shared = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(shared.contributions.len(), 2);
    }

    /// With no input lists both fusion functions return nothing.
    #[test]
    fn empty() {
        let plain = reciprocal_rank_fusion(&[], None, 10);
        assert!(plain.is_empty());

        let weighted = reciprocal_rank_fusion_weighted(&[], &[], 10);
        assert!(weighted.is_empty());
    }
}