use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use hf_hub::api::sync::{Api, ApiBuilder};
use ndarray::{Array2, ArrayViewD};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use mnem_core::rerank::{RerankError, Reranker};
/// Cross-encoder reranker checkpoints this backend can fetch and run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RerankerModel {
    /// Hosted at `Xenova/ms-marco-MiniLM-L-6-v2`.
    MsMarcoMiniLmL6V2,
    /// Hosted at `BAAI/bge-reranker-v2-m3`.
    BgeRerankerV2M3,
    /// Hosted at `BAAI/bge-reranker-base`.
    BgeRerankerBase,
}

impl RerankerModel {
    /// Repo-relative path of the ONNX graph; every supported repo is probed here first.
    const ONNX_GRAPH: &'static str = "onnx/model.onnx";

    /// HuggingFace Hub repository id that hosts the graph and tokenizer.
    fn repo_id(self) -> &'static str {
        match self {
            Self::MsMarcoMiniLmL6V2 => "Xenova/ms-marco-MiniLM-L-6-v2",
            Self::BgeRerankerV2M3 => "BAAI/bge-reranker-v2-m3",
            Self::BgeRerankerBase => "BAAI/bge-reranker-base",
        }
    }

    /// Preferred repo-relative path of the ONNX graph for this checkpoint.
    fn onnx_path(self) -> &'static str {
        match self {
            Self::MsMarcoMiniLmL6V2 | Self::BgeRerankerV2M3 | Self::BgeRerankerBase => {
                Self::ONNX_GRAPH
            }
        }
    }

    /// Stable identifier reported through `Reranker::model`.
    fn wire_id(self) -> &'static str {
        match self {
            Self::MsMarcoMiniLmL6V2 => "onnx:ms-marco-MiniLM-L-6-v2",
            Self::BgeRerankerV2M3 => "onnx:bge-reranker-v2-m3",
            Self::BgeRerankerBase => "onnx:bge-reranker-base",
        }
    }

    /// Max sequence length used when neither the caller nor the environment
    /// requests one; currently 512 for every checkpoint.
    pub const fn default_max_length(self) -> usize {
        512
    }

    /// Upper bound that any requested max sequence length is clamped to.
    pub const fn positional_limit(self) -> usize {
        match self {
            Self::BgeRerankerV2M3 => 8192,
            Self::MsMarcoMiniLmL6V2 | Self::BgeRerankerBase => 512,
        }
    }
}
/// Paths returned by hf-hub for the two artifacts fetched in `fetch_files`.
struct ModelFiles {
    /// Path to the ONNX graph (either `onnx_path()` or the root `model.onnx`).
    model_onnx: PathBuf,
    /// Path to the repo's `tokenizer.json`.
    tokenizer_json: PathBuf,
}
/// Builds a synchronous hf-hub client with the builder's default settings.
///
/// # Errors
/// `RerankError::Config` if the client cannot be constructed.
fn hf_api() -> Result<Api, RerankError> {
    let builder = ApiBuilder::new();
    match builder.build() {
        Ok(api) => Ok(api),
        Err(err) => Err(RerankError::Config(format!("hf-hub init: {err}"))),
    }
}
/// Downloads (or resolves from the hf-hub cache) the ONNX graph and
/// `tokenizer.json` for `kind`.
///
/// The graph is looked up at `kind.onnx_path()` first, then at a root-level
/// `model.onnx`.
///
/// # Errors
/// `RerankError::Config` if the hub client cannot be built, or if the graph or
/// the tokenizer cannot be fetched. When both graph locations fail, the message
/// carries both underlying errors, so the primary failure (e.g. a network
/// error) is not masked by the fallback's "not found".
fn fetch_files(kind: RerankerModel) -> Result<ModelFiles, RerankError> {
    let api = hf_api()?;
    let repo_id = kind.repo_id();
    let repo = api.model(repo_id.to_string());
    let preferred = kind.onnx_path();
    let model_onnx = match repo.get(preferred) {
        Ok(path) => path,
        // Fall back to a graph stored at the repository root.
        Err(preferred_err) => repo.get("model.onnx").map_err(|fallback_err| {
            RerankError::Config(format!(
                "download {repo_id} onnx: {preferred}: {preferred_err}; model.onnx: {fallback_err}"
            ))
        })?,
    };
    let tokenizer_json = repo.get("tokenizer.json").map_err(|e| {
        RerankError::Config(format!("download {repo_id} tokenizer.json: {e}"))
    })?;
    Ok(ModelFiles {
        model_onnx,
        tokenizer_json,
    })
}
/// Loads `tokenizer.json` and configures it for cross-encoder pair batches.
///
/// Truncation is capped at `max_len` tokens, using the library's default
/// truncation settings otherwise. Padding uses `BatchLongest`, so every
/// encoding in a batch has the same length; `rerank` relies on this to build
/// rectangular tensors.
///
/// Pad token and id choice:
/// - If the file already carries a padding configuration, its pad token and
///   id are kept and only the strategy is overridden.
/// - Otherwise the pad token is looked up in the vocabulary (`[PAD]`, then
///   `<pad>`).
/// - If neither is found, the `PaddingParams` defaults are used.
///
/// Unconditionally using the defaults (id 0, `[PAD]`) would pad with the
/// wrong id for vocabularies where `<pad>` is not id 0.
///
/// # Errors
/// `RerankError::Config` if the file cannot be loaded or truncation is rejected.
fn load_tokenizer(path: &Path, max_len: usize) -> Result<Tokenizer, RerankError> {
    let mut tok = Tokenizer::from_file(path)
        .map_err(|e| RerankError::Config(format!("tokenizer.json load: {e}")))?;
    tok.with_truncation(Some(TruncationParams {
        max_length: max_len,
        ..Default::default()
    }))
    .map_err(|e| RerankError::Config(format!("tokenizer truncation: {e}")))?;
    let padding = match tok.get_padding() {
        Some(existing) => PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..existing.clone()
        },
        None => {
            let mut params = PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            let vocab_pad = ["[PAD]", "<pad>"]
                .iter()
                .copied()
                .find_map(|token| tok.token_to_id(token).map(|id| (token, id)));
            if let Some((token, id)) = vocab_pad {
                params.pad_token = token.to_string();
                params.pad_id = id;
            }
            params
        }
    };
    tok.with_padding(Some(padding));
    Ok(tok)
}
const ENV_RERANK_MAX_LEN: &str = "MNEM_ONNX_RERANK_MAX_LEN";
fn resolve_max_length(kind: RerankerModel, override_: Option<usize>) -> usize {
let ceiling = kind.positional_limit();
let requested = override_
.or_else(|| {
std::env::var(ENV_RERANK_MAX_LEN)
.ok()
.and_then(|s| s.parse::<usize>().ok())
})
.unwrap_or_else(|| kind.default_max_length());
if requested == 0 {
eprintln!(
"mnem-rerank: requested max_length=0 for {}; snapping to default {}",
kind.wire_id(),
kind.default_max_length()
);
return kind.default_max_length();
}
if requested > ceiling {
eprintln!(
"mnem-rerank: requested max_length={requested} exceeds {}'s positional limit {ceiling}; clamping",
kind.wire_id()
);
return ceiling;
}
requested
}
/// A committed ORT session plus input metadata probed once at load time.
struct OnnxSession {
    /// The ONNX Runtime session for the reranker graph.
    session: Session,
    /// True when the graph declares an input named `token_type_ids`.
    needs_token_type_ids: bool,
}

