use crate::engine::EmbeddingEngine;
use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokenizers::{
EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
};
use tracing::{info, instrument, warn};
// ---- Model identity -------------------------------------------------------

/// Hugging Face repository the reranker files are fetched from.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the quantized ONNX graph inside the reranker repository.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Maximum number of tokens per `(query, passage)` pair; longer pairs are truncated.
const MAX_SEQ_LENGTH: usize = 512;

// ---- Concurrency / batching knobs ----------------------------------------

/// Number of independent ONNX sessions kept in the pool.
const RERANKER_POOL_SIZE: usize = 2;
/// Passages handed to a single blocking task (scored on one pooled session).
const RERANKER_CHUNK_SIZE: usize = 32;
/// Rows per ONNX `run` call within a chunk (tests assert it is <= the chunk size).
const RERANKER_ONNX_BATCH_SIZE: usize = 16;
/// Cross-encoder reranker that scores `(query, passage)` pairs with an ONNX
/// export of `Xenova/bge-reranker-base`.
///
/// It owns a small pool of ONNX sessions, each behind its own mutex, plus a
/// shared tokenizer. Requests pick their starting session round-robin.
pub struct CrossEncoderEngine {
    /// Shared tokenizer configured (padding + truncation) for pair encoding.
    tokenizer: Arc<Tokenizer>,
    /// Pool of ONNX sessions; each mutex serialises inference on that session.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Request counter (wraps on overflow) used to choose the starting session.
    next_session: AtomicUsize,
    /// Whether the loaded model declares a `token_type_ids` input.
    has_token_type_ids: bool,
}
impl CrossEncoderEngine {
    /// Downloads (if necessary) and loads the cross-encoder reranker.
    ///
    /// Model files are fetched on a blocking thread. The tokenizer is configured
    /// for pair encoding: batch-longest padding and truncation to `MAX_SEQ_LENGTH`
    /// tokens. `RERANKER_POOL_SIZE` ONNX sessions are then built so that several
    /// chunks can be scored in parallel.
    ///
    /// # Errors
    ///
    /// * [`InferenceError::ModelLoadError`] if a blocking task fails, a file cannot
    ///   be downloaded, or an ONNX session cannot be built.
    /// * [`InferenceError::TokenizationError`] if the tokenizer cannot be loaded or
    ///   its truncation settings are rejected.
    #[instrument(skip_all)]
    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
        let (tokenizer_path, onnx_path) =
            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
                .await
                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
                // Download failures (hub or filesystem) are reported as ModelLoadError.
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        // Reuse the tokenizer's own pad id/token when it defines them. Padding to the
        // longest row gives every row in a batch the same length, which the
        // rectangular input tensors require.
        let (pad_id, pad_token) = tokenizer.get_padding().map_or_else(
            || (0, "[PAD]".to_string()),
            |p| (p.pad_id, p.pad_token.clone()),
        );
        tokenizer.with_padding(Some(PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id,
            pad_token,
            ..Default::default()
        }));

        // Truncation keeps pairs within the model's input limit. If it cannot be
        // configured, over-long inputs would fail later inside ONNX, so surface the
        // error here instead of discarding it.
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: MAX_SEQ_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        info!(
            "Loading reranker ONNX model from {:?} (pool_size={}, onnx_batch_size={})",
            onnx_path, RERANKER_POOL_SIZE, RERANKER_ONNX_BATCH_SIZE
        );
        let (sessions, has_token_type_ids) =
            tokio::task::spawn_blocking(move || -> Result<(Vec<Arc<Mutex<Session>>>, bool)> {
                let raw = (0..RERANKER_POOL_SIZE)
                    .map(|_| {
                        Session::builder()
                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
                            .with_optimization_level(GraphOptimizationLevel::Level3)
                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
                            .with_intra_threads(4)
                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
                            .commit_from_file(&onnx_path)
                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))
                    })
                    .collect::<Result<Vec<Session>>>()?;
                // All sessions load the same graph, so inspecting the first suffices.
                let has_tti = raw
                    .first()
                    .is_some_and(|s| s.inputs().iter().any(|i| i.name() == "token_type_ids"));
                let sessions: Vec<Arc<Mutex<Session>>> =
                    raw.into_iter().map(|s| Arc::new(Mutex::new(s))).collect();
                Ok((sessions, has_tti))
            })
            .await
            .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
            // The closure's errors are already ModelLoadError; propagate them as-is
            // instead of re-wrapping their Display text.
            ?;

        info!(
            has_token_type_ids,
            pool_size = sessions.len(),
            onnx_batch_size = RERANKER_ONNX_BATCH_SIZE,
            "Cross-encoder reranker loaded successfully"
        );
        Ok(Self {
            sessions,
            tokenizer: Arc::new(tokenizer),
            has_token_type_ids,
            next_session: AtomicUsize::new(0),
        })
    }

    /// Scores every passage against `query` and returns sigmoid-mapped relevance
    /// scores in the same order as `passages`.
    ///
    /// Passages are split into chunks of `RERANKER_CHUNK_SIZE`. Each chunk is
    /// scored on a blocking thread using a session chosen round-robin from the
    /// pool. An empty `passages` slice returns an empty vector without touching
    /// the model.
    ///
    /// # Errors
    ///
    /// Propagates tokenization or inference errors from any chunk. A blocking task
    /// that fails to complete is reported as [`InferenceError::InferenceError`].
    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
        if passages.is_empty() {
            return Ok(Vec::new());
        }
        let pool_len = self.sessions.len();
        // `fetch_add` wraps silently on overflow. Reducing modulo the pool size right
        // away keeps `base + i` small, so it cannot overflow (and panic in debug
        // builds) when the counter nears usize::MAX.
        let base = self.next_session.fetch_add(1, Ordering::Relaxed) % pool_len;
        let has_token_type_ids = self.has_token_type_ids;
        // Shared by reference count instead of cloning the query string per chunk.
        let query: Arc<str> = Arc::from(query);

        // spawn_blocking starts each task immediately, so all chunks run concurrently.
        let handles: Vec<_> = passages
            .chunks(RERANKER_CHUNK_SIZE)
            .enumerate()
            .map(|(i, chunk)| {
                let session = Arc::clone(&self.sessions[(base + i) % pool_len]);
                let tokenizer = Arc::clone(&self.tokenizer);
                let query = Arc::clone(&query);
                let chunk = chunk.to_vec();
                tokio::task::spawn_blocking(move || {
                    score_pairs_blocking(&session, &tokenizer, &query, &chunk, has_token_type_ids)
                })
            })
            .collect();

        // Await in spawn order so scores line up with the input passages.
        let mut scores = Vec::with_capacity(passages.len());
        for handle in handles {
            let chunk_scores = handle
                .await
                .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))??;
            scores.extend(chunk_scores);
        }
        Ok(scores)
    }

    /// Number of ONNX sessions in the pool.
    pub fn pool_size(&self) -> usize {
        self.sessions.len()
    }

    /// Number of rows sent to ONNX per `run` call (`RERANKER_ONNX_BATCH_SIZE`).
    pub fn onnx_batch_size(&self) -> usize {
        RERANKER_ONNX_BATCH_SIZE
    }
}
/// Scores one chunk of passages against `query` on a single pooled session.
///
/// The chunk is processed in mini-batches of `RERANKER_ONNX_BATCH_SIZE`. Each
/// mini-batch is tokenized as `(query, passage)` pairs and converted to `i64`
/// tensors before the session lock is taken. The lock is then held only for the
/// ONNX `run` call and logit extraction. Each logit is mapped through [`sigmoid`].
///
/// The result always has exactly `passages.len()` entries:
/// * A mini-batch that tokenizes to zero-length rows scores `0.5` for each row.
/// * If the model returns a different number of logits than rows, a warning is
///   logged. Missing scores are filled with `0.5` and extra ones are dropped.
///
/// # Errors
///
/// [`InferenceError::TokenizationError`] if tokenization fails.
/// [`InferenceError::InferenceError`] if tensor construction, the ONNX run, or
/// output extraction fails.
fn score_pairs_blocking(
    session: &Arc<Mutex<Session>>,
    tokenizer: &Tokenizer,
    query: &str,
    passages: &[String],
    has_token_type_ids: bool,
) -> Result<Vec<f32>> {
    let mut all_scores = Vec::with_capacity(passages.len());
    for mini_batch in passages.chunks(RERANKER_ONNX_BATCH_SIZE) {
        let batch_size = mini_batch.len();

        // Tokenize and build tensors; none of this needs the session.
        let inputs: Vec<EncodeInput> = mini_batch
            .iter()
            .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
            .collect();
        let encodings = tokenizer
            .encode_batch(inputs, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        // Assumes batch-longest padding, so every row shares the first row's length.
        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
        if seq_len == 0 {
            all_scores.extend(std::iter::repeat_n(0.5f32, batch_size));
            continue;
        }
        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
        for enc in &encodings {
            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
            let type_ids = enc.get_type_ids();
            if type_ids.is_empty() {
                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
            } else {
                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
            }
        }
        let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
        let attention_mask_tensor =
            Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
        let token_type_ids_tensor =
            Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;

        let mut mini_scores: Vec<f32> = {
            // Lock only around the run. Other chunks routed to this session are not
            // blocked while this one tokenizes. `outputs` borrows the guard, and both
            // are dropped at the end of this block.
            let mut sess = session.lock();
            let outputs = if has_token_type_ids {
                sess.run(inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                    "token_type_ids" => token_type_ids_tensor
                ])
                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
            } else {
                sess.run(inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor
                ])
                .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
            };
            let (out_shape, logits_slice) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
            if out_shape.is_empty() || out_shape[0] as usize != batch_size {
                warn!(
                    "Reranker output shape mismatch: expected [{}, 1], got {:?}",
                    batch_size, out_shape
                );
            }
            logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
        };

        if mini_scores.len() != batch_size {
            warn!(
                "Reranker score count mismatch: expected {}, got {}",
                batch_size,
                mini_scores.len()
            );
            // Keep the output aligned with the input: pad with neutral 0.5 or drop extras.
            mini_scores.resize(batch_size, 0.5);
        }
        all_scores.extend(mini_scores);
    }
    Ok(all_scores)
}
/// Logistic function `1 / (1 + e^(-x))`, mapping a raw reranker logit into `[0, 1]`.
///
/// In `f32` it saturates to exactly `0.0` for very negative inputs and `1.0` for
/// very positive ones.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
/// Downloads the reranker's tokenizer and ONNX files into a local cache directory.
/// Returns `(tokenizer_path, onnx_path)` inside that directory.
///
/// When `cache_dir` is given it is used as-is and created if missing. Otherwise
/// the cache lives under `$HOME/.cache/huggingface/dakera/<repo id with '/'
/// replaced by "--">`, falling back to `/tmp` when `HOME` is unset. Each required
/// file is fetched through `EmbeddingEngine::download_hf_file_pub`.
///
/// NOTE(review): with an explicit `cache_dir`, files land directly in that
/// directory rather than in a repo-specific subdirectory. If another model is
/// cached in the same directory, `tokenizer.json` and friends may collide. Confirm
/// how `download_hf_file_pub` lays out files before sharing a cache dir.
///
/// # Errors
///
/// `ModelLoadError` if a cache directory cannot be created; `HubError` if any
/// download fails.
fn download_reranker_files(
    cache_dir: Option<String>,
) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
    let cache = if let Some(dir) = cache_dir {
        let explicit = PathBuf::from(dir);
        std::fs::create_dir_all(&explicit)
            .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
        explicit
    } else {
        let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
        [".cache", "huggingface", "dakera"]
            .iter()
            .fold(PathBuf::from(home), |path, segment| path.join(segment))
            .join(RERANKER_REPO_ID.replace('/', "--"))
    };
    // A no-op for an explicit directory; this is what creates the default one.
    std::fs::create_dir_all(&cache)
        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;

    let required = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        RERANKER_ONNX_FILE,
    ];
    for filename in required.iter() {
        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
    }

    Ok((cache.join("tokenizer.json"), cache.join(RERANKER_ONNX_FILE)))
}
/// Summarises the engine by model id and pool configuration. The sessions and the
/// tokenizer are deliberately not printed.
impl std::fmt::Debug for CrossEncoderEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CrossEncoderEngine");
        out.field("model", &RERANKER_REPO_ID);
        out.field("pool_size", &self.sessions.len());
        out.field("onnx_batch_size", &RERANKER_ONNX_BATCH_SIZE);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` passages whose text is produced by `label`.
    fn make_passages(n: usize, label: impl Fn(usize) -> String) -> Vec<String> {
        (0..n).map(label).collect()
    }

    /// Length of every slice produced by `items.chunks(size)`, in order.
    fn chunk_lengths(items: &[String], size: usize) -> Vec<usize> {
        items.chunks(size).map(<[String]>::len).collect()
    }

    /// Splits `items` into `size`-sized chunks and concatenates them back together.
    fn rejoin_chunks(items: &[String], size: usize) -> Vec<String> {
        items.chunks(size).flatten().cloned().collect()
    }

    #[test]
    fn test_sigmoid() {
        let midpoint = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_chunk_count_exact() {
        let passages = make_passages(64, |i| format!("passage {i}"));
        assert_eq!(chunk_lengths(&passages, RERANKER_CHUNK_SIZE), vec![32, 32]);
    }

    #[test]
    fn test_chunk_count_remainder() {
        let passages = make_passages(50, |i| format!("passage {i}"));
        assert_eq!(chunk_lengths(&passages, RERANKER_CHUNK_SIZE), vec![32, 18]);
    }

    #[test]
    fn test_chunk_count_small_batch() {
        let passages = make_passages(10, |i| format!("passage {i}"));
        assert_eq!(chunk_lengths(&passages, RERANKER_CHUNK_SIZE), vec![10]);
    }

    #[test]
    fn test_chunk_order_preserved() {
        let passages = make_passages(70, |i| format!("p{i:03}"));
        assert_eq!(rejoin_chunks(&passages, RERANKER_CHUNK_SIZE), passages);
    }

    #[test]
    fn test_pool_size_constant() {
        const { assert!(RERANKER_POOL_SIZE >= 1) };
        const { assert!(RERANKER_CHUNK_SIZE >= 1) };
    }

    #[test]
    fn test_round_robin_wraps() {
        let pool_len = RERANKER_POOL_SIZE;
        assert!((0usize..10)
            .map(|start| start % pool_len)
            .all(|idx| idx < pool_len));
    }

    #[test]
    fn test_onnx_batch_size_constant_invariants() {
        const { assert!(RERANKER_ONNX_BATCH_SIZE >= 1) };
        const { assert!(RERANKER_ONNX_BATCH_SIZE <= RERANKER_CHUNK_SIZE) };
    }

    #[test]
    fn test_onnx_mini_batch_count_full_chunk() {
        let passages = make_passages(RERANKER_CHUNK_SIZE, |i| format!("p{i}"));
        let lens = chunk_lengths(&passages, RERANKER_ONNX_BATCH_SIZE);
        assert_eq!(
            lens.len(),
            RERANKER_CHUNK_SIZE.div_ceil(RERANKER_ONNX_BATCH_SIZE)
        );
        // Every mini-batch except possibly the last is full.
        let (_last, leading) = lens.split_last().expect("at least one mini-batch");
        assert!(leading.iter().all(|&len| len == RERANKER_ONNX_BATCH_SIZE));
    }

    #[test]
    fn test_onnx_mini_batch_count_partial_chunk() {
        let n = RERANKER_ONNX_BATCH_SIZE + 1;
        let passages = make_passages(n, |i| format!("p{i}"));
        assert_eq!(
            chunk_lengths(&passages, RERANKER_ONNX_BATCH_SIZE),
            vec![RERANKER_ONNX_BATCH_SIZE, 1]
        );
    }

    #[test]
    fn test_onnx_mini_batch_count_smaller_than_batch_size() {
        let n = RERANKER_ONNX_BATCH_SIZE / 2;
        let passages = make_passages(n, |i| format!("p{i}"));
        assert_eq!(chunk_lengths(&passages, RERANKER_ONNX_BATCH_SIZE), vec![n]);
    }

    #[test]
    fn test_onnx_mini_batch_order_preserved() {
        let passages = make_passages(70, |i| format!("p{i:03}"));
        assert_eq!(rejoin_chunks(&passages, RERANKER_ONNX_BATCH_SIZE), passages);
    }

    #[test]
    fn test_onnx_mini_batch_total_score_count_matches_input() {
        for n in [1, 8, 15, 16, 17, 32, 33, 47, 64] {
            let passages = make_passages(n, |i| format!("p{i}"));
            let total: usize = chunk_lengths(&passages, RERANKER_ONNX_BATCH_SIZE)
                .iter()
                .sum();
            assert_eq!(total, n, "score count mismatch for n={n}");
        }
    }

    #[test]
    fn test_onnx_batch_size_accessor() {
        assert_eq!(RERANKER_ONNX_BATCH_SIZE, 16);
    }
}