use std::sync::{Mutex, OnceLock};
use ort::{ep, session::Session, value::Tensor};
use serde::Serialize;
use tokenizers::Tokenizer;
/// Reranker ONNX model bytes, embedded into the binary at compile time from `OUT_DIR`
/// (presumably written there by this crate's build script — confirm in `build.rs`).
const RERANKER_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/reranker.onnx"));
/// Reranker tokenizer definition (tokenizer JSON), embedded the same way from `OUT_DIR`.
const RERANKER_TOKENIZER: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/reranker_tokenizer.json"));
/// One successfully scored document, as returned by [`rerank`] and in
/// [`RerankReport::scored`].
#[derive(Debug, Serialize)]
pub struct RerankResult {
    /// Position of the document in the `documents` slice passed to the reranker.
    pub index: usize,
    /// Score taken from the first element of the model's first output, widened to `f64`.
    /// Results are ordered by descending score. NOTE(review): whether this is a raw logit or a
    /// probability depends on the exported model — confirm before applying thresholds.
    pub score: f64,
    /// Path copied from the first element of the document's `(path, text)` input tuple.
    pub path: String,
}
/// Reranker resources created once by `state()` and shared by every scoring call.
struct RerankState {
    /// ONNX Runtime session. It is behind a `Mutex` because inference is run through the
    /// lock guard (presumably `Session::run` needs `&mut self` in this `ort` version), which
    /// also serializes concurrent reranking calls.
    session: Mutex<Session>,
    /// Tokenizer used to encode `(query, document)` pairs. `state()` configures it with
    /// 512-token truncation and no padding.
    tokenizer: Tokenizer,
}
/// Process-wide reranker resources, built once on first use by [`state`].
static STATE: OnceLock<RerankState> = OnceLock::new();