impl OnnxSession {
    /// Opens `model_path` with graph optimisation level 3.
    ///
    /// The intra-op thread count comes from `MNEM_ORT_INTRA_THREADS`. It
    /// defaults to 1 when that variable is unset or not a `usize`.
    ///
    /// # Errors
    /// `RerankError::Config` if any builder step or the model load fails.
    fn open(model_path: &Path) -> Result<Self, RerankError> {
        let intra_threads = match std::env::var("MNEM_ORT_INTRA_THREADS") {
            Ok(raw) => raw.parse::<usize>().unwrap_or(1),
            Err(_) => 1,
        };

        let builder = Session::builder()
            .map_err(|e| RerankError::Config(format!("ort session builder: {e}")))?;
        let builder = builder
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| RerankError::Config(format!("ort opt level: {e}")))?;
        let builder = builder
            .with_intra_threads(intra_threads)
            .map_err(|e| RerankError::Config(format!("ort intra threads: {e}")))?;
        let session = builder.commit_from_file(model_path).map_err(|e| {
            RerankError::Config(format!("ort commit {}: {e}", model_path.display()))
        })?;

        let needs_token_type_ids = session
            .inputs()
            .iter()
            .any(|input| input.name() == "token_type_ids");

        Ok(Self {
            session,
            needs_token_type_ids,
        })
    }
}
/// ONNX Runtime cross-encoder reranker for one `RerankerModel` checkpoint.
pub struct OnnxReranker {
    /// Checkpoint this instance was built for.
    kind: RerankerModel,
    /// Tokenizer configured with truncation at `max_len` and batch-longest padding.
    tokenizer: Tokenizer,
    /// ORT session; wrapped in a `Mutex` because `rerank` locks it to obtain a
    /// mutable guard before running inference.
    session: Mutex<OnnxSession>,
    /// Identifier returned by `Reranker::model` (the checkpoint's wire id).
    model_fq: String,
    /// Effective max sequence length after clamping by `resolve_max_length`.
    max_len: usize,
    /// Flipped to `true` the first time the truncation warning is printed, so
    /// the warning appears at most once per instance.
    warned_truncation: AtomicBool,
}

