Skip to main content

reddb_server/runtime/ai/
rrf_fuser.rs

1//! `RrfFuser` — pure Reciprocal Rank Fusion for ASK hybrid retrieval.
2//!
3//! Issue #398 (PRD #391): the ASK pipeline retrieves candidates from
4//! multiple buckets — BM25 text search, vector similarity, graph
5//! traversal — and needs a single ranked list to feed the prompt
6//! assembler. RRF is the standard, parameter-light way to combine
7//! ranked lists from heterogeneous scorers whose raw scores are not
8//! directly comparable.
9//!
10//! Deep module: no I/O, no transport, no global state. The caller
11//! hands in pre-ranked per-bucket lists, optional per-bucket
12//! `min_score` filters, the RRF constant `k`, and the final cap.
//! The fuser returns the fused list, sorted by RRF score with ties
//! broken deterministically, and capped to `total_k`.
15//!
16//! ## Formula
17//!
18//! For each item `d` and each ranker `r`:
19//!
20//! ```text
21//! rrf_score(d) = Σ_r 1 / (k + rank_r(d))
22//! ```
23//!
24//! where `rank_r(d)` is 1-indexed (best item in list r is rank 1).
25//! Items absent from a list contribute nothing. The convention `k=60`
26//! comes from Cormack, Clarke & Büttcher 2009 and is the value used
27//! by published RRF baselines (Elasticsearch, Weaviate, Qdrant) — we
28//! keep it as the default but expose it for tests.
29//!
30//! ## Per-bucket filtering
31//!
32//! Each bucket carries native scores (BM25 score, cosine similarity,
33//! graph traversal weight). `min_score` is applied per bucket *before*
34//! fusion, because the natural threshold differs by ranker (cosine 0.7
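//! ≠ BM25 0.7). A candidate that fails its bucket's floor is removed
//! from that bucket only: it occupies no rank there, so every candidate
//! below it moves up one rank. The same id can still contribute through
//! any other bucket whose floor it passes.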
37//!
38//! ## Tie-break
39//!
//! When two items share an RRF score, the fused order is determined
//! by the item id, ascending (lexicographic for strings, numeric for
//! ints). Without this rule, ties would follow the iteration order of
//! the fuser's internal hash map, which is randomized; with it,
//! identical inputs produce byte-identical outputs, which the ASK
//! determinism contract (#400) relies on.
45
46use std::collections::HashMap;
47use std::hash::Hash;
48
/// A single hit from one retriever: the item's id plus the raw score
/// that retriever assigned to it.
///
/// Only the bucket's `min_score` floor ever inspects `score`. The RRF
/// arithmetic depends purely on the hit's position within its bucket,
/// so the magnitude of `score` never reaches the fused result.
#[derive(Clone, Debug, PartialEq)]
pub struct Candidate<Id> {
    /// Identifier shared across buckets; equal ids are fused together.
    pub id: Id,
    /// Retriever-native score, compared only against the bucket floor.
    pub score: f64,
}
57
58/// A ranked list from one retriever (a "bucket"). Order matters —
59/// position 0 is best, position 1 is second, etc.
60#[derive(Debug, Clone)]
61pub struct Bucket<Id> {
62    pub candidates: Vec<Candidate<Id>>,
63    /// Per-bucket score floor. `None` means "no filter". Applied
64    /// before fusion. Use this so that BM25 0.4 and cosine 0.7 can
65    /// coexist sensibly.
66    pub min_score: Option<f64>,
67}
68
/// One row of fused output: a surviving id and its summed RRF score.
///
/// `fuse` returns these sorted by `rrf_score` descending, with equal
/// scores ordered by id ascending so the output is deterministic.
#[derive(Clone, Debug, PartialEq)]
pub struct FusedItem<Id> {
    /// The id that survived filtering in at least one bucket.
    pub id: Id,
    /// Sum of `1 / (k + rank)` over every bucket the id survived in.
    pub rrf_score: f64,
}
76
/// Default RRF constant `k`, the value from Cormack, Clarke & Büttcher
/// (2009). It is public so tests can refer to it; production callers
/// should hand it to `fuse` unchanged.
pub const RRF_K_DEFAULT: u32 = 60;
80
81/// Fuse per-bucket ranked lists into a single ranked list capped at
82/// `total_k`. Pure function — no I/O, no clock.
83///
84/// `k` is the RRF constant (use [`RRF_K_DEFAULT`] = 60 in production).
85/// `total_k` is the maximum number of items to return; if zero, the
86/// result is empty.
87pub fn fuse<Id>(buckets: &[Bucket<Id>], k: u32, total_k: usize) -> Vec<FusedItem<Id>>
88where
89    Id: Clone + Eq + Hash + Ord,
90{
91    if total_k == 0 {
92        return Vec::new();
93    }
94
95    let k_f = f64::from(k);
96    let mut scores: HashMap<Id, f64> = HashMap::new();
97
98    for bucket in buckets {
99        let mut rank: u32 = 0;
100        for cand in &bucket.candidates {
101            if let Some(floor) = bucket.min_score {
102                if cand.score < floor {
103                    continue;
104                }
105            }
106            rank += 1;
107            let contribution = 1.0 / (k_f + f64::from(rank));
108            scores
109                .entry(cand.id.clone())
110                .and_modify(|s| *s += contribution)
111                .or_insert(contribution);
112        }
113    }
114
115    let mut fused: Vec<FusedItem<Id>> = scores
116        .into_iter()
117        .map(|(id, rrf_score)| FusedItem { id, rrf_score })
118        .collect();
119
120    // Descending by score; ties broken by id ascending. partial_cmp
121    // is safe here because rrf_score is always a finite positive sum
122    // of positive reciprocals — no NaN possible.
123    fused.sort_by(|a, b| {
124        b.rrf_score
125            .partial_cmp(&a.rrf_score)
126            .unwrap_or(std::cmp::Ordering::Equal)
127            .then_with(|| a.id.cmp(&b.id))
128    });
129
130    fused.truncate(total_k);
131    fused
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing computed RRF scores with hand-written
    /// reference fractions.
    const EPS: f64 = 1e-12;

    fn cand<Id>(id: Id, score: f64) -> Candidate<Id> {
        Candidate { id, score }
    }

    fn bucket_no_floor<Id>(cs: Vec<Candidate<Id>>) -> Bucket<Id> {
        Bucket {
            candidates: cs,
            min_score: None,
        }
    }

    /// Assert that `actual` is within `EPS` of `expected`, reporting both
    /// values and the caller's location on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// The fused ids, in output order.
    fn ids<Id: Clone>(out: &[FusedItem<Id>]) -> Vec<Id> {
        out.iter().map(|f| f.id.clone()).collect()
    }

    // ---- Reference values ---------------------------------------------

    #[test]
    fn rrf_single_list_matches_reference() {
        // Single list, k=60. Expected scores by rank: 1/61, 1/62, 1/63.
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5), cand("c", 0.1)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(ids(&out), vec!["a", "b", "c"]);
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_close(out[1].rrf_score, 1.0 / 62.0);
        assert_close(out[2].rrf_score, 1.0 / 63.0);
    }

    #[test]
    fn rrf_two_lists_sums_contributions() {
        // 'a' is rank 1 in both → 2/61.
        // 'b' is rank 2 in both → 2/62.
        // 'c' only in list 1 at rank 3 → 1/63.
        // 'd' only in list 2 at rank 3 → 1/63.
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.9), cand("c", 0.8)]);
        let b2 = bucket_no_floor(vec![cand("a", 0.95), cand("b", 0.85), cand("d", 0.7)]);
        let out = fuse(&[b1, b2], 60, 10);
        let by_id: std::collections::HashMap<_, _> =
            out.iter().map(|f| (f.id, f.rrf_score)).collect();
        assert_close(by_id["a"], 2.0 / 61.0);
        assert_close(by_id["b"], 2.0 / 62.0);
        assert_close(by_id["c"], 1.0 / 63.0);
        assert_close(by_id["d"], 1.0 / 63.0);
        // Order: a > b > (c,d) with c first by id tie-break.
        assert_eq!(ids(&out), vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn rrf_k_default_is_60() {
        assert_eq!(RRF_K_DEFAULT, 60);
    }

    #[test]
    fn alternate_k_changes_scores() {
        // k=1 → rank 1 contribution is 1/2. Sanity check the constant
        // is actually wired in.
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 1, 10);
        assert_close(out[0].rrf_score, 0.5);
    }

    // ---- LIMIT total_k ------------------------------------------------

    #[test]
    fn total_k_caps_output() {
        let bucket = bucket_no_floor(vec![
            cand("a", 1.0),
            cand("b", 0.9),
            cand("c", 0.8),
            cand("d", 0.7),
        ]);
        let out = fuse(&[bucket], 60, 2);
        assert_eq!(ids(&out), vec!["a", "b"]);
    }

    #[test]
    fn total_k_zero_returns_empty() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0)]);
        let out = fuse(&[bucket], 60, 0);
        assert!(out.is_empty());
    }

    #[test]
    fn total_k_larger_than_candidates_returns_all() {
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let out = fuse(&[bucket], 60, 100);
        assert_eq!(out.len(), 2);
    }

    // ---- MIN_SCORE per-bucket -----------------------------------------

    #[test]
    fn min_score_drops_items_before_ranking() {
        // 'b' fails the 0.5 floor — must be dropped, AND 'c' must
        // then be promoted to rank 2 (not rank 3).
        let bucket = Bucket {
            candidates: vec![cand("a", 0.9), cand("b", 0.4), cand("c", 0.6)],
            min_score: Some(0.5),
        };
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(ids(&out), vec!["a", "c"]);
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        // c is rank 2 after filter, not rank 3.
        assert_close(out[1].rrf_score, 1.0 / 62.0);
    }

    #[test]
    fn min_score_independent_per_bucket() {
        // bm25-bucket uses min_score 0.4, vector-bucket uses 0.7.
        let bm25 = Bucket {
            candidates: vec![cand("x", 0.5), cand("y", 0.3)],
            min_score: Some(0.4),
        };
        let vec_bucket = Bucket {
            candidates: vec![cand("x", 0.85), cand("y", 0.6)],
            min_score: Some(0.7),
        };
        let out = fuse(&[bm25, vec_bucket], 60, 10);
        // y is dropped from vector bucket (0.6 < 0.7) and from bm25
        // bucket (0.3 < 0.4). x survives both → 2/61.
        assert_eq!(ids(&out), vec!["x"]);
        assert_close(out[0].rrf_score, 2.0 / 61.0);
    }

    #[test]
    fn item_filtered_in_one_bucket_still_counts_in_another() {
        // 'x' fails the floor in the first bucket but the second bucket
        // has no floor, so x still earns 1/61 from it. 'y' moves up to
        // rank 1 in the first bucket once 'x' is filtered out.
        let floored = Bucket {
            candidates: vec![cand("x", 0.3), cand("y", 0.9)],
            min_score: Some(0.5),
        };
        let open = bucket_no_floor(vec![cand("x", 0.1)]);
        let out = fuse(&[floored, open], 60, 10);
        // Both end up at 1/61 and tie; id order decides.
        assert_eq!(ids(&out), vec!["x", "y"]);
        assert_close(out[0].rrf_score, 1.0 / 61.0);
        assert_close(out[1].rrf_score, 1.0 / 61.0);
    }

    #[test]
    fn floor_rejecting_everything_returns_empty() {
        let bucket = Bucket {
            candidates: vec![cand("a", 0.1), cand("b", 0.2)],
            min_score: Some(0.9),
        };
        assert!(fuse(&[bucket], 60, 10).is_empty());
    }

    #[test]
    fn min_score_none_keeps_everything() {
        let bucket = bucket_no_floor(vec![cand("a", -10.0), cand("b", 0.0)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 2);
    }

    // ---- Tie-break determinism ----------------------------------------

    #[test]
    fn tie_break_is_id_ascending() {
        // Three items each appearing once at rank 1 across three
        // separate buckets — all share the same rrf_score 1/61.
        let b1 = bucket_no_floor(vec![cand("zebra", 1.0)]);
        let b2 = bucket_no_floor(vec![cand("apple", 1.0)]);
        let b3 = bucket_no_floor(vec![cand("mango", 1.0)]);
        let out = fuse(&[b1, b2, b3], 60, 10);
        assert_eq!(ids(&out), vec!["apple", "mango", "zebra"]);
    }

    #[test]
    fn fuse_is_deterministic_across_calls() {
        // Same inputs → byte-equal outputs. Required by the ASK
        // determinism contract (#400).
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let a = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let c = fuse(&[b1, b2], 60, 10);
        assert_eq!(a, c);
    }

    #[test]
    fn fuse_is_order_independent_across_buckets() {
        // The bucket order on input should not affect the fused
        // output — RRF is commutative across rankers.
        let b1 = bucket_no_floor(vec![cand("a", 1.0), cand("b", 0.5)]);
        let b2 = bucket_no_floor(vec![cand("b", 0.9), cand("c", 0.4)]);
        let forward = fuse(&[b1.clone(), b2.clone()], 60, 10);
        let reverse = fuse(&[b2, b1], 60, 10);
        assert_eq!(forward, reverse);
    }

    // ---- Edge cases ---------------------------------------------------

    #[test]
    fn empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> = vec![];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn all_empty_buckets_returns_empty() {
        let buckets: Vec<Bucket<&'static str>> =
            vec![bucket_no_floor(vec![]), bucket_no_floor(vec![])];
        let out = fuse(&buckets, 60, 10);
        assert!(out.is_empty());
    }

    #[test]
    fn duplicate_id_within_one_bucket_keeps_both_ranks() {
        // Realistically a retriever should not emit the same id twice,
        // but if it does, the later occurrence gets a worse (numerically
        // larger) rank. Document the behavior so it isn't a future
        // surprise: both contributions accumulate.
        let bucket = bucket_no_floor(vec![cand("a", 1.0), cand("a", 0.5)]);
        let out = fuse(&[bucket], 60, 10);
        assert_eq!(out.len(), 1);
        assert_close(out[0].rrf_score, 1.0 / 61.0 + 1.0 / 62.0);
    }

    #[test]
    fn integer_ids_supported() {
        // The fuser is generic over id type; ints are valid.
        // 2 appears in both lists (1/62 + 1/61); 1 and 3 once each
        // (1/61 and 1/62 respectively).
        let b1 = bucket_no_floor(vec![cand(1u64, 1.0), cand(2u64, 0.5)]);
        let b2 = bucket_no_floor(vec![cand(2u64, 0.9), cand(3u64, 0.4)]);
        let out = fuse(&[b1, b2], 60, 10);
        assert_eq!(ids(&out), vec![2, 1, 3]);
        assert_close(out[0].rrf_score, 1.0 / 61.0 + 1.0 / 62.0);
    }
}