Skip to main content

nodedb_query/
fusion.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Reciprocal Rank Fusion (RRF) for combining ranked results from multiple engines.
//!
//! RRF is used when a query hits multiple engines (e.g. vector similarity +
//! metadata filter + BM25 text search). Each engine returns a ranked list;
//! RRF combines them into a single ranked list.
//!
//! Formula: `RRF_score(d) = Σ_i 1 / (k + rank_i(d))`, where `rank_i(d)` is the
//! 1-based position of `d` in list `i` (results store a 0-based rank, so the
//! code computes `1 / (k + rank + 1)`), and `k` is a smoothing constant
//! (default [`DEFAULT_RRF_K`]). A document absent from a list receives no
//! contribution from that list.

/// RRF smoothing constant `k`. Standard value from Cormack et al. (2009).
pub const DEFAULT_RRF_K: f64 = 60.0;
13
/// One hit reported by a single engine, before fusion.
#[derive(Clone, Debug)]
pub struct RankedResult {
    /// Engine-specific identifier of the matched document. Hits from
    /// different engines are merged when their identifiers are equal.
    pub document_id: String,
    /// Zero-based position of this hit in its engine's ranking (0 = best).
    pub rank: usize,
    /// Raw score reported by the engine; kept for diagnostics only.
    pub score: f32,
    /// Name of the engine that produced this hit.
    pub source: &'static str,
}
26
/// One document's combined entry after Reciprocal Rank Fusion.
#[derive(Clone, Debug)]
pub struct FusedResult {
    /// Identifier shared by every input hit merged into this entry.
    pub document_id: String,
    /// Fused score: the sum of the values in `contributions`. Fused output
    /// is ordered by this value, highest first.
    pub rrf_score: f64,
    /// One `(source engine, contribution)` pair per merged input hit, kept
    /// so the final ranking can be explained.
    pub contributions: Vec<(&'static str, f64)>,
}
35
36/// Fuse multiple ranked result lists using Reciprocal Rank Fusion.
37///
38/// Each inner Vec is a ranked list from one engine (ordered by relevance).
39/// Returns the top_k fused results sorted by RRF score (descending).
40pub fn reciprocal_rank_fusion(
41    ranked_lists: &[Vec<RankedResult>],
42    k: Option<f64>,
43    top_k: usize,
44) -> Vec<FusedResult> {
45    let k = k.unwrap_or(DEFAULT_RRF_K);
46
47    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
48        std::collections::HashMap::new();
49
50    for list in ranked_lists {
51        for result in list {
52            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
53            scores
54                .entry(result.document_id.clone())
55                .or_default()
56                .push((result.source, contribution));
57        }
58    }
59
60    let mut fused: Vec<FusedResult> = scores
61        .into_iter()
62        .map(|(doc_id, contributions)| {
63            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
64            FusedResult {
65                document_id: doc_id,
66                rrf_score,
67                contributions,
68            }
69        })
70        .collect();
71
72    fused.sort_unstable_by(|a, b| {
73        b.rrf_score
74            .partial_cmp(&a.rrf_score)
75            .unwrap_or(std::cmp::Ordering::Equal)
76            // Deterministic tie-break: RRF produces many equal scores, and the
77            // score map iterates in nondeterministic order, so without a stable
78            // secondary key the output ranking varies run-to-run. document_id
79            // is unique, giving a total deterministic order.
80            .then_with(|| a.document_id.cmp(&b.document_id))
81    });
82    fused.truncate(top_k);
83    fused
84}
85
86/// Fuse ranked lists with per-list **linear weights**.
87///
88/// Each list's reciprocal-rank contribution is scaled by its weight, so a
89/// more-trusted source can dominate: `contribution = weight_i / (k + rank + 1)`.
90/// Unlike [`reciprocal_rank_fusion_weighted`] (which varies the `k` decay
91/// constant per list), this scales contribution magnitude directly, which is
92/// the right lever when one source (e.g. BM25) is far more reliable than another
93/// (e.g. a weak dense index) — equal-weight RRF would let the weak source drag
94/// down the strong source's ranking. `weights.len()` must equal
95/// `ranked_lists.len()`.
96///
97/// # Panics
98///
99/// Panics if `weights.len() != ranked_lists.len()`.
100pub fn reciprocal_rank_fusion_linear(
101    ranked_lists: &[Vec<RankedResult>],
102    k: Option<f64>,
103    weights: &[f64],
104    top_k: usize,
105) -> Vec<FusedResult> {
106    assert_eq!(
107        ranked_lists.len(),
108        weights.len(),
109        "weights length must match ranked_lists length"
110    );
111    let k = k.unwrap_or(DEFAULT_RRF_K);
112
113    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
114        std::collections::HashMap::new();
115
116    for (list_idx, list) in ranked_lists.iter().enumerate() {
117        let w = weights[list_idx];
118        for result in list {
119            let contribution = w / (k + result.rank as f64 + 1.0);
120            scores
121                .entry(result.document_id.clone())
122                .or_default()
123                .push((result.source, contribution));
124        }
125    }
126
127    let mut fused: Vec<FusedResult> = scores
128        .into_iter()
129        .map(|(doc_id, contributions)| {
130            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
131            FusedResult {
132                document_id: doc_id,
133                rrf_score,
134                contributions,
135            }
136        })
137        .collect();
138
139    fused.sort_unstable_by(|a, b| {
140        b.rrf_score
141            .partial_cmp(&a.rrf_score)
142            .unwrap_or(std::cmp::Ordering::Equal)
143            // Deterministic tie-break by unique document_id (see note above).
144            .then_with(|| a.document_id.cmp(&b.document_id))
145    });
146    fused.truncate(top_k);
147    fused
148}
149
150/// Fuse ranked lists with per-list k-constants for weighted influence.
151///
152/// Each list gets its own k value: lower k → steeper rank discount → more
153/// influence. Typical usage: `k_i = base_k / weight_i`.
154///
155/// # Panics
156///
157/// Panics if `k_per_list.len() != ranked_lists.len()`.
158pub fn reciprocal_rank_fusion_weighted(
159    ranked_lists: &[Vec<RankedResult>],
160    k_per_list: &[f64],
161    top_k: usize,
162) -> Vec<FusedResult> {
163    assert_eq!(
164        ranked_lists.len(),
165        k_per_list.len(),
166        "k_per_list length must match ranked_lists length"
167    );
168
169    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
170        std::collections::HashMap::new();
171
172    for (list_idx, list) in ranked_lists.iter().enumerate() {
173        let k = k_per_list[list_idx];
174        for result in list {
175            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
176            scores
177                .entry(result.document_id.clone())
178                .or_default()
179                .push((result.source, contribution));
180        }
181    }
182
183    let mut fused: Vec<FusedResult> = scores
184        .into_iter()
185        .map(|(doc_id, contributions)| {
186            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
187            FusedResult {
188                document_id: doc_id,
189                rrf_score,
190                contributions,
191            }
192        })
193        .collect();
194
195    fused.sort_unstable_by(|a, b| {
196        b.rrf_score
197            .partial_cmp(&a.rrf_score)
198            .unwrap_or(std::cmp::Ordering::Equal)
199            // Deterministic tie-break by unique document_id (see note above).
200            .then_with(|| a.document_id.cmp(&b.document_id))
201    });
202    fused.truncate(top_k);
203    fused
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one engine's list: each id's `rank` is its index in `doc_ids`,
    /// with a decreasing diagnostic score.
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        doc_ids
            .iter()
            .enumerate()
            .map(|(rank, &id)| RankedResult {
                document_id: id.to_string(),
                rank,
                score: 1.0 - (rank as f32 * 0.1),
                source,
            })
            .collect()
    }

