use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use hf_hub::api::sync::{Api, ApiBuilder};
use ndarray::{Array2, ArrayViewD};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use mnem_core::sparse::{SparseEmbed, SparseEncoder, SparseError};
/// Sparse-encoder checkpoints this module knows how to load.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelKind {
    /// Checkpoint for which [`ModelKind::query_is_inference_free`] is `true`.
    OpensearchDocV3Distill,
    /// Checkpoint for which [`ModelKind::query_is_inference_free`] is `false`.
    OpensearchBiV2Distill,
}

impl ModelKind {
    /// Repository id passed to the Hub client when fetching this checkpoint.
    pub fn repo_id(self) -> &'static str {
        match self {
            Self::OpensearchBiV2Distill => {
                "opensearch-project/opensearch-neural-sparse-encoding-v2-distill"
            }
            Self::OpensearchDocV3Distill => {
                "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill"
            }
        }
    }

    /// Short identifier used as the vocabulary tag of produced embeddings.
    pub fn vocab_id(self) -> &'static str {
        match self {
            Self::OpensearchBiV2Distill => "opensearch-bi-v2-distill",
            Self::OpensearchDocV3Distill => "opensearch-doc-v3-distill",
        }
    }

    /// Whether query encoding for this checkpoint uses an IDF table
    /// (`idf.json`) rather than running the model.
    pub fn query_is_inference_free(self) -> bool {
        // Spelled out per variant so adding a variant forces a decision here.
        match self {
            Self::OpensearchDocV3Distill => true,
            Self::OpensearchBiV2Distill => false,
        }
    }

    /// Activation applied to this checkpoint's logits.
    fn activation(self) -> Activation {
        match self {
            Self::OpensearchBiV2Distill => Activation::SingleLog,
            Self::OpensearchDocV3Distill => Activation::DoubleLog,
        }
    }

    /// Upper bound on the sequence length; requested lengths above it are
    /// clamped by `resolve_max_length`.
    pub const fn positional_limit(self) -> usize {
        match self {
            Self::OpensearchDocV3Distill => 512,
            Self::OpensearchBiV2Distill => 512,
        }
    }

    /// Sequence length used when nothing else is requested; currently the
    /// same as [`ModelKind::positional_limit`].
    pub const fn default_max_length(self) -> usize {
        Self::positional_limit(self)
    }
}
/// Non-linearity that turns a raw logit into a non-negative sparse weight.
#[derive(Debug, Clone, Copy)]
enum Activation {
    /// `ln(1 + relu(x))`.
    SingleLog,
    /// `ln(1 + ln(1 + relu(x)))`.
    DoubleLog,
}

impl Activation {
    /// Applies the activation to a single logit.
    ///
    /// The input is first clamped at zero; because `f32::max` returns the
    /// non-NaN operand, a NaN logit is treated as `0.0` as well.
    fn apply(self, x: f32) -> f32 {
        // Both variants share the first log; only `DoubleLog` takes a second.
        let once = x.max(0.0).ln_1p();
        match self {
            Self::SingleLog => once,
            Self::DoubleLog => once.ln_1p(),
        }
    }
}
/// Paths of the files one model needs, as returned by `fetch_files`.
#[derive(Debug, Clone)]
struct ModelFiles {
/// The ONNX graph (`onnx/model.onnx`, or `model.onnx` as a fallback).
model_onnx: PathBuf,
/// The tokenizer definition (`tokenizer.json`).
tokenizer_json: PathBuf,
/// The IDF table (`idf.json`); `Some` only for kinds whose
/// `query_is_inference_free()` is true.
idf_json: Option<PathBuf>,
}
/// Builds a Hub API client with the builder's default settings.
///
/// # Errors
/// Returns `SparseError::Config` when the client cannot be constructed.
fn hf_api() -> Result<Api, SparseError> {
    match ApiBuilder::new().build() {
        Ok(api) => Ok(api),
        Err(e) => Err(SparseError::Config(format!("hf-hub init: {e}"))),
    }
}
/// Resolves every file `kind` needs through the Hub client.
///
/// The ONNX graph is requested as `onnx/model.onnx` first and as `model.onnx`
/// if that fails. `idf.json` is requested only when
/// `kind.query_is_inference_free()` is true.
///
/// # Errors
/// Returns `SparseError::Config` if the client cannot be built or any
/// required file cannot be obtained.
fn fetch_files(kind: ModelKind) -> Result<ModelFiles, SparseError> {
    let api = hf_api()?;
    let repo = api.model(kind.repo_id().to_string());

    let model_onnx = match repo.get("onnx/model.onnx") {
        Ok(path) => path,
        // The first failure is dropped on purpose; only the fallback's error
        // is reported to the caller.
        Err(_) => match repo.get("model.onnx") {
            Ok(path) => path,
            Err(e) => {
                return Err(SparseError::Config(format!(
                    "download model.onnx from {}: {e}",
                    kind.repo_id()
                )));
            }
        },
    };

    let tokenizer_json = match repo.get("tokenizer.json") {
        Ok(path) => path,
        Err(e) => {
            return Err(SparseError::Config(format!(
                "download tokenizer.json: {e}"
            )));
        }
    };

    let idf_json = kind
        .query_is_inference_free()
        .then(|| {
            repo.get("idf.json")
                .map_err(|e| SparseError::Config(format!("download idf.json: {e}")))
        })
        .transpose()?;

    Ok(ModelFiles {
        model_onnx,
        tokenizer_json,
        idf_json,
    })
}
/// Environment variable consulted for the maximum sequence length when the
/// caller passes no explicit override.
const ENV_SPARSE_MAX_LEN: &str = "MNEM_ONNX_SPARSE_MAX_LEN";

