Skip to main content

nodedb_query/
fusion.rs

//! Reciprocal Rank Fusion (RRF) for combining ranked results from multiple engines.
//!
//! RRF is used when a query hits multiple engines (e.g., vector similarity +
//! metadata filter + BM25 text search). Each engine returns a ranked list;
//! RRF combines them into a single ranked list.
//!
//! Formula: `RRF_score(d) = Σ_i 1 / (k + rank_i(d))`
//! where `rank_i(d)` is the 1-based rank of document `d` in list `i` and `k`
//! is a smoothing constant (default 60). The code in this module stores ranks
//! 0-based and therefore evaluates `1 / (k + rank + 1)`.

/// RRF smoothing constant. Standard value from Cormack et al. (2009).
pub const DEFAULT_RRF_K: f64 = 60.0;
11
/// A scored result from a single engine.
///
/// Derives `PartialEq` so results can be compared directly (e.g. with
/// `assert_eq!`). `Eq` is intentionally not derived because `score` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct RankedResult {
    /// Document identifier (engine-specific).
    pub document_id: String,
    /// Rank within the engine's result list (0-based; 0 is the best match).
    pub rank: usize,
    /// Original score from the engine (for diagnostics).
    pub score: f32,
    /// Source engine identifier.
    pub source: &'static str,
}
24
/// A fused result after RRF combination.
///
/// Derives `PartialEq` so fused outputs can be compared directly (e.g. with
/// `assert_eq!`). `Eq` is intentionally not derived because scores are floats.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Identifier of the document the score belongs to.
    pub document_id: String,
    /// Combined RRF score: the sum of all entries in `contributions`.
    pub rrf_score: f64,
    /// Per-engine contributions for explainability, as
    /// `(source engine, contribution)` pairs.
    pub contributions: Vec<(&'static str, f64)>,
}
33
34/// Fuse multiple ranked result lists using Reciprocal Rank Fusion.
35///
36/// Each inner Vec is a ranked list from one engine (ordered by relevance).
37/// Returns the top_k fused results sorted by RRF score (descending).
38pub fn reciprocal_rank_fusion(
39    ranked_lists: &[Vec<RankedResult>],
40    k: Option<f64>,
41    top_k: usize,
42) -> Vec<FusedResult> {
43    let k = k.unwrap_or(DEFAULT_RRF_K);
44
45    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
46        std::collections::HashMap::new();
47
48    for list in ranked_lists {
49        for result in list {
50            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
51            scores
52                .entry(result.document_id.clone())
53                .or_default()
54                .push((result.source, contribution));
55        }
56    }
57
58    let mut fused: Vec<FusedResult> = scores
59        .into_iter()
60        .map(|(doc_id, contributions)| {
61            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
62            FusedResult {
63                document_id: doc_id,
64                rrf_score,
65                contributions,
66            }
67        })
68        .collect();
69
70    fused.sort_unstable_by(|a, b| {
71        b.rrf_score
72            .partial_cmp(&a.rrf_score)
73            .unwrap_or(std::cmp::Ordering::Equal)
74    });
75    fused.truncate(top_k);
76    fused
77}
78
79/// Fuse ranked lists with per-list k-constants for weighted influence.
80///
81/// Each list gets its own k value: lower k → steeper rank discount → more
82/// influence. Typical usage: `k_i = base_k / weight_i`.
83///
84/// # Panics
85///
86/// Panics if `k_per_list.len() != ranked_lists.len()`.
87pub fn reciprocal_rank_fusion_weighted(
88    ranked_lists: &[Vec<RankedResult>],
89    k_per_list: &[f64],
90    top_k: usize,
91) -> Vec<FusedResult> {
92    assert_eq!(
93        ranked_lists.len(),
94        k_per_list.len(),
95        "k_per_list length must match ranked_lists length"
96    );
97
98    let mut scores: std::collections::HashMap<String, Vec<(&'static str, f64)>> =
99        std::collections::HashMap::new();
100
101    for (list_idx, list) in ranked_lists.iter().enumerate() {
102        let k = k_per_list[list_idx];
103        for result in list {
104            let contribution = 1.0 / (k + result.rank as f64 + 1.0);
105            scores
106                .entry(result.document_id.clone())
107                .or_default()
108                .push((result.source, contribution));
109        }
110    }
111
112    let mut fused: Vec<FusedResult> = scores
113        .into_iter()
114        .map(|(doc_id, contributions)| {
115            let rrf_score = contributions.iter().map(|(_, s)| s).sum();
116            FusedResult {
117                document_id: doc_id,
118                rrf_score,
119                contributions,
120            }
121        })
122        .collect();
123
124    fused.sort_unstable_by(|a, b| {
125        b.rrf_score
126            .partial_cmp(&a.rrf_score)
127            .unwrap_or(std::cmp::Ordering::Equal)
128    });
129    fused.truncate(top_k);
130    fused
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a ranked list for `source` in which each id's rank is its
    /// position in `doc_ids` and the engine score drops by 0.1 per rank.
    fn make_ranked(doc_ids: &[&str], source: &'static str) -> Vec<RankedResult> {
        let mut ranked = Vec::with_capacity(doc_ids.len());
        for (rank, id) in doc_ids.iter().enumerate() {
            ranked.push(RankedResult {
                document_id: (*id).to_owned(),
                rank,
                score: 1.0 - (rank as f32 * 0.1),
                source,
            });
        }
        ranked
    }

    /// A single input list keeps its best document on top.
    #[test]
    fn single_list_preserves_order() {
        let lists = [make_ranked(&["d1", "d2", "d3"], "vector")];
        let fused = reciprocal_rank_fusion(&lists, None, 10);
        assert_eq!(fused.len(), 3);
        assert_eq!(fused[0].document_id, "d1");
    }

    /// Documents present in both lists outrank documents present in only one.
    #[test]
    fn overlapping_lists_boost_common_docs() {
        let lists = [
            make_ranked(&["d1", "d2", "d3"], "vector"),
            make_ranked(&["d2", "d1", "d4"], "sparse"),
        ];
        let fused = reciprocal_rank_fusion(&lists, None, 10);
        // d1 and d2 must fill the first two slots, in either order.
        let leaders = &fused[..2];
        for expected in ["d1", "d2"] {
            assert!(leaders.iter().any(|f| f.document_id == expected));
        }
    }

    /// A document found by two engines records one contribution per engine.
    #[test]
    fn weighted_rrf() {
        let lists = [
            make_ranked(&["a1", "a2"], "vector"),
            make_ranked(&["b1", "a1"], "text"),
        ];
        let k_values = [30.0, 120.0];
        let fused = reciprocal_rank_fusion_weighted(&lists, &k_values, 10);
        let a1 = fused.iter().find(|f| f.document_id == "a1").unwrap();
        assert_eq!(a1.contributions.len(), 2);
    }

    /// No input lists means no fused results for either entry point.
    #[test]
    fn empty() {
        let no_lists: &[Vec<RankedResult>] = &[];
        assert!(reciprocal_rank_fusion(no_lists, None, 10).is_empty());
        assert!(reciprocal_rank_fusion_weighted(no_lists, &[], 10).is_empty());
    }
}