use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use ndarray::Array2;
use once_cell::sync::OnceCell;
use ort::session::Session;
use crate::embedder::{create_session, pad_2d_i64, select_provider, ExecutionProvider};
use crate::store::SearchResult;
/// Default Hugging Face repository for the cross-encoder reranker model.
const DEFAULT_MODEL_REPO: &str = "cross-encoder/ms-marco-MiniLM-L-6-v2";
/// Path of the ONNX model file within the repository.
const MODEL_FILE: &str = "onnx/model.onnx";
/// Path of the tokenizer configuration within the repository.
const TOKENIZER_FILE: &str = "tokenizer.json";
/// Pinned BLAKE3 hex digest of the model file; empty string disables verification.
const MODEL_BLAKE3: &str = "";
/// Pinned BLAKE3 hex digest of the tokenizer file; empty string disables verification.
const TOKENIZER_BLAKE3: &str = "";
/// Default number of query/passage pairs scored per ONNX forward pass.
const DEFAULT_RERANKER_BATCH: usize = 32;

/// Returns the reranker inference batch size.
///
/// Honors the `CQS_RERANKER_BATCH` environment variable; unset,
/// unparseable, or zero values fall back to [`DEFAULT_RERANKER_BATCH`].
fn reranker_batch_size() -> usize {
    match std::env::var("CQS_RERANKER_BATCH") {
        Ok(raw) => match raw.parse::<usize>() {
            // Zero would produce empty chunks, so treat it as invalid.
            Ok(n) if n > 0 => n,
            _ => DEFAULT_RERANKER_BATCH,
        },
        Err(_) => DEFAULT_RERANKER_BATCH,
    }
}
/// Resolves the model repository to download the reranker from.
///
/// `CQS_RERANKER_MODEL` overrides the built-in default and is logged so
/// operators can see which model is in use.
fn model_repo() -> String {
    if let Ok(repo) = std::env::var("CQS_RERANKER_MODEL") {
        tracing::info!(model = %repo, "Using custom reranker model");
        return repo;
    }
    DEFAULT_MODEL_REPO.to_string()
}
/// Errors surfaced by the reranking pipeline.
#[derive(Debug, thiserror::Error)]
pub enum RerankerError {
    /// Fetching the model or tokenizer from the hub failed.
    #[error("Model download failed: {0}")]
    ModelDownload(String),
    /// Loading the tokenizer from disk or encoding a query/passage pair failed.
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    /// ONNX session creation or execution failed, or the model produced
    /// malformed output.
    #[error("Inference error: {0}")]
    Inference(String),
    /// A downloaded artifact's BLAKE3 digest did not match the pinned value.
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
}
fn ort_err<T>(e: ort::Error<T>) -> RerankerError {
RerankerError::Inference(e.to_string())
}
/// Cross-encoder reranker backed by ONNX Runtime.
///
/// All heavy resources are lazily initialized, so construction is cheap
/// and download/load failures surface on first use rather than at startup.
pub struct Reranker {
    // ONNX session; `None` until first inference or after `clear_session`.
    session: Mutex<Option<Session>>,
    // Tokenizer shared via `Arc` so encoding can happen without holding the lock.
    tokenizer: Mutex<Option<Arc<tokenizers::Tokenizer>>>,
    // (model path, tokenizer path), resolved/downloaded at most once.
    model_paths: OnceCell<(PathBuf, PathBuf)>,
    // Execution provider selected at construction time.
    provider: ExecutionProvider,
    // Maximum token sequence length fed to the model (default 512).
    max_length: usize,
}
impl Reranker {
    /// Creates a reranker. No model download or session creation happens
    /// here; those are deferred to first use.
    ///
    /// Reads `CQS_RERANKER_MAX_LENGTH` to cap the token sequence length;
    /// missing or unparseable values fall back to 512.
    pub fn new() -> Result<Self, RerankerError> {
        let provider = select_provider();
        let max_length = match std::env::var("CQS_RERANKER_MAX_LENGTH") {
            Ok(val) => match val.parse::<usize>() {
                Ok(len) => {
                    tracing::info!(max_length = len, "Using custom reranker max_length");
                    len
                }
                Err(e) => {
                    // Invalid override: warn and keep the default rather than fail.
                    tracing::warn!(
                        value = %val,
                        error = %e,
                        "Invalid CQS_RERANKER_MAX_LENGTH, using default 512"
                    );
                    512
                }
            },
            Err(_) => 512,
        };
        Ok(Self {
            session: Mutex::new(None),
            tokenizer: Mutex::new(None),
            model_paths: OnceCell::new(),
            provider,
            max_length,
        })
    }

