rexis_rag/retrieval/
fusion.rs

1//! # Rank Fusion Algorithms
2//!
3//! Advanced algorithms for combining results from multiple retrieval methods.
4//! Implements state-of-the-art fusion techniques for optimal ranking.
5
6use crate::{RragResult, SearchResult};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
10/// Trait for rank fusion algorithms
11pub trait RankFusion: Send + Sync {
12    /// Fuse multiple result sets into a single ranked list
13    fn fuse(
14        &self,
15        result_sets: Vec<Vec<SearchResult>>,
16        limit: usize,
17    ) -> RragResult<Vec<SearchResult>>;
18}
19
20/// Reciprocal Rank Fusion (RRF)
21///
22/// RRF is a simple yet effective fusion method that combines rankings
23/// by summing reciprocal ranks. It's robust to outliers and doesn't
24/// require score calibration.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ReciprocalRankFusion {
27    /// Constant k to avoid division by zero (typically 60)
28    pub k: f32,
29
30    /// Whether to normalize final scores
31    pub normalize_scores: bool,
32}
33
34impl Default for ReciprocalRankFusion {
35    fn default() -> Self {
36        Self {
37            k: 60.0,
38            normalize_scores: true,
39        }
40    }
41}
42
43impl RankFusion for ReciprocalRankFusion {
44    fn fuse(
45        &self,
46        result_sets: Vec<Vec<SearchResult>>,
47        limit: usize,
48    ) -> RragResult<Vec<SearchResult>> {
49        let mut fusion_scores: HashMap<String, f32> = HashMap::new();
50        let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
51            HashMap::new();
52
53        // Calculate RRF scores
54        for results in &result_sets {
55            for (rank, result) in results.iter().enumerate() {
56                // RRF formula: 1 / (k + rank)
57                let rrf_score = 1.0 / (self.k + rank as f32 + 1.0);
58
59                *fusion_scores.entry(result.id.clone()).or_insert(0.0) += rrf_score;
60
61                // Store document content and metadata
62                doc_contents
63                    .entry(result.id.clone())
64                    .or_insert((result.content.clone(), result.metadata.clone()));
65            }
66        }
67
68        // Sort by fusion score
69        let mut sorted_results: Vec<_> = fusion_scores
70            .into_iter()
71            .filter_map(|(id, score)| {
72                doc_contents
73                    .get(&id)
74                    .map(|(content, metadata)| SearchResult {
75                        id: id.clone(),
76                        content: content.clone(),
77                        score,
78                        rank: 0,
79                        metadata: metadata.clone(),
80                        embedding: None,
81                    })
82            })
83            .collect();
84
85        sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
86
87        // Normalize scores if requested
88        if self.normalize_scores && !sorted_results.is_empty() {
89            let max_score = sorted_results[0].score;
90            for result in &mut sorted_results {
91                result.score /= max_score;
92            }
93        }
94
95        // Truncate and update ranks
96        sorted_results.truncate(limit);
97        for (i, result) in sorted_results.iter_mut().enumerate() {
98            result.rank = i;
99        }
100
101        Ok(sorted_results)
102    }
103}
104
105/// Weighted linear combination fusion
106///
107/// Combines scores from different retrievers using weighted linear combination.
108/// Requires score calibration for best results.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct WeightedFusion {
111    /// Weights for each retriever (should sum to 1.0)
112    pub weights: Vec<f32>,
113
114    /// Score normalization method
115    pub normalization: ScoreNormalization,
116}
117
118/// Score normalization methods
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub enum ScoreNormalization {
121    /// Min-max normalization
122    MinMax,
123    /// Z-score normalization
124    ZScore,
125    /// No normalization
126    None,
127}
128
129impl WeightedFusion {
130    pub fn new(weights: Vec<f32>) -> Self {
131        // Normalize weights to sum to 1.0
132        let sum: f32 = weights.iter().sum();
133        let normalized_weights = if sum > 0.0 {
134            weights.iter().map(|w| w / sum).collect()
135        } else {
136            weights
137        };
138
139        Self {
140            weights: normalized_weights,
141            normalization: ScoreNormalization::MinMax,
142        }
143    }
144
145    fn normalize_scores(&self, results: &mut Vec<SearchResult>) {
146        match self.normalization {
147            ScoreNormalization::MinMax => {
148                if results.is_empty() {
149                    return;
150                }
151
152                let min = results
153                    .iter()
154                    .map(|r| r.score)
155                    .fold(f32::INFINITY, f32::min);
156                let max = results
157                    .iter()
158                    .map(|r| r.score)
159                    .fold(f32::NEG_INFINITY, f32::max);
160
161                if max > min {
162                    for result in results {
163                        result.score = (result.score - min) / (max - min);
164                    }
165                }
166            }
167            ScoreNormalization::ZScore => {
168                if results.is_empty() {
169                    return;
170                }
171
172                let mean: f32 = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
173                let variance: f32 = results
174                    .iter()
175                    .map(|r| (r.score - mean).powi(2))
176                    .sum::<f32>()
177                    / results.len() as f32;
178                let std_dev = variance.sqrt();
179
180                if std_dev > 0.0 {
181                    for result in results {
182                        result.score = (result.score - mean) / std_dev;
183                    }
184                }
185            }
186            ScoreNormalization::None => {}
187        }
188    }
189}
190
191impl RankFusion for WeightedFusion {
192    fn fuse(
193        &self,
194        mut result_sets: Vec<Vec<SearchResult>>,
195        limit: usize,
196    ) -> RragResult<Vec<SearchResult>> {
197        // Normalize scores in each result set
198        for results in &mut result_sets {
199            self.normalize_scores(results);
200        }
201
202        let mut fusion_scores: HashMap<String, f32> = HashMap::new();
203        let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
204            HashMap::new();
205
206        // Apply weighted combination
207        for (i, results) in result_sets.iter().enumerate() {
208            let weight = self
209                .weights
210                .get(i)
211                .copied()
212                .unwrap_or(1.0 / result_sets.len() as f32);
213
214            for result in results {
215                *fusion_scores.entry(result.id.clone()).or_insert(0.0) += result.score * weight;
216
217                doc_contents
218                    .entry(result.id.clone())
219                    .or_insert((result.content.clone(), result.metadata.clone()));
220            }
221        }
222
223        // Sort by fusion score
224        let mut sorted_results: Vec<_> = fusion_scores
225            .into_iter()
226            .filter_map(|(id, score)| {
227                doc_contents
228                    .get(&id)
229                    .map(|(content, metadata)| SearchResult {
230                        id: id.clone(),
231                        content: content.clone(),
232                        score,
233                        rank: 0,
234                        metadata: metadata.clone(),
235                        embedding: None,
236                    })
237            })
238            .collect();
239
240        sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
241        sorted_results.truncate(limit);
242
243        // Update ranks
244        for (i, result) in sorted_results.iter_mut().enumerate() {
245            result.rank = i;
246        }
247
248        Ok(sorted_results)
249    }
250}
251
/// Fusion that scores documents with a learned linear model over ranking
/// features.
#[derive(Debug, Clone)]
pub struct LearnedFusion {
    /// Weights applied, position by position, to the extracted feature vector.
    feature_weights: Vec<f32>,

