use crate::engine::EmbeddingEngine;
use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::{
EncodeInput, InputSequence, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
};
use tracing::{info, instrument, warn};
/// Hugging Face repository that hosts the cross-encoder reranker weights.
const RERANKER_REPO_ID: &str = "Xenova/bge-reranker-base";
/// Path of the quantized ONNX export inside the repository.
const RERANKER_ONNX_FILE: &str = "onnx/model_quantized.onnx";
/// Token-length cap per (query, passage) pair; longer inputs are truncated.
const MAX_SEQ_LENGTH: usize = 512;
/// ONNX-backed cross-encoder that scores (query, passage) pairs for reranking.
pub struct CrossEncoderEngine {
    // Locked for each forward pass (`run` is called through a `mut` guard);
    // the Arc lets the session move into `spawn_blocking` closures.
    session: Arc<Mutex<Session>>,
    // Shared so tokenization can also run on blocking worker threads.
    tokenizer: Arc<Tokenizer>,
    // Whether the exported graph declares a `token_type_ids` input; some
    // exports omit it, in which case that tensor is not fed at inference.
    has_token_type_ids: bool,
}
impl CrossEncoderEngine {
    /// Downloads (if not cached) and loads the reranker tokenizer and ONNX
    /// session, configuring batch-longest padding and truncation at
    /// [`MAX_SEQ_LENGTH`]. Heavy work runs on a blocking thread so the async
    /// runtime is not stalled.
    ///
    /// # Errors
    /// Returns `ModelLoadError` when download or session creation fails, and
    /// `TokenizationError` when the tokenizer cannot be loaded or configured.
    #[instrument(skip_all)]
    pub async fn new(cache_dir: Option<String>) -> Result<Self> {
        info!("Initializing cross-encoder reranker: {}", RERANKER_REPO_ID);
        let (tokenizer_path, onnx_path) =
            tokio::task::spawn_blocking(move || download_reranker_files(cache_dir))
                .await
                .map_err(|e| InferenceError::ModelLoadError(format!("spawn_blocking: {e}")))?
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
        info!("Loading reranker tokenizer from {:?}", tokenizer_path);
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        // Reuse the tokenizer's own pad token when it defines one; otherwise
        // fall back to the conventional BERT defaults. Read `get_padding()`
        // once instead of twice.
        let (pad_id, pad_token) = tokenizer
            .get_padding()
            .map_or_else(|| (0, "[PAD]".to_string()), |p| (p.pad_id, p.pad_token.clone()));
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id,
            pad_token,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        let truncation = TruncationParams {
            max_length: MAX_SEQ_LENGTH,
            ..Default::default()
        };
        // Propagate an invalid truncation configuration instead of silently
        // discarding the error (previously swallowed with `let _ =`).
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        info!("Loading reranker ONNX model from {:?}", onnx_path);
        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(4)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
        // Some exports omit `token_type_ids`; detect it so inference only
        // feeds the inputs the graph actually declares.
        let has_token_type_ids = session
            .inputs()
            .iter()
            .any(|i| i.name() == "token_type_ids");
        info!(
            has_token_type_ids,
            "Cross-encoder reranker loaded successfully"
        );
        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
            has_token_type_ids,
        })
    }
    /// Scores the relevance of each passage against `query`, returning one
    /// sigmoid-normalized score in `[0, 1]` per passage, in input order.
    ///
    /// Returns an empty vec for an empty passage list without touching the
    /// model. Tokenization and inference run on a blocking thread.
    #[instrument(skip(self, passages), fields(n_passages = passages.len()))]
    pub async fn score_pairs(&self, query: &str, passages: &[String]) -> Result<Vec<f32>> {
        if passages.is_empty() {
            return Ok(Vec::new());
        }
        // Clone inputs and Arc handles so the closure is 'static.
        let query = query.to_string();
        let passages = passages.to_vec();
        let tokenizer = Arc::clone(&self.tokenizer);
        let session = Arc::clone(&self.session);
        let has_token_type_ids = self.has_token_type_ids;
        tokio::task::spawn_blocking(move || {
            score_pairs_blocking(&session, &tokenizer, &query, &passages, has_token_type_ids)
        })
        .await
        .map_err(|e| InferenceError::InferenceError(format!("spawn_blocking: {e}")))?
    }
}
/// Runs one batched forward pass of the cross-encoder over `(query, passage)`
/// pairs and maps each output logit through a sigmoid into a `[0, 1]` score.
///
/// Blocking: holds the session lock for the duration of the forward pass, so
/// call it from a blocking context (see `CrossEncoderEngine::score_pairs`).
/// On an output-count mismatch the result is padded/truncated to
/// `passages.len()` with a neutral 0.5 score.
///
/// # Errors
/// `TokenizationError` when batch encoding fails; `InferenceError` when
/// tensor construction or the session run fails.
fn score_pairs_blocking(
    // `&Mutex<Session>` rather than `&Arc<Mutex<Session>>`: the function never
    // clones the Arc, and callers holding an Arc still pass `&arc` unchanged
    // via deref coercion.
    session: &Mutex<Session>,
    tokenizer: &Tokenizer,
    query: &str,
    passages: &[String],
    has_token_type_ids: bool,
) -> Result<Vec<f32>> {
    let batch_size = passages.len();
    let inputs: Vec<EncodeInput> = passages
        .iter()
        .map(|p| EncodeInput::Dual(InputSequence::from(query), InputSequence::from(p.as_str())))
        .collect();
    let encodings = tokenizer
        .encode_batch(inputs, true)
        .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
    // BatchLongest padding makes every row the same length, so the first
    // encoding's length is the batch-wide sequence length.
    let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
    if seq_len == 0 {
        // Degenerate input (e.g. all-empty strings): return neutral scores.
        return Ok(vec![0.5; batch_size]);
    }
    // Flatten the batch into row-major [batch, seq] i64 buffers for ort.
    let mut input_ids = Vec::with_capacity(batch_size * seq_len);
    let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
    let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
    for enc in &encodings {
        input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
        attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
        let type_ids = enc.get_type_ids();
        if type_ids.is_empty() {
            // Some tokenizers emit no type ids; substitute zeros to keep the
            // flattened buffer's [batch, seq] layout intact.
            token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
        } else {
            token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
        }
    }
    let input_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], input_ids))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let attention_mask_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], attention_mask))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let token_type_ids_tensor = Tensor::<i64>::from_array(([batch_size, seq_len], token_type_ids))
        .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
    let scores: Vec<f32> = {
        let mut sess = session.lock();
        // Only feed `token_type_ids` when the graph declares that input.
        let outputs = if has_token_type_ids {
            sess.run(inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor
            ])
            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
        } else {
            sess.run(inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor
            ])
            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?
        };
        let (out_shape, logits_slice) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
        if out_shape.is_empty() || out_shape[0] as usize != batch_size {
            warn!(
                "Reranker output shape mismatch: expected [{}, 1], got {:?}",
                batch_size, out_shape
            );
        }
        logits_slice.iter().map(|&logit| sigmoid(logit)).collect()
    };
    if scores.len() != batch_size {
        // Keep the contract of one score per passage even when the model
        // returned an unexpected number of logits.
        warn!(
            "Reranker score count mismatch: expected {}, got {}",
            batch_size,
            scores.len()
        );
        let mut padded = scores;
        padded.resize(batch_size, 0.5);
        return Ok(padded);
    }
    Ok(scores)
}
/// Logistic function: maps any real input into the open interval (0, 1).
#[inline]
fn sigmoid(x: f32) -> f32 {
    // recip() computes 1.0 / denom, identical to the explicit division.
    (1.0 + (-x).exp()).recip()
}
/// Resolves the local cache directory and downloads the reranker's tokenizer
/// and ONNX files from the Hugging Face hub (skipping files already cached by
/// the downloader). Returns `(tokenizer.json path, ONNX model path)`.
///
/// When `cache_dir` is `None`, files land under
/// `$HOME/.cache/huggingface/dakera/<repo-with-slashes-escaped>/`, falling
/// back to `/tmp` when `HOME` is unset.
///
/// # Errors
/// `ModelLoadError` when the cache directory cannot be created; `HubError`
/// when a download fails.
fn download_reranker_files(
    cache_dir: Option<String>,
) -> std::result::Result<(PathBuf, PathBuf), InferenceError> {
    let cache = match cache_dir {
        Some(dir) => {
            let p = PathBuf::from(dir);
            std::fs::create_dir_all(&p)
                .map_err(|e| InferenceError::ModelLoadError(format!("cache_dir create: {e}")))?;
            p
        }
        None => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            PathBuf::from(home)
                .join(".cache")
                .join("huggingface")
                .join("dakera")
                .join(RERANKER_REPO_ID.replace('/', "--"))
        }
    };
    std::fs::create_dir_all(&cache)
        .map_err(|e| InferenceError::ModelLoadError(format!("create cache dir: {e}")))?;
    let files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        RERANKER_ONNX_FILE,
    ];
    for filename in &files {
        // Name the failing file in the error instead of the old "(unknown)"
        // placeholder, so download failures are diagnosable.
        EmbeddingEngine::download_hf_file_pub(RERANKER_REPO_ID, filename, &cache)
            .map_err(|e| InferenceError::HubError(format!("download {filename}: {e}")))?;
    }
    let tokenizer_path = cache.join("tokenizer.json");
    let onnx_path = cache.join(RERANKER_ONNX_FILE);
    Ok((tokenizer_path, onnx_path))
}
impl std::fmt::Debug for CrossEncoderEngine {
    // Neither the ort session nor the tokenizer has a useful Debug form, so
    // report only the model identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CrossEncoderEngine");
        builder.field("model", &RERANKER_REPO_ID);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Checks midpoint, saturation, symmetry, and monotonicity of the
    /// logistic function used to normalize reranker logits.
    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
        // sigmoid(x) + sigmoid(-x) == 1 for all x.
        assert!((sigmoid(3.0) + sigmoid(-3.0) - 1.0).abs() < 1e-6);
        // Strictly increasing.
        assert!(sigmoid(1.0) < sigmoid(2.0));
    }
}