1use crate::engine::EmbeddingEngine;
34use crate::error::{InferenceError, Result};
35use ort::execution_providers::CUDAExecutionProvider;
36use ort::inputs;
37use ort::session::builder::GraphOptimizationLevel;
38use ort::session::Session;
39use ort::value::Tensor;
40use parking_lot::Mutex;
41use std::path::PathBuf;
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::{
45 EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
46};
47use tracing::{info, instrument, warn};
48
/// Repository identifier passed to the Hugging Face download helper and used
/// (with `/` replaced by `--`) as the default cache sub-directory name.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the ONNX model file inside the repository; the same relative path
/// is joined onto the cache directory to locate the downloaded model.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Maximum token length configured on the tokenizer's truncation parameters.
const MAX_SEQ_LENGTH: usize = 512;
/// Number of ONNX sessions created when the engine is constructed.
const RERANKER_POOL_SIZE: usize = 2;

/// Admission limit for concurrent `score_pairs` calls; calls arriving once
/// this many are already active are rejected with `InferenceError::Overloaded`.
const RERANKER_MAX_CONCURRENT: usize = RERANKER_POOL_SIZE * 3;

/// Number of passages handed to a single blocking scoring task.
const RERANKER_CHUNK_SIZE: usize = 32;
/// Number of passages tokenized and sent through one ONNX `run` call.
const RERANKER_ONNX_BATCH_SIZE: usize = 16;
93
/// Drop guard around a shared counter.
///
/// Constructing the guard does not touch the counter; the caller is expected
/// to have incremented it already. Dropping the guard subtracts one.
struct ActiveGuard(Arc<AtomicUsize>);

