1use std::collections::HashMap;
7
/// One document after several ranked result lists have been merged.
///
/// The fusion functions in this file return one `FusedResult` per distinct
/// document id found in their inputs.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Document identifier, exactly as it appeared in the input lists.
    pub id: String,
    /// Combined score computed by the fusion function; results are returned
    /// in descending order of this value.
    pub score: f32,
    /// The original (un-fused) score of this document in each source that
    /// returned it, keyed by source name. Sources that did not return the
    /// document have no entry.
    pub source_scores: HashMap<String, f32>,
}
18
/// Selects a strategy for combining several ranked result lists.
///
/// The variant names mirror the fusion functions defined in this file
/// (`reciprocal_rank_fusion`, `weighted_sum_fusion`, `max_score_fusion`).
/// NOTE(review): the code that dispatches from a variant to a function is not
/// visible here — confirm the mapping at the call site.
///
/// `Eq` and `Hash` are derived so the method can be used as a map/set key
/// (for example in per-method configuration or metrics).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum FusionMethod {
    /// Rank-based fusion; this is the `Default` variant.
    #[default]
    ReciprocalRank,
    /// Fusion by a weighted sum of per-source scores.
    WeightedSum,
    /// Fusion by taking the highest per-source score.
    MaxScore,
}
30
31pub fn reciprocal_rank_fusion(
46 ranked_lists: Vec<(&str, Vec<(String, f32)>)>,
47 k: f32,
48 top_k: usize,
49) -> Vec<FusedResult> {
50 let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
51
52 for (source_name, rankings) in ranked_lists {
53 for (rank, (id, original_score)) in rankings.into_iter().enumerate() {
54 let rrf_score = 1.0 / (k + (rank + 1) as f32);
55
56 let entry = scores.entry(id).or_insert_with(|| (0.0, HashMap::new()));
57 entry.0 += rrf_score;
58 entry.1.insert(source_name.to_string(), original_score);
59 }
60 }
61
62 let mut results: Vec<FusedResult> = scores
63 .into_iter()
64 .map(|(id, (score, source_scores))| FusedResult {
65 id,
66 score,
67 source_scores,
68 })
69 .collect();
70
71 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
73
74 results.truncate(top_k);
75 results
76}
77
78pub fn weighted_sum_fusion(
89 weighted_lists: Vec<(&str, f32, Vec<(String, f32)>)>,
90 top_k: usize,
91) -> Vec<FusedResult> {
92 let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
93
94 for (source_name, weight, rankings) in weighted_lists {
95 let (min_score, max_score) = rankings.iter().fold((f32::MAX, f32::MIN), |(min, max), (_, s)| {
97 (min.min(*s), max.max(*s))
98 });
99
100 let range = max_score - min_score;
101
102 for (id, original_score) in rankings {
103 let normalized = if range > 0.0 {
105 (original_score - min_score) / range
106 } else {
107 1.0 };
109
110 let weighted_score = normalized * weight;
111
112 let entry = scores.entry(id).or_insert_with(|| (0.0, HashMap::new()));
113 entry.0 += weighted_score;
114 entry.1.insert(source_name.to_string(), original_score);
115 }
116 }
117
118 let mut results: Vec<FusedResult> = scores
119 .into_iter()
120 .map(|(id, (score, source_scores))| FusedResult {
121 id,
122 score,
123 source_scores,
124 })
125 .collect();
126
127 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
128
129 results.truncate(top_k);
130 results
131}
132
133pub fn max_score_fusion(
137 ranked_lists: Vec<(&str, Vec<(String, f32)>)>,
138 top_k: usize,
139) -> Vec<FusedResult> {
140 let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
141
142 for (source_name, rankings) in ranked_lists {
143 for (id, original_score) in rankings {
144 let entry = scores.entry(id).or_insert_with(|| (f32::MIN, HashMap::new()));
145 if original_score > entry.0 {
146 entry.0 = original_score;
147 }
148 entry.1.insert(source_name.to_string(), original_score);
149 }
150 }
151
152 let mut results: Vec<FusedResult> = scores
153 .into_iter()
154 .map(|(id, (score, source_scores))| FusedResult {
155 id,
156 score,
157 source_scores,
158 })
159 .collect();
160
161 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
162
163 results.truncate(top_k);
164 results
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing fused `f32` scores. It must stay well below the
    /// gap between adjacent RRF ranks (1/61 - 1/62 is about 2.6e-4), otherwise
    /// the assertions cannot tell a swapped ranking from a correct one.
    const EPS: f32 = 1e-6;

    /// Returns the ids of `results` sorted alphabetically, for comparisons
    /// that must not depend on the order of equally scored documents.
    fn sorted_ids(results: &[FusedResult]) -> Vec<&str> {
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_rrf_fusion_basic() {
        let dense_results = vec![
            ("doc1".to_string(), 0.95),
            ("doc2".to_string(), 0.85),
            ("doc3".to_string(), 0.75),
        ];

        let sparse_results = vec![
            ("doc2".to_string(), 10.5),
            ("doc1".to_string(), 8.3),
            ("doc4".to_string(), 7.1),
        ];

        let results = reciprocal_rank_fusion(
            vec![("dense", dense_results), ("sparse", sparse_results)],
            60.0,
            5,
        );

        // Four distinct documents, fewer than top_k.
        assert_eq!(results.len(), 4);

        // doc1 and doc2 each hold ranks 1 and 2, so they tie for the top two
        // places; doc3 and doc4 each hold a single rank 3 and tie below them.
        assert_eq!(sorted_ids(&results[..2]), vec!["doc1", "doc2"]);
        assert_eq!(sorted_ids(&results[2..]), vec!["doc3", "doc4"]);

        let top = 1.0 / 61.0 + 1.0 / 62.0;
        assert!((results[0].score - top).abs() < EPS);
        assert!((results[1].score - top).abs() < EPS);
        assert!((results[2].score - 1.0 / 63.0).abs() < EPS);
        assert!((results[3].score - 1.0 / 63.0).abs() < EPS);
    }

    #[test]
    fn test_rrf_k_parameter() {
        let list = vec![("doc1".to_string(), 1.0), ("doc2".to_string(), 0.9)];

        let results = reciprocal_rank_fusion(vec![("test", list)], 60.0, 5);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert!((results[0].score - 1.0 / 61.0).abs() < EPS);
        assert_eq!(results[1].id, "doc2");
        assert!((results[1].score - 1.0 / 62.0).abs() < EPS);
    }

    #[test]
    fn test_rrf_top_k_truncates() {
        let list = vec![
            ("doc1".to_string(), 1.0),
            ("doc2".to_string(), 0.9),
            ("doc3".to_string(), 0.8),
        ];

        let results = reciprocal_rank_fusion(vec![("test", list)], 60.0, 1);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_weighted_sum_fusion() {
        let dense_results = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.7)];

        let sparse_results = vec![("doc1".to_string(), 5.0), ("doc2".to_string(), 10.0)];

        let results = weighted_sum_fusion(
            vec![("dense", 0.7, dense_results), ("sparse", 0.3, sparse_results)],
            5,
        );

        // After min-max normalization doc1 is the dense maximum (1.0 * 0.7)
        // and the sparse minimum (0.0 * 0.3); doc2 is the reverse.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert!((results[0].score - 0.7).abs() < EPS);
        assert_eq!(results[1].id, "doc2");
        assert!((results[1].score - 0.3).abs() < EPS);
    }

    #[test]
    fn test_fusion_with_empty_list() {
        assert!(reciprocal_rank_fusion(vec![], 60.0, 5).is_empty());
        assert!(weighted_sum_fusion(vec![], 5).is_empty());
        assert!(max_score_fusion(vec![], 5).is_empty());
    }

    #[test]
    fn test_fusion_source_scores_preserved() {
        let dense_results = vec![("doc1".to_string(), 0.95)];
        let sparse_results = vec![("doc1".to_string(), 8.5)];

        let results = reciprocal_rank_fusion(
            vec![("dense", dense_results), ("sparse", sparse_results)],
            60.0,
            5,
        );

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].source_scores.len(), 2);
        assert_eq!(results[0].source_scores.get("dense"), Some(&0.95));
        assert_eq!(results[0].source_scores.get("sparse"), Some(&8.5));
    }

    #[test]
    fn test_max_score_fusion() {
        let list1 = vec![("doc1".to_string(), 0.5), ("doc2".to_string(), 0.8)];

        let list2 = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.3)];

        let results = max_score_fusion(vec![("a", list1), ("b", list2)], 5);

        assert_eq!(results.len(), 2);
        // Highest fused score first.
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");

        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!((doc1.score - 0.9).abs() < EPS);
        assert!((doc2.score - 0.8).abs() < EPS);
    }
}