1use crate::engine::EmbeddingEngine;
34use crate::error::{InferenceError, Result};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::path::PathBuf;
41use std::sync::atomic::{AtomicUsize, Ordering};
42use std::sync::Arc;
43use tokenizers::{
44 EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
45};
46use tracing::{info, instrument, warn};
47
// --- Model artefacts -------------------------------------------------------

/// Repository identifier handed to the download helper and used (with `/`
/// replaced by `--`) as the leaf of the default cache directory.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";

/// Path of the ONNX model file, relative to the repository root; the same
/// relative path is joined onto the local cache directory when loading.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

// --- Tokenization ----------------------------------------------------------

/// Truncation limit (`TruncationParams::max_length`) applied to every encoded
/// query/passage pair.
const MAX_SEQ_LENGTH: usize = 512;

// --- Parallelism and batching ----------------------------------------------

/// Number of independent ONNX sessions created when the engine is built.
const RERANKER_POOL_SIZE: usize = 2;

/// Maximum number of passages handed to one blocking scoring task.
const RERANKER_CHUNK_SIZE: usize = 32;

/// Maximum number of passages sent through the model in a single `run` call
/// inside a blocking scoring task.
const RERANKER_ONNX_BATCH_SIZE: usize = 16;
82
83pub struct CrossEncoderEngine {
88 sessions: Vec<Arc<Mutex<Session>>>,
90 tokenizer: Arc<Tokenizer>,
91 has_token_type_ids: bool,
95 next_session: AtomicUsize,
97}
98
99impl CrossEncoderEngine {
100 #[instrument(skip_all)]
105 pub async fn new(cache_dir: Option<String>) -> Result<Self> {
106 info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
107
108 let (tokenizer_path, onnx_path) =
109 tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
110 .await
111 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
112 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
113
114 info!("Loading reranker tokenizer from {:?}", tokenizer_path);
115 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
116 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
117
118 let padding = PaddingParams {
120 strategy: PaddingStrategy::BatchLongest,
121 pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
122 pad_token: tokenizer
123 .get_padding()
124 .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
125 ..Default::default()
126 };
127 tokenizer.with_padding(Some(padding));
128 let truncation = TruncationParams {
129 max_length: MAX_SEQ_LENGTH,
130 ..Default::default()
131 };
132 let _ = tokenizer.with_truncation(Some(truncation));
133
134 info!(
135 "Loading reranker ONNX model from {:?} (pool_size={}, onnx_batch_size={})",
136 onnx_path, RERANKER_POOL_SIZE, RERANKER_ONNX_BATCH_SIZE
137 );
138
139 let (sessions, has_token_type_ids) =
142 tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
143 let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
144 .map(|_| {
145 Session::builder()
146 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
147 .with_optimization_level(GraphOptimizationLevel::Level3)
148 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
149 .with_intra_threads(4)
150 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
151 .commit_from_file(&onnx_path)
152 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
153 })
154 .collect();
155 let raw = raw?;
156 let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
158 let sessions: Vec<Arc<Mutex<Session>>> =
159 raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
160 Ok((sessions, has_tti))
161 })
162 .await
163 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
164 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
165
166 info!(
167 has_token_type_ids,
168 pool_size = sessions.len(),
169 onnx_batch_size = RERANKER_ONNX_BATCH_SIZE,
170 "Cross-encoder reranker loaded successfully"
171 );
172
173 Ok(Self {
174 sessions,
175 tokenizer: Arc::new(tokenizer),
176 has_token_type_ids,
177 next_session: AtomicUsize::new(0),
178 })
179 }
180
181 #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
192 pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
193 if passages.is_empty() {
194 return Ok(Vec::new());
195 }
196
197 let pool_len = self.sessions.len();
198 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
201 let tokenizer = Arc::clone(&self.tokenizer);
202 let has_token_type_ids = self.has_token_type_ids;
203 let query_str = query.to_string();
204
205 let chunks: Vec<Vec<String>> = passages
207 .chunks(RERANKER_CHUNK_SIZE)
208 .map(<[String]>::to_vec)
209 .collect();
210
211 let mut handles = Vec::with_capacity(chunks.len());
213 for (i, chunk) in chunks.into_iter().enumerate() {
214 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
215 let tok = Arc::clone(&tokenizer);
216 let q = query_str.clone();
217 handles.push(tokio::task::spawn_blocking(move || {
218 score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
219 }));
220 }
221
222 let mut scores = Vec::with_capacity(passages.len());
224 for handle in handles {
225 let chunk_scores = handle
226 .await
227 .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
228 scores.extend(chunk_scores);
229 }
230
231 Ok(scores)
232 }
233
234 pub fn pool_size(&self) -> usize {
236 self.sessions.len()
237 }
238
239 pub fn onnx_batch_size(&self) -> usize {
241 RERANKER_ONNX_BATCH_SIZE
242 }
243}
244
245fn score_pairs_blocking(
253 session: &Arc<Mutex<Session>>,
254 tokenizer: &Tokenizer,
255 query: &str,
256 passages: &[String],
257 has_token_type_ids: bool,
258) -> Result<Vec<f32>> {
259 let total = passages.len();
260 if total == 0 {
261 return Ok(Vec::new());
262 }
263
264 let mut all_scores = Vec::with_capacity(total);
265 let mut sess = session.lock();
269
270 for mini_batch in passages.chunks(RERANKER_ONNX_BATCH_SIZE) {
271 let batch_size = mini_batch.len();
272
273 let inputs: Vec<EncodeInput> = mini_batch
275 .iter()
276 .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
277 .collect();
278
279 let encodings = tokenizer
280 .encode_batch(inputs, true)
281 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
282
283 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
284 if seq_len == 0 {
285 all_scores.extend(std::iter::repeat_n(0.5f32, batch_size));
286 continue;
287 }
288
289 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
291 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
292 let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
293
294 for enc in &encodings {
295 input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
296 attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
297 let type_ids = enc.get_type_ids();
298 if type_ids.is_empty() {
299 token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
300 } else {
301 token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
302 }
303 }
304
305 let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
307 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
308 let attention_mask_tensor =
309 Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
310 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
311 let token_type_ids_tensor =
312 Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
313 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
314
315 let mini_scores: Vec<f32> = {
319 let outputs = if has_token_type_ids {
320 sess.run(inputs![
321 "input_ids" => input_ids_tensor,
322 "attention_mask" => attention_mask_tensor,
323 "token_type_ids" => token_type_ids_tensor
324 ])
325 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
326 } else {
327 sess.run(inputs![
328 "input_ids" => input_ids_tensor,
329 "attention_mask" => attention_mask_tensor
330 ])
331 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
332 };
333
334 let (out_shape, logits_slice) = outputs[0]
336 .try_extract_tensor::<f32>()
337 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
338
339 if out_shape.is_empty() || out_shape[0] as usize != batch_size {
340 warn!(
341 "Reranker output shape mismatch: expected [{}, 1], got {:?}",
342 batch_size, out_shape
343 );
344 }
345
346 logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
348 };
350
351 let n_scores = mini_scores.len();
352 if n_scores != batch_size {
353 warn!(
354 "Reranker score count mismatch: expected {}, got {}",
355 batch_size, n_scores
356 );
357 let mut padded = mini_scores;
358 padded.resize(batch_size, 0.5);
359 all_scores.extend(padded);
360 } else {
361 all_scores.extend(mini_scores);
362 }
363 }
364 Ok(all_scores)
367}
368
/// Logistic function: maps a raw logit to the open interval `(0, 1)`,
/// saturating to `0.0` / `1.0` for very large magnitudes.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
374
375fn download_reranker_files(
378 cache_dir: Option<String>,
379) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
380 let cache = match cache_dir {
381 Some(dir) => {
382 let p = PathBuf::from(dir);
383 std::fs::create_dir_all(&p)
384 .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
385 p
386 }
387 None => {
388 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
389 PathBuf::from(home)
390 .join(".cache")
391 .join("huggingface")
392 .join("dakera")
393 .join(RERANKER_REPO_ID.replace('/', "--"))
394 }
395 };
396
397 std::fs::create_dir_all(&cache)
398 .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
399
400 let files = [
401 "tokenizer.json",
402 "tokenizer_config.json",
403 "special_tokens_map.json",
404 RERANKER_ONNX_FILE,
405 ];
406
407 for filename in &files {
408 EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
409 .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
410 }
411
412 let tokenizer_path = cache.join("tokenizer.json");
413 let onnx_path = cache.join(RERANKER_ONNX_FILE);
414 Ok((tokenizer_path, onnx_path))
415}
416
417impl std::fmt::Debug for CrossEncoderEngine {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 f.debug_struct("CrossEncoderEngine")
420 .field("model", &RERANKER_REPO_ID)
421 .field("pool_size", &self.sessions.len())
422 .field("onnx_batch_size", &RERANKER_ONNX_BATCH_SIZE)
423 .finish()
424 }
425}
426
/// Unit tests for the pure parts of the reranker: the sigmoid, the chunking
/// arithmetic and the sizing constants. Nothing here loads a model.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` distinct passage strings.
    fn numbered_passages(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("p{i:03}")).collect()
    }

    /// Lengths of the pieces produced by splitting `passages` into `size`-sized chunks.
    fn chunk_lens(passages: &[String], size: usize) -> Vec<usize> {
        passages.chunks(size).map(<[String]>::len).collect()
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
        // Extreme logits must saturate instead of producing NaN.
        assert_eq!(sigmoid(1000.0), 1.0);
        assert_eq!(sigmoid(-1000.0), 0.0);
        // sigmoid(x) + sigmoid(-x) == 1 (up to rounding).
        for x in [0.25f32, 1.0, 3.5] {
            assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_chunk_count_exact() {
        let passages = numbered_passages(64);
        assert_eq!(chunk_lens(&passages, RERANKER_CHUNK_SIZE), vec![32, 32]);
    }

    #[test]
    fn test_chunk_count_remainder() {
        let passages = numbered_passages(50);
        assert_eq!(chunk_lens(&passages, RERANKER_CHUNK_SIZE), vec![32, 18]);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        let passages = numbered_passages(10);
        assert_eq!(chunk_lens(&passages, RERANKER_CHUNK_SIZE), vec![10]);
    }

    #[test]
    fn test_chunk_order_preserved() {
        let passages = numbered_passages(70);
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_round_robin_wraps() {
        // Models the `(start + chunk_index) % pool_len` session selection.
        let pool_len = RERANKER_POOL_SIZE;
        for start in 0usize..10 {
            let mut visited = vec![false; pool_len];
            for i in 0..pool_len {
                let idx = (start + i) % pool_len;
                assert!(idx < pool_len);
                // One full lap later the same session is selected again.
                assert_eq!(idx, (start + i + pool_len) % pool_len);
                visited[idx] = true;
            }
            // `pool_len` consecutive chunks touch every session exactly once.
            assert!(visited.iter().all(|&seen| seen), "start={start}");
        }
    }

    #[test]
    fn test_onnx_batch_size_constant_invariants() {
        const { assert!(RERANKER_ONNX_BATCH_SIZE >= 1) };
        const { assert!(RERANKER_ONNX_BATCH_SIZE <= RERANKER_CHUNK_SIZE) };
    }

    #[test]
    fn test_onnx_mini_batch_count_full_chunk() {
        let passages = numbered_passages(RERANKER_CHUNK_SIZE);
        let lens = chunk_lens(&passages, RERANKER_ONNX_BATCH_SIZE);
        let expected = RERANKER_CHUNK_SIZE.div_ceil(RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(lens.len(), expected);

        let (last, full) = lens.split_last().expect("a non-empty chunk has a mini-batch");
        assert!(full.iter().all(|&len| len == RERANKER_ONNX_BATCH_SIZE));
        // The final mini-batch holds whatever the full ones left over.
        assert_eq!(
            *last,
            RERANKER_CHUNK_SIZE - RERANKER_ONNX_BATCH_SIZE * (expected - 1)
        );
    }

    #[test]
    fn test_onnx_mini_batch_count_partial_chunk() {
        let passages = numbered_passages(RERANKER_ONNX_BATCH_SIZE + 1);
        assert_eq!(
            chunk_lens(&passages, RERANKER_ONNX_BATCH_SIZE),
            vec![RERANKER_ONNX_BATCH_SIZE, 1]
        );
    }

    #[test]
    fn test_onnx_mini_batch_count_smaller_than_batch_size() {
        // Clamp to 1 so the test stays meaningful if the batch size is ever 1.
        let n = (RERANKER_ONNX_BATCH_SIZE / 2).max(1);
        let passages = numbered_passages(n);
        assert_eq!(chunk_lens(&passages, RERANKER_ONNX_BATCH_SIZE), vec![n]);
    }

    #[test]
    fn test_onnx_mini_batch_order_preserved() {
        let passages = numbered_passages(70);
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_ONNX_BATCH_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_onnx_mini_batch_total_score_count_matches_input() {
        for n in [1, 8, 15, 16, 17, 32, 33, 47, 64] {
            let passages = numbered_passages(n);
            let total: usize = chunk_lens(&passages, RERANKER_ONNX_BATCH_SIZE).iter().sum();
            assert_eq!(total, n, "score count mismatch for n={n}");
        }
    }

    #[test]
    fn test_onnx_batch_size_accessor() {
        // Pins the tuned value that `onnx_batch_size()` reports; building an
        // engine to call the accessor itself would require loading the model.
        assert_eq!(RERANKER_ONNX_BATCH_SIZE, 16);
    }
}