impl Drop for ActiveGuard {
    /// Gives back the slot this guard represents by decrementing the counter.
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
102
/// Cross-encoder reranker backed by a small pool of ONNX sessions.
pub struct CrossEncoderEngine {
    /// Session pool; every session sits behind its own mutex so that one
    /// blocking scoring task at a time can run inference on it.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Tokenizer shared by all scoring tasks.
    tokenizer: Arc<Tokenizer>,
    /// Whether the loaded model declares a `token_type_ids` input.
    has_token_type_ids: bool,
    /// Monotonic counter used to pick the starting session for each request.
    next_session: AtomicUsize,
    /// Number of `score_pairs` calls currently admitted and in flight.
    active_requests: Arc<AtomicUsize>,
}
123
124impl CrossEncoderEngine {
125 #[instrument(skip_all)]
130 pub async fn new(cache_dir: Option<String>) -> Result<Self> {
131 info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
132
133 let (tokenizer_path, onnx_path) =
134 tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
135 .await
136 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
137 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
138
139 info!("Loading reranker tokenizer from {:?}", tokenizer_path);
140 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
141 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
142
143 let padding = PaddingParams {
145 strategy: PaddingStrategy::BatchLongest,
146 pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
147 pad_token: tokenizer
148 .get_padding()
149 .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
150 ..Default::default()
151 };
152 tokenizer.with_padding(Some(padding));
153 let truncation = TruncationParams {
154 max_length: MAX_SEQ_LENGTH,
155 ..Default::default()
156 };
157 let _ = tokenizer.with_truncation(Some(truncation));
158
159 info!(
160 "Loading reranker ONNX model from {:?} (pool_size={}, onnx_batch_size={})",
161 onnx_path, RERANKER_POOL_SIZE, RERANKER_ONNX_BATCH_SIZE
162 );
163
164 let use_gpu = std::env::var("DAKERA_USE_GPU")
165 .map(|v| v == "1")
166 .unwrap_or(false);
167 if use_gpu {
168 info!("CUDA execution provider enabled for reranker (DAKERA_USE_GPU=1)");
169 }
170
171 let (sessions, has_token_type_ids) =
174 tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
175 let raw: Result<Vec<Session>> = (0..RERANKER_POOL_SIZE)
176 .map(|_| {
177 let builder = Session::builder()
178 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
179 .with_optimization_level(GraphOptimizationLevel::Level3)
180 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
181 .with_intra_threads(4)
182 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
183
184 let mut builder = if use_gpu {
185 builder
186 .with_execution_providers(
187 [CUDAExecutionProvider::default().build()],
188 )
189 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
190 } else {
191 builder
192 };
193
194 builder
195 .commit_from_file(&onnx_path)
196 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
197 })
198 .collect();
199 let raw = raw?;
200 let has_tti = raw[0].inputs().iter().any(|i| i.name() == "token_type_ids");
202 let sessions: Vec<Arc<Mutex<Session>>> =
203 raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
204 Ok((sessions, has_tti))
205 })
206 .await
207 .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
208 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
209
210 info!(
211 has_token_type_ids,
212 pool_size = sessions.len(),
213 onnx_batch_size = RERANKER_ONNX_BATCH_SIZE,
214 "Cross-encoder reranker loaded successfully"
215 );
216
217 Ok(Self {
218 sessions,
219 tokenizer: Arc::new(tokenizer),
220 has_token_type_ids,
221 next_session: AtomicUsize::new(0),
222 active_requests: Arc::new(AtomicUsize::new(0)),
223 })
224 }
225
226 #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
241 pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
242 if passages.is_empty() {
243 return Ok(Vec::new());
244 }
245
246 let prev = self.active_requests.fetch_add(1, Ordering::SeqCst);
249 if prev >= RERANKER_MAX_CONCURRENT {
250 self.active_requests.fetch_sub(1, Ordering::SeqCst);
251 warn!(
252 active = prev,
253 max = RERANKER_MAX_CONCURRENT,
254 "Cross-encoder at capacity — returning Overloaded (API will use unranked results)"
255 );
256 return Err(InferenceError::Overloaded {
257 active: prev,
258 max: RERANKER_MAX_CONCURRENT,
259 });
260 }
261 let _guard = ActiveGuard(Arc::clone(&self.active_requests));
263
264 let pool_len = self.sessions.len();
265 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
268 let tokenizer = Arc::clone(&self.tokenizer);
269 let has_token_type_ids = self.has_token_type_ids;
270 let query_str = query.to_string();
271
272 let chunks: Vec<Vec<String>> = passages
274 .chunks(RERANKER_CHUNK_SIZE)
275 .map(<[String]>::to_vec)
276 .collect();
277
278 let mut handles = Vec::with_capacity(chunks.len());
280 for (i, chunk) in chunks.into_iter().enumerate() {
281 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
282 let tok = Arc::clone(&tokenizer);
283 let q = query_str.clone();
284 handles.push(tokio::task::spawn_blocking(move || {
285 score_pairs_blocking(&session, &tok, &q, &chunk, has_token_type_ids)
286 }));
287 }
288
289 let mut scores = Vec::with_capacity(passages.len());
291 for handle in handles {
292 let chunk_scores = handle
293 .await
294 .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
295 scores.extend(chunk_scores);
296 }
297
298 Ok(scores)
299 }
300
301 pub fn pool_size(&self) -> usize {
303 self.sessions.len()
304 }
305
306 pub fn onnx_batch_size(&self) -> usize {
308 RERANKER_ONNX_BATCH_SIZE
309 }
310
311 pub fn active_requests_count(&self) -> usize {
314 self.active_requests.load(Ordering::Relaxed)
315 }
316
317 pub fn max_concurrent(&self) -> usize {
319 RERANKER_MAX_CONCURRENT
320 }
321}
322
323fn score_pairs_blocking(
331 session: &Arc<Mutex<Session>>,
332 tokenizer: &Tokenizer,
333 query: &str,
334 passages: &[String],
335 has_token_type_ids: bool,
336) -> Result<Vec<f32>> {
337 let total = passages.len();
338 if total == 0 {
339 return Ok(Vec::new());
340 }
341
342 let mut all_scores = Vec::with_capacity(total);
343 let mut sess = session.lock();
347
348 for mini_batch in passages.chunks(RERANKER_ONNX_BATCH_SIZE) {
349 let batch_size = mini_batch.len();
350
351 let inputs: Vec<EncodeInput> = mini_batch
353 .iter()
354 .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
355 .collect();
356
357 let encodings = tokenizer
358 .encode_batch(inputs, true)
359 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
360
361 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
362 if seq_len == 0 {
363 all_scores.extend(std::iter::repeat_n(0.5f32, batch_size));
364 continue;
365 }
366
367 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
369 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
370 let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
371
372 for enc in &encodings {
373 input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
374 attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
375 let type_ids = enc.get_type_ids();
376 if type_ids.is_empty() {
377 token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
378 } else {
379 token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
380 }
381 }
382
383 let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
385 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
386 let attention_mask_tensor =
387 Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
388 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
389 let token_type_ids_tensor =
390 Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
391 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
392
393 let mini_scores: Vec<f32> = {
397 let outputs = if has_token_type_ids {
398 sess.run(inputs![
399 "input_ids" => input_ids_tensor,
400 "attention_mask" => attention_mask_tensor,
401 "token_type_ids" => token_type_ids_tensor
402 ])
403 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
404 } else {
405 sess.run(inputs![
406 "input_ids" => input_ids_tensor,
407 "attention_mask" => attention_mask_tensor
408 ])
409 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
410 };
411
412 let (out_shape, logits_slice) = outputs[0]
414 .try_extract_tensor::<f32>()
415 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
416
417 if out_shape.is_empty() || out_shape[0] as usize != batch_size {
418 warn!(
419 "Reranker output shape mismatch: expected [{}, 1], got {:?}",
420 batch_size, out_shape
421 );
422 }
423
424 logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
426 };
428
429 let n_scores = mini_scores.len();
430 if n_scores != batch_size {
431 warn!(
432 "Reranker score count mismatch: expected {}, got {}",
433 batch_size, n_scores
434 );
435 let mut padded = mini_scores;
436 padded.resize(batch_size, 0.5);
437 all_scores.extend(padded);
438 } else {
439 all_scores.extend(mini_scores);
440 }
441 }
442 Ok(all_scores)
445}
446
/// Logistic function: maps any real logit into the range `[0, 1]`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
452
453fn download_reranker_files(
456 cache_dir: Option<String>,
457) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
458 let cache = match cache_dir {
459 Some(dir) => {
460 let p = PathBuf::from(dir);
461 std::fs::create_dir_all(&p)
462 .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
463 p
464 }
465 None => {
466 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
467 PathBuf::from(home)
468 .join(".cache")
469 .join("huggingface")
470 .join("dakera")
471 .join(RERANKER_REPO_ID.replace('/', "--"))
472 }
473 };
474
475 std::fs::create_dir_all(&cache)
476 .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
477
478 let files = [
479 "tokenizer.json",
480 "tokenizer_config.json",
481 "special_tokens_map.json",
482 RERANKER_ONNX_FILE,
483 ];
484
485 for filename in &files {
486 EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
487 .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
488 }
489
490 let tokenizer_path = cache.join("tokenizer.json");
491 let onnx_path = cache.join(RERANKER_ONNX_FILE);
492 Ok((tokenizer_path, onnx_path))
493}
494
495impl std::fmt::Debug for CrossEncoderEngine {
496 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497 f.debug_struct("CrossEncoderEngine")
498 .field("model", &RERANKER_REPO_ID)
499 .field("pool_size", &self.sessions.len())
500 .field("onnx_batch_size", &RERANKER_ONNX_BATCH_SIZE)
501 .field(
502 "active_requests",
503 &self.active_requests.load(Ordering::Relaxed),
504 )
505 .field("max_concurrent", &RERANKER_MAX_CONCURRENT)
506 .finish()
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Single lock shared by every test that mutates `DAKERA_USE_GPU`.
    ///
    /// It must live at module level: a `static` declared inside each test
    /// function is a separate lock per test and serializes nothing.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes `ENV_LOCK`, recovering from poisoning so that one failed env test
    /// does not cascade into the others.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_use_gpu_default_is_false() {
        let _guard = env_lock();
        // SAFETY: every test in this module that touches DAKERA_USE_GPU holds ENV_LOCK.
        unsafe { std::env::remove_var("DAKERA_USE_GPU") };
        let use_gpu = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);
        assert!(
            !use_gpu,
            "expected CPU default when DAKERA_USE_GPU is unset"
        );
    }

    #[test]
    fn test_use_gpu_enabled_when_env_var_is_1() {
        let _guard = env_lock();
        // SAFETY: every test in this module that touches DAKERA_USE_GPU holds ENV_LOCK.
        unsafe { std::env::set_var("DAKERA_USE_GPU", "1") };
        let use_gpu = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);
        // SAFETY: still holding ENV_LOCK.
        unsafe { std::env::remove_var("DAKERA_USE_GPU") };
        assert!(use_gpu, "expected GPU mode when DAKERA_USE_GPU=1");
    }

    #[test]
    fn test_use_gpu_not_enabled_for_other_values() {
        let _guard = env_lock();
        for val in ["0", "true", "yes", "gpu", ""] {
            // SAFETY: every test in this module that touches DAKERA_USE_GPU holds ENV_LOCK.
            unsafe { std::env::set_var("DAKERA_USE_GPU", val) };
            let use_gpu = std::env::var("DAKERA_USE_GPU")
                .map(|v| v == "1")
                .unwrap_or(false);
            // SAFETY: still holding ENV_LOCK.
            unsafe { std::env::remove_var("DAKERA_USE_GPU") };
            assert!(
                !use_gpu,
                "expected CPU when DAKERA_USE_GPU={val:?} (only '1' enables GPU)"
            );
        }
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_chunk_count_exact() {
        let passages: Vec<String> = (0..64).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 32);
    }

    #[test]
    fn test_chunk_count_remainder() {
        let passages: Vec<String> = (0..50).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 32);
        assert_eq!(chunks[1].len(), 18);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        let passages: Vec<String> = (0..10).map(|i| format!("passage {i}")).collect();
        let chunks: Vec<Vec<String>> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .map(<[String]>::to_vec)
            .collect();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10);
    }

    #[test]
    fn test_chunk_order_preserved() {
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .flat_map(<[String]>::to_vec)
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_max_concurrent_exceeds_pool_size() {
        const { assert!(RERANKER_MAX_CONCURRENT > RERANKER_POOL_SIZE) };
        const { assert!(RERANKER_MAX_CONCURRENT < 20) };
    }

    #[test]
    fn test_active_guard_decrements() {
        let counter = Arc::new(AtomicUsize::new(1));
        {
            let _g = ActiveGuard(Arc::clone(&counter));
            // Constructing the guard must not change the counter.
            assert_eq!(counter.load(Ordering::SeqCst), 1);
        }
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_round_robin_wraps() {
        let pool_len = RERANKER_POOL_SIZE;
        for start in 0usize..10 {
            let idx = start % pool_len;
            assert!(idx < pool_len);
        }
    }

    #[test]
    fn test_onnx_batch_size_constant_invariants() {
        const { assert!(RERANKER_ONNX_BATCH_SIZE >= 1) };
        const { assert!(RERANKER_ONNX_BATCH_SIZE <= RERANKER_CHUNK_SIZE) };
    }

    #[test]
    fn test_onnx_mini_batch_count_full_chunk() {
        let passages: Vec<String> = (0..RERANKER_CHUNK_SIZE).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        let expected = RERANKER_CHUNK_SIZE.div_ceil(RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(mini_batches.len(), expected);
        // Every mini-batch except possibly the last must be full.
        for mb in &mini_batches[..mini_batches.len() - 1] {
            assert_eq!(mb.len(), RERANKER_ONNX_BATCH_SIZE);
        }
    }

    #[test]
    fn test_onnx_mini_batch_count_partial_chunk() {
        let n = RERANKER_ONNX_BATCH_SIZE + 1;
        let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        assert_eq!(mini_batches.len(), 2);
        assert_eq!(mini_batches[0].len(), RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(mini_batches[1].len(), 1);
    }

    #[test]
    fn test_onnx_mini_batch_count_smaller_than_batch_size() {
        let n = RERANKER_ONNX_BATCH_SIZE / 2;
        let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
        let mini_batches: Vec<&[String]> = passages.chunks(RERANKER_ONNX_BATCH_SIZE).collect();
        assert_eq!(mini_batches.len(), 1);
        assert_eq!(mini_batches[0].len(), n);
    }

    #[test]
    fn test_onnx_mini_batch_order_preserved() {
        let passages: Vec<String> = (0..70).map(|i| format!("p{i:03}")).collect();
        let reassembled: Vec<String> = passages
            .chunks(RERANKER_ONNX_BATCH_SIZE)
            .flat_map(|mb| mb.to_vec())
            .collect();
        assert_eq!(passages, reassembled);
    }

    #[test]
    fn test_onnx_mini_batch_total_score_count_matches_input() {
        for n in [1, 8, 15, 16, 17, 32, 33, 47, 64] {
            let passages: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
            let total: usize = passages
                .chunks(RERANKER_ONNX_BATCH_SIZE)
                .map(|mb| mb.len())
                .sum();
            assert_eq!(total, n, "score count mismatch for n={n}");
        }
    }

    #[test]
    fn test_onnx_batch_size_accessor() {
        assert_eq!(RERANKER_ONNX_BATCH_SIZE, 16);
    }
}