/// Chooses the tokenizer truncation length for `kind`.
///
/// Precedence: `override_`, then [`ENV_SPARSE_MAX_LEN`] (only if it parses as
/// `usize`), then `kind.default_max_length()`.
///
/// A resulting value of `0` is replaced by the default, and a value above
/// `kind.positional_limit()` is clamped to that limit; both adjustments are
/// reported on stderr.
fn resolve_max_length(kind: ModelKind, override_: Option<usize>) -> usize {
    let ceiling = kind.positional_limit();

    // Unset, non-UTF-8 and unparsable values are all treated as "not provided".
    let from_env = || -> Option<usize> {
        std::env::var(ENV_SPARSE_MAX_LEN)
            .ok()?
            .parse::<usize>()
            .ok()
    };

    let requested = match override_ {
        Some(explicit) => explicit,
        None => match from_env() {
            Some(parsed) => parsed,
            None => kind.default_max_length(),
        },
    };

    match requested {
        0 => {
            eprintln!(
                "mnem-sparse: requested max_length=0 for {}; snapping to default {}",
                kind.vocab_id(),
                kind.default_max_length()
            );
            kind.default_max_length()
        }
        n if n > ceiling => {
            eprintln!(
                "mnem-sparse: requested max_length={requested} exceeds {}'s positional limit {ceiling}; clamping",
                kind.vocab_id()
            );
            ceiling
        }
        n => n,
    }
}
/// Loads the tokenizer at `path` and configures it for this encoder.
///
/// Truncation is set to `max_len` tokens and padding to the
/// `BatchLongest` strategy; every other truncation/padding parameter keeps
/// its `Default` value.
///
/// # Errors
/// Returns `SparseError::Config` if the file cannot be loaded or the
/// truncation parameters are rejected.
fn load_tokenizer(path: &Path, max_len: usize) -> Result<Tokenizer, SparseError> {
    let truncation = TruncationParams {
        max_length: max_len,
        ..Default::default()
    };
    let padding = PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        ..Default::default()
    };

    let mut tok = match Tokenizer::from_file(path) {
        Ok(loaded) => loaded,
        Err(e) => return Err(SparseError::Config(format!("tokenizer.json load: {e}"))),
    };
    if let Err(e) = tok.with_truncation(Some(truncation)) {
        return Err(SparseError::Config(format!("tokenizer truncation: {e}")));
    }
    tok.with_padding(Some(padding));
    Ok(tok)
}
/// Per-token query weights, indexed by token id.
#[derive(Debug, Clone)]
struct IdfTable {
/// `weights[id]` is the weight for token `id`; ids that had no entry in the
/// loaded file keep `0.0`.
weights: Vec<f32>,
}
impl IdfTable {
    /// Builds the weight table from the JSON file at `path`.
    ///
    /// The file is parsed as an object mapping token surface to `f32`. Each
    /// surface is resolved with `tokenizer.token_to_id`; surfaces the
    /// tokenizer does not know, and ids outside the table, are dropped. The
    /// table has `tokenizer.get_vocab_size(true)` slots, initialised to `0.0`.
    ///
    /// # Errors
    /// Returns `SparseError::Config` if the file cannot be read or parsed.
    fn load(path: &Path, tokenizer: &Tokenizer) -> Result<Self, SparseError> {
        let raw = match std::fs::read_to_string(path) {
            Ok(text) => text,
            Err(e) => return Err(SparseError::Config(format!("read idf.json: {e}"))),
        };
        let by_surface: HashMap<String, f32> = match serde_json::from_str(&raw) {
            Ok(parsed) => parsed,
            Err(e) => return Err(SparseError::Config(format!("parse idf.json: {e}"))),
        };

        let mut weights = vec![0.0_f32; tokenizer.get_vocab_size(true)];
        for (surface, idf) in by_surface {
            let Some(id) = tokenizer.token_to_id(&surface) else {
                continue;
            };
            if let Some(slot) = weights.get_mut(id as usize) {
                *slot = idf;
            }
        }
        Ok(Self { weights })
    }

