Skip to main content

codesearch/rerank/
neural.rs

//! Neural reranking using cross-encoder models.
//!
//! Provides second-pass reranking using fastembed's `TextRerank`,
//! defaulting to the Jina Reranker v1 Turbo (English) model, for improved accuracy.
5
use anyhow::{Context, Result};
use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

use crate::info_print;
9
/// Weight of the sigmoid-normalised cross-encoder score in the blended score.
///
/// Together with [`RRF_WEIGHT`] this yields a 57.5% rerank / 42.5% RRF split
/// (the osgrep blending pattern); the two weights sum to 1.0.
pub const RERANK_WEIGHT: f32 = 57.5 / 100.0;

/// Weight of the min-max-normalised RRF score in the blended score; see [`RERANK_WEIGHT`].
pub const RRF_WEIGHT: f32 = 42.5 / 100.0;
14
15/// Neural reranker using cross-encoder model
16#[allow(dead_code)] // model_name used for diagnostics
17pub struct NeuralReranker {
18    reranker: TextRerank,
19    model_name: String,
20}
21
22impl NeuralReranker {
23    /// Create a new neural reranker with the default Jina model
24    pub fn new() -> Result<Self> {
25        Self::with_model(RerankerModel::JINARerankerV1TurboEn)
26    }
27
28    /// Create a neural reranker with a specific model
29    pub fn with_model(model: RerankerModel) -> Result<Self> {
30        let model_name = model.to_string();
31        info_print!("Loading reranker model: {}", model_name);
32
33        let mut options = RerankInitOptions::default();
34        options.model_name = model;
35        options.show_download_progress = false;
36
37        let reranker = TextRerank::try_new(options)?;
38
39        info_print!("Reranker model loaded successfully!");
40
41        Ok(Self {
42            reranker,
43            model_name,
44        })
45    }
46
47    /// Get the model name
48    #[allow(dead_code)] // Reserved for diagnostics
49    pub fn model_name(&self) -> &str {
50        &self.model_name
51    }
52
53    /// Rerank documents given a query
54    ///
55    /// Returns Vec of (original_index, rerank_score) sorted by score descending
56    pub fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
57        if documents.is_empty() {
58            return Ok(vec![]);
59        }
60
61        // Convert to &str references for fastembed API
62        let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
63
64        // Rerank using the cross-encoder
65        let results = self.reranker.rerank(
66            query, doc_refs, false, // Don't return documents (we have them)
67            None,  // Use default batch size
68        )?;
69
70        // Convert to (index, score) pairs
71        Ok(results.into_iter().map(|r| (r.index, r.score)).collect())
72    }
73
74    /// Rerank and blend scores with existing RRF scores
75    ///
76    /// Uses weighted blending: final_score = RERANK_WEIGHT * rerank_score + RRF_WEIGHT * rrf_score
77    pub fn rerank_and_blend(
78        &mut self,
79        query: &str,
80        documents: &[String],
81        rrf_scores: &[f32],
82    ) -> Result<Vec<(usize, f32)>> {
83        if documents.is_empty() {
84            return Ok(vec![]);
85        }
86
87        assert_eq!(
88            documents.len(),
89            rrf_scores.len(),
90            "Documents and RRF scores must have same length"
91        );
92
93        // Get rerank scores
94        let rerank_results = self.rerank(query, documents)?;
95
96        // Normalize rerank scores to [0, 1] using sigmoid (scores can be negative)
97        let normalized: Vec<(usize, f32)> = rerank_results
98            .iter()
99            .map(|(idx, score)| (*idx, sigmoid(*score)))
100            .collect();
101
102        // Normalize RRF scores to [0, 1] (they're already positive, just need min-max)
103        let rrf_min = rrf_scores.iter().cloned().fold(f32::INFINITY, f32::min);
104        let rrf_max = rrf_scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
105        let rrf_range = (rrf_max - rrf_min).max(0.0001); // Avoid division by zero
106
107        // Blend scores
108        let mut blended: Vec<(usize, f32)> = normalized
109            .into_iter()
110            .map(|(idx, rerank_norm)| {
111                let rrf_norm = (rrf_scores[idx] - rrf_min) / rrf_range;
112                let blended_score = RERANK_WEIGHT * rerank_norm + RRF_WEIGHT * rrf_norm;
113                (idx, blended_score)
114            })
115            .collect();
116
117        // Sort by blended score descending
118        blended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
119
120        Ok(blended)
121    }
122}
123
/// Logistic sigmoid, `1 / (1 + e^(-x))`, mapping any score into `[0, 1]`.
///
/// Saturates rather than failing: very negative inputs give `0.0` (the
/// exponential overflows to infinity) and very positive inputs give `1.0`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sigmoid() {
        let approx_eq = |actual: f32, expected: f32| (actual - expected).abs() < 0.0001;

        assert!(approx_eq(sigmoid(0.0), 0.5));
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    #[ignore] // Requires model download
    fn test_reranker_creation() {
        assert!(NeuralReranker::new().is_ok());
    }

    #[test]
    #[ignore] // Requires model download
    fn test_rerank_basic() {
        let mut reranker = NeuralReranker::new().unwrap();

        let query = "How do I authenticate users?";
        let documents: Vec<String> = [
            "fn authenticate(user: &str, password: &str) -> bool { ... }",
            "fn calculate_sum(a: i32, b: i32) -> i32 { a + b }",
            "impl UserAuth for App { fn login(&self, credentials: Credentials) -> Result<Token> }",
        ]
        .iter()
        .map(|doc| doc.to_string())
        .collect();

        let results = reranker.rerank(query, &documents).unwrap();

        // Every candidate comes back exactly once.
        assert_eq!(results.len(), documents.len());

        // Scores must be non-increasing from first to last.
        assert!(results.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }
}