    /// Document ids of `fused`, in output order.
    fn ids(fused: &[FusedResult]) -> Vec<&str> {
        fused.iter().map(|f| f.document_id.as_str()).collect()
    }

    /// Output position of `id` in `fused`; panics if it is absent.
    fn position(fused: &[FusedResult], id: &str) -> usize {
        fused
            .iter()
            .position(|f| f.document_id == id)
            .unwrap_or_else(|| panic!("{id} missing from fused output"))
    }

    #[test]
    fn single_list_preserves_order() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list], None, 10);
        assert_eq!(ids(&fused), ["d1", "d2", "d3"]);
    }

    #[test]
    fn overlapping_lists_boost_common_docs() {
        let vector = make_ranked(&["d1", "d2", "d3"], "vector");
        let sparse = make_ranked(&["d2", "d1", "d4"], "sparse");
        let fused = reciprocal_rank_fusion(&[vector, sparse], None, 10);
        let order = ids(&fused);
        assert!(order[..2].contains(&"d1"));
        assert!(order[..2].contains(&"d2"));
    }

    #[test]
    fn ties_are_broken_by_document_id() {
        // Both docs are rank 0 in their only list, so their scores are equal;
        // the output must still be deterministic.
        let x = make_ranked(&["zeta"], "vector");
        let y = make_ranked(&["alpha"], "text");
        let fused = reciprocal_rank_fusion(&[x, y], None, 10);
        assert_eq!(ids(&fused), ["alpha", "zeta"]);
    }

    #[test]
    fn top_k_truncates() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        assert_eq!(ids(&reciprocal_rank_fusion(&[list.clone()], None, 2)), ["d1", "d2"]);
        assert!(reciprocal_rank_fusion(&[list], None, 0).is_empty());
    }

    #[test]
    fn weighted_rrf() {
        let list_a = make_ranked(&["a1", "a2"], "vector");
        let list_b = make_ranked(&["b1", "a1"], "text");
        let fused = reciprocal_rank_fusion_weighted(&[list_a, list_b], &[30.0, 120.0], 10);
        let a1 = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(a1.contributions.len(), 2);
        // Each list uses its own k: 1/(30+0+1) from `vector`, 1/(120+1+1) from `text`.
        let expected = 1.0 / 31.0 + 1.0 / 122.0;
        assert!((a1.rrf_score - expected).abs() < 1e-12);
        assert_eq!(ids(&fused), ["a1", "a2", "b1"]);
    }

    #[test]
    fn linear_weight_lets_strong_source_dominate() {
        // `a1` is ranked #2 by the strong source only; `b1` is ranked #1 by the
        // weak source only. With equal weights `b1` wins (1/61 > 1/62), so this
        // test only passes if the per-list weight actually changes the outcome.
        let strong = make_ranked(&["s0", "a1"], "strong");
        let weak = make_ranked(&["b1"], "weak");
        let lists = [strong, weak];

        let equal = reciprocal_rank_fusion_linear(&lists, None, &[1.0, 1.0], 10);
        assert!(
            position(&equal, "b1") < position(&equal, "a1"),
            "baseline: equal weights should favour b1"
        );

        let fused = reciprocal_rank_fusion_linear(&lists, None, &[4.0, 0.25], 10);
        let a1 = position(&fused, "a1");
        let b1 = position(&fused, "b1");
        assert!(a1 < b1, "a1 (rank {a1}) should outrank b1 (rank {b1})");

        // Contribution magnitude scales linearly with the per-list weight.
        let a1_res = &fused[a1];
        assert_eq!(a1_res.contributions.len(), 1);
        let (src, strong_contrib) = a1_res.contributions[0];
        assert_eq!(src, "strong");
        let expected = 4.0 / (DEFAULT_RRF_K + 1.0 + 1.0);
        assert!((strong_contrib - expected).abs() < 1e-12);
    }

    #[test]
    #[should_panic(expected = "weights length must match ranked_lists length")]
    fn linear_mismatched_weights_panics() {
        let list = make_ranked(&["d1"], "vector");
        let _ = reciprocal_rank_fusion_linear(&[list], None, &[1.0, 2.0], 10);
    }

    #[test]
    fn empty() {
        assert!(reciprocal_rank_fusion(&[], None, 10).is_empty());
        assert!(reciprocal_rank_fusion_linear(&[], None, &[], 10).is_empty());
        assert!(reciprocal_rank_fusion_weighted(&[], &[], 10).is_empty());
    }
}