1use crate::engine::EmbeddingEngine;
26use crate::error::{InferenceError, Result};
27use ort::inputs;
28use ort::session::builder::GraphOptimizationLevel;
29use ort::session::Session;
30use ort::value::Tensor;
31use parking_lot::Mutex;
32use std::path::PathBuf;
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::Arc;
35use tokenizers::{
36 EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
37};
38use tracing::{info, instrument, warn};
39
/// Hugging Face repository the reranker tokenizer and ONNX weights are fetched from.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the quantized ONNX graph inside [`RERANKER_REPO_ID`].
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Maximum number of tokens per `(query, passage)` pair; longer pairs are truncated.
const MAX_SEQ_LENGTH: usize = 512;
/// Number of independent ONNX sessions kept in the round-robin pool.
const RERANKER_POOL_SIZE: usize = 2;
/// Maximum number of passages scored in a single session run.
const RERANKER_CHUNK_SIZE: usize = 32;
60
61pub struct CrossEncoderEngine {
66 sessions: Vec<Arc<Mutex<Session>>>,
68 tokenizer: Arc<Tokenizer>,
69 has_token_type_ids: bool,
73 next_session: AtomicUsize,
75}
76
77impl CrossEncoderEngine {
78 #[instrument(skip_all)]
83 pub async fn new(cache_dir: Option<String>) -> Result<Self> {
84 info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
85
86 let (tokenizer_path, onnx_path) =
87 tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
88 .await
89 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
90 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
91
92 info!("Loading reranker tokenizer from {:?}", tokenizer_path);
93 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
94 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
95
96 let padding = PaddingParams {
98 strategy: PaddingStrategy::BatchLongest,
99 pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
100 pad_token: tokenizer
101 .get_padding()
102 .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
103 ..Default::default()
104 };
105 tokenizer.with_padding(Some(padding));
106 let truncation = TruncationParams {
107 max_length: MAX_SEQ_LENGTH,
108 ..Default::default()
109 };
110 let _ = tokenizer.with_truncation(Some(truncation));
111
112 info!(
113 "Loading reranker ONNX model from {:?} (pool_size={})",
114 onnx_path, RERANKER_POOL_SIZE
115 );
116
117 let (sessions, has_token_type_ids) =
120 tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
121 let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
122 .map(|_| {
123 Session::builder()
124 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
125 .with_optimization_level(GraphOptimizationLevel::Level3)
126 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
127 .with_intra_threads(4)
128 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
129 .commit_from_file(&onnx_path)
130 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
131 })
132 .collect();
133 let raw = raw?;
134 let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
136 let sessions: Vec<Arc<Mutex<Session>>> =
137 raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
138 Ok((sessions, has_tti))
139 })
140 .await
141 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
142 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
143
144 info!(
145 has_token_type_ids,
146 pool_size = sessions.len(),
147 "Cross-encoder reranker loaded successfully"
148 );
149
150 Ok(Self {
151 sessions,
152 tokenizer: Arc::new(tokenizer),
153 has_token_type_ids,
154 next_session: AtomicUsize::new(0),
155 })
156 }
157
158 #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
167 pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
168 if passages.is_empty() {
169 return Ok(Vec::new());
170 }
171
172 let pool_len = self.sessions.len();
173 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
176 let tokenizer = Arc::clone(&self.tokenizer);
177 let has_token_type_ids = self.has_token_type_ids;
178 let query_str = query.to_string();
179
180 let chunks: Vec<Vec<String>> = passages
182 .chunks(RERANKER_CHUNK_SIZE)
183 .map(<[String]>::to_vec)
184 .collect();
185
186 let mut handles = Vec::with_capacity(chunks.len());
188 for (i, chunk) in chunks.into_iter().enumerate() {
189 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
190 let tok = Arc::clone(&tokenizer);
191 let q = query_str.clone();
192 handles.push(tokio::task::spawn_blocking(move || {
193 score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
194 }));
195 }
196
197 let mut scores = Vec::with_capacity(passages.len());
199 for handle in handles {
200 let chunk_scores = handle
201 .await
202 .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
203 scores.extend(chunk_scores);
204 }
205
206 Ok(scores)
207 }
208
209 pub fn pool_size(&self) -> usize {
211 self.sessions.len()
212 }
213}
214
215fn score_pairs_blocking(
217 session: &Arc<Mutex<Session>>,
218 tokenizer: &Tokenizer,
219 query: &str,
220 passages: &[String],
221 has_token_type_ids: bool,
222) -> Result<Vec<f32>> {
223 let batch_size = passages.len();
224
225 let inputs: Vec<EncodeInput> = passages
227 .iter()
228 .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
229 .collect();
230
231 let encodings = tokenizer
232 .encode_batch(inputs, true)
233 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
234
235 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
236 if seq_len == 0 {
237 return Ok(vec![0.5; batch_size]);
238 }
239
240 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
242 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
243 let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
244
245 for enc in &encodings {
246 input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
247 attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
248 let type_ids = enc.get_type_ids();
249 if type_ids.is_empty() {
250 token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
251 } else {
252 token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
253 }
254 }
255
256 let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
258 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
259 let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
260 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
261 let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
262 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
263
264 let scores: Vec<f32> = {
267 let mut sess = session.lock();
268 let outputs = if has_token_type_ids {
269 sess.run(inputs![
270 "input_ids" => input_ids_tensor,
271 "attention_mask" => attention_mask_tensor,
272 "token_type_ids" => token_type_ids_tensor
273 ])
274 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
275 } else {
276 sess.run(inputs![
277 "input_ids" => input_ids_tensor,
278 "attention_mask" => attention_mask_tensor
279 ])
280 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
281 };
282
283 let (out_shape, logits_slice) = outputs[0]
285 .try_extract_tensor::<f32>()
286 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
287
288 if out_shape.is_empty() || out_shape[0] as usize != batch_size {
289 warn!(
290 "Reranker output shape mismatch: expected [{}, 1], got {:?}",
291 batch_size, out_shape
292 );
293 }
294
295 logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
297 };
299
300 if scores.len() != batch_size {
301 warn!(
302 "Reranker score count mismatch: expected {}, got {}",
303 batch_size,
304 scores.len()
305 );
306 let mut padded = scores;
307 padded.resize(batch_size, 0.5);
308 return Ok(padded);
309 }
310
311 Ok(scores)
312}
313
/// Logistic function `1 / (1 + e^-x)`, mapping a raw logit into `[0, 1]`.
///
/// Saturates cleanly in `f32`: large negative inputs give `0.0` and large
/// positive inputs give `1.0` (no NaN for finite input).
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
319
320fn download_reranker_files(
323 cache_dir: Option<String>,
324) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
325 let cache = match cache_dir {
326 Some(dir) => {
327 let p = PathBuf::from(dir);
328 std::fs::create_dir_all(&p)
329 .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
330 p
331 }
332 None => {
333 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
334 PathBuf::from(home)
335 .join(".cache")
336 .join("huggingface")
337 .join("dakera")
338 .join(RERANKER_REPO_ID.replace('/', "--"))
339 }
340 };
341
342 std::fs::create_dir_all(&cache)
343 .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
344
345 let files = [
346 "tokenizer.json",
347 "tokenizer_config.json",
348 "special_tokens_map.json",
349 RERANKER_ONNX_FILE,
350 ];
351
352 for filename in &files {
353 EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
354 .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
355 }
356
357 let tokenizer_path = cache.join("tokenizer.json");
358 let onnx_path = cache.join(RERANKER_ONNX_FILE);
359 Ok((tokenizer_path, onnx_path))
360}
361
362impl std::fmt::Debug for CrossEncoderEngine {
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("CrossEncoderEngine")
365 .field("model", &RERANKER_REPO_ID)
366 .field("pool_size", &self.sessions.len())
367 .finish()
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` synthetic passages and chunks them the way `score_pairs` does.
    fn chunked(n: usize) -> Vec<Vec<String>> {
        let passages: Vec<String> = (0..n).map(|i| format!("passage {i}")).collect();
        passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect()
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
        // Saturates without producing NaN for extreme logits.
        assert_eq!(sigmoid(200.0), 1.0);
        assert_eq!(sigmoid(-200.0), 0.0);
        // Symmetric around 0.5.
        assert!((sigmoid(3.0) + sigmoid(-3.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_chunk_count_exact() {
        let chunks = chunked(64);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 32);
    }

    #[test]
    fn test_chunk_count_remainder() {
        let chunks = chunked(50);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 18);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        let chunks = chunked(10);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10);
    }

    #[test]
    fn test_chunk_order_preserved() {
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_round_robin_wraps() {
        let pool_len = RERANKER_POOL_SIZE;

        // Consecutive requests start on every session equally often.
        let rounds = 3;
        let mut hits = vec![0usize; pool_len];
        for start in 0..pool_len * rounds {
            hits[start % pool_len] += 1;
        }
        assert!(hits.iter().all(|&h| h == rounds));

        // Within one request, the first `pool_len` chunks land on distinct sessions.
        let start = 7usize;
        let used: std::collections::HashSet<usize> =
            (0..pool_len).map(|i| start.wrapping_add(i) % pool_len).collect();
        assert_eq!(used.len(), pool_len);
    }
}