    /// Encodes `text` as `(token id, weight)` pairs without running a model.
    ///
    /// The text is tokenized (with special tokens added), then every token
    /// that is not in `special_ids`, lies inside the table and has a weight
    /// strictly greater than zero contributes one pair. Repeated tokens
    /// collapse to a single pair. The result is sorted by token id.
    ///
    /// # Errors
    /// Returns `SparseError::Inference` if tokenization fails.
    fn encode_query(
        &self,
        tokenizer: &Tokenizer,
        text: &str,
        special_ids: &[u32],
    ) -> Result<Vec<(u32, f32)>, SparseError> {
        let encoded = match tokenizer.encode(text, true) {
            Ok(enc) => enc,
            Err(e) => return Err(SparseError::Inference(format!("tokenize query: {e}"))),
        };
        let ids = encoded.get_ids();

        let weighted = ids
            .iter()
            .copied()
            .filter(|id| !special_ids.contains(id))
            .filter_map(|id| self.weights.get(id as usize).map(|&w| (id, w)))
            .filter(|&(_, w)| w > 0.0);

        let mut best: HashMap<u32, f32> = HashMap::with_capacity(ids.len());
        for (id, w) in weighted {
            // Keep the largest weight seen for this id.
            best.entry(id)
                .and_modify(|current| {
                    if w > *current {
                        *current = w;
                    }
                })
                .or_insert(w);
        }

        let mut out: Vec<(u32, f32)> = best.into_iter().collect();
        // Ids are unique map keys, so an unstable sort yields the same order.
        out.sort_unstable_by_key(|pair| pair.0);
        Ok(out)
    }
}
/// Looks up the ids of the special-token surfaces `[CLS]`, `[SEP]`, `[PAD]`,
/// `[UNK]` and `[MASK]`.
///
/// Surfaces the tokenizer does not know are skipped, so the result may hold
/// fewer than five ids. Ids are returned in the order listed above.
fn collect_special_ids(tokenizer: &Tokenizer) -> Vec<u32> {
    const SURFACES: [&str; 5] = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"];

    let mut ids = Vec::with_capacity(SURFACES.len());
    for surface in SURFACES {
        if let Some(id) = tokenizer.token_to_id(surface) {
            ids.push(id);
        }
    }
    ids
}
/// An ONNX Runtime session plus what `forward_single` needs to know about
/// the graph's inputs.
struct OnnxSession {
/// The loaded inference session.
session: Session,
/// True when the graph declares an input named `token_type_ids`.
needs_token_type_ids: bool,
}
impl OnnxSession {
    /// Loads the ONNX graph at `model_path` into a new session.
    ///
    /// Graph optimisation is set to `Level3`. The intra-op thread count is
    /// read from the `MNEM_ORT_INTRA_THREADS` environment variable and falls
    /// back to `1` when the variable is unset or does not parse as `usize`.
    ///
    /// # Errors
    /// Returns `SparseError::Config` if any builder step or the model load
    /// fails.
    fn open(model_path: &Path) -> Result<Self, SparseError> {
        let threads = match std::env::var("MNEM_ORT_INTRA_THREADS") {
            Ok(raw) => raw.parse::<usize>().unwrap_or(1),
            Err(_) => 1,
        };

        let session = Session::builder()
            .map_err(|e| SparseError::Config(format!("ort session builder: {e}")))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| SparseError::Config(format!("ort opt level: {e}")))?
            .with_intra_threads(threads)
            .map_err(|e| SparseError::Config(format!("ort intra threads: {e}")))?
            .commit_from_file(model_path)
            .map_err(|e| {
                SparseError::Config(format!("ort commit {}: {e}", model_path.display()))
            })?;