    /// Whether pairwise retriever interaction features are appended to the
    /// feature vector.
    use_interactions: bool,
}
261
262impl LearnedFusion {
263    pub fn new(feature_weights: Vec<f32>) -> Self {
264        Self {
265            feature_weights,
266            use_interactions: true,
267        }
268    }
269
270    /// Extract features from result sets for learning
271    pub fn extract_features(&self, result_sets: &[Vec<SearchResult>], doc_id: &str) -> Vec<f32> {
272        let mut features = Vec::new();
273
274        for results in result_sets {
275            // Find document in this result set
276            let doc_result = results.iter().find(|r| r.id == doc_id);
277
278            if let Some(result) = doc_result {
279                // Position features
280                features.push(1.0 / (result.rank as f32 + 1.0)); // Reciprocal rank
281                features.push(result.score); // Raw score
282                features.push((results.len() - result.rank) as f32 / results.len() as f32);
283            // Normalized position
284            } else {
285                // Document not found in this retriever
286                features.push(0.0);
287                features.push(0.0);
288                features.push(0.0);
289            }
290        }
291
292        // Add interaction features if enabled
293        if self.use_interactions && result_sets.len() > 1 {
294            for i in 0..result_sets.len() {
295                for j in i + 1..result_sets.len() {
296                    let score_i = result_sets[i]
297                        .iter()
298                        .find(|r| r.id == doc_id)
299                        .map(|r| r.score)
300                        .unwrap_or(0.0);
301                    let score_j = result_sets[j]
302                        .iter()
303                        .find(|r| r.id == doc_id)
304                        .map(|r| r.score)
305                        .unwrap_or(0.0);
306
307                    // Interaction features
308                    features.push(score_i * score_j); // Product
309                    features.push((score_i - score_j).abs()); // Difference
310                    features.push(score_i.max(score_j)); // Max
311                }
312            }
313        }
314
315        features
316    }
317}
318
319impl RankFusion for LearnedFusion {
320    fn fuse(
321        &self,
322        result_sets: Vec<Vec<SearchResult>>,
323        limit: usize,
324    ) -> RragResult<Vec<SearchResult>> {
325        // Collect all unique document IDs
326        let mut all_docs: HashSet<String> = HashSet::new();
327        let mut doc_contents: HashMap<String, (String, HashMap<String, serde_json::Value>)> =
328            HashMap::new();
329
330        for results in &result_sets {
331            for result in results {
332                all_docs.insert(result.id.clone());
333                doc_contents
334                    .entry(result.id.clone())
335                    .or_insert((result.content.clone(), result.metadata.clone()));
336            }
337        }
338
339        // Score each document using learned weights
340        let mut scored_docs: Vec<(String, f32)> = all_docs
341            .into_iter()
342            .map(|doc_id| {
343                let features = self.extract_features(&result_sets, &doc_id);
344                let score: f32 = features
345                    .iter()
346                    .zip(self.feature_weights.iter())
347                    .map(|(f, w)| f * w)
348                    .sum();
349                (doc_id, score)
350            })
351            .collect();
352
353        // Sort by learned score
354        scored_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
355        scored_docs.truncate(limit);
356
357        // Build final results
358        let results: Vec<SearchResult> = scored_docs
359            .into_iter()
360            .enumerate()
361            .filter_map(|(rank, (doc_id, score))| {
362                doc_contents
363                    .get(&doc_id)
364                    .map(|(content, metadata)| SearchResult {
365                        id: doc_id,
366                        content: content.clone(),
367                        score,
368                        rank,
369                        metadata: metadata.clone(),
370                        embedding: None,
371                    })
372            })
373            .collect();
374
375        Ok(results)
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Two overlapping ranked lists: documents 2 and 3 appear in both,
    /// document 1 only in the first and document 4 only in the second.
    fn create_test_results() -> Vec<Vec<SearchResult>> {
        vec![
            vec![
                SearchResult::new("1", "Doc 1", 0.9, 0),
                SearchResult::new("2", "Doc 2", 0.8, 1),
                SearchResult::new("3", "Doc 3", 0.7, 2),
            ],
            vec![
                SearchResult::new("2", "Doc 2", 0.95, 0),
                SearchResult::new("3", "Doc 3", 0.85, 1),
                SearchResult::new("4", "Doc 4", 0.75, 2),
            ],
        ]
    }

    /// Ids of `results` in output order.
    fn ids(results: &[SearchResult]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    /// Asserts that every result's `rank` equals its output position.
    fn assert_ranks_are_positions(results: &[SearchResult]) {
        for (position, result) in results.iter().enumerate() {
            assert_eq!(result.rank, position);
        }
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let rrf = ReciprocalRankFusion::default();
        let results = rrf.fuse(create_test_results(), 3).unwrap();

        assert_eq!(results.len(), 3);
        // Doc 2 ranks highest (high in both lists), then doc 3 (in both lists),
        // then doc 1 (top of one list only).
        assert_eq!(ids(&results), ["2", "3", "1"]);
        // Default configuration normalizes by the best fused score.
        assert_eq!(results[0].score, 1.0);
        assert_ranks_are_positions(&results);
    }

    #[test]
    fn test_weighted_fusion() {
        let fusion = WeightedFusion::new(vec![0.3, 0.7]);
        let results = fusion.fuse(create_test_results(), 3).unwrap();

        assert_eq!(results.len(), 3);
        // After min-max normalization each list scores 1.0 / 0.5 / 0.0, so:
        // doc 2 = 0.3 * 0.5 + 0.7 * 1.0, doc 3 = 0.7 * 0.5, doc 1 = 0.3 * 1.0.
        assert_eq!(ids(&results), ["2", "3", "1"]);
        assert!((results[0].score - 0.85).abs() < 1e-3);
        assert_ranks_are_positions(&results);
    }

    #[test]
    fn test_learned_fusion() {
        // Weight only the raw-score feature of each retriever; the remaining
        // (interaction) features are left without a weight and are ignored.
        let fusion = LearnedFusion::new(vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]);
        let results = fusion.fuse(create_test_results(), 3).unwrap();

        assert_eq!(results.len(), 3);
        // doc 2 = 0.8 + 0.95, doc 3 = 0.7 + 0.85, doc 1 = 0.9, doc 4 = 0.75.
        assert_eq!(ids(&results), ["2", "3", "1"]);
        assert!((results[0].score - 1.75).abs() < 1e-3);
        assert_ranks_are_positions(&results);
    }
}