1use crate::{RragResult, SearchResult};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
10pub trait RankFusion: Send + Sync {
12 fn fuse(
14 &self,
15 result_sets: Vec<Vec<SearchResult>>,
16 limit: usize,
17 ) -> RragResult<Vec<SearchResult>>;
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ReciprocalRankFusion {
27 pub k: f32,
29
30 pub normalize_scores: bool,
32}
33
34impl Default for ReciprocalRankFusion {
35 fn default() -> Self {
36 Self {
37 k: 60.0,
38 normalize_scores: true,
39 }
40 }
41}
42
43impl RankFusion for ReciprocalRankFusion {
44 fn fuse(
45 &self,
46 result_sets: Vec<Vec<SearchResult>>,
47 limit: usize,
48 ) -> RragResult<Vec<SearchResult>> {
49 let mut fusion_scores: HashMap<String, f32> = HashMap::new();
50 let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
51 HashMap::new();
52
53 for results in &result_sets {
55 for (rank, result) in results.iter().enumerate() {
56 let rrf_score = 1.0 / (self.k + rank as f32 + 1.0);
58
59 *fusion_scores.entry(result.id.clone()).or_insert(0.0) += rrf_score;
60
61 doc_contents
63 .entry(result.id.clone())
64 .or_insert((result.content.clone(), result.metadata.clone()));
65 }
66 }
67
68 let mut sorted_results: Vec<_> = fusion_scores
70 .into_iter()
71 .filter_map(|(id, score)| {
72 doc_contents
73 .get(&id)
74 .map(|(content, metadata)| SearchResult {
75 id: id.clone(),
76 content: content.clone(),
77 score,
78 rank: 0,
79 metadata: metadata.clone(),
80 embedding: None,
81 })
82 })
83 .collect();
84
85 sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
86
87 if self.normalize_scores && !sorted_results.is_empty() {
89 let max_score = sorted_results[0].score;
90 for result in &mut sorted_results {
91 result.score /= max_score;
92 }
93 }
94
95 sorted_results.truncate(limit);
97 for (i, result) in sorted_results.iter_mut().enumerate() {
98 result.rank = i;
99 }
100
101 Ok(sorted_results)
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct WeightedFusion {
111 pub weights: Vec<f32>,
113
114 pub normalization: ScoreNormalization,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub enum ScoreNormalization {
121 MinMax,
123 ZScore,
125 None,
127}
128
129impl WeightedFusion {
130 pub fn new(weights: Vec<f32>) -> Self {
131 let sum: f32 = weights.iter().sum();
133 let normalized_weights = if sum > 0.0 {
134 weights.iter().map(|w| w / sum).collect()
135 } else {
136 weights
137 };
138
139 Self {
140 weights: normalized_weights,
141 normalization: ScoreNormalization::MinMax,
142 }
143 }
144
145 fn normalize_scores(&self, results: &mut Vec<SearchResult>) {
146 match self.normalization {
147 ScoreNormalization::MinMax => {
148 if results.is_empty() {
149 return;
150 }
151
152 let min = results
153 .iter()
154 .map(|r| r.score)
155 .fold(f32::INFINITY, f32::min);
156 let max = results
157 .iter()
158 .map(|r| r.score)
159 .fold(f32::NEG_INFINITY, f32::max);
160
161 if max > min {
162 for result in results {
163 result.score = (result.score - min) / (max - min);
164 }
165 }
166 }
167 ScoreNormalization::ZScore => {
168 if results.is_empty() {
169 return;
170 }
171
172 let mean: f32 = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
173 let variance: f32 = results
174 .iter()
175 .map(|r| (r.score - mean).powi(2))
176 .sum::<f32>()
177 / results.len() as f32;
178 let std_dev = variance.sqrt();
179
180 if std_dev > 0.0 {
181 for result in results {
182 result.score = (result.score - mean) / std_dev;
183 }
184 }
185 }
186 ScoreNormalization::None => {}
187 }
188 }
189}
190
191impl RankFusion for WeightedFusion {
192 fn fuse(
193 &self,
194 mut result_sets: Vec<Vec<SearchResult>>,
195 limit: usize,
196 ) -> RragResult<Vec<SearchResult>> {
197 for results in &mut result_sets {
199 self.normalize_scores(results);
200 }
201
202 let mut fusion_scores: HashMap<String, f32> = HashMap::new();
203 let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
204 HashMap::new();
205
206 for (i, results) in result_sets.iter().enumerate() {
208 let weight = self
209 .weights
210 .get(i)
211 .copied()
212 .unwrap_or(1.0 / result_sets.len() as f32);
213
214 for result in results {
215 *fusion_scores.entry(result.id.clone()).or_insert(0.0) += result.score * weight;
216
217 doc_contents
218 .entry(result.id.clone())
219 .or_insert((result.content.clone(), result.metadata.clone()));
220 }
221 }
222
223 let mut sorted_results: Vec<_> = fusion_scores
225 .into_iter()
226 .filter_map(|(id, score)| {
227 doc_contents
228 .get(&id)
229 .map(|(content, metadata)| SearchResult {
230 id: id.clone(),
231 content: content.clone(),
232 score,
233 rank: 0,
234 metadata: metadata.clone(),
235 embedding: None,
236 })
237 })
238 .collect();
239
240 sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
241 sorted_results.truncate(limit);
242
243 for (i, result) in sorted_results.iter_mut().enumerate() {
245 result.rank = i;
246 }
247
248 Ok(sorted_results)
249 }
250}
251
/// Linear fusion model: a document's fused score is the dot product of its
/// feature vector with `feature_weights`.
#[derive(Clone, Debug)]
pub struct LearnedFusion {
    /// Weight for each feature, in the order `extract_features` emits them.
    feature_weights: Vec<f32>,

    /// Whether pairwise interaction features between result lists are added.
    use_interactions: bool,
}
261
262impl LearnedFusion {
263 pub fn new(feature_weights: Vec<f32>) -> Self {
264 Self {
265 feature_weights,
266 use_interactions: true,
267 }
268 }
269
270 pub fn extract_features(&self, result_sets: &[Vec<SearchResult>], doc_id: &str) -> Vec<f32> {
272 let mut features = Vec::new();
273
274 for results in result_sets {
275 let doc_result = results.iter().find(|r| r.id == doc_id);
277
278 if let Some(result) = doc_result {
279 features.push(1.0 / (result.rank as f32 + 1.0)); features.push(result.score); features.push((results.len() - result.rank) as f32 / results.len() as f32);
283 } else {
285 features.push(0.0);
287 features.push(0.0);
288 features.push(0.0);
289 }
290 }
291
292 if self.use_interactions && result_sets.len() > 1 {
294 for i in 0..result_sets.len() {
295 for j in i + 1..result_sets.len() {
296 let score_i = result_sets[i]
297 .iter()
298 .find(|r| r.id == doc_id)
299 .map(|r| r.score)
300 .unwrap_or(0.0);
301 let score_j = result_sets[j]
302 .iter()
303 .find(|r| r.id == doc_id)
304 .map(|r| r.score)
305 .unwrap_or(0.0);
306
307 features.push(score_i * score_j); features.push((score_i - score_j).abs()); features.push(score_i.max(score_j)); }
312 }
313 }
314
315 features
316 }
317}
318
319impl RankFusion for LearnedFusion {
320 fn fuse(
321 &self,
322 result_sets: Vec<Vec<SearchResult>>,
323 limit: usize,
324 ) -> RragResult<Vec<SearchResult>> {
325 let mut all_docs: HashSet<String> = HashSet::new();
327 let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
328 HashMap::new();
329
330 for results in &result_sets {
331 for result in results {
332 all_docs.insert(result.id.clone());
333 doc_contents
334 .entry(result.id.clone())
335 .or_insert((result.content.clone(), result.metadata.clone()));
336 }
337 }
338
339 let mut scored_docs: Vec<(String, f32)> = all_docs
341 .into_iter()
342 .map(|doc_id| {
343 let features = self.extract_features(&result_sets, &doc_id);
344 let score: f32 = features
345 .iter()
346 .zip(self.feature_weights.iter())
347 .map(|(f, w)| f * w)
348 .sum();
349 (doc_id, score)
350 })
351 .collect();
352
353 scored_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
355 scored_docs.truncate(limit);
356
357 let results: Vec<SearchResult> = scored_docs
359 .into_iter()
360 .enumerate()
361 .filter_map(|(rank, (doc_id, score))| {
362 doc_contents
363 .get(&doc_id)
364 .map(|(content, metadata)| SearchResult {
365 id: doc_id,
366 content: content.clone(),
367 score,
368 rank,
369 metadata: metadata.clone(),
370 embedding: None,
371 })
372 })
373 .collect();
374
375 Ok(results)
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Two overlapping ranked lists: documents "2" and "3" appear in both,
    /// "1" only in the first and "4" only in the second.
    fn create_test_results() -> Vec<Vec<SearchResult>> {
        vec![
            vec![
                SearchResult::new("1", "Doc 1", 0.9, 0),
                SearchResult::new("2", "Doc 2", 0.8, 1),
                SearchResult::new("3", "Doc 3", 0.7, 2),
            ],
            vec![
                SearchResult::new("2", "Doc 2", 0.95, 0),
                SearchResult::new("3", "Doc 3", 0.85, 1),
                SearchResult::new("4", "Doc 4", 0.75, 2),
            ],
        ]
    }

    fn ids(results: &[SearchResult]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    fn ranks(results: &[SearchResult]) -> Vec<usize> {
        results.iter().map(|r| r.rank).collect()
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let rrf = ReciprocalRankFusion::default();
        let results = rrf.fuse(create_test_results(), 3).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "2");
        // Documents present in both lists outrank those present in one.
        assert_eq!(ids(&results), vec!["2", "3", "1"]);
        assert_eq!(ranks(&results), vec![0, 1, 2]);
        // Normalization scales the best score to exactly 1.
        assert!((results[0].score - 1.0).abs() < f32::EPSILON);
        assert!(results[1].score < results[0].score);
    }

    #[test]
    fn test_reciprocal_rank_fusion_limit_above_document_count() {
        let rrf = ReciprocalRankFusion::default();
        let results = rrf.fuse(create_test_results(), 10).unwrap();

        assert_eq!(ids(&results), vec!["2", "3", "1", "4"]);
    }

    #[test]
    fn test_weighted_fusion() {
        let fusion = WeightedFusion::new(vec![0.3, 0.7]);
        let results = fusion.fuse(create_test_results(), 3).unwrap();

        assert_eq!(results.len(), 3);
        // After min-max scaling: "2" ~ 0.85, "3" ~ 0.35, "1" = 0.30, "4" = 0.
        assert_eq!(ids(&results), vec!["2", "3", "1"]);
        assert_eq!(ranks(&results), vec![0, 1, 2]);
        assert!((results[0].score - 0.85).abs() < 1e-4);
    }

    #[test]
    fn test_learned_fusion() {
        // Only the raw-score feature of each list is weighted, so the fused
        // score is the sum of the document's scores across both lists.
        let fusion = LearnedFusion::new(vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]);
        let results = fusion.fuse(create_test_results(), 2).unwrap();

        assert_eq!(ids(&results), vec!["2", "3"]);
        assert_eq!(ranks(&results), vec![0, 1]);
        assert!((results[0].score - 1.75).abs() < 1e-5);
    }
}