        // Recorded once so `forward_single` knows whether to feed the input.
        let needs_token_type_ids = session
            .inputs()
            .iter()
            .any(|input| input.name() == "token_type_ids");

        Ok(Self {
            session,
            needs_token_type_ids,
        })
    }

    /// Runs the model on one encoding as a batch of size one.
    ///
    /// `input_ids` and `attention_mask` come from `encoded`; when the graph
    /// needs `token_type_ids`, an all-zero tensor of the same shape is fed.
    /// The output named `logits` is used, or the first output if none has
    /// that name.
    ///
    /// Returns the logits flattened in the tensor's iteration order together
    /// with dimensions 1 and 2 of its shape (sequence length and vocabulary
    /// size).
    ///
    /// # Errors
    /// Returns `SparseError::Inference` if tensors cannot be built, the run
    /// fails, no output exists, the output is not `f32`, or it is not rank 3.
    fn forward_single(
        &mut self,
        encoded: &tokenizers::Encoding,
    ) -> Result<(Vec<f32>, usize, usize), SparseError> {
        let token_ids = encoded.get_ids();
        let seq_len = token_ids.len();

        // The graph takes i64 tensors; the tokenizer hands out u32.
        let as_i64 = |src: &[u32]| -> Vec<i64> { src.iter().map(|&x| i64::from(x)).collect() };

        let ids_arr: Array2<i64> = Array2::from_shape_vec((1, seq_len), as_i64(token_ids))
            .map_err(|e| SparseError::Inference(format!("ids reshape: {e}")))?;
        let mask_arr: Array2<i64> =
            Array2::from_shape_vec((1, seq_len), as_i64(encoded.get_attention_mask()))
                .map_err(|e| SparseError::Inference(format!("mask reshape: {e}")))?;

        let ids_value = Value::from_array(ids_arr)
            .map_err(|e| SparseError::Inference(format!("ids tensor: {e}")))?
            .into_dyn();
        let mask_value = Value::from_array(mask_arr)
            .map_err(|e| SparseError::Inference(format!("mask tensor: {e}")))?
            .into_dyn();

        let mut inputs: Vec<(&'static str, Value)> = Vec::with_capacity(3);
        inputs.push(("input_ids", ids_value));
        inputs.push(("attention_mask", mask_value));
        if self.needs_token_type_ids {
            let zeros: Array2<i64> = Array2::zeros((1, seq_len));
            let type_ids_value = Value::from_array(zeros)
                .map_err(|e| SparseError::Inference(format!("type_ids tensor: {e}")))?
                .into_dyn();
            inputs.push(("token_type_ids", type_ids_value));
        }

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| SparseError::Inference(format!("ort run: {e}")))?;
        let value = outputs
            .iter()
            .find(|(name, _)| *name == "logits")
            .map(|(_, v)| v)
            .or_else(|| outputs.iter().next().map(|(_, v)| v))
            .ok_or_else(|| SparseError::Inference("no logits output".into()))?;
        let view: ArrayViewD<'_, f32> = value
            .try_extract_array::<f32>()
            .map_err(|e| SparseError::Inference(format!("extract logits: {e}")))?;

        // Validate the rank first so a malformed output never pays for the copy.
        match view.shape() {
            &[_, seq, vocab] => Ok((view.iter().copied().collect(), seq, vocab)),
            other => Err(SparseError::Inference(format!(
                "expected rank-3 logits, got shape {:?}",
                other
            ))),
        }
    }
}
/// Collapses per-position logits into one sparse weight per vocabulary entry.
///
/// `logits` is read as consecutive rows of `vocab_size` values, one row per
/// sequence position. For every position whose `attention_mask` entry is
/// non-zero, each vocabulary slot keeps the largest logit seen (never below
/// zero). The activation is then applied once per slot, slots listed in
/// `special_ids` are forced to zero, and only strictly positive weights are
/// returned.
///
/// Pooling before activating is equivalent to activating every logit and
/// pooling afterwards because both activations are monotonically
/// non-decreasing, but it needs `vocab_size` log evaluations instead of
/// `seq_len * vocab_size`.
///
/// At most `seq_len` rows are read. Positions without a mask entry are treated
/// as masked, and a trailing partial row (buffer shorter than
/// `seq_len * vocab_size`) is ignored instead of panicking.
///
/// Returns `(token id, weight)` pairs in ascending id order.
fn reduce_doc_logits(
    logits: &[f32],
    seq_len: usize,
    vocab_size: usize,
    attention_mask: &[u32],
    activation: Activation,
    special_ids: &[u32],
) -> Vec<(u32, f32)> {
    // `chunks_exact` panics on a zero chunk size; an empty vocabulary has no
    // weights to report anyway.
    if vocab_size == 0 {
        return Vec::new();
    }

    // Running maximum of relu(logit) per vocabulary slot. Starting at 0.0
    // folds the relu into the max; NaN never compares greater, so it is
    // ignored exactly as `f32::max(NaN, 0.0) == 0.0` would.
    let mut peaks = vec![0.0_f32; vocab_size];
    let rows = logits.chunks_exact(vocab_size).take(seq_len);
    for (row, &mask) in rows.zip(attention_mask) {
        if mask == 0 {
            continue;
        }
        for (peak, &logit) in peaks.iter_mut().zip(row) {
            if logit > *peak {
                *peak = logit;
            }
        }
    }

    for &id in special_ids {
        if let Some(slot) = peaks.get_mut(id as usize) {
            *slot = 0.0;
        }
    }

    // `enumerate` already yields ascending ids, so no sort is needed.
    peaks
        .into_iter()
        .enumerate()
        .filter_map(|(idx, peak)| {
            let weight = activation.apply(peak);
            // Token ids are u32 throughout this module; vocabularies fit.
            (weight > 0.0).then_some((idx as u32, weight))
        })
        .collect()
}
/// Sparse text encoder that tokenizes input, runs an ONNX model and reduces
/// the logits to `(token id, weight)` pairs.
pub struct OnnxSparseEncoder {
/// Checkpoint this encoder was built for.
kind: ModelKind,
/// Tokenizer configured with truncation at `max_len`.
tokenizer: Tokenizer,
/// The model session; the mutex is needed because a forward pass takes
/// `&mut OnnxSession` while the encoder's methods take `&self`.
session: std::sync::Mutex<OnnxSession>,
/// IDF table used by `encode_query` when present; `None` means queries are
/// encoded like documents.
idf: Option<Arc<IdfTable>>,
/// Token ids that are excluded from every produced embedding.
special_ids: Vec<u32>,
/// Model name reported through `SparseEncoder::model` (`onnx:<vocab_id>`).
model_fq: String,
/// Resolved truncation length in tokens.
max_len: usize,
}
/// `provider:model` keys for which the truncation notice was already printed.
static TOKENIZER_TRUNCATE_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();