/// Returns the shared reranker state, loading the embedded model and tokenizer on first call.
///
/// The ONNX session is created from `RERANKER_MODEL` with the CPU execution provider. The
/// tokenizer is loaded from `RERANKER_TOKENIZER`, truncates at 512 tokens and has padding
/// disabled, so each encoded pair keeps its natural (possibly truncated) length.
///
/// # Panics
///
/// Panics if the session or the tokenizer cannot be built. The build runs inside
/// `OnceLock::get_or_init`, so after such a panic the cell stays empty and a later call
/// retries initialization.
fn state() -> &'static RerankState {
    fn build() -> RerankState {
        let session = Session::builder()
            .expect("session builder")
            .with_execution_providers([ep::CPU::default().build()])
            .expect("CPU EP")
            .commit_from_memory(RERANKER_MODEL)
            .expect("load reranker model");

        let truncation = tokenizers::TruncationParams {
            max_length: 512,
            ..Default::default()
        };
        let mut tokenizer =
            Tokenizer::from_bytes(RERANKER_TOKENIZER).expect("load reranker tokenizer");
        tokenizer
            .with_truncation(Some(truncation))
            .expect("set truncation");
        tokenizer.with_padding(None);

        RerankState {
            tokenizer,
            session: Mutex::new(session),
        }
    }

    STATE.get_or_init(build)
}
/// A document that could not be scored, together with the reason.
#[derive(Debug, Serialize)]
pub struct RerankFailure {
    /// Position of the document in the `documents` slice passed to the reranker.
    pub index: usize,
    /// Path copied from the first element of the document's `(path, text)` input tuple.
    pub path: String,
    /// Human-readable description of the step that failed (tokenization, tensor build,
    /// inference, output extraction, ...).
    pub reason: String,
}
/// Full outcome of [`rerank_with_report`]: the scored documents plus those that failed.
#[derive(Debug, Serialize)]
pub struct RerankReport {
    /// Successfully scored documents, sorted by descending score and truncated to `top_n`.
    pub scored: Vec<RerankResult>,
    /// Every document that failed to score, in input order. This list is not truncated.
    pub failed: Vec<RerankFailure>,
}
pub fn rerank(query: &str, documents: &[(String, String)], top_n: usize) -> Vec<RerankResult> {
rerank_with_report(query, documents, top_n).scored
}
/// Scores every document against `query` and reports both successes and failures.
///
/// `documents` is a slice of `(path, text)` pairs. `text` is paired with `query` and scored
/// by the cross-encoder; `path` is copied into the result for the caller's convenience.
///
/// The returned [`RerankReport`] contains:
/// - `scored`: sorted by descending score, with NaN scores (if any) placed last, and truncated
///   to at most `top_n` entries. Equal scores keep their input order because the sort is
///   stable.
/// - `failed`: every document that could not be scored, in input order.
///
/// `index` fields refer to positions in `documents`.
///
/// # Panics
///
/// Panics on first use if the embedded model or tokenizer cannot be loaded (see `state`).
pub fn rerank_with_report(
    query: &str,
    documents: &[(String, String)],
    top_n: usize,
) -> RerankReport {
    let st = state();
    let mut scored: Vec<RerankResult> = Vec::with_capacity(documents.len());
    let mut failed: Vec<RerankFailure> = Vec::new();
    for (index, (path, text)) in documents.iter().enumerate() {
        match score_pair_typed(st, query, text) {
            Ok(score) => scored.push(RerankResult {
                index,
                score,
                path: path.clone(),
            }),
            Err(reason) => failed.push(RerankFailure {
                index,
                path: path.clone(),
                reason,
            }),
        }
    }
    // `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN appears: NaN would
    // compare "equal" to everything, which breaks transitivity. `sort_by` may then panic
    // (Rust >= 1.81) or return an arbitrary order. Instead, order non-NaN scores descending
    // and treat every NaN as greater than (i.e. sorted after) all real scores.
    scored.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or_else(|| a.score.is_nan().cmp(&b.score.is_nan()))
    });
    scored.truncate(top_n);
    RerankReport { scored, failed }
}
/// Scores a single `(query, document)` pair with the cross-encoder model.
///
/// The pair is tokenized as a sentence pair with special tokens, subject to the tokenizer's
/// truncation settings. It is then turned into three `[1, seq_len]` `i64` tensors
/// (`input_ids`, `attention_mask`, `token_type_ids`) and run through the shared session. The
/// first element of the first model output is returned as the score.
///
/// # Errors
///
/// Returns a human-readable reason if any of these fail:
/// - tokenization, tensor construction or locking the session mutex;
/// - inference or output extraction.
///
/// It also returns an error if the model output is empty or the score is NaN, since a NaN
/// cannot be ranked meaningfully.
fn score_pair_typed(st: &RerankState, query: &str, document: &str) -> Result<f64, String> {
    // The tokenizer yields u32 values; widen them losslessly to the tensors' i64 element type.
    fn widen(values: &[u32]) -> Vec<i64> {
        values.iter().map(|&v| i64::from(v)).collect()
    }

    let encoding = st
        .tokenizer
        .encode((query, document), true)
        .map_err(|e| format!("tokenization failed: {e}"))?;
    let ids = widen(encoding.get_ids());
    let mask = widen(encoding.get_attention_mask());
    let type_ids = widen(encoding.get_type_ids());

    // Batch of one sequence. The cast cannot overflow: token counts are far below i64::MAX.
    let shape = vec![1i64, ids.len() as i64];
    let input_ids = Tensor::from_array((shape.clone(), ids.into_boxed_slice()))
        .map_err(|e| format!("tensor build failed: {e}"))?;
    let attention_mask = Tensor::from_array((shape.clone(), mask.into_boxed_slice()))
        .map_err(|e| format!("tensor build failed: {e}"))?;
    let token_type_ids = Tensor::from_array((shape, type_ids.into_boxed_slice()))
        .map_err(|e| format!("tensor build failed: {e}"))?;
    let inputs = ort::inputs! {
        "input_ids" => input_ids,
        "attention_mask" => attention_mask,
        "token_type_ids" => token_type_ids,
    };

    // The guard must outlive `outputs`, which is produced by running the locked session.
    let mut session = st
        .session
        .lock()
        .map_err(|_| "reranker session mutex poisoned".to_string())?;
    let outputs = session
        .run(inputs)
        .map_err(|e| format!("inference failed: {e}"))?;
    let (_, data) = outputs[0]
        .try_extract_tensor::<f32>()
        .map_err(|e| format!("output extract failed: {e}"))?;

    // Report an empty output as a failure instead of panicking on `data[0]`.
    let score = *data
        .first()
        .ok_or_else(|| "model returned an empty output tensor".to_string())?;
    if score.is_nan() {
        return Err("model returned a NaN score".to_string());
    }
    Ok(f64::from(score))
}