/// Smoothing constant `k` that [`reciprocal_rank_fusion`] falls back to when the caller passes
/// `None` for its `k` argument.
///
/// It appears in the denominator of every contribution, `1 / (k + rank + 1)`, so a larger value
/// shrinks the score gap between neighbouring ranks.
pub const DEFAULT_RRF_K: f64 = 60.0;
13
/// One entry of a single ranked result list, used as input to the fusion functions in this file.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Identifier of the document; entries with the same id in different lists are fused
    /// into one result.
    pub document_id: String,
    /// Position of the document within its list. The fusion functions add 1 to it before
    /// taking the reciprocal, so `0` denotes the top entry.
    pub rank: usize,
    /// Score attached to this entry by whatever produced the list. The fusion functions in
    /// this file do not read it; only `rank` influences the fused score.
    pub score: f32,
    /// Label of the list this entry came from; it is copied into
    /// `FusedResult::contributions` next to the contribution this entry made.
    pub source: &'static str,
}
26
/// One document in the output of a rank fusion, together with the breakdown of its score.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Identifier of the fused document.
    pub document_id: String,
    /// Fused score: the sum of the second element of every pair in `contributions`.
    pub rrf_score: f64,
    /// One `(source, contribution)` pair per input entry that referred to this document,
    /// in the order the input lists were traversed.
    pub contributions: Vec<(&'static str, f64)>,
}
35
36pub fn reciprocal_rank_fusion(
41 ranked_lists: &[Vec<RankedResult>],
42 k: Option<f64>,
43 top_k: usize,
44) -> Vec<FusedResult> {
45 let k = k.unwrap_or(DEFAULT_RRF_K);
46
47 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
48 std::collections::HashMap::new();
49
50 for list in ranked_lists {
51 for result in list {
52 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
53 scores
54 .entry(result.document_id.clone())
55 .or_default()
56 .push((result.source, contribution));
57 }
58 }
59
60 let mut fused: Vec<FusedResult> = scores
61 .into_iter()
62 .map(|(doc_id, contributions)| {
63 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
64 FusedResult {
65 document_id: doc_id,
66 rrf_score,
67 contributions,
68 }
69 })
70 .collect();
71
72 fused.sort_unstable_by(|a, b| {
73 b.rrf_score
74 .partial_cmp(&a.rrf_score)
75 .unwrap_or(std::cmp::Ordering::Equal)
76 });
77 fused.truncate(top_k);
78 fused
79}
80
81pub fn reciprocal_rank_fusion_weighted(
90 ranked_lists: &[Vec<RankedResult>],
91 k_per_list: &[f64],
92 top_k: usize,
93) -> Vec<FusedResult> {
94 assert_eq!(
95 ranked_lists.len(),
96 k_per_list.len(),
97 "k_per_list length must match ranked_lists length"
98 );
99
100 let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
101 std::collections::HashMap::new();
102
103 for (list_idx, list) in ranked_lists.iter().enumerate() {
104 let k = k_per_list[list_idx];
105 for result in list {
106 let contribution = 1.0 / (k + result.rank as f64 + 1.0);
107 scores
108 .entry(result.document_id.clone())
109 .or_default()
110 .push((result.source, contribution));
111 }
112 }
113
114 let mut fused: Vec<FusedResult> = scores
115 .into_iter()
116 .map(|(doc_id, contributions)| {
117 let rrf_score = contributions.iter().map(|(_, s)| s).sum();
118 FusedResult {
119 document_id: doc_id,
120 rrf_score,
121 contributions,
122 }
123 })
124 .collect();
125
126 fused.sort_unstable_by(|a, b| {
127 b.rrf_score
128 .partial_cmp(&a.rrf_score)
129 .unwrap_or(std::cmp::Ordering::Equal)
130 });
131 fused.truncate(top_k);
132 fused
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing fused scores with hand-computed expectations.
    const EPS: f64 = 1e-12;

    /// Builds a ranked list in which each id's rank is its zero-based position in `doc_ids`.
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        doc_ids
            .iter()
            .enumerate()
            .map(|(rank, &id)| RankedResult {
                document_id: id.to_string(),
                rank,
                score: 1.0 - (rank as f32 * 0.1),
                source,
            })
            .collect()
    }

    /// Document ids of `fused`, in output order.
    fn ids(fused: &[FusedResult]) -> Vec<&str> {
        fused.iter().map(|f| f.document_id.as_str()).collect()
    }

    #[test]
    fn single_list_preserves_order() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list], None, 10);
        assert_eq!(ids(&fused), vec!["d1", "d2", "d3"]);
        assert!((fused[0].rrf_score - 1.0 / (DEFAULT_RRF_K + 1.0)).abs() < EPS);
    }

    #[test]
    fn overlapping_lists_boost_common_docs() {
        let vector = make_ranked(&["d1", "d2", "d3"], "vector");
        let sparse = make_ranked(&["d2", "d1", "d4"], "sparse");
        let fused = reciprocal_rank_fusion(&[vector, sparse], None, 10);
        assert_eq!(fused.len(), 4);

        // d1 and d2 have identical scores, so only their membership in the top two is checked.
        let mut top2 = ids(&fused[..2]);
        top2.sort_unstable();
        assert_eq!(top2, vec!["d1", "d2"]);
        assert!(fused[..2].iter().all(|f| f.contributions.len() == 2));

        let mut rest = ids(&fused[2..]);
        rest.sort_unstable();
        assert_eq!(rest, vec!["d3", "d4"]);
    }

    #[test]
    fn explicit_k_overrides_default() {
        let list = make_ranked(&["d1", "d2"], "vector");
        let fused = reciprocal_rank_fusion(&[list], Some(0.0), 10);
        assert_eq!(ids(&fused), vec!["d1", "d2"]);
        assert!((fused[0].rrf_score - 1.0).abs() < EPS);
        assert!((fused[1].rrf_score - 0.5).abs() < EPS);
    }

    #[test]
    fn top_k_truncates_results() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list.clone()], None, 2);
        assert_eq!(ids(&fused), vec!["d1", "d2"]);
        assert!(reciprocal_rank_fusion(&[list], None, 0).is_empty());
    }

    #[test]
    fn weighted_rrf() {
        let list_a = make_ranked(&["a1", "a2"], "vector");
        let list_b = make_ranked(&["b1", "a1"], "text");
        let fused = reciprocal_rank_fusion_weighted(&[list_a, list_b], &[30.0, 120.0], 10);
        assert_eq!(ids(&fused), vec!["a1", "a2", "b1"]);

        let a1 = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(a1.contributions.len(), 2);
        assert_eq!(a1.contributions[0].0, "vector");
        assert!((a1.contributions[0].1 - 1.0 / 31.0).abs() < EPS);
        assert_eq!(a1.contributions[1].0, "text");
        assert!((a1.contributions[1].1 - 1.0 / 122.0).abs() < EPS);
        assert!((a1.rrf_score - (1.0 / 31.0 + 1.0 / 122.0)).abs() < EPS);
    }

    #[test]
    #[should_panic(expected = "k_per_list length must match ranked_lists length")]
    fn weighted_rejects_mismatched_lengths() {
        let list = make_ranked(&["d1"], "vector");
        let _ = reciprocal_rank_fusion_weighted(&[list], &[], 10);
    }

    #[test]
    fn empty() {
        assert!(reciprocal_rank_fusion(&[], None, 10).is_empty());
        assert!(reciprocal_rank_fusion_weighted(&[], &[], 10).is_empty());
    }
}