impl std::fmt::Debug for OnnxReranker {
    /// Prints the identifying configuration only. The tokenizer, session and
    /// warning flag are omitted, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OnnxReranker")
            .field("kind", &self.kind)
            .field("model_fq", &self.model_fq)
            .field("max_len", &self.max_len)
            .finish_non_exhaustive()
    }
}
/// One-shot guard for process-wide ONNX Runtime setup.
static ORT_INIT: OnceLock<()> = OnceLock::new();

/// Runs the ORT initialisation hook at most once per process.
///
/// NOTE(review): the hook body is currently empty, so this only marks
/// `ORT_INIT` as initialised; any global ORT environment setup would go here.
fn ensure_ort_init() {
    let _ = ORT_INIT.get_or_init(|| ());
}
impl OnnxReranker {
pub fn new(kind: RerankerModel) -> Result<Self, RerankError> {
Self::with_max_length(kind, None)
}
pub fn with_max_length(
kind: RerankerModel,
max_length: Option<usize>,
) -> Result<Self, RerankError> {
ensure_ort_init();
let max_len = resolve_max_length(kind, max_length);
let files = fetch_files(kind)?;
let tokenizer = load_tokenizer(&files.tokenizer_json, max_len)?;
let session = OnnxSession::open(&files.model_onnx)?;
let model_fq = kind.wire_id().to_string();
Ok(Self {
kind,
tokenizer,
session: Mutex::new(session),
model_fq,
max_len,
warned_truncation: AtomicBool::new(false),
})
}
pub fn max_length(&self) -> usize {
self.max_len
}
}
impl Reranker for OnnxReranker {
    /// Returns the wire id of the loaded checkpoint (e.g. `onnx:bge-reranker-base`).
    fn model(&self) -> &str {
        &self.model_fq
    }

