Skip to main content

codesearch/rerank/
mod.rs

//! Reranking and result fusion strategies.
//!
//! Provides RRF (Reciprocal Rank Fusion) for combining vector and FTS results,
//! and neural reranking using cross-encoder models for improved accuracy.
5
6mod neural;
7
8use std::collections::HashMap;
9
10use crate::fts::FtsResult;
11use crate::vectordb::SearchResult;
12
13pub use neural::NeuralReranker;
14
/// RRF smoothing constant `k` applied by default to the vector and FTS rankings
/// (value follows the osgrep reference).
pub const DEFAULT_RRF_K: f32 = 20.0;

/// RRF smoothing constant `k` for exact identifier matches. Because each rank
/// contributes `1 / (k + rank)`, this smaller `k` gives exact hits a stronger boost.
pub const EXACT_MATCH_RRF_K: f32 = 5.0;
20
/// One chunk's combined standing after merging vector and full-text rankings.
///
/// The per-source score and rank fields are `None` when the chunk did not
/// appear in that source's result list.
#[derive(Debug, Clone)]
#[allow(dead_code)] // some fields are only read for debugging/diagnostics
pub struct FusedResult {
    /// Identifier of the chunk this result refers to.
    pub chunk_id: u32,
    /// Combined ranking score (the summed RRF contributions for fused results).
    pub rrf_score: f32,
    /// Similarity score reported by the vector search, when present there.
    pub vector_score: Option<f32>,
    /// FTS/BM25 score reported by full-text search, when present there.
    pub fts_score: Option<f32>,
    /// 1-indexed position in the vector results; `None` if absent.
    pub vector_rank: Option<usize>,
    /// 1-indexed position in the FTS results; `None` if absent.
    pub fts_rank: Option<usize>,
}
38
39/// Reciprocal Rank Fusion (RRF) for combining search results
40///
41/// RRF formula: score = sum(1 / (k + rank)) for each ranking list
42/// where k is a constant (default 20) and rank is 1-indexed position.
43///
44/// This is a proven technique for combining multiple ranking signals
45/// without needing to normalize scores across different systems.
46type ScoreEntry = (f32, Option<f32>, Option<f32>, Option<usize>, Option<usize>);
47
48pub fn rrf_fusion(
49    vector_results: &[SearchResult],
50    fts_results: &[FtsResult],
51    k: f32,
52) -> Vec<FusedResult> {
53    // Maps chunk_id -> (rrf_score, vector_score, fts_score, vector_rank, fts_rank)
54    let mut scores: HashMap<u32, ScoreEntry> = HashMap::new();
55
56    // Process vector results
57    for (rank, result) in vector_results.iter().enumerate() {
58        let chunk_id = result.id;
59        let rrf_score = 1.0 / (k + rank as f32 + 1.0);
60
61        let entry = scores
62            .entry(chunk_id)
63            .or_insert((0.0, None, None, None, None));
64        entry.0 += rrf_score;
65        entry.1 = Some(result.score);
66        entry.3 = Some(rank + 1);
67    }
68
69    // Process FTS results
70    for (rank, result) in fts_results.iter().enumerate() {
71        let chunk_id = result.chunk_id;
72        let rrf_score = 1.0 / (k + rank as f32 + 1.0);
73
74        let entry = scores
75            .entry(chunk_id)
76            .or_insert((0.0, None, None, None, None));
77        entry.0 += rrf_score;
78        entry.2 = Some(result.score);
79        entry.4 = Some(rank + 1);
80    }
81
82    // Convert to FusedResult and sort by RRF score
83    let mut results: Vec<FusedResult> = scores
84        .into_iter()
85        .map(
86            |(chunk_id, (rrf_score, vector_score, fts_score, vector_rank, fts_rank))| FusedResult {
87                chunk_id,
88                rrf_score,
89                vector_score,
90                fts_score,
91                vector_rank,
92                fts_rank,
93            },
94        )
95        .collect();
96
97    // Sort by RRF score descending
98    results.sort_by(|a, b| {
99        b.rrf_score
100            .partial_cmp(&a.rrf_score)
101            .unwrap_or(std::cmp::Ordering::Equal)
102    });
103
104    results
105}
106
107/// Simple vector-only pass-through (no fusion)
108pub fn vector_only(vector_results: &[SearchResult]) -> Vec<FusedResult> {
109    vector_results
110        .iter()
111        .enumerate()
112        .map(|(rank, result)| FusedResult {
113            chunk_id: result.id,
114            rrf_score: result.score,
115            vector_score: Some(result.score),
116            fts_score: None,
117            vector_rank: Some(rank + 1),
118            fts_rank: None,
119        })
120        .collect()
121}
122
123/// Reciprocal Rank Fusion with exact match boosting
124///
125/// Three-way RRF fusion: vector, FTS, and exact matches.
126/// Exact matches get a lower k value (stronger boost) because they're more likely
127/// to be what the user wants when searching for specific identifiers.
128///
129/// # Arguments
130/// * `vector_results` - Vector similarity results
131/// * `fts_results` - BM25 full-text search results
132/// * `exact_results` - Exact identifier match results (from signature field)
133/// * `vector_k` - RRF k for vector (default 20)
134/// * `fts_k` - RRF k for FTS (default 20)
135/// * `exact_k` - RRF k for exact matches (default 5, stronger boost)
136pub fn rrf_fusion_with_exact(
137    vector_results: &[SearchResult],
138    fts_results: &[FtsResult],
139    exact_results: &[FtsResult],
140    vector_k: f32,
141    fts_k: f32,
142    exact_k: f32,
143) -> Vec<FusedResult> {
144    // Maps chunk_id -> (rrf_score, vector_score, fts_score, exact_score, vector_rank, fts_rank, exact_rank)
145    let mut scores: HashMap<
146        u32,
147        (
148            f32,
149            Option<f32>,
150            Option<f32>,
151            Option<f32>,
152            Option<usize>,
153            Option<usize>,
154            Option<usize>,
155        ),
156    > = HashMap::new();
157
158    // Process vector results
159    for (rank, result) in vector_results.iter().enumerate() {
160        let chunk_id = result.id;
161        let rrf_score = 1.0 / (vector_k + rank as f32 + 1.0);
162
163        let entry = scores
164            .entry(chunk_id)
165            .or_insert((0.0, None, None, None, None, None, None));
166        entry.0 += rrf_score;
167        entry.1 = Some(result.score);
168        entry.4 = Some(rank + 1);
169    }
170
171    // Process FTS results
172    for (rank, result) in fts_results.iter().enumerate() {
173        let chunk_id = result.chunk_id;
174        let rrf_score = 1.0 / (fts_k + rank as f32 + 1.0);
175
176        let entry = scores
177            .entry(chunk_id)
178            .or_insert((0.0, None, None, None, None, None, None));
179        entry.0 += rrf_score;
180        entry.2 = Some(result.score);
181        entry.5 = Some(rank + 1);
182    }
183
184    // Process exact results (stronger boost with lower k)
185    for (rank, result) in exact_results.iter().enumerate() {
186        let chunk_id = result.chunk_id;
187        let rrf_score = 1.0 / (exact_k + rank as f32 + 1.0);
188
189        let entry = scores
190            .entry(chunk_id)
191            .or_insert((0.0, None, None, None, None, None, None));
192        entry.0 += rrf_score;
193        entry.3 = Some(result.score);
194        entry.6 = Some(rank + 1);
195    }
196
197    // Convert to FusedResult and sort by RRF score
198    let mut results: Vec<FusedResult> = scores
199        .into_iter()
200        .map(
201            |(
202                chunk_id,
203                (
204                    rrf_score,
205                    vector_score,
206                    fts_score,
207                    exact_score,
208                    vector_rank,
209                    fts_rank,
210                    exact_rank,
211                ),
212            )| {
213                // Combine FTS and exact scores for fts_score field
214                let combined_fts_score = match (fts_score, exact_score) {
215                    (Some(f), Some(e)) => Some((f + e) / 2.0),
216                    (Some(f), None) => Some(f),
217                    (None, Some(e)) => Some(e),
218                    (None, None) => None,
219                };
220
221                FusedResult {
222                    chunk_id,
223                    rrf_score,
224                    vector_score,
225                    fts_score: combined_fts_score,
226                    vector_rank,
227                    fts_rank: fts_rank.or(exact_rank),
228                }
229            },
230        )
231        .collect();
232
233    // Sort by RRF score descending
234    results.sort_by(|a, b| {
235        b.rrf_score
236            .partial_cmp(&a.rrf_score)
237            .unwrap_or(std::cmp::Ordering::Equal)
238    });
239
240    results
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    fn make_vector_result(id: u32, score: f32) -> SearchResult {
        SearchResult {
            id,
            score,
            path: format!("file_{}.rs", id),
            content: format!("content {}", id),
            start_line: 1,
            end_line: 10,
            kind: "function".to_string(),
            signature: None,
            context_prev: None,
            context_next: None,
            distance: 0.0,
            context: None,
            docstring: None,
            hash: String::new(),
        }
    }

    fn make_fts_result(id: u32, score: f32) -> FtsResult {
        FtsResult {
            chunk_id: id,
            score,
        }
    }

    #[test]
    fn test_rrf_fusion_basic() {
        let vector_results = vec![
            make_vector_result(1, 0.9),
            make_vector_result(2, 0.8),
            make_vector_result(3, 0.7),
        ];

        let fts_results = vec![
            make_fts_result(2, 10.0), // ID 2 is top in FTS
            make_fts_result(1, 8.0),
            make_fts_result(4, 6.0), // ID 4 only in FTS
        ];

        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);

        // One entry per distinct chunk ID (1, 2, 3, 4).
        assert_eq!(fused.len(), 4);

        let id1 = fused.iter().find(|r| r.chunk_id == 1).unwrap();
        let id2 = fused.iter().find(|r| r.chunk_id == 2).unwrap();

        // IDs 1 and 2 swap ranks 1 and 2 between the two lists, so both score
        // 1/21 + 1/22 and tie: neither is ranked above the other by score alone.
        assert_eq!(id1.vector_rank, Some(1));
        assert_eq!(id1.fts_rank, Some(2));
        assert_eq!(id2.vector_rank, Some(2));
        assert_eq!(id2.fts_rank, Some(1));
        assert!((id1.rrf_score - id2.rrf_score).abs() < 0.0001);

        // Chunks found by both sources outrank chunks found by only one.
        let top_two: Vec<u32> = fused.iter().take(2).map(|r| r.chunk_id).collect();
        assert!(top_two.contains(&1) && top_two.contains(&2));

        // ID 4 should only be in FTS, at rank 3.
        let id4 = fused.iter().find(|r| r.chunk_id == 4).unwrap();
        assert!(id4.vector_rank.is_none());
        assert_eq!(id4.fts_rank, Some(3));
    }

    #[test]
    fn test_rrf_score_calculation() {
        // With k=20:
        // Rank 1: 1/(20+1) ≈ 0.0476
        // Rank 2: 1/(20+2) ≈ 0.0455
        let vector_results = vec![make_vector_result(1, 0.9)];
        let fts_results = vec![make_fts_result(1, 10.0)];

        let fused = rrf_fusion(&vector_results, &fts_results, 20.0);

        assert_eq!(fused.len(), 1);
        let result = &fused[0];

        // Should be sum of both contributions
        let expected = 1.0 / 21.0 + 1.0 / 21.0;
        assert!((result.rrf_score - expected).abs() < 0.0001);
    }

    #[test]
    fn test_rrf_fusion_with_exact_boosts_exact_matches() {
        let vector_results = vec![make_vector_result(1, 0.9)];
        let fts_results = vec![make_fts_result(2, 10.0)];
        let exact_results = vec![make_fts_result(3, 1.0)];

        let fused = rrf_fusion_with_exact(
            &vector_results,
            &fts_results,
            &exact_results,
            DEFAULT_RRF_K,
            DEFAULT_RRF_K,
            EXACT_MATCH_RRF_K,
        );

        assert_eq!(fused.len(), 3);
        // Rank 1 with k=5 scores 1/6, well above 1/21 for rank 1 with k=20.
        assert_eq!(fused[0].chunk_id, 3);
        assert!((fused[0].rrf_score - 1.0 / 6.0).abs() < 0.0001);
    }

    #[test]
    fn test_rrf_fusion_with_exact_folds_exact_into_fts_fields() {
        let fts_results = vec![make_fts_result(5, 4.0)];
        let exact_results = vec![make_fts_result(9, 7.0), make_fts_result(5, 2.0)];

        let fused = rrf_fusion_with_exact(
            &[],
            &fts_results,
            &exact_results,
            DEFAULT_RRF_K,
            DEFAULT_RRF_K,
            EXACT_MATCH_RRF_K,
        );

        // In both FTS and exact: scores are averaged, FTS rank wins.
        let both = fused.iter().find(|r| r.chunk_id == 5).unwrap();
        assert_eq!(both.fts_score, Some(3.0));
        assert_eq!(both.fts_rank, Some(1));

        // Exact only: exact score/rank fill the FTS fields.
        let exact_only = fused.iter().find(|r| r.chunk_id == 9).unwrap();
        assert_eq!(exact_only.fts_score, Some(7.0));
        assert_eq!(exact_only.fts_rank, Some(1));
        assert!(exact_only.vector_score.is_none());
    }

    #[test]
    fn test_vector_only() {
        let vector_results = vec![make_vector_result(1, 0.9), make_vector_result(2, 0.8)];

        let results = vector_only(&vector_results);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].chunk_id, 1);
        assert_eq!(results[0].rrf_score, 0.9);
        assert!(results[0].fts_score.is_none());
        assert_eq!(results[1].vector_rank, Some(2));
        assert!(vector_only(&[]).is_empty());
    }
}