use std::path::PathBuf;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;
use crate::semantic_index::{format_embedding_init_error, pre_validate_onnx_runtime};
use crate::slog_info;
/// Hugging Face Hub repository that hosts the ONNX export of all-MiniLM-L6-v2.
const MINILM_REPO: &str = "Qdrant/all-MiniLM-L6-v2-onnx";
/// File name of the ONNX model inside the repository / cache snapshot.
const MINILM_MODEL_FILE: &str = "model.onnx";
/// File name of the tokenizer definition inside the repository / cache snapshot.
const MINILM_TOKENIZER_FILE: &str = "tokenizer.json";
/// Maximum number of tokens kept per input; the tokenizer truncates anything
/// longer (configured in `LocalEmbedder::new`).
const MINILM_MAX_LENGTH: usize = 512;
/// Budget for `batch_len * max_seq_len^2` (a proxy for quadratic attention cost)
/// that `LocalEmbedder::embed` uses to split a large request into several
/// smaller inference batches.
const MAX_BATCH_ATTENTION_UNITS: usize = 4_000_000;
/// Number of intra-op threads to give ONNX Runtime.
///
/// Uses half of the reported hardware parallelism (rounded up), falling back to
/// a single core when parallelism cannot be queried, and never returns zero.
fn intra_thread_cap() -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    std::cmp::max(cores.div_ceil(2), 1)
}
/// Sentence embedder backed by a local ONNX export of all-MiniLM-L6-v2.
///
/// `embed` returns one mean-pooled, L2-normalised vector per input text.
pub struct LocalEmbedder {
    /// ONNX Runtime session holding the loaded model graph.
    session: Session,
    /// Tokenizer loaded from the model's `tokenizer.json`, configured with
    /// truncation.
    tokenizer: Tokenizer,
    /// Whether the model graph declares a `token_type_ids` input; when true an
    /// all-zeros tensor is supplied at inference time.
    wants_token_type_ids: bool,
}
impl LocalEmbedder {
    /// Loads the all-MiniLM-L6-v2 ONNX model and its tokenizer.
    ///
    /// `model` is compared ASCII-case-insensitively with `all-MiniLM-L6-v2`,
    /// the only supported model. Model files come from the local cache or are
    /// downloaded (`resolve_model_files`). The ONNX session uses Level3 graph
    /// optimisation and `intra_thread_cap()` intra-op threads, and the tokenizer
    /// truncates every input to `MINILM_MAX_LENGTH` tokens.
    ///
    /// # Errors
    /// Returns a message if the model name is unsupported, ONNX Runtime fails
    /// pre-validation, the model files cannot be resolved, or the session or
    /// tokenizer cannot be initialised.
    pub fn new(model: &str) -> Result<Self, String> {
        // Accept any ASCII casing of the canonical model name.
        if !model.eq_ignore_ascii_case("all-MiniLM-L6-v2") {
            return Err(format!(
                "unsupported local embedding model '{model}'. Supported: all-MiniLM-L6-v2"
            ));
        }
        pre_validate_onnx_runtime()?;
        let (model_path, tokenizer_path) = resolve_model_files()?;
        let threads = intra_thread_cap();
        let session = Session::builder()
            .map_err(|e| format!("failed to create ONNX session builder: {e}"))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| format!("failed to set ONNX optimization level: {e}"))?
            .with_intra_threads(threads)
            .map_err(|e| format!("failed to set ONNX intra-op threads: {e}"))?
            .commit_from_file(&model_path)
            .map_err(format_embedding_init_error)?;
        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| format!("failed to load tokenizer {}: {e}", tokenizer_path.display()))?;
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: MINILM_MAX_LENGTH,
                ..Default::default()
            }))
            .map_err(|e| format!("failed to set tokenizer truncation: {e}"))?;
        // Only feed `token_type_ids` when the loaded graph declares that input.
        let wants_token_type_ids = session
            .inputs()
            .iter()
            .any(|input| input.name() == "token_type_ids");
        slog_info!(
            "local embedder ready: model=all-MiniLM-L6-v2 intra_threads={} token_type_ids={}",
            threads,
            wants_token_type_ids
        );
        Ok(Self {
            session,
            tokenizer,
            wants_token_type_ids,
        })
    }

    /// Embeds each text into a mean-pooled, L2-normalised vector.
    ///
    /// All texts are tokenised together, then split into consecutive
    /// sub-batches whose `batch_len * max_seq_len^2` stays within
    /// `MAX_BATCH_ATTENTION_UNITS`. A single text always forms a batch, even if
    /// it exceeds the budget on its own. Output order matches input order, and
    /// an empty input yields an empty output.
    ///
    /// # Errors
    /// Returns a message if tokenisation or any inference call fails.
    pub fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        // Borrow the inputs rather than cloning every String.
        let inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let encodings = self
            .tokenizer
            .encode_batch(inputs, true)
            .map_err(|e| format!("tokenize batch: {e}"))?;
        let mut result = Vec::with_capacity(encodings.len());
        let mut batch_start = 0usize;
        let mut batch_max = 0usize;
        for (i, enc) in encodings.iter().enumerate() {
            let len = enc.get_ids().len().max(1);
            let count = i - batch_start;
            let candidate_max = batch_max.max(len);
            // Cost of the current batch if this encoding were added to it.
            let cost = (count + 1)
                .saturating_mul(candidate_max)
                .saturating_mul(candidate_max);
            if count > 0 && cost > MAX_BATCH_ATTENTION_UNITS {
                // Flush the batch without `enc`; `enc` starts the next one.
                let vecs = self.run_inference(&encodings[batch_start..i])?;
                result.extend(vecs);
                batch_start = i;
                batch_max = len;
            } else {
                batch_max = candidate_max;
            }
        }
        let vecs = self.run_inference(&encodings[batch_start..])?;
        result.extend(vecs);
        Ok(result)
    }

    /// Runs one padded batch through the model and mean-pools its output.
    ///
    /// Rows are right-padded with id 0 / mask 0 to the longest encoding. The
    /// first model output must be a `[batch, seq, dim]` tensor matching the
    /// padded input. Each row's token vectors are averaged over positions whose
    /// attention mask is 1, then L2-normalised.
    ///
    /// # Errors
    /// Returns a message if tensor construction or inference fails, or if the
    /// output tensor has an unexpected rank, dtype, or shape.
    fn run_inference(
        &mut self,
        encodings: &[tokenizers::Encoding],
    ) -> Result<Vec<Vec<f32>>, String> {
        if encodings.is_empty() {
            return Ok(Vec::new());
        }
        let batch = encodings.len();
        let max_len = encodings
            .iter()
            .map(|e| e.get_ids().len())
            .max()
            .unwrap_or(1)
            .max(1);
        let mut ids = vec![0i64; batch * max_len];
        let mut mask = vec![0i64; batch * max_len];
        for (row, enc) in encodings.iter().enumerate() {
            let base = row * max_len;
            let row_ids = enc.get_ids();
            let row_mask = enc.get_attention_mask();
            for (col, (&id, &m)) in row_ids.iter().zip(row_mask).enumerate() {
                ids[base + col] = i64::from(id);
                mask[base + col] = i64::from(m);
            }
        }
        let input_ids = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), ids)
            .map_err(|e| format!("build input_ids tensor: {e}"))?;
        let attention_mask = ndarray::Array2::<i64>::from_shape_vec((batch, max_len), mask)
            .map_err(|e| format!("build attention_mask tensor: {e}"))?;
        // The mask is cloned because it is needed again for pooling.
        let mut inputs = ort::inputs![
            "input_ids" => Tensor::from_array(input_ids).map_err(|e| format!("input_ids: {e}"))?,
            "attention_mask" => Tensor::from_array(attention_mask.clone())
                .map_err(|e| format!("attention_mask: {e}"))?,
        ];
        if self.wants_token_type_ids {
            let token_type_ids = ndarray::Array2::<i64>::zeros((batch, max_len));
            inputs.push((
                "token_type_ids".into(),
                Tensor::from_array(token_type_ids)
                    .map_err(|e| format!("token_type_ids: {e}"))?
                    .into(),
            ));
        }
        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| format!("ONNX inference failed: {e}"))?;
        let output = outputs
            .values()
            .next()
            .ok_or_else(|| "ONNX model produced no output".to_string())?;
        // Borrow an f32 output in place; only an f16 output needs a converted copy.
        let (shape, data): (Vec<i64>, std::borrow::Cow<'_, [f32]>) =
            match output.try_extract_tensor::<f32>() {
                Ok((s, d)) => (s.to_vec(), std::borrow::Cow::Borrowed(d)),
                Err(_) => {
                    let (s, d) = output
                        .try_extract_tensor::<half::f16>()
                        .map_err(|e| format!("extract output tensor: {e}"))?;
                    let converted: Vec<f32> = d.iter().map(|h| h.to_f32()).collect();
                    (s.to_vec(), std::borrow::Cow::Owned(converted))
                }
            };
        if shape.len() != 3 {
            return Err(format!(
                "unexpected ONNX output rank {} (expected 3: [batch, seq, dim])",
                shape.len()
            ));
        }
        // Reject negative (dynamic/unknown) dims before casting to usize.
        if shape.iter().any(|&d| d < 0) {
            return Err(format!("ONNX output has a negative dimension: {shape:?}"));
        }
        let (out_batch, seq, dim) = (shape[0] as usize, shape[1] as usize, shape[2] as usize);
        // Pooling indexes the mask with the input's [batch, max_len] layout, so
        // the output must match it exactly.
        if out_batch != batch || seq != max_len {
            return Err(format!(
                "ONNX output shape {shape:?} does not match input [{batch}, {max_len}]"
            ));
        }
        let expected = batch * seq * dim;
        if data.len() != expected {
            return Err(format!(
                "ONNX output holds {} values, expected {expected} for shape {shape:?}",
                data.len()
            ));
        }
        let mut result = Vec::with_capacity(batch);
        for row in 0..batch {
            let mut emb = vec![0.0f32; dim];
            let mut valid = 0.0f32;
            for col in 0..seq {
                if mask_at(&attention_mask, row, col) != 1 {
                    continue;
                }
                valid += 1.0;
                let base = (row * seq + col) * dim;
                for (slot, &x) in emb.iter_mut().zip(&data[base..base + dim]) {
                    *slot += x;
                }
            }
            // Mean over real tokens; an all-padding row stays a zero vector.
            let denom = if valid == 0.0 { 1.0 } else { valid };
            emb.iter_mut().for_each(|slot| *slot /= denom);
            // L2-normalise; the epsilon avoids division by zero for zero vectors.
            let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            emb.iter_mut().for_each(|slot| *slot /= norm + 1e-12);
            result.push(emb);
        }
        Ok(result)
    }
}
/// Returns the attention-mask entry at (`row`, `col`).
///
/// Panics if the position lies outside the mask.
#[inline]
fn mask_at(mask: &ndarray::Array2<i64>, row: usize, col: usize) -> i64 {
    let position = (row, col);
    mask[position]
}
/// Finds the model and tokenizer files, returning `(model_path, tokenizer_path)`.
///
/// A complete snapshot already in the cache directory is preferred; otherwise
/// the files are fetched via hf-hub into that same directory.
fn resolve_model_files() -> Result<(PathBuf, PathBuf), String> {
    let cache_dir = embedding_cache_dir();
    match scan_local_snapshot(&cache_dir) {
        Some(paths) => Ok(paths),
        None => download_via_hf_hub(&cache_dir),
    }
}
/// Returns the directory used to cache downloaded embedding models.
///
/// Resolution order:
/// 1. `$FASTEMBED_CACHE_DIR`, used as-is.
/// 2. Otherwise `<home>/.cache/fastembed`, where `<home>` is `$HOME`, then
///    `$USERPROFILE`, then the system temp directory.
///
/// Empty environment values are treated as unset, so a blank variable can't
/// silently turn the cache into a path relative to the working directory.
fn embedding_cache_dir() -> PathBuf {
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());
    if let Some(dir) = non_empty("FASTEMBED_CACHE_DIR") {
        return PathBuf::from(dir);
    }
    let home = non_empty("HOME")
        .or_else(|| non_empty("USERPROFILE"))
        .map(PathBuf::from)
        .unwrap_or_else(std::env::temp_dir);
    home.join(".cache").join("fastembed")
}
fn scan_local_snapshot(cache_dir: &std::path::Path) -> Option<(PathBuf, PathBuf)> {
let repo_dir = cache_dir.join("models--Qdrant--all-MiniLM-L6-v2-onnx");
let snapshots = repo_dir.join("snapshots");
let mut candidates: Vec<PathBuf> = std::fs::read_dir(&snapshots)
.ok()?
.filter_map(|entry| entry.ok().map(|e| e.path()))
.filter(|p| p.is_dir())
.collect();
candidates.sort_by_key(|p| {
std::fs::metadata(p)
.and_then(|m| m.modified())
.unwrap_or(std::time::UNIX_EPOCH)
});
candidates.reverse();
for snap in candidates {
let model = snap.join(MINILM_MODEL_FILE);
let tokenizer = snap.join(MINILM_TOKENIZER_FILE);
if model.is_file() && tokenizer.is_file() {
return Some((model, tokenizer));
}
}
None
}
/// Fetches the model and tokenizer files for `MINILM_REPO` through hf-hub's
/// synchronous API, using `cache_dir` as its cache directory.
///
/// Returns `(model_path, tokenizer_path)` as reported by hf-hub.
///
/// # Errors
/// Returns a message if the hf-hub client cannot be built or if either file
/// cannot be fetched.
fn download_via_hf_hub(cache_dir: &std::path::Path) -> Result<(PathBuf, PathBuf), String> {
    use hf_hub::api::sync::ApiBuilder;

    slog_info!(
        "downloading all-MiniLM-L6-v2 ({}) to {}",
        MINILM_REPO,
        cache_dir.display()
    );
    let builder = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(cache_dir.to_path_buf());
    let api = match builder.build() {
        Ok(api) => api,
        Err(e) => return Err(format!("failed to init hf-hub api: {e}")),
    };
    let repo = api.model(MINILM_REPO.to_string());
    // Fetch a single repo file, labelling any failure with that file's name.
    let fetch = |file: &str| {
        repo.get(file)
            .map_err(|e| format!("failed to download {file}: {e}"))
    };
    let model = fetch(MINILM_MODEL_FILE)?;
    let tokenizer = fetch(MINILM_TOKENIZER_FILE)?;
    Ok((model, tokenizer))
}