nodedb_vector/hybrid/
rrf.rs1use std::collections::HashMap;
10
/// One entry of a ranked result list, as consumed by [`rrf_fuse`].
#[derive(Clone, Debug)]
pub struct RankedResult {
    /// Identifier of the matched item; entries sharing an id are merged during fusion.
    pub id: u32,
    /// Position of the entry within its source list.
    ///
    /// The tests in this file use `1` for the best hit.
    pub rank: u32,
    /// Score carried alongside the rank; [`rrf_fuse`] does not read it.
    pub raw_score: f32,
}
22
/// Tuning parameters for [`rrf_fuse`].
///
/// Derives `Debug`, `Clone` and `PartialEq` so the options can be logged,
/// duplicated and compared like the other public types in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct RrfOptions {
    /// Constant added to every rank in the denominator of a contribution,
    /// i.e. each entry adds `1 / (k + rank)`. The `Default` impl uses `60.0`.
    pub k: f32,
}
29
30impl Default for RrfOptions {
31 fn default() -> Self {
32 Self { k: 60.0 }
33 }
34}
35
36pub fn rrf_fuse(
48 lists: &[Vec<RankedResult>],
49 options: &RrfOptions,
50 top_k: usize,
51) -> Vec<(u32, f32)> {
52 let mut scores: HashMap<u32, f32> = HashMap::new();
54
55 for list in lists {
56 for entry in list {
57 let contribution = 1.0 / (options.k + entry.rank as f32);
58 *scores.entry(entry.id).or_insert(0.0) += contribution;
59 }
60 }
61
62 let mut ranked: Vec<(u32, f32)> = scores.into_iter().collect();
64 ranked.sort_unstable_by(|a, b| {
65 b.1.partial_cmp(&a.1)
66 .unwrap_or(std::cmp::Ordering::Equal)
67 .then_with(|| a.0.cmp(&b.0))
68 });
69
70 if top_k > 0 {
71 ranked.truncate(top_k);
72 }
73
74 ranked
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let delta = (a - b).abs();
        delta < eps
    }

    /// Shorthand constructor for a [`RankedResult`].
    fn hit(id: u32, rank: u32, raw_score: f32) -> RankedResult {
        RankedResult {
            id,
            rank,
            raw_score,
        }
    }

    /// An id ranked first in both lists comes out on top with score 2 / (k + 1).
    #[test]
    fn two_lists_top_ranked_in_both() {
        let first = vec![hit(5, 1, 0.9), hit(3, 2, 0.7)];
        let second = vec![hit(5, 1, 0.85), hit(7, 2, 0.6)];

        let fused = rrf_fuse(&[first, second], &RrfOptions::default(), 10);

        assert_eq!(fused[0].0, 5);
        let expected = 2.0 / 61.0;
        assert!(
            approx_eq(fused[0].1, expected, 1e-5),
            "expected ≈{expected:.6}, got {:.6}",
            fused[0].1
        );
    }

    /// A non-zero `top_k` caps the number of returned entries.
    #[test]
    fn top_k_limits_output() {
        let mut list: Vec<RankedResult> = Vec::with_capacity(20);
        for i in 1..=20 {
            list.push(hit(i, i, 1.0 / i as f32));
        }

        let fused = rrf_fuse(&[list], &RrfOptions::default(), 5);

        assert_eq!(fused.len(), 5);
    }

    /// Fusing zero lists yields no results.
    #[test]
    fn empty_input_lists_returns_empty() {
        let no_lists: &[Vec<RankedResult>] = &[];
        let fused = rrf_fuse(no_lists, &RrfOptions::default(), 10);
        assert!(fused.is_empty());
    }

    /// An empty list next to a populated one does not affect the outcome.
    #[test]
    fn empty_individual_lists_ignored() {
        let empty: Vec<RankedResult> = Vec::new();
        let populated = vec![hit(1, 1, 1.0)];

        let fused = rrf_fuse(&[empty, populated], &RrfOptions::default(), 10);

        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].0, 1);
    }
}