    /// Scores every `(query, candidate)` pair with the cross-encoder in one
    /// batch and returns one score per candidate.
    ///
    /// Scores come from `extract_pair_scores`, applied to the `logits` output
    /// (or to the first output if none is named `logits`).
    ///
    /// Locking: the session mutex is acquired before the ORT tensors are built
    /// and stays held until the function returns.
    ///
    /// # Errors
    /// - `RerankError::Inference` for tokenization, reshape, tensor-creation,
    ///   mutex-poisoning or `run` failures.
    /// - `RerankError::Decode` when no output exists, it cannot be read as
    ///   `f32`, or its shape is not accepted by `extract_pair_scores`.
    fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
        // Nothing to score: skip tokenization and the session lock entirely.
        if candidates.is_empty() {
            return Ok(Vec::new());
        }
        // Encode each candidate as a sentence pair with the query, special tokens added.
        let pairs: Vec<(&str, &str)> = candidates.iter().map(|c| (query, *c)).collect();
        let encodings = self
            .tokenizer
            .encode_batch(pairs, true)
            .map_err(|e| RerankError::Inference(format!("tokenize batch: {e}")))?;
        let batch = encodings.len();
        // The first encoding's length stands in for the whole batch. Padding is
        // expected to make every row the same length; Array2::from_shape_vec
        // below errors if that does not hold.
        let seq_len = encodings
            .first()
            .map(|e| e.get_ids().len())
            .ok_or_else(|| RerankError::Inference("empty encoding batch".into()))?;
        // One-shot warning when the padded batch reaches max_len.
        // NOTE(review): this also fires when a pair fits exactly without
        // truncation, because the check is `>=` on the padded length.
        if seq_len >= self.max_len && !self.warned_truncation.swap(true, Ordering::Relaxed) {
            eprintln!(
                "mnem-rerank: batch filled max_length={} on {}; pair tail truncated. \
                 Raise via MNEM_ONNX_RERANK_MAX_LEN (<= {}) or chunk upstream.",
                self.max_len,
                self.kind.wire_id(),
                self.kind.positional_limit()
            );
        }
        // Flatten the per-pair u32 columns into row-major i64 buffers for ORT.
        // type_flat is always built but only used if the graph declares token_type_ids.
        let mut ids_flat: Vec<i64> = Vec::with_capacity(batch * seq_len);
        let mut mask_flat: Vec<i64> = Vec::with_capacity(batch * seq_len);
        let mut type_flat: Vec<i64> = Vec::with_capacity(batch * seq_len);
        for enc in &encodings {
            ids_flat.extend(enc.get_ids().iter().map(|&x| x as i64));
            mask_flat.extend(enc.get_attention_mask().iter().map(|&x| x as i64));
            type_flat.extend(enc.get_type_ids().iter().map(|&x| x as i64));
        }
        let ids_arr = Array2::from_shape_vec((batch, seq_len), ids_flat)
            .map_err(|e| RerankError::Inference(format!("ids reshape: {e}")))?;
        let mask_arr = Array2::from_shape_vec((batch, seq_len), mask_flat)
            .map_err(|e| RerankError::Inference(format!("mask reshape: {e}")))?;
        let mut session = self
            .session
            .lock()
            .map_err(|_| RerankError::Inference("session mutex poisoned".into()))?;
        // Named inputs; token_type_ids is appended only when the graph declares it.
        let mut inputs: Vec<(&'static str, Value)> = Vec::with_capacity(3);
        inputs.push((
            "input_ids",
            Value::from_array(ids_arr)
                .map_err(|e| RerankError::Inference(format!("ids tensor: {e}")))?
                .into_dyn(),
        ));
        inputs.push((
            "attention_mask",
            Value::from_array(mask_arr)
                .map_err(|e| RerankError::Inference(format!("mask tensor: {e}")))?
                .into_dyn(),
        ));
        if session.needs_token_type_ids {
            let type_arr = Array2::from_shape_vec((batch, seq_len), type_flat)
                .map_err(|e| RerankError::Inference(format!("type_ids reshape: {e}")))?;
            inputs.push((
                "token_type_ids",
                Value::from_array(type_arr)
                    .map_err(|e| RerankError::Inference(format!("type_ids tensor: {e}")))?
                    .into_dyn(),
            ));
        }
        let outputs = session
            .session
            .run(inputs)
            .map_err(|e| RerankError::Inference(format!("ort run: {e}")))?;
        // Prefer the output named "logits"; otherwise use the first output.
        let value = outputs
            .iter()
            .find(|(name, _)| *name == "logits")
            .map(|(_, v)| v)
            .or_else(|| outputs.iter().next().map(|(_, v)| v))
            .ok_or_else(|| RerankError::Decode("no logits output".into()))?;
        let view: ArrayViewD<'_, f32> = value
            .try_extract_array::<f32>()
            .map_err(|e| RerankError::Decode(format!("extract logits: {e}")))?;
        // Copy out in logical (row-major iteration) order, independent of memory layout.
        let shape = view.shape().to_vec();
        let buffer: Vec<f32> = view.iter().copied().collect();
        let scores = extract_pair_scores(&buffer, &shape, batch)?;
        Ok(scores)
    }
}
/// Converts a flattened logits tensor into one relevance score per pair.
///
/// `shape` is the logits tensor shape and `buffer` its row-major contents.
/// Accepted layouts:
/// - `[batch]` or `[batch, 1]`: each single logit is the score.
/// - `[batch, labels]` with `labels >= 2`: the last label's logit in each row
///   is taken as the positive-class score.
///
/// # Errors
/// `RerankError::Decode` in either of these cases:
/// - `buffer.len()` differs from the element count implied by `shape`. This
///   is checked up front so a short buffer cannot cause an index panic and a
///   long one cannot leak extra values.
/// - `shape` is not one of the layouts above for `batch`.
fn extract_pair_scores(
    buffer: &[f32],
    shape: &[usize],
    batch: usize,
) -> Result<Vec<f32>, RerankError> {
    let implied_len = shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
    if implied_len != Some(buffer.len()) {
        return Err(RerankError::Decode(format!(
            "logits buffer length {} does not match shape {shape:?}",
            buffer.len()
        )));
    }
    match *shape {
        [n] | [n, 1] if n == batch => Ok(buffer.to_vec()),
        [n, labels] if n == batch && labels >= 2 => Ok(buffer
            .chunks_exact(labels)
            .map(|row| row[labels - 1])
            .collect()),
        _ => Err(RerankError::Decode(format!(
            "unexpected logits shape {shape:?} for batch={batch}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers: logits decoding, max-length resolution
    //! and the per-model positional limits.
    use super::*;

    #[test]
    fn extract_rank_1_batch_scores() {
        let buf = vec![0.1, 0.2, 0.3];
        let out = extract_pair_scores(&buf, &[3], 3).unwrap();
        assert_eq!(out, vec![0.1, 0.2, 0.3]);
    }

    #[test]
    fn extract_rank_2_single_label() {
        let buf = vec![0.1, 0.2, 0.3];
        let out = extract_pair_scores(&buf, &[3, 1], 3).unwrap();
        assert_eq!(out, vec![0.1, 0.2, 0.3]);
    }

    #[test]
    fn extract_rank_2_two_class_takes_positive() {
        let buf = vec![0.9, 0.1, 0.2, 0.8];
        let out = extract_pair_scores(&buf, &[2, 2], 2).unwrap();
        assert_eq!(out, vec![0.1, 0.8]);
    }

    #[test]
    fn extract_rejects_mismatched_batch() {
        let buf = vec![0.1, 0.2];
        let err = extract_pair_scores(&buf, &[3], 2).unwrap_err();
        assert!(matches!(err, RerankError::Decode(_)));
    }

    #[test]
    fn extract_rejects_unknown_shape() {
        let buf = vec![0.1, 0.2, 0.3, 0.4];
        let err = extract_pair_scores(&buf, &[1, 2, 2], 1).unwrap_err();
        assert!(matches!(err, RerankError::Decode(_)));
    }

    #[test]
    fn resolve_max_length_uses_default_when_none() {
        // With `None`, resolve_max_length consults MNEM_ONNX_RERANK_MAX_LEN before
        // the default, so an exported value would change the expected result.
        if std::env::var_os(ENV_RERANK_MAX_LEN).is_some() {
            eprintln!("skipping: {ENV_RERANK_MAX_LEN} is set in the environment");
            return;
        }
        let n = resolve_max_length(RerankerModel::MsMarcoMiniLmL6V2, None);
        assert_eq!(n, 512);
    }

    #[test]
    fn resolve_max_length_passes_through_in_range() {
        let n = resolve_max_length(RerankerModel::BgeRerankerV2M3, Some(2048));
        assert_eq!(n, 2048);
    }

    #[test]
    fn resolve_max_length_clamps_above_positional_limit() {
        let n = resolve_max_length(RerankerModel::MsMarcoMiniLmL6V2, Some(8192));
        assert_eq!(n, 512);
    }

    #[test]
    fn resolve_max_length_v2_m3_can_unlock_full_window() {
        let n = resolve_max_length(RerankerModel::BgeRerankerV2M3, Some(8192));
        assert_eq!(n, 8192);
    }

    #[test]
    fn resolve_max_length_zero_snaps_to_default() {
        let n = resolve_max_length(RerankerModel::BgeRerankerBase, Some(0));
        assert_eq!(n, 512);
    }

    #[test]
    fn positional_limits_match_published_windows() {
        assert_eq!(RerankerModel::MsMarcoMiniLmL6V2.positional_limit(), 512);
        assert_eq!(RerankerModel::BgeRerankerV2M3.positional_limit(), 8192);
        assert_eq!(RerankerModel::BgeRerankerBase.positional_limit(), 512);
    }
}