Skip to main content

nodedb_query/
fusion.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Reciprocal Rank Fusion (RRF) for combining ranked results from multiple engines.
//!
//! RRF is used when a query hits multiple engines (e.g., vector similarity +
//! metadata filter + BM25 text search). Each engine returns a ranked list;
//! RRF combines them into a single ranked list.
//!
//! Formula: `RRF_score(d) = Σ 1 / (k + rank_i(d))`
//! where `k` is a smoothing constant (default 60) and `rank_i(d)` is the
//! 1-based rank of document `d` in list `i` (the code adds 1 to the 0-based
//! `RankedResult::rank`).

/// RRF smoothing constant. Standard value from Cormack et al. (2009).
pub const DEFAULT_RRF_K: f64 = 60.0;
13
/// A scored result from a single engine.
///
/// One entry of an engine's ranked list. Only `document_id`, `rank` and
/// `source` feed into the fusion functions in this file; `score` is carried
/// along for diagnostics.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Document identifier (engine-specific).
    pub document_id: String,
    /// Rank within the engine's result list (0-based).
    pub rank: usize,
    /// Original score from the engine (for diagnostics).
    pub score: f32,
    /// Source engine identifier.
    pub source: &'static str,
}
26
/// A fused result after RRF combination.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Identifier of the document, copied from the contributing
    /// `RankedResult::document_id` values.
    pub document_id: String,
    /// Total RRF score: the sum of the second element of every entry in
    /// `contributions`.
    pub rrf_score: f64,
    /// Per-engine contributions for explainability, as
    /// `(source, reciprocal-rank contribution)` pairs.
    pub contributions: Vec<(&'static str, f64)>,
}
35
36/// Fuse multiple ranked result lists using Reciprocal Rank Fusion.
37///
38/// Each inner Vec is a ranked list from one engine (ordered by relevance).
39/// Returns the top_k fused results sorted by RRF score (descending).
40pub fn reciprocal_rank_fusion(
41    ranked_lists: &[Vec<RankedResult>],
42    k: Option<f64>,
43    top_k: usize,
44) -> Vec<FusedResult> {
45    let k = k.unwrap_or(DEFAULT_RRF_K);
46
47    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
48        std::collections::HashMap::new();
49
50    for list in ranked_lists {
51        for result in list {
52            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
53            scores
54                .entry(result.document_id.clone())
55                .or_default()
56                .push((result.source, contribution));
57        }
58    }
59
60    let mut fused: Vec<FusedResult> = scores
61        .into_iter()
62        .map(|(doc_id, contributions)| {
63            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
64            FusedResult {
65                document_id: doc_id,
66                rrf_score,
67                contributions,
68            }
69        })
70        .collect();
71
72    fused.sort_unstable_by(|a, b| {
73        b.rrf_score
74            .partial_cmp(&a.rrf_score)
75            .unwrap_or(std::cmp::Ordering::Equal)
76    });
77    fused.truncate(top_k);
78    fused
79}
80
81/// Fuse ranked lists with per-list k-constants for weighted influence.
82///
83/// Each list gets its own k value: lower k → steeper rank discount → more
84/// influence. Typical usage: `k_i = base_k / weight_i`.
85///
86/// # Panics
87///
88/// Panics if `k_per_list.len() != ranked_lists.len()`.
89pub fn reciprocal_rank_fusion_weighted(
90    ranked_lists: &[Vec<RankedResult>],
91    k_per_list: &[f64],
92    top_k: usize,
93) -> Vec<FusedResult> {
94    assert_eq!(
95        ranked_lists.len(),
96        k_per_list.len(),
97        "k_per_list length must match ranked_lists length"
98    );
99
100    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
101        std::collections::HashMap::new();
102
103    for (list_idx, list) in ranked_lists.iter().enumerate() {
104        let k = k_per_list[list_idx];
105        for result in list {
106            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
107            scores
108                .entry(result.document_id.clone())
109                .or_default()
110                .push((result.source, contribution));
111        }
112    }
113
114    let mut fused: Vec<FusedResult> = scores
115        .into_iter()
116        .map(|(doc_id, contributions)| {
117            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
118            FusedResult {
119                document_id: doc_id,
120                rrf_score,
121                contributions,
122            }
123        })
124        .collect();
125
126    fused.sort_unstable_by(|a, b| {
127        b.rrf_score
128            .partial_cmp(&a.rrf_score)
129            .unwrap_or(std::cmp::Ordering::Equal)
130    });
131    fused.truncate(top_k);
132    fused
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a ranked list whose entries get ranks 0, 1, 2, … in slice order
    /// and a synthetic, decreasing engine score.
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        doc_ids
            .iter()
            .enumerate()
            .map(|(rank, &id)| RankedResult {
                document_id: id.to_string(),
                rank,
                score: 1.0 - (rank as f32 * 0.1),
                source,
            })
            .collect()
    }

    /// Document ids of a fused result list, in output order.
    fn ids(fused: &[FusedResult]) -> Vec<&str> {
        fused.iter().map(|f| f.document_id.as_str()).collect()
    }

    /// Float comparison with a tolerance far below any RRF score difference
    /// used in these tests.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn single_list_preserves_order() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list], None, 10);
        assert_eq!(ids(&fused), ["d1", "d2", "d3"]);
        // Default k = 60 and 0-based ranks → 1/61, 1/62, 1/63.
        assert!(close(fused[0].rrf_score, 1.0 / (DEFAULT_RRF_K + 1.0)));
        assert!(close(fused[1].rrf_score, 1.0 / (DEFAULT_RRF_K + 2.0)));
        assert!(close(fused[2].rrf_score, 1.0 / (DEFAULT_RRF_K + 3.0)));
    }

    #[test]
    fn explicit_k_overrides_default() {
        let list = make_ranked(&["d1", "d2"], "vector");
        let fused = reciprocal_rank_fusion(&[list], Some(0.0), 10);
        assert_eq!(ids(&fused), ["d1", "d2"]);
        assert!(close(fused[0].rrf_score, 1.0));
        assert!(close(fused[1].rrf_score, 0.5));
    }

    #[test]
    fn overlapping_lists_boost_common_docs() {
        let vector = make_ranked(&["d1", "d2", "d3"], "vector");
        let sparse = make_ranked(&["d2", "d1", "d4"], "sparse");
        let fused = reciprocal_rank_fusion(&[vector, sparse], None, 10);
        assert_eq!(fused.len(), 4);
        // d1 and d2 have equal scores, so only membership of the top two is
        // asserted, not their relative order.
        let top2_ids: Vec<&str> = fused[..2].iter().map(|f| f.document_id.as_str()).collect();
        assert!(top2_ids.contains(&"d1"));
        assert!(top2_ids.contains(&"d2"));
        for top in &fused[..2] {
            assert_eq!(top.contributions.len(), 2);
            assert!(close(top.rrf_score, 1.0 / 61.0 + 1.0 / 62.0));
        }
        for rest in &fused[2..] {
            assert_eq!(rest.contributions.len(), 1);
        }
    }

    #[test]
    fn weighted_rrf() {
        let list_a = make_ranked(&["a1", "a2"], "vector");
        let list_b = make_ranked(&["b1", "a1"], "text");
        let fused = reciprocal_rank_fusion_weighted(&[list_a, list_b], &[30.0, 120.0], 10);
        // a1: 1/31 + 1/122, a2: 1/32, b1: 1/121.
        assert_eq!(ids(&fused), ["a1", "a2", "b1"]);
        let a1 = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(a1.contributions.len(), 2);
        assert_eq!(a1.contributions[0].0, "vector");
        assert!(close(a1.contributions[0].1, 1.0 / 31.0));
        assert_eq!(a1.contributions[1].0, "text");
        assert!(close(a1.contributions[1].1, 1.0 / 122.0));
        assert!(close(a1.rrf_score, 1.0 / 31.0 + 1.0 / 122.0));
    }

    #[test]
    fn top_k_truncates_results() {
        let list = make_ranked(&["d1", "d2", "d3"], "vector");
        let fused = reciprocal_rank_fusion(&[list.clone()], None, 2);
        assert_eq!(ids(&fused), ["d1", "d2"]);
        assert!(reciprocal_rank_fusion(&[list], None, 0).is_empty());
    }

    #[test]
    #[should_panic(expected = "k_per_list length must match ranked_lists length")]
    fn weighted_panics_on_length_mismatch() {
        let list = make_ranked(&["d1"], "vector");
        let _ = reciprocal_rank_fusion_weighted(&[list], &[], 10);
    }

    #[test]
    fn empty() {
        assert!(reciprocal_rank_fusion(&[], None, 10).is_empty());
        assert!(reciprocal_rank_fusion_weighted(&[], &[], 10).is_empty());
    }
}