/// Prints the truncation notice on stderr at most once per
/// `provider`/`model` pair for the lifetime of the process.
///
/// `max_len` and `positional_limit` only appear in the message text.
fn warn_truncation_once(provider: &str, model: &str, max_len: usize, positional_limit: usize) {
    let registry = TOKENIZER_TRUNCATE_WARNED.get_or_init(Mutex::default);
    // A poisoned lock is recovered rather than propagated: this is
    // best-effort logging and the set stays usable.
    let mut seen = registry
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    let first_time = seen.insert(format!("{provider}:{model}"));
    if !first_time {
        return;
    }
    eprintln!(
        "{provider}: input filled max_length={max_len} on {model}; tail truncated. \
         Raise via MNEM_ONNX_SPARSE_MAX_LEN (<= {positional_limit}) or chunk upstream."
    );
}
/// Manual `Debug` that prints only `kind`, `model_fq`, whether an IDF table
/// is loaded, and `max_len`; the tokenizer, session and special ids are
/// left out.
impl std::fmt::Debug for OnnxSparseEncoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_idf = self.idf.is_some();

        let mut out = f.debug_struct("OnnxSparseEncoder");
        out.field("kind", &self.kind);
        out.field("model_fq", &self.model_fq);
        out.field("has_idf", &has_idf);
        out.field("max_len", &self.max_len);
        out.finish()
    }
}
/// Set the first time [`ensure_ort_init`] runs.
static ORT_INIT: OnceLock<()> = OnceLock::new();

