1pub mod config;
5pub mod error;
7
8#[cfg(test)]
9mod tests;
10
11pub use config::{DEFAULT_THRESHOLD, MAX_SEQ_LEN, RerankerConfig};
12pub use error::RerankerError;
13
14use crate::embedding::bert::BertClassifier;
15use candle_core::Tensor;
16use tokenizers::Tokenizer;
17use tracing::{debug, info};
18
19use crate::embedding::device::select_device;
20use crate::embedding::utils::load_tokenizer_with_truncation;
21
/// Cross-encoder reranker over query/candidate pairs.
///
/// Runs a BERT classifier when a model path is configured; otherwise operates
/// in "stub" mode and scores with a lexical-overlap heuristic.
pub struct Reranker {
    // Compute device chosen by `select_device()` during `load`.
    device: candle_core::Device,
    // Validated configuration (scoring threshold, optional model path).
    config: RerankerConfig,
    // True only when real model weights were loaded; false in stub mode.
    model_loaded: bool,
    // Loaded classifier; `None` in stub mode.
    model: Option<BertClassifier>,
    // Tokenizer from the model directory; `None` in stub mode.
    tokenizer: Option<Tokenizer>,
}
30
31impl std::fmt::Debug for Reranker {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("Reranker")
34 .field("device", &format!("{:?}", self.device))
35 .field("config", &self.config)
36 .field("model_loaded", &self.model_loaded)
37 .finish()
38 }
39}
40
41impl Reranker {
42 pub fn load(config: RerankerConfig) -> Result<Self, RerankerError> {
44 if let Err(msg) = config.validate() {
45 return Err(RerankerError::InvalidConfig { reason: msg });
46 }
47
48 let device = select_device()?;
49 debug!(?device, "Selected compute device for reranker");
50
51 if let Some(ref model_path) = config.model_path {
52 if !model_path.exists() {
53 return Err(RerankerError::ModelLoadFailed {
54 reason: format!("Reranker model path not found: {}", model_path.display()),
55 });
56 }
57
58 let config_path = model_path.join("config.json");
59 if !config_path.exists() {
60 return Err(RerankerError::ModelLoadFailed {
61 reason: format!("Missing config.json in {}", model_path.display()),
62 });
63 }
64
65 let weights_path = model_path.join("model.safetensors");
66 if !weights_path.exists() {
67 return Err(RerankerError::ModelLoadFailed {
68 reason: format!("Missing model.safetensors in {}", model_path.display()),
69 });
70 }
71
72 info!(
73 model_path = %model_path.display(),
74 threshold = config.threshold,
75 "Loading reranker model"
76 );
77
78 let model = BertClassifier::load(model_path, &device).map_err(|e| {
79 RerankerError::ModelLoadFailed {
80 reason: format!("Failed to load BERT model: {}", e),
81 }
82 })?;
83
84 let tokenizer =
85 load_tokenizer_with_truncation(model_path, MAX_SEQ_LEN).map_err(|e| {
86 RerankerError::ModelLoadFailed {
87 reason: format!("Failed to load tokenizer: {}", e),
88 }
89 })?;
90
91 info!(
92 threshold = config.threshold,
93 "Reranker model loaded successfully"
94 );
95
96 Ok(Self {
97 device,
98 config,
99 model_loaded: true,
100 model: Some(model),
101 tokenizer: Some(tokenizer),
102 })
103 } else {
104 info!("No reranker model path configured, operating in stub mode");
105 Ok(Self::create_stub(device, config))
106 }
107 }
108
109 pub fn stub() -> Result<Self, RerankerError> {
111 Self::load(RerankerConfig::stub())
112 }
113
114 fn create_stub(device: candle_core::Device, config: RerankerConfig) -> Self {
115 Self {
116 device,
117 config,
118 model_loaded: false,
119 model: None,
120 tokenizer: None,
121 }
122 }
123
124 pub fn score(&self, query: &str, candidate: &str) -> Result<f32, RerankerError> {
126 debug!(
127 query_len = query.len(),
128 candidate_len = candidate.len(),
129 model_loaded = self.model_loaded,
130 "Scoring query-candidate pair"
131 );
132
133 if let (Some(model), Some(tokenizer)) = (&self.model, &self.tokenizer) {
134 let tokens = tokenizer.encode((query, candidate), true).map_err(|e| {
135 RerankerError::TokenizationFailed {
136 reason: e.to_string(),
137 }
138 })?;
139
140 let token_ids = tokens.get_ids();
141 let token_ids = Tensor::new(token_ids, &self.device)
142 .map_err(RerankerError::from)?
143 .unsqueeze(0)
144 .map_err(RerankerError::from)?;
145
146 let type_ids = tokens.get_type_ids();
147 let type_ids = Tensor::new(type_ids, &self.device)
148 .map_err(RerankerError::from)?
149 .unsqueeze(0)
150 .map_err(RerankerError::from)?;
151
152 let attention_mask_data = tokens.get_attention_mask();
155 let attention_mask = Tensor::new(attention_mask_data, &self.device)
156 .map_err(RerankerError::from)?
157 .unsqueeze(0)
158 .map_err(RerankerError::from)?;
159
160 let logits = model
161 .forward(&token_ids, &type_ids, Some(&attention_mask))
162 .map_err(|e| RerankerError::InferenceFailed {
163 reason: e.to_string(),
164 })?;
165
166 let score = logits
167 .flatten_all()
168 .map_err(RerankerError::from)?
169 .to_vec1::<f32>()
170 .map_err(RerankerError::from)?[0];
171 return Ok(score);
172 }
173
174 let score = self.compute_placeholder_score(query, candidate);
175
176 debug!(score = score, "Computed score (stub)");
177
178 Ok(score)
179 }
180
181 pub fn rerank(
183 &self,
184 query: &str,
185 candidates: &[&str],
186 ) -> Result<Vec<(usize, f32)>, RerankerError> {
187 debug!(
188 query_len = query.len(),
189 num_candidates = candidates.len(),
190 "Reranking candidates"
191 );
192
193 let mut scored: Vec<(usize, f32)> = candidates
194 .iter()
195 .enumerate()
196 .map(|(idx, candidate)| {
197 let score = self.score(query, candidate)?;
198 Ok((idx, score))
199 })
200 .collect::<Result<Vec<_>, RerankerError>>()?;
201
202 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
203
204 debug!(
205 top_score = scored.first().map(|(_, s)| *s),
206 "Reranking complete"
207 );
208
209 Ok(scored)
210 }
211
212 pub fn rerank_with_threshold(
214 &self,
215 query: &str,
216 candidates: &[&str],
217 ) -> Result<Vec<(usize, f32)>, RerankerError> {
218 let ranked = self.rerank(query, candidates)?;
219 let threshold = self.config.threshold;
220
221 let filtered: Vec<_> = ranked
222 .into_iter()
223 .filter(|(_, score)| *score > threshold)
224 .collect();
225
226 debug!(
227 threshold = threshold,
228 hits = filtered.len(),
229 total = candidates.len(),
230 "Filtered by threshold"
231 );
232
233 Ok(filtered)
234 }
235
    /// Returns `true` when real model weights were loaded, `false` in stub mode.
    pub fn is_model_loaded(&self) -> bool {
        self.model_loaded
    }
240
    /// The score threshold used by [`Self::is_hit`] and
    /// [`Self::rerank_with_threshold`].
    pub fn threshold(&self) -> f32 {
        self.config.threshold
    }
245
    /// Borrow the configuration this reranker was constructed with.
    pub fn config(&self) -> &RerankerConfig {
        &self.config
    }
250
    /// The compute device selected at load time.
    pub fn device(&self) -> &candle_core::Device {
        &self.device
    }
255
    /// Whether `score` counts as a hit: strictly greater than the threshold.
    pub fn is_hit(&self, score: f32) -> bool {
        score > self.config.threshold
    }
260
261 fn compute_placeholder_score(&self, query: &str, candidate: &str) -> f32 {
262 use std::collections::HashSet;
263
264 let stop_words: HashSet<&str> = [
265 "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
266 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
267 "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
268 "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
269 "below", "between", "under", "again", "further", "then", "once", "here", "there",
270 "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
271 "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
272 "and", "but", "if", "or", "because", "until", "while", "what", "which", "who", "whom",
273 "this", "that", "these", "those", "am", "it", "its",
274 ]
275 .into_iter()
276 .collect();
277
278 let query_lower = query.to_lowercase();
279 let query_words: HashSet<&str> = query_lower
280 .split(|c: char| !c.is_alphanumeric())
281 .filter(|w| !w.is_empty() && !stop_words.contains(w))
282 .collect();
283
284 let candidate_lower = candidate.to_lowercase();
285 let candidate_words: HashSet<&str> = candidate_lower
286 .split(|c: char| !c.is_alphanumeric())
287 .filter(|w| !w.is_empty() && !stop_words.contains(w))
288 .collect();
289
290 if query_words.is_empty() {
291 let len_ratio = (query.len().min(candidate.len()) as f32)
292 / (query.len().max(candidate.len()).max(1) as f32);
293 return len_ratio * 0.3;
294 }
295
296 let matches = query_words.intersection(&candidate_words).count();
297 let recall = matches as f32 / query_words.len() as f32;
298
299 let union = query_words.union(&candidate_words).count();
300 let jaccard = if union > 0 {
301 matches as f32 / union as f32
302 } else {
303 0.0
304 };
305
306 let base_score = 0.6 * recall + 0.4 * jaccard;
307
308 let normalized = 1.0 / (1.0 + (-8.0 * (base_score - 0.5)).exp());
309
310 normalized.clamp(0.0, 1.0)
311 }
312}