    /// Re-scores `results` against `query` with the cross-encoder, sorts
    /// them by the new score, and truncates to `limit`.
    ///
    /// Uses each result's chunk content as the passage text.
    pub fn rerank(
        &self,
        query: &str,
        results: &mut Vec<SearchResult>,
        limit: usize,
    ) -> Result<(), RerankerError> {
        let scores = {
            // Borrow passage text from the results; scope ends before the
            // mutable use of `results` below.
            let passages: Vec<&str> = results.iter().map(|r| r.chunk.content.as_str()).collect();
            self.compute_scores(query, &passages)?
        };
        apply_rerank_scores(results, scores, limit);
        Ok(())
    }

    /// Like [`Self::rerank`], but scores caller-supplied passage texts,
    /// where `passages[i]` corresponds to `results[i]`.
    ///
    /// No-ops for zero or one results; errors if the slice lengths disagree.
    pub fn rerank_with_passages(
        &self,
        query: &str,
        results: &mut Vec<SearchResult>,
        passages: &[&str],
        limit: usize,
    ) -> Result<(), RerankerError> {
        let _span = tracing::info_span!(
            "rerank",
            count = results.len(),
            limit,
            query_len = query.len()
        )
        .entered();
        if results.len() <= 1 {
            // Nothing to reorder.
            return Ok(());
        }
        if results.len() != passages.len() {
            return Err(RerankerError::Inference(format!(
                "passages length ({}) must match results length ({})",
                passages.len(),
                results.len()
            )));
        }
        // `None` means every encoding was empty; leave the original order.
        let Some(scores) = self.compute_scores_opt(query, passages)? else {
            return Ok(());
        };
        apply_rerank_scores(results, scores, limit);
        Ok(())
    }