/// Marks process-wide initialisation as done; later calls are no-ops.
///
/// The initialiser currently performs no work beyond setting the flag.
fn ensure_ort_init() {
    let _: &() = ORT_INIT.get_or_init(|| ());
}
impl OnnxSparseEncoder {
pub fn new(kind: ModelKind) -> Result<Self, SparseError> {
Self::with_max_length(kind, None)
}
pub fn with_max_length(
kind: ModelKind,
max_length: Option<usize>,
) -> Result<Self, SparseError> {
ensure_ort_init();
let max_len = resolve_max_length(kind, max_length);
let files = fetch_files(kind)?;
let tokenizer = load_tokenizer(&files.tokenizer_json, max_len)?;
let special_ids = collect_special_ids(&tokenizer);
let idf = match &files.idf_json {
Some(p) => Some(Arc::new(IdfTable::load(p, &tokenizer)?)),
None => None,
};
let session = OnnxSession::open(&files.model_onnx)?;
let model_fq = format!("onnx:{}", kind.vocab_id());
Ok(Self {
kind,
tokenizer,
session: std::sync::Mutex::new(session),
idf,
special_ids,
model_fq,
max_len,
})
}
pub fn max_length(&self) -> usize {
self.max_len
}
pub fn encode_document(&self, text: &str) -> Result<SparseEmbed, SparseError> {
let encoded = self
.tokenizer
.encode(text, true)
.map_err(|e| SparseError::Inference(format!("tokenize doc: {e}")))?;
if encoded.get_ids().len() >= self.max_len {
warn_truncation_once(
"mnem-sparse",
self.kind.vocab_id(),
self.max_len,
self.kind.positional_limit(),
);
}
let mask = encoded.get_attention_mask().to_vec();
let mut session = self
.session
.lock()
.map_err(|_| SparseError::Inference("session mutex poisoned".into()))?;
let (logits, seq_len, vocab_size) = session.forward_single(&encoded)?;
drop(session);
let pairs = reduce_doc_logits(
&logits,
seq_len,
vocab_size,
&mask,
self.kind.activation(),
&self.special_ids,
);
pairs_to_sparse(pairs, self.kind.vocab_id())
}
pub fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
if let Some(idf) = &self.idf {
let pairs = idf.encode_query(&self.tokenizer, text, &self.special_ids)?;
return pairs_to_sparse(pairs, self.kind.vocab_id());
}
self.encode_document(text)
}
}
/// Converts `(token id, weight)` pairs into a `SparseEmbed` tagged with
/// `vocab_id`.
///
/// Pairs are sorted by id first (stably, so pairs with equal ids keep their
/// relative order) and then split into parallel index and value vectors.
///
/// # Errors
/// Propagates whatever `SparseEmbed::new` returns.
fn pairs_to_sparse(mut pairs: Vec<(u32, f32)>, vocab_id: &str) -> Result<SparseEmbed, SparseError> {
    pairs.sort_by_key(|pair| pair.0);
    let (indices, values): (Vec<u32>, Vec<f32>) = pairs.into_iter().unzip();
    SparseEmbed::new(indices, values, String::from(vocab_id))
}
impl SparseEncoder for OnnxSparseEncoder {
fn model(&self) -> &str {
&self.model_fq
}
fn vocab_id(&self) -> &str {
self.kind.vocab_id()
}
fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
self.encode_document(text)
}
fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
Self::encode_query(self, text)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const DOC: ModelKind = ModelKind::OpensearchDocV3Distill;
    const BI: ModelKind = ModelKind::OpensearchBiV2Distill;

    /// Without an override the default length is used.
    ///
    /// NOTE: this reads `MNEM_ONNX_SPARSE_MAX_LEN`, so it assumes the
    /// variable is unset (or set to 512) in the test environment.
    #[test]
    fn resolve_max_length_uses_default_when_none() {
        assert_eq!(resolve_max_length(DOC, None), 512);
    }

    /// An in-range override is returned unchanged.
    #[test]
    fn resolve_max_length_passes_through_in_range() {
        assert_eq!(resolve_max_length(DOC, Some(256)), 256);
    }

    /// An override above the positional limit is clamped to it.
    #[test]
    fn resolve_max_length_clamps_to_positional_limit() {
        assert_eq!(resolve_max_length(DOC, Some(8192)), 512);
    }

    /// A zero override falls back to the default.
    #[test]
    fn resolve_max_length_zero_snaps_to_default() {
        assert_eq!(resolve_max_length(BI, Some(0)), 512);
    }

    /// For both kinds the default length equals the positional limit.
    #[test]
    fn positional_limit_and_default_coincide_for_distilbert() {
        for kind in [DOC, BI] {
            assert_eq!(kind.positional_limit(), kind.default_max_length());
        }
    }
}