codesearch/rerank/
neural.rs1use crate::info_print;
7use anyhow::Result;
8use fastembed::{RerankInitOptions, RerankerModel, TextRerank};
9
/// Weight given to the sigmoid-normalized neural reranker score when
/// `NeuralReranker::rerank_and_blend` combines it with the RRF score.
/// Together with [`RRF_WEIGHT`] it sums to 1.0.
pub const RERANK_WEIGHT: f32 = 0.575;
/// Weight given to the min-max normalized RRF score when
/// `NeuralReranker::rerank_and_blend` combines it with the reranker score.
/// Together with [`RERANK_WEIGHT`] it sums to 1.0.
pub const RRF_WEIGHT: f32 = 0.425;
14
/// Wrapper around a `fastembed` [`TextRerank`] model that also remembers the
/// display name of the model it was built with.
// `model_name` is only read through the `model_name()` accessor, which is
// itself marked `dead_code` in this file, hence the lint suppression.
#[allow(dead_code)]
pub struct NeuralReranker {
    /// The loaded reranking model.
    reranker: TextRerank,
    /// String form of the `RerankerModel` passed to `with_model`.
    model_name: String,
}
21
22impl NeuralReranker {
23 pub fn new() -> Result<Self> {
25 Self::with_model(RerankerModel::JINARerankerV1TurboEn)
26 }
27
28 pub fn with_model(model: RerankerModel) -> Result<Self> {
30 let model_name = model.to_string();
31 info_print!("Loading reranker model: {}", model_name);
32
33 let mut options = RerankInitOptions::default();
34 options.model_name = model;
35 options.show_download_progress = false;
36
37 let reranker = TextRerank::try_new(options)?;
38
39 info_print!("Reranker model loaded successfully!");
40
41 Ok(Self {
42 reranker,
43 model_name,
44 })
45 }
46
47 #[allow(dead_code)] pub fn model_name(&self) -> &str {
50 &self.model_name
51 }
52
53 pub fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<(usize, f32)>> {
57 if documents.is_empty() {
58 return Ok(vec![]);
59 }
60
61 let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
63
64 let results = self.reranker.rerank(
66 query, doc_refs, false, None, )?;
69
70 Ok(results.into_iter().map(|r| (r.index, r.score)).collect())
72 }
73
74 pub fn rerank_and_blend(
78 &mut self,
79 query: &str,
80 documents: &[String],
81 rrf_scores: &[f32],
82 ) -> Result<Vec<(usize, f32)>> {
83 if documents.is_empty() {
84 return Ok(vec![]);
85 }
86
87 assert_eq!(
88 documents.len(),
89 rrf_scores.len(),
90 "Documents and RRF scores must have same length"
91 );
92
93 let rerank_results = self.rerank(query, documents)?;
95
96 let normalized: Vec<(usize, f32)> = rerank_results
98 .iter()
99 .map(|(idx, score)| (*idx, sigmoid(*score)))
100 .collect();
101
102 let rrf_min = rrf_scores.iter().cloned().fold(f32::INFINITY, f32::min);
104 let rrf_max = rrf_scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
105 let rrf_range = (rrf_max - rrf_min).max(0.0001); let mut blended: Vec<(usize, f32)> = normalized
109 .into_iter()
110 .map(|(idx, rerank_norm)| {
111 let rrf_norm = (rrf_scores[idx] - rrf_min) / rrf_range;
112 let blended_score = RERANK_WEIGHT * rerank_norm + RRF_WEIGHT * rrf_norm;
113 (idx, blended_score)
114 })
115 .collect();
116
117 blended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
119
120 Ok(blended)
121 }
122}
123
/// Logistic function: maps any real `x` into the interval `[0, 1]`,
/// with `sigmoid(0.0) == 0.5`.
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    let denominator = 1.0 + decay;
    1.0 / denominator
}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// The sigmoid is 0.5 at the origin and saturates towards 1 and 0.
    #[test]
    fn test_sigmoid() {
        let midpoint = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 0.0001);

        let high = sigmoid(10.0);
        assert!(high > 0.99);

        let low = sigmoid(-10.0);
        assert!(low < 0.01);
    }

    /// Constructing the default reranker succeeds.
    // Ignored by default; presumably because it has to load the model.
    #[test]
    #[ignore]
    fn test_reranker_creation() {
        assert!(NeuralReranker::new().is_ok());
    }

    /// Reranking three snippets returns three scores in non-increasing order.
    // Ignored by default; presumably because it has to load the model.
    #[test]
    #[ignore]
    fn test_rerank_basic() {
        let mut reranker = NeuralReranker::new().unwrap();

        let documents: Vec<String> = [
            "fn authenticate(user: &str, password: &str) -> bool { ... }",
            "fn calculate_sum(a: i32, b: i32) -> i32 { a + b }",
            "impl UserAuth for App { fn login(&self, credentials: Credentials) -> Result<Token> }",
        ]
        .iter()
        .map(|snippet| snippet.to_string())
        .collect();

        let results = reranker
            .rerank("How do I authenticate users?", &documents)
            .unwrap();

        assert_eq!(results.len(), 3);

        // Each adjacent pair must be in non-increasing score order.
        for pair in results.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }
}