    /// Encodes every (query, passage) pair and scores them in batches.
    ///
    /// Returns `Ok(None)` when all encodings are empty (nothing meaningful
    /// to score); otherwise one score per passage, in input order.
    fn compute_scores_opt(
        &self,
        query: &str,
        passages: &[&str],
    ) -> Result<Option<Vec<f32>>, RerankerError> {
        let tokenizer = self.tokenizer()?;
        let encodings: Vec<tokenizers::Encoding> = passages
            .iter()
            .map(|passage| {
                // Pair encoding: (query, passage) with special tokens added.
                tokenizer
                    .encode((query, *passage), true)
                    .map_err(|e| RerankerError::Tokenizer(e.to_string()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        // Only used to detect the all-empty case; per-chunk padding width is
        // recomputed in `run_chunk`.
        let overall_max = encodings
            .iter()
            .map(|e| e.get_ids().len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);
        if overall_max == 0 {
            return Ok(None);
        }
        let batch_cap = reranker_batch_size();
        let mut scores = Vec::with_capacity(passages.len());
        for chunk in encodings.chunks(batch_cap) {
            scores.extend(self.run_chunk(chunk)?);
        }
        Ok(Some(scores))
    }

    /// Runs one padded batch through the ONNX session and returns a
    /// sigmoid-activated relevance score per encoding.
    fn run_chunk(&self, chunk: &[tokenizers::Encoding]) -> Result<Vec<f32>, RerankerError> {
        let batch_size = chunk.len();
        debug_assert!(batch_size > 0, "run_chunk called with empty chunk");
        let input_ids: Vec<Vec<i64>> = chunk
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = chunk
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();
        // Pad to the longest sequence in this chunk, capped at max_length.
        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);
        if max_len == 0 {
            // Entirely empty chunk: return the neutral score sigmoid(0) = 0.5
            // for each entry instead of running a zero-width tensor.
            return Ok(vec![sigmoid(0.0); batch_size]);
        }
        let ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        // BERT-style single-segment input: token type ids are all zero.
        let type_arr = Array2::<i64>::zeros((batch_size, max_len));
        use ort::value::Tensor;
        let ids_tensor = Tensor::from_array(ids_arr).map_err(ort_err)?;
        let mask_tensor = Tensor::from_array(mask_arr).map_err(ort_err)?;
        let type_tensor = Tensor::from_array(type_arr).map_err(ort_err)?;
        let mut session_guard = self.session()?;
        let session = session_guard
            .as_mut()
            .expect("session() guarantees initialized after Ok return");
        let outputs = session
            .run(ort::inputs![
                "input_ids" => ids_tensor,
                "attention_mask" => mask_tensor,
                "token_type_ids" => type_tensor,
            ])
            .map_err(ort_err)?;
        if outputs.len() == 0 {
            return Err(RerankerError::Inference(
                "ONNX model produced no outputs".to_string(),
            ));
        }
        let (shape, data) = outputs[0].try_extract_tensor::<f32>().map_err(ort_err)?;
        // Accept either a 2-D `[batch, k]` logit tensor (stride = k, take the
        // first logit per row) or a flat 1-D tensor (stride = 1).
        let stride = if shape.len() == 2 {
            shape[1] as usize
        } else {
            1
        };
        if stride == 0 {
            return Err(RerankerError::Inference(
                "Model returned zero-width output tensor".to_string(),
            ));
        }
        let expected_len = batch_size * stride;
        if data.len() < expected_len {
            return Err(RerankerError::Inference(format!(
                "Model output too short: expected {} elements, got {}",
                expected_len,
                data.len()
            )));
        }
        let scores: Vec<f32> = (0..batch_size).map(|i| sigmoid(data[i * stride])).collect();
        Ok(scores)
    }

    /// Scores `passages` against `query`, returning an empty vector when
    /// there is at most one passage (nothing to reorder) or when every
    /// encoding was empty. `apply_rerank_scores` treats empty as a no-op.
    fn compute_scores(&self, query: &str, passages: &[&str]) -> Result<Vec<f32>, RerankerError> {
        if passages.len() <= 1 {
            return Ok(Vec::new());
        }
        Ok(self
            .compute_scores_opt(query, passages)?
            .unwrap_or_default())
    }

    /// Resolves (downloading if necessary) the model and tokenizer files,
    /// verifying pinned checksums at most once per cache directory.
    fn model_paths(&self) -> Result<&(PathBuf, PathBuf), RerankerError> {
        self.model_paths.get_or_try_init(|| {
            let _span = tracing::info_span!("reranker_model_download").entered();
            use hf_hub::api::sync::Api;
            let api = Api::new().map_err(|e| RerankerError::ModelDownload(e.to_string()))?;
            let repo = api.model(model_repo());
            let model_path = repo
                .get(MODEL_FILE)
                .map_err(|e| RerankerError::ModelDownload(e.to_string()))?;
            let tokenizer_path = repo
                .get(TOKENIZER_FILE)
                .map_err(|e| RerankerError::ModelDownload(e.to_string()))?;
            if !MODEL_BLAKE3.is_empty() || !TOKENIZER_BLAKE3.is_empty() {
                // A marker file containing the pinned digests lets us skip
                // re-hashing the (large) model on every startup.
                let marker = model_path
                    .parent()
                    .unwrap_or(std::path::Path::new("."))
                    .join(".cqs_reranker_verified");
                let expected_marker = format!("{}\n{}", MODEL_BLAKE3, TOKENIZER_BLAKE3);
                let already_verified = std::fs::read_to_string(&marker)
                    .map(|s| s == expected_marker)
                    .unwrap_or(false);
                if !already_verified {
                    if !MODEL_BLAKE3.is_empty() {
                        verify_checksum(&model_path, MODEL_BLAKE3)?;
                    }
                    if !TOKENIZER_BLAKE3.is_empty() {
                        verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
                    }
                    // Best-effort: failure to write the marker only means we
                    // re-verify next time.
                    let _ = std::fs::write(&marker, &expected_marker);
                }
            }
            tracing::info!(model = %model_path.display(), "Reranker model ready");
            Ok((model_path, tokenizer_path))
        })
    }

    /// Returns the session guard, lazily creating the ONNX session on first
    /// call. On `Ok`, the guarded `Option` is guaranteed to be `Some`.
    fn session(&self) -> Result<std::sync::MutexGuard<'_, Option<Session>>, RerankerError> {
        // Recover from a poisoned lock: the session itself is still usable.
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        if guard.is_none() {
            let _span = tracing::info_span!("reranker_session_init").entered();
            let (model_path, _) = self.model_paths()?;
            *guard = Some(
                create_session(model_path, self.provider)
                    .map_err(|e| RerankerError::Inference(e.to_string()))?,
            );
            tracing::info!("Reranker session initialized");
        }
        Ok(guard)
    }

    /// Drops the session and tokenizer to release memory; both are
    /// re-created lazily on next use. Downloaded files stay cached.
    pub fn clear_session(&self) {
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        *guard = None;
        let mut tok = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
        *tok = None;
        tracing::info!("Reranker session and tokenizer cleared");
    }

    /// Returns the shared tokenizer, loading it from disk on first call.
    ///
    /// The file load happens outside the lock; a second check after
    /// re-acquiring handles the race where another thread loaded it first.
    fn tokenizer(&self) -> Result<Arc<tokenizers::Tokenizer>, RerankerError> {
        {
            let guard = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
            if let Some(t) = guard.as_ref() {
                return Ok(Arc::clone(t));
            }
        }
        let (_, tokenizer_path) = self.model_paths()?;
        let _span = tracing::info_span!("reranker_tokenizer_init").entered();
        let loaded = Arc::new(
            tokenizers::Tokenizer::from_file(tokenizer_path)
                .map_err(|e| RerankerError::Tokenizer(e.to_string()))?,
        );
        let mut guard = self.tokenizer.lock().unwrap_or_else(|p| p.into_inner());
        if let Some(existing) = guard.as_ref() {
            // Lost the race: prefer the instance already installed.
            return Ok(Arc::clone(existing));
        }
        *guard = Some(Arc::clone(&loaded));
        Ok(loaded)
    }
}
/// Streams `path` through BLAKE3 and compares the digest against
/// `expected` (lowercase hex).
///
/// # Errors
/// `ModelDownload` on I/O failure; `ChecksumMismatch` when the computed
/// digest disagrees with the pinned value.
fn verify_checksum(path: &std::path::Path, expected: &str) -> Result<(), RerankerError> {
    let mut hasher = blake3::Hasher::new();
    let mut file = std::fs::File::open(path).map_err(|e| {
        RerankerError::ModelDownload(format!("Cannot open {}: {}", path.display(), e))
    })?;
    // Stream the file through the hasher to avoid loading it into memory.
    std::io::copy(&mut file, &mut hasher).map_err(|e| {
        RerankerError::ModelDownload(format!("Read error on {}: {}", path.display(), e))
    })?;
    let actual = hasher.finalize().to_hex().to_string();
    if actual == expected {
        Ok(())
    } else {
        Err(RerankerError::ChecksumMismatch {
            path: path.display().to_string(),
            expected: expected.to_string(),
            actual,
        })
    }
}
/// Logistic function mapping a raw logit into (0, 1).
///
/// Uses the `exp(-x)` formulation, which stays finite for large positive
/// `x` (where `exp(x)` would overflow to infinity).
fn sigmoid(x: f32) -> f32 {
    ((-x).exp() + 1.0).recip()
}
/// Installs reranker probabilities into `results`, re-sorts them by score
/// (descending, ties broken by chunk id for determinism), and truncates
/// to `limit`.
///
/// An empty `scores` vector is treated as "no rerank performed" and
/// leaves `results` untouched.
fn apply_rerank_scores(results: &mut Vec<SearchResult>, scores: Vec<f32>, limit: usize) {
    if scores.is_empty() {
        return;
    }
    // `zip` stops at the shorter side, so extra results keep their scores.
    for (slot, new_score) in results.iter_mut().zip(scores) {
        slot.score = new_score;
    }
    let batch_size = results.len();
    results.sort_by(|lhs, rhs| {
        rhs.score
            .total_cmp(&lhs.score)
            .then(lhs.chunk.id.cmp(&rhs.chunk.id))
    });
    results.truncate(limit);
    tracing::info!(reranked = results.len(), batch_size, "Re-ranking complete");
}
#[cfg(test)]
mod tests {
    use super::*;
    // sigmoid(0) must be exactly the neutral midpoint 0.5.
    #[test]
    fn test_sigmoid_zero() {
        let result = sigmoid(0.0);
        assert!((result - 0.5).abs() < 1e-6);
    }
    #[test]
    fn test_sigmoid_large_positive() {
        let result = sigmoid(10.0);
        assert!(result > 0.999);
    }
    #[test]
    fn test_sigmoid_large_negative() {
        let result = sigmoid(-10.0);
        assert!(result < 0.001);
    }
    // Guards the numerically stable formulation: exp(100) would overflow,
    // but sigmoid(-100) must stay finite and non-negative.
    #[test]
    fn test_sigmoid_extreme_negative() {
        let result = sigmoid(-100.0);
        assert!(result >= 0.0 && result.is_finite());
    }
    // NaN logits (e.g. from a broken model output) must propagate, not panic.
    #[test]
    fn test_sigmoid_nan_does_not_panic() {
        let result = sigmoid(f32::NAN);
        assert!(result.is_nan(), "sigmoid(NaN) should be NaN, got {result}");
    }
    #[test]
    fn test_sigmoid_infinity_does_not_panic() {
        let pos = sigmoid(f32::INFINITY);
        assert!(
            pos.is_finite() || pos.is_nan(),
            "sigmoid(+inf) should not panic"
        );
        let neg = sigmoid(f32::NEG_INFINITY);
        assert!(
            neg.is_finite() || neg.is_nan(),
            "sigmoid(-inf) should not panic"
        );
    }
    // Construction must be cheap and infallible in practice: no model
    // download or session creation happens in `new()`.
    #[test]
    fn test_reranker_new() {
        let reranker = Reranker::new();
        assert!(reranker.is_ok());
    }
    // Empty input short-circuits before any tokenizer/session init.
    #[test]
    fn test_rerank_empty_results() {
        let reranker = Reranker::new().unwrap();
        let mut results = Vec::new();
        let result = reranker.rerank("test query", &mut results, 10);
        assert!(result.is_ok());
        assert!(results.is_empty());
    }
}