#[cfg(all(feature = "onnx", feature = "onnx-bundled"))]
compile_error!(
"mnem-embed-providers: enable exactly one of `onnx` or `onnx-bundled` (mutually exclusive)"
);
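// ONNX embedding provider: fetches Xenova ONNX exports and tokenizers into a
// Hugging Face-style local cache, runs them through ONNX Runtime, and reduces
// the hidden states to L2-normalized, mean-pooled vectors.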
use std::collections::HashSet;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::sync::OnceLock;
use ndarray::{Array2, ArrayViewD};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use crate::embedder::Embedder;
use crate::error::EmbedError;
use crate::manifest::EmbedderManifest;
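/// Embedding models this provider knows how to fetch and run. Each variant
/// maps to a Xenova ONNX export on the Hugging Face hub and carries its
/// published dimension, token budget, and scoring constants.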
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelKind {
BgeLargeEnV15,
BgeBaseEnV15,
BgeSmallEnV15,
AllMiniLmL6V2,
}
impl ModelKind {
fn repo_id(self) -> &'static str {
match self {
Self::BgeLargeEnV15 => "Xenova/bge-large-en-v1.5",
Self::BgeBaseEnV15 => "Xenova/bge-base-en-v1.5",
Self::BgeSmallEnV15 => "Xenova/bge-small-en-v1.5",
Self::AllMiniLmL6V2 => "Xenova/all-MiniLM-L6-v2",
}
}
fn onnx_path(self) -> &'static str {
"onnx/model.onnx"
}
fn wire_id(self) -> &'static str {
match self {
Self::BgeLargeEnV15 => "onnx:bge-large-en-v1.5",
Self::BgeBaseEnV15 => "onnx:bge-base-en-v1.5",
Self::BgeSmallEnV15 => "onnx:bge-small-en-v1.5",
Self::AllMiniLmL6V2 => "onnx:all-MiniLM-L6-v2",
}
}
#[must_use]
pub const fn dim(self) -> u32 {
match self {
Self::BgeLargeEnV15 => 1024,
Self::BgeBaseEnV15 => 768,
Self::BgeSmallEnV15 | Self::AllMiniLmL6V2 => 384,
}
}
#[must_use]
pub const fn default_max_length(self) -> usize {
match self {
Self::BgeLargeEnV15 | Self::BgeBaseEnV15 | Self::BgeSmallEnV15 => 512,
Self::AllMiniLmL6V2 => 256,
}
}
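/// Hard ceiling on sequence length. All four models are BERT-family
/// encoders with 512 learned positions, so `resolve_max_length` clamps
/// any request above this.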
#[must_use]
pub const fn positional_limit(self) -> usize {
512
}
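/// Cosine-similarity floor below which matches are treated as noise for
/// this model; surfaced to consumers via `EmbedderManifest`. These are
/// crate-tuned constants, not published model metadata.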
#[must_use]
pub const fn noise_floor(self) -> f32 {
match self {
Self::AllMiniLmL6V2 => 0.22,
Self::BgeLargeEnV15 | Self::BgeBaseEnV15 | Self::BgeSmallEnV15 => 0.31,
}
}
}
struct ModelFiles {
model_onnx: PathBuf,
tokenizer_json: PathBuf,
}
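/// Resolve the local model cache root. Honors `HF_HOME` (with the usual
/// `hub` subdirectory), then falls back to `~/.cache/huggingface/hub`
/// (`USERPROFILE` on Windows), and finally to a project-local directory.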
fn hf_cache_root() -> PathBuf {
if let Ok(v) = std::env::var("HF_HOME") {
return PathBuf::from(v).join("hub");
}
let home = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE"));
if let Ok(h) = home {
return PathBuf::from(h)
.join(".cache")
.join("huggingface")
.join("hub");
}
PathBuf::from(".mnem-hf-cache")
}
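/// Download one file from the Hugging Face hub into a hub-style cache path,
/// reusing any existing non-empty copy. Bytes are written to a temporary
/// `*.download-partial` path and renamed into place, so an interrupted
/// download is never mistaken for a complete file.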
fn fetch_to_cache(repo: &str, revision: &str, file: &str) -> Result<PathBuf, EmbedError> {
let base = hf_cache_root();
let repo_slug = format!("models--{}", repo.replace('/', "--"));
let target = base
.join(&repo_slug)
.join("resolve")
.join(revision)
.join(file);
if target.is_file() {
if let Ok(md) = fs::metadata(&target) {
if md.len() > 0 {
return Ok(target);
}
}
}
if let Some(parent) = target.parent() {
fs::create_dir_all(parent).map_err(|e| {
EmbedError::Config(format!("create cache dir {}: {e}", parent.display()))
})?;
}
let url = format!("https://huggingface.co/{repo}/resolve/{revision}/{file}");
eprintln!("mnem-embed: downloading {} -> {}", url, target.display());
let resp = ureq::get(&url)
.call()
.map_err(|e| EmbedError::Config(format!("hf download {url}: {e}")))?;
let status = resp.status();
if status != 200 {
return Err(EmbedError::Config(format!(
"hf download {url}: status {status}"
)));
}
let tmp = target.with_extension("download-partial");
{
let mut reader = resp.into_reader();
let mut out = fs::File::create(&tmp)
.map_err(|e| EmbedError::Config(format!("create {}: {e}", tmp.display())))?;
io::copy(&mut reader, &mut out)
.map_err(|e| EmbedError::Config(format!("download {}: {e}", target.display())))?;
}
fs::rename(&tmp, &target).map_err(|e| {
EmbedError::Config(format!(
"rename {} -> {}: {e}",
tmp.display(),
target.display()
))
})?;
Ok(target)
}
fn fetch_files(kind: ModelKind) -> Result<ModelFiles, EmbedError> {
let revision = "main";
let model_onnx = fetch_to_cache(kind.repo_id(), revision, kind.onnx_path())?;
let tokenizer_json = fetch_to_cache(kind.repo_id(), revision, "tokenizer.json")?;
Ok(ModelFiles {
model_onnx,
tokenizer_json,
})
}
fn load_tokenizer(path: &Path, max_len: usize) -> Result<Tokenizer, EmbedError> {
let mut tok = Tokenizer::from_file(path)
.map_err(|e| EmbedError::Config(format!("tokenizer.json load: {e}")))?;
tok.with_truncation(Some(TruncationParams {
max_length: max_len,
..Default::default()
}))
.map_err(|e| EmbedError::Config(format!("tokenizer truncation: {e}")))?;
tok.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
..Default::default()
}));
Ok(tok)
}
const ENV_EMBED_MAX_LEN: &str = "MNEM_ONNX_EMBED_MAX_LEN";
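/// Pick the effective token budget: an explicit override wins, then
/// `MNEM_ONNX_EMBED_MAX_LEN`, then the model default. Zero snaps back to
/// the default, and anything above the positional limit is clamped.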
fn resolve_max_length(kind: ModelKind, override_: Option<usize>) -> usize {
let ceiling = kind.positional_limit();
let requested = override_
.or_else(|| {
std::env::var(ENV_EMBED_MAX_LEN)
.ok()
.and_then(|s| s.parse::<usize>().ok())
})
.unwrap_or_else(|| kind.default_max_length());
if requested == 0 {
eprintln!(
"mnem-embed: requested max_length=0 for {}; snapping to default {}",
kind.wire_id(),
kind.default_max_length()
);
return kind.default_max_length();
}
if requested > ceiling {
eprintln!(
"mnem-embed: requested max_length={requested} exceeds {}'s positional limit {ceiling}; clamping",
kind.wire_id()
);
return ceiling;
}
requested
}
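/// A committed ONNX Runtime session plus a flag recording whether the graph
/// declares a `token_type_ids` input (some exports fold it away).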
struct OnnxSession {
session: Session,
needs_token_type_ids: bool,
}
impl OnnxSession {
fn open(model_path: &Path) -> Result<Self, EmbedError> {
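// Intra-op thread default: the bundled runtime gets every core, while a
// system-linked runtime stays single-threaded (a conservative guess for
// shared environments). `MNEM_ORT_INTRA_THREADS` overrides either way.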
#[cfg(feature = "onnx-bundled")]
let default_threads: usize = std::thread::available_parallelism()
.map(std::num::NonZeroUsize::get)
.unwrap_or(1);
#[cfg(not(feature = "onnx-bundled"))]
let default_threads: usize = 1;
let threads: usize = std::env::var("MNEM_ORT_INTRA_THREADS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(default_threads);
#[allow(unused_mut)]
let mut builder = Session::builder()
.map_err(|e| EmbedError::Config(format!("ort session builder: {e}")))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| EmbedError::Config(format!("ort opt level: {e}")))?
.with_intra_threads(threads)
.map_err(|e| EmbedError::Config(format!("ort intra threads: {e}")))?;
#[cfg(any(feature = "onnx-bundled-cuda", feature = "onnx-bundled-directml"))]
{
use ort::execution_providers::ExecutionProviderDispatch;
#[allow(unused_mut)]
let mut providers: Vec<ExecutionProviderDispatch> = Vec::new();
#[cfg(feature = "onnx-bundled-cuda")]
providers.push(ort::execution_providers::CUDAExecutionProvider::default().build());
#[cfg(feature = "onnx-bundled-directml")]
providers.push(ort::execution_providers::DirectMLExecutionProvider::default().build());
builder = builder
.with_execution_providers(providers)
.map_err(|e| EmbedError::Config(format!("ort execution providers: {e}")))?;
}
let session = builder
.commit_from_file(model_path)
.map_err(|e| EmbedError::Config(format!("ort commit {}: {e}", model_path.display())))?;
let needs_token_type_ids = session
.inputs()
.iter()
.any(|i| i.name() == "token_type_ids");
Ok(Self {
session,
needs_token_type_ids,
})
}
}
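/// ONNX-backed text embedder: tokenize, run the encoder, mean-pool the
/// hidden states over the attention mask, then L2-normalize.
///
/// A minimal usage sketch (the re-export paths are hypothetical, and the
/// first use downloads model files, so this is not run as a doctest):
///
/// ```ignore
/// use mnem_embed_providers::{Embedder, ModelKind, OnnxEmbedder};
///
/// let embedder = OnnxEmbedder::new(ModelKind::AllMiniLmL6V2).expect("model load");
/// let v = embedder.embed("hello world").expect("embed");
/// assert_eq!(v.len(), embedder.dim() as usize); // 384 for MiniLM
/// ```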
pub struct OnnxEmbedder {
kind: ModelKind,
tokenizer: Tokenizer,
session: Mutex<OnnxSession>,
model_fq: String,
max_len: usize,
}
static TOKENIZER_TRUNCATE_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
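/// Warn at most once per provider/model pair for the life of the process;
/// a poisoned lock is recovered rather than silencing the warning.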
fn warn_truncation_once(provider: &str, model: &str, max_len: usize, positional_limit: usize) {
let key = format!("{provider}:{model}");
let set = TOKENIZER_TRUNCATE_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let mut guard = match set.lock() {
Ok(g) => g,
Err(p) => p.into_inner(),
};
if guard.insert(key) {
eprintln!(
"{provider}: input filled max_length={max_len} on {model}; tail truncated. \
Raise via MNEM_ONNX_EMBED_MAX_LEN (<= {positional_limit}) or chunk upstream."
);
}
}
impl std::fmt::Debug for OnnxEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OnnxEmbedder")
.field("kind", &self.kind)
.field("model_fq", &self.model_fq)
.field("dim", &self.kind.dim())
.field("max_len", &self.max_len)
.finish()
}
}
static ORT_INIT: OnceLock<()> = OnceLock::new();
fn ensure_ort_init() {
ORT_INIT.get_or_init(|| {
// `ort` also initializes its global environment lazily on first session
// build; committing once up front just surfaces environment failures early.
let _ = ort::init().commit();
});
}
impl OnnxEmbedder {
pub fn new(kind: ModelKind) -> Result<Self, EmbedError> {
Self::with_max_length(kind, None)
}
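/// Build an embedder with an explicit token budget. `None` defers to
/// `MNEM_ONNX_EMBED_MAX_LEN` and then the model default.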
pub fn with_max_length(kind: ModelKind, max_length: Option<usize>) -> Result<Self, EmbedError> {
ensure_ort_init();
let max_len = resolve_max_length(kind, max_length);
let files = fetch_files(kind)?;
let tokenizer = load_tokenizer(&files.tokenizer_json, max_len)?;
let session = OnnxSession::open(&files.model_onnx)?;
Ok(Self {
kind,
tokenizer,
session: Mutex::new(session),
model_fq: kind.wire_id().to_string(),
max_len,
})
}
#[must_use]
pub fn max_length(&self) -> usize {
self.max_len
}
fn forward_single(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
let encoded = self
.tokenizer
.encode(text, true)
.map_err(|e| EmbedError::Decode(format!("tokenize: {e}")))?;
let seq_len = encoded.get_ids().len();
if seq_len >= self.max_len {
warn_truncation_once(
"mnem-embed",
self.kind.wire_id(),
self.max_len,
self.kind.positional_limit(),
);
}
let ids: Vec<i64> = encoded.get_ids().iter().map(|&x| i64::from(x)).collect();
let mask: Vec<i64> = encoded
.get_attention_mask()
.iter()
.map(|&x| i64::from(x))
.collect();
let mask_for_pool: Vec<f32> = mask.iter().map(|&x| x as f32).collect();
let ids_arr = Array2::from_shape_vec((1, seq_len), ids)
.map_err(|e| EmbedError::Decode(format!("ids reshape: {e}")))?;
let mask_arr = Array2::from_shape_vec((1, seq_len), mask)
.map_err(|e| EmbedError::Decode(format!("mask reshape: {e}")))?;
let mut session = self
.session
.lock()
.map_err(|_| EmbedError::Decode("session mutex poisoned".into()))?;
let mut inputs: Vec<(&'static str, Value)> = Vec::with_capacity(3);
inputs.push((
"input_ids",
Value::from_array(ids_arr)
.map_err(|e| EmbedError::Decode(format!("ids tensor: {e}")))?
.into_dyn(),
));
inputs.push((
"attention_mask",
Value::from_array(mask_arr)
.map_err(|e| EmbedError::Decode(format!("mask tensor: {e}")))?
.into_dyn(),
));
if session.needs_token_type_ids {
let type_arr: Array2<i64> = Array2::zeros((1, seq_len));
inputs.push((
"token_type_ids",
Value::from_array(type_arr)
.map_err(|e| EmbedError::Decode(format!("type_ids tensor: {e}")))?
.into_dyn(),
));
}
let outputs = session
.session
.run(inputs)
.map_err(|e| EmbedError::Decode(format!("ort run: {e}")))?;
let value = outputs
.iter()
.find(|(name, _)| *name == "last_hidden_state" || *name == "token_embeddings")
.map(|(_, v)| v)
.or_else(|| outputs.iter().next().map(|(_, v)| v))
.ok_or_else(|| EmbedError::Decode("no hidden-state output".into()))?;
let view: ArrayViewD<'_, f32> = value
.try_extract_array::<f32>()
.map_err(|e| EmbedError::Decode(format!("extract hidden state: {e}")))?;
let shape = view.shape().to_vec();
if shape.len() != 3 || shape[0] != 1 {
return Err(EmbedError::Decode(format!(
"expected (1, seq, hidden) hidden state, got {shape:?}"
)));
}
let seq = shape[1];
let hidden = shape[2];
let buffer: Vec<f32> = view.iter().copied().collect();
drop(outputs);
drop(session);
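// Masked mean pooling: only positions with attention_mask == 1 contribute,
// so padding affects neither the sum nor the denominator.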
let mut pooled = vec![0.0_f32; hidden];
let mut denom = 0.0_f32;
for s in 0..seq {
let m = mask_for_pool.get(s).copied().unwrap_or(0.0);
if m == 0.0 {
continue;
}
denom += m;
let row = &buffer[s * hidden..(s + 1) * hidden];
for (i, v) in row.iter().enumerate() {
pooled[i] += m * v;
}
}
if denom > 0.0 {
let inv = 1.0_f32 / denom;
for v in &mut pooled {
*v *= inv;
}
}
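// L2-normalize so downstream cosine similarity reduces to a dot product.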
let norm: f32 = pooled.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm > 0.0 {
let inv = 1.0_f32 / norm;
for v in &mut pooled {
*v *= inv;
}
}
let expected = self.kind.dim() as usize;
if pooled.len() != expected {
return Err(EmbedError::DimMismatch {
expected: self.kind.dim(),
got: u32::try_from(pooled.len()).unwrap_or(u32::MAX),
});
}
Ok(pooled)
}
}
impl Embedder for OnnxEmbedder {
fn model(&self) -> &str {
&self.model_fq
}
fn dim(&self) -> u32 {
self.kind.dim()
}
fn manifest(&self) -> EmbedderManifest {
EmbedderManifest::new(
self.model_fq.clone(),
self.kind.dim(),
self.kind.noise_floor(),
)
}
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
self.forward_single(text)
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
if texts.is_empty() {
return Ok(Vec::new());
}
if texts.len() == 1 {
return Ok(vec![self.forward_single(texts[0])?]);
}
self.forward_batch(texts)
}
}
impl OnnxEmbedder {
fn forward_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError> {
let inputs_vec: Vec<tokenizers::EncodeInput<'_>> = texts
.iter()
.map(|t| tokenizers::EncodeInput::Single((*t).into()))
.collect();
let encoded = self
.tokenizer
.encode_batch(inputs_vec, true)
.map_err(|e| EmbedError::Decode(format!("tokenize batch: {e}")))?;
let batch = encoded.len();
let seq_len = encoded.first().map(|e| e.get_ids().len()).unwrap_or(0);
if seq_len == 0 {
return texts.iter().map(|t| self.forward_single(t)).collect();
}
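// BatchLongest padding should give every row the same length; if a
// tokenizer config somehow disables padding, fall back to per-text passes.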
if encoded.iter().any(|e| e.get_ids().len() != seq_len) {
return texts.iter().map(|t| self.forward_single(t)).collect();
}
if seq_len >= self.max_len {
warn_truncation_once(
"mnem-embed",
self.kind.wire_id(),
self.max_len,
self.kind.positional_limit(),
);
}
let total = batch * seq_len;
let mut ids: Vec<i64> = Vec::with_capacity(total);
let mut mask: Vec<i64> = Vec::with_capacity(total);
for e in &encoded {
ids.extend(e.get_ids().iter().map(|&x| i64::from(x)));
mask.extend(e.get_attention_mask().iter().map(|&x| i64::from(x)));
}
let mask_for_pool: Vec<f32> = mask.iter().map(|&x| x as f32).collect();
let ids_arr = Array2::from_shape_vec((batch, seq_len), ids)
.map_err(|e| EmbedError::Decode(format!("ids reshape: {e}")))?;
let mask_arr = Array2::from_shape_vec((batch, seq_len), mask)
.map_err(|e| EmbedError::Decode(format!("mask reshape: {e}")))?;
let mut session = self
.session
.lock()
.map_err(|_| EmbedError::Decode("session mutex poisoned".into()))?;
let mut inputs: Vec<(&'static str, Value)> = Vec::with_capacity(3);
inputs.push((
"input_ids",
Value::from_array(ids_arr)
.map_err(|e| EmbedError::Decode(format!("ids tensor: {e}")))?
.into_dyn(),
));
inputs.push((
"attention_mask",
Value::from_array(mask_arr)
.map_err(|e| EmbedError::Decode(format!("mask tensor: {e}")))?
.into_dyn(),
));
if session.needs_token_type_ids {
let type_arr: Array2<i64> = Array2::zeros((batch, seq_len));
inputs.push((
"token_type_ids",
Value::from_array(type_arr)
.map_err(|e| EmbedError::Decode(format!("type_ids tensor: {e}")))?
.into_dyn(),
));
}
let outputs = session
.session
.run(inputs)
.map_err(|e| EmbedError::Decode(format!("ort run: {e}")))?;
let value = outputs
.iter()
.find(|(name, _)| *name == "last_hidden_state" || *name == "token_embeddings")
.map(|(_, v)| v)
.or_else(|| outputs.iter().next().map(|(_, v)| v))
.ok_or_else(|| EmbedError::Decode("no hidden-state output".into()))?;
let view: ArrayViewD<'_, f32> = value
.try_extract_array::<f32>()
.map_err(|e| EmbedError::Decode(format!("extract hidden state: {e}")))?;
let shape = view.shape().to_vec();
if shape.len() != 3 || shape[0] != batch || shape[1] != seq_len {
return Err(EmbedError::Decode(format!(
"expected ({batch}, {seq_len}, hidden) hidden state, got {shape:?}"
)));
}
let hidden = shape[2];
let buffer: Vec<f32> = view.iter().copied().collect();
drop(outputs);
drop(session);
let expected = self.kind.dim() as usize;
let mut out: Vec<Vec<f32>> = Vec::with_capacity(batch);
let row_stride = seq_len * hidden;
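// Pool each batch row independently: the same masked mean + L2 normalize
// as the single-text path, indexed through flat (batch, seq, hidden) strides.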
for b in 0..batch {
let mut pooled = vec![0.0_f32; hidden];
let mut denom = 0.0_f32;
let row_base = b * row_stride;
let mask_base = b * seq_len;
for s in 0..seq_len {
let m = mask_for_pool[mask_base + s];
if m == 0.0 {
continue;
}
denom += m;
let tok_base = row_base + s * hidden;
let row = &buffer[tok_base..tok_base + hidden];
for (i, v) in row.iter().enumerate() {
pooled[i] += m * v;
}
}
if denom > 0.0 {
let inv = 1.0_f32 / denom;
for v in &mut pooled {
*v *= inv;
}
}
let norm: f32 = pooled.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm > 0.0 {
let inv = 1.0_f32 / norm;
for v in &mut pooled {
*v *= inv;
}
}
if pooled.len() != expected {
return Err(EmbedError::DimMismatch {
expected: self.kind.dim(),
got: u32::try_from(pooled.len()).unwrap_or(u32::MAX),
});
}
out.push(pooled);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolve_max_length_uses_default_when_none() {
let n = resolve_max_length(ModelKind::BgeLargeEnV15, None);
assert_eq!(n, 512);
}
#[test]
fn resolve_max_length_clamps_above_positional_limit() {
let n = resolve_max_length(ModelKind::BgeLargeEnV15, Some(8192));
assert_eq!(n, 512);
}
#[test]
fn resolve_max_length_zero_snaps_to_default() {
let n = resolve_max_length(ModelKind::BgeBaseEnV15, Some(0));
assert_eq!(n, 512);
}
#[test]
fn dims_match_published_sizes() {
assert_eq!(ModelKind::BgeLargeEnV15.dim(), 1024);
assert_eq!(ModelKind::BgeBaseEnV15.dim(), 768);
assert_eq!(ModelKind::BgeSmallEnV15.dim(), 384);
assert_eq!(ModelKind::AllMiniLmL6V2.dim(), 384);
}
#[test]
fn wire_ids_are_stable_and_namespaced() {
assert_eq!(ModelKind::BgeLargeEnV15.wire_id(), "onnx:bge-large-en-v1.5");
assert_eq!(ModelKind::BgeBaseEnV15.wire_id(), "onnx:bge-base-en-v1.5");
assert_eq!(ModelKind::BgeSmallEnV15.wire_id(), "onnx:bge-small-en-v1.5");
assert_eq!(ModelKind::AllMiniLmL6V2.wire_id(), "onnx:all-MiniLM-L6-v2");
}
#[test]
fn minilm_default_max_length_is_256() {
assert_eq!(ModelKind::AllMiniLmL6V2.default_max_length(), 256);
}
}