1mod neural;
7
8use std::collections::HashMap;
9
10use crate::fts::FtsResult;
11use crate::vectordb::SearchResult;
12
13pub use neural::NeuralReranker;
14
/// Default RRF smoothing constant `k`.
///
/// In reciprocal rank fusion each hit contributes `1 / (k + rank)` (1-based
/// `rank`). A smaller `k` makes top-ranked hits dominate more strongly; the
/// original RRF paper uses 60.
// NOTE(review): presumably the value callers pass as `k` to `rrf_fusion` (and
// as `vector_k`/`fts_k` to `rrf_fusion_with_exact`) — confirm at call sites.
pub const DEFAULT_RRF_K: f32 = 20.0;

/// RRF smoothing constant for exact-match hits.
///
/// With `k = 5` a rank-1 hit contributes `1/6`, versus `1/21` for a rank-1 hit
/// at [`DEFAULT_RRF_K`], so exact matches carry much more weight per rank.
// NOTE(review): presumably passed as `exact_k` to `rrf_fusion_with_exact` — confirm.
pub const EXACT_MATCH_RRF_K: f32 = 5.0;
20
/// A single chunk's combined ranking, as produced by the fusion functions in
/// this module.
///
/// Ranks are 1-based positions in the corresponding input list. A `None` score
/// or rank means the chunk did not appear in that list.
// NOTE(review): `dead_code` is allowed presumably because some fields are only
// read in certain configurations (diagnostics/tests) — confirm and document.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct FusedResult {
    /// Identifier of the chunk this entry describes.
    pub chunk_id: u32,
    /// Score used for ordering. For RRF fusion it is the sum of the
    /// reciprocal-rank contributions; `vector_only` stores the raw vector score
    /// here instead.
    pub rrf_score: f32,
    /// Raw vector-search score, if the chunk was a vector hit.
    pub vector_score: Option<f32>,
    /// Raw full-text score. `rrf_fusion_with_exact` folds exact-match scores
    /// into this slot.
    pub fts_score: Option<f32>,
    /// 1-based rank in the vector results.
    pub vector_rank: Option<usize>,
    /// 1-based rank in the full-text results (or the exact-match results as a
    /// fallback in `rrf_fusion_with_exact`).
    pub fts_rank: Option<usize>,
}
38
/// Per-chunk accumulator used while fusing in `rrf_fusion`.
///
/// Ranks are stored 1-based.
type ScoreEntry = (
    f32,           // accumulated RRF score
    Option<f32>,   // raw vector-search score
    Option<f32>,   // raw full-text score
    Option<usize>, // rank in the vector list
    Option<usize>, // rank in the full-text list
);
47
48pub fn rrf_fusion(
49 vector_results: &[SearchResult],
50 fts_results: &[FtsResult],
51 k: f32,
52) -> Vec<FusedResult> {
53 let mut scores: HashMap<u32, ScoreEntry> = HashMap::new();
55
56 for (rank, result) in vector_results.iter().enumerate() {
58 let chunk_id = result.id;
59 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
60
61 let entry = scores
62 .entry(chunk_id)
63 .or_insert((0.0, None, None, None, None));
64 entry.0 += rrf_score;
65 entry.1 = Some(result.score);
66 entry.3 = Some(rank + 1);
67 }
68
69 for (rank, result) in fts_results.iter().enumerate() {
71 let chunk_id = result.chunk_id;
72 let rrf_score = 1.0 / (k + rank as f32 + 1.0);
73
74 let entry = scores
75 .entry(chunk_id)
76 .or_insert((0.0, None, None, None, None));
77 entry.0 += rrf_score;
78 entry.2 = Some(result.score);
79 entry.4 = Some(rank + 1);
80 }
81
82 let mut results: Vec<FusedResult> = scores
84 .into_iter()
85 .map(
86 |(chunk_id, (rrf_score, vector_score, fts_score, vector_rank, fts_rank))| FusedResult {
87 chunk_id,
88 rrf_score,
89 vector_score,
90 fts_score,
91 vector_rank,
92 fts_rank,
93 },
94 )
95 .collect();
96
97 results.sort_by(|a, b| {
99 b.rrf_score
100 .partial_cmp(&a.rrf_score)
101 .unwrap_or(std::cmp::Ordering::Equal)
102 });
103
104 results
105}
106
107pub fn vector_only(vector_results: &[SearchResult]) -> Vec<FusedResult> {
109 vector_results
110 .iter()
111 .enumerate()
112 .map(|(rank, result)| FusedResult {
113 chunk_id: result.id,
114 rrf_score: result.score,
115 vector_score: Some(result.score),
116 fts_score: None,
117 vector_rank: Some(rank + 1),
118 fts_rank: None,
119 })
120 .collect()
121}
122
123pub fn rrf_fusion_with_exact(
137 vector_results: &[SearchResult],
138 fts_results: &[FtsResult],
139 exact_results: &[FtsResult],
140 vector_k: f32,
141 fts_k: f32,
142 exact_k: f32,
143) -> Vec<FusedResult> {
144 let mut scores: HashMap<
146 u32,
147 (
148 f32,
149 Option<f32>,
150 Option<f32>,
151 Option<f32>,
152 Option<usize>,
153 Option<usize>,
154 Option<usize>,
155 ),
156 > = HashMap::new();
157
158 for (rank, result) in vector_results.iter().enumerate() {
160 let chunk_id = result.id;
161 let rrf_score = 1.0 / (vector_k + rank as f32 + 1.0);
162
163 let entry = scores
164 .entry(chunk_id)
165 .or_insert((0.0, None, None, None, None, None, None));
166 entry.0 += rrf_score;
167 entry.1 = Some(result.score);
168 entry.4 = Some(rank + 1);
169 }
170
171 for (rank, result) in fts_results.iter().enumerate() {
173 let chunk_id = result.chunk_id;
174 let rrf_score = 1.0 / (fts_k + rank as f32 + 1.0);
175
176 let entry = scores
177 .entry(chunk_id)
178 .or_insert((0.0, None, None, None, None, None, None));
179 entry.0 += rrf_score;
180 entry.2 = Some(result.score);
181 entry.5 = Some(rank + 1);
182 }
183
184 for (rank, result) in exact_results.iter().enumerate() {
186 let chunk_id = result.chunk_id;
187 let rrf_score = 1.0 / (exact_k + rank as f32 + 1.0);
188
189 let entry = scores
190 .entry(chunk_id)
191 .or_insert((0.0, None, None, None, None, None, None));
192 entry.0 += rrf_score;
193 entry.3 = Some(result.score);
194 entry.6 = Some(rank + 1);
195 }
196
197 let mut results: Vec<FusedResult> = scores
199 .into_iter()
200 .map(
201 |(
202 chunk_id,
203 (
204 rrf_score,
205 vector_score,
206 fts_score,
207 exact_score,
208 vector_rank,
209 fts_rank,
210 exact_rank,
211 ),
212 )| {
213 let combined_fts_score = match (fts_score, exact_score) {
215 (Some(f), Some(e)) => Some((f + e) / 2.0),
216 (Some(f), None) => Some(f),
217 (None, Some(e)) => Some(e),
218 (None, None) => None,
219 };
220
221 FusedResult {
222 chunk_id,
223 rrf_score,
224 vector_score,
225 fts_score: combined_fts_score,
226 vector_rank,
227 fts_rank: fts_rank.or(exact_rank),
228 }
229 },
230 )
231 .collect();
232
233 results.sort_by(|a, b| {
235 b.rrf_score
236 .partial_cmp(&a.rrf_score)
237 .unwrap_or(std::cmp::Ordering::Equal)
238 });
239
240 results
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    fn make_vector_result(id: u32, score: f32) -> SearchResult {
        SearchResult {
            id,
            score,
            path: format!("file_{}.rs", id),
            content: format!("content {}", id),
            start_line: 1,
            end_line: 10,
            kind: "function".to_string(),
            signature: None,
            context_prev: None,
            context_next: None,
            distance: 0.0,
            context: None,
            docstring: None,
            hash: String::new(),
        }
    }

    fn make_fts_result(id: u32, score: f32) -> FtsResult {
        FtsResult {
            chunk_id: id,
            score,
        }
    }

    /// Chunk ids of `results`, in order.
    fn ids(results: &[FusedResult]) -> Vec<u32> {
        results.iter().map(|r| r.chunk_id).collect()
    }

    #[test]
    fn test_rrf_fusion_basic() {
        let vector_results = vec![
            make_vector_result(1, 0.9),
            make_vector_result(2, 0.8),
            make_vector_result(3, 0.7),
        ];

        let fts_results = vec![
            make_fts_result(2, 10.0),
            make_fts_result(1, 8.0),
            make_fts_result(4, 6.0),
        ];

        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);

        // One entry per distinct chunk id across both lists.
        assert_eq!(fused.len(), 4);

        // Chunks 1 and 2 are in both lists, so they outrank 3 and 4. Their
        // relative order is an exact score tie and is not asserted here.
        let mut top_two = ids(&fused[..2]);
        top_two.sort_unstable();
        assert_eq!(top_two, vec![1, 2]);

        let id1 = fused.iter().find(|r| r.chunk_id == 1).unwrap();
        let id2 = fused.iter().find(|r| r.chunk_id == 2).unwrap();

        assert_eq!(id1.vector_rank, Some(1));
        assert_eq!(id1.fts_rank, Some(2));
        assert_eq!(id2.vector_rank, Some(2));
        assert_eq!(id2.fts_rank, Some(1));

        let id4 = fused.iter().find(|r| r.chunk_id == 4).unwrap();
        assert!(id4.vector_rank.is_none());
        assert_eq!(id4.fts_rank, Some(3));
        assert_eq!(id4.fts_score, Some(6.0));
    }

    #[test]
    fn test_rrf_score_calculation() {
        let vector_results = vec![make_vector_result(1, 0.9)];
        let fts_results = vec![make_fts_result(1, 10.0)];

        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);

        assert_eq!(fused.len(), 1);
        let result = &fused[0];

        // Rank 1 in both lists: 1/(20 + 1) twice.
        let expected = 1.0 / 21.0 + 1.0 / 21.0;
        assert!((result.rrf_score - expected).abs() < 0.0001);
    }

    #[test]
    fn test_fusion_empty_inputs() {
        assert!(rrf_fusion(&[], &[], DEFAULT_RRF_K).is_empty());
        assert!(rrf_fusion_with_exact(
            &[],
            &[],
            &[],
            DEFAULT_RRF_K,
            DEFAULT_RRF_K,
            EXACT_MATCH_RRF_K
        )
        .is_empty());
        assert!(vector_only(&[]).is_empty());
    }

    #[test]
    fn test_vector_only() {
        let vector_results = vec![make_vector_result(1, 0.9), make_vector_result(2, 0.8)];

        let results = vector_only(&vector_results);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].chunk_id, 1);
        assert_eq!(results[0].rrf_score, 0.9);
        assert!(results[0].fts_score.is_none());
        assert_eq!(results[0].vector_rank, Some(1));
        assert_eq!(results[1].vector_rank, Some(2));
    }

    #[test]
    fn test_rrf_fusion_with_exact_combines_sources() {
        let vector_results = vec![make_vector_result(1, 0.9)];
        let fts_results = vec![make_fts_result(2, 10.0)];
        let exact_results = vec![make_fts_result(3, 2.0), make_fts_result(2, 4.0)];

        let fused = rrf_fusion_with_exact(
            &vector_results,
            &fts_results,
            &exact_results,
            20.0,
            20.0,
            5.0,
        );

        // Scores: chunk 2 = 1/21 + 1/7, chunk 3 = 1/6, chunk 1 = 1/21 (no ties).
        assert_eq!(ids(&fused), vec![2, 3, 1]);

        let both = &fused[0];
        assert!((both.rrf_score - (1.0 / 21.0 + 1.0 / 7.0)).abs() < 0.0001);
        // FTS and exact scores are averaged; the FTS rank takes precedence.
        assert_eq!(both.fts_score, Some(7.0));
        assert_eq!(both.fts_rank, Some(1));

        let exact_only = &fused[1];
        // Without an FTS hit, the exact-match score and rank fill the FTS slot.
        assert_eq!(exact_only.fts_score, Some(2.0));
        assert_eq!(exact_only.fts_rank, Some(1));
        assert!(exact_only.vector_rank.is_none());
    }
}