use super::{
normalize_vector, EmbeddingCache, EmbeddingConfig, EmbeddingProvider, EmbeddingResult,
};
use crate::{Error, Result};
use async_trait::async_trait;
use ort::{
session::{builder::GraphOptimizationLevel, Session},
value::Tensor,
};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use tokenizers::Tokenizer;
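/// Embedding provider that runs a sentence-embedding ONNX model locally via
/// `ort`, tokenizing input with a Hugging Face `tokenizers` tokenizer and
/// optionally caching results in memory.
///
/// A minimal usage sketch; the models directory and the presence of the model
/// files are assumptions, not guarantees made by this crate:
///
/// ```ignore
/// // Illustrative only: assumes the model files exist under ./models and that
/// // the `EmbeddingProvider` trait is in scope for `embed` / `dimension`.
/// let provider = LocalONNXEmbedding::bge_m3("./models")?;
/// let result = provider.embed("hello world").await?;
/// assert_eq!(result.dense.unwrap().len(), provider.dimension());
/// ```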
pub struct LocalONNXEmbedding {
config: EmbeddingConfig,
session: Mutex<Session>,
tokenizer: Tokenizer,
cache: Option<Arc<EmbeddingCache>>,
}
impl LocalONNXEmbedding {
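    /// Loads an ONNX model and tokenizer from the given paths and builds a
    /// session with graph optimizations enabled and four intra-op threads.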
pub fn new(
model_path: impl Into<PathBuf>,
tokenizer_path: impl Into<PathBuf>,
config: EmbeddingConfig,
) -> Result<Self> {
let model_path = model_path.into();
let tokenizer_path = tokenizer_path.into();
let session = Session::builder()
.map_err(|e| Error::embedding(format!("Failed to create session builder: {}", e)))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| Error::embedding(format!("Failed to set optimization level: {}", e)))?
.with_intra_threads(4)
.map_err(|e| Error::embedding(format!("Failed to set intra threads: {}", e)))?
.commit_from_file(&model_path)
.map_err(|e| {
Error::embedding(format!(
"Failed to load ONNX model from {:?}: {}",
model_path, e
))
})?;
let tokenizer = Tokenizer::from_file(&tokenizer_path).map_err(|e| {
Error::embedding(format!(
"Failed to load tokenizer from {:?}: {}",
tokenizer_path, e
))
})?;
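        // Cache up to 10_000 embeddings (keyed by model + text hash) when
        // caching is enabled in the config.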
let cache = if config.enable_cache {
Some(Arc::new(EmbeddingCache::new(10000, config.cache_ttl_secs)))
} else {
None
};
Ok(Self {
config,
session: Mutex::new(session),
tokenizer,
cache,
})
}
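    /// Convenience constructor for a BGE-M3 model laid out as
    /// `bge-m3.onnx` / `bge-m3-tokenizer.json` inside `models_dir`.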
pub fn bge_m3(models_dir: impl Into<PathBuf>) -> Result<Self> {
let models_dir = models_dir.into();
Self::new(
models_dir.join("bge-m3.onnx"),
models_dir.join("bge-m3-tokenizer.json"),
EmbeddingConfig::bge_m3(),
)
}
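    /// Convenience constructor for an E5-small-v2 model laid out as
    /// `e5-small-v2.onnx` / `e5-small-v2-tokenizer.json` inside `models_dir`.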
pub fn e5_small(models_dir: impl Into<PathBuf>) -> Result<Self> {
let models_dir = models_dir.into();
Self::new(
models_dir.join("e5-small-v2.onnx"),
models_dir.join("e5-small-v2-tokenizer.json"),
EmbeddingConfig::e5_small(),
)
}
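    /// Builds a cache key as the SHA-256 hash of `"<model>:<text>"` so that
    /// entries produced by different models never collide.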
fn cache_key(&self, text: &str) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(self.config.model.as_bytes());
hasher.update(b":");
hasher.update(text.as_bytes());
format!("{:x}", hasher.finalize())
}
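    /// Tokenizes `texts`, pads them to a common length, runs the ONNX session
    /// once for the whole batch, and returns one embedding per input text.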
fn infer(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let encodings = self
.tokenizer
.encode_batch(texts.to_vec(), true)
.map_err(|e| Error::embedding(format!("Tokenization failed: {}", e)))?;
let batch_size = texts.len();
let max_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);
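        // Flatten the batch into row-major (batch, max_len) buffers; positions
        // past a sequence's end keep id 0 and attention mask 0.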
let mut input_ids = vec![0i64; batch_size * max_len];
let mut attention_mask = vec![0i64; batch_size * max_len];
for (i, encoding) in encodings.iter().enumerate() {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
for (j, &id) in ids.iter().enumerate() {
input_ids[i * max_len + j] = id as i64;
attention_mask[i * max_len + j] = mask[j] as i64;
}
}
let input_ids_array = ndarray::Array2::from_shape_vec((batch_size, max_len), input_ids)
.map_err(|e| Error::embedding(format!("Failed to create input_ids array: {}", e)))?;
let attention_mask_array =
ndarray::Array2::from_shape_vec((batch_size, max_len), attention_mask).map_err(
|e| Error::embedding(format!("Failed to create attention_mask array: {}", e)),
)?;
let input_ids_tensor = Tensor::from_array(input_ids_array)
.map_err(|e| Error::embedding(format!("Failed to create input_ids tensor: {}", e)))?;
let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
Error::embedding(format!("Failed to create attention_mask tensor: {}", e))
})?;
let mut session = self
.session
.lock()
.map_err(|_| Error::embedding("Failed to lock ONNX session"))?;
let outputs = session
.run(ort::inputs![input_ids_tensor, attention_mask_tensor])
.map_err(|e| Error::embedding(format!("ONNX inference failed: {}", e)))?;
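        // Output names vary between exporters; try the common ones, then fall
        // back to the first output if none of them match.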
let embeddings_tensor = outputs
.get("last_hidden_state")
.or_else(|| outputs.get("output"))
.or_else(|| outputs.get("embeddings"))
.or_else(|| outputs.get("sentence_embedding"))
.or_else(|| {
if outputs.len() > 0 {
Some(&outputs[0])
} else {
None
}
})
.ok_or_else(|| Error::embedding("No output from ONNX model"))?;
let embeddings_array = embeddings_tensor
.try_extract_array::<f32>()
.map_err(|e| Error::embedding(format!("Failed to extract embeddings: {}", e)))?;
let mut results = Vec::with_capacity(batch_size);
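        // A 2-D output means the model already pooled to sentence embeddings;
        // a 3-D output is per-token and needs pooling here.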
match embeddings_array.ndim() {
2 => {
if embeddings_array.shape()[0] != batch_size {
return Err(Error::embedding(format!(
"Unexpected embedding batch size: expected {}, got {}",
batch_size,
embeddings_array.shape()[0]
)));
}
for i in 0..batch_size {
let embedding = embeddings_array.slice(ndarray::s![i, ..]).to_vec();
let embedding = if self.config.normalize {
normalize_vector(&embedding)
} else {
embedding
};
results.push(embedding);
}
}
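            // [batch, seq_len, hidden]: mean-pool token embeddings per text.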
3 => {
if embeddings_array.shape()[0] != batch_size {
return Err(Error::embedding(format!(
"Unexpected embedding batch size: expected {}, got {}",
batch_size,
embeddings_array.shape()[0]
)));
}
let token_embeddings: ndarray::ArrayView3<'_, f32> =
embeddings_array.into_dimensionality().map_err(|e| {
Error::embedding(format!("Wrong embedding tensor shape: {}", e))
})?;
                for (i, encoding) in encodings.iter().enumerate() {
                    // Mean-pool only over real tokens so padding does not skew the result.
                    let mask = encoding.get_attention_mask();
                    let seq_len = mask.iter().filter(|&&m| m != 0).count();
                    let tokens = token_embeddings.slice(ndarray::s![i, ..seq_len, ..]);
                    let pooled = tokens.mean_axis(ndarray::Axis(0)).ok_or_else(|| {
                        Error::embedding("Failed to pool embeddings: empty sequence")
                    })?;
let embedding = pooled.to_vec();
let embedding = if self.config.normalize {
normalize_vector(&embedding)
} else {
embedding
};
results.push(embedding);
}
}
other => {
return Err(Error::embedding(format!(
"Unexpected embedding tensor dimensionality: {}",
other
)));
}
}
Ok(results)
}
}
#[async_trait]
impl EmbeddingProvider for LocalONNXEmbedding {
fn dimension(&self) -> usize {
self.config.dimension
}
fn model_name(&self) -> &str {
&self.config.model
}
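    /// Embeds a single text, consulting the cache before running inference.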
async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
if let Some(ref cache) = self.cache {
let key = self.cache_key(text);
if let Some(cached) = cache.get(&key) {
return Ok(EmbeddingResult {
dense: Some(cached),
sparse: None,
token_count: text.split_whitespace().count(),
});
}
}
let embeddings = self.infer(&[text])?;
let embedding = embeddings
.into_iter()
.next()
.ok_or_else(|| Error::embedding("No embedding returned"))?;
if let Some(ref cache) = self.cache {
let key = self.cache_key(text);
cache.put(key, embedding.clone());
}
Ok(EmbeddingResult {
dense: Some(embedding),
sparse: None,
token_count: text.split_whitespace().count(),
})
}
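    /// Embeds a batch of texts, serving cache hits directly and running a
    /// single batched inference pass for the remaining texts.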
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<EmbeddingResult>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(texts.len());
let mut uncached_indices = Vec::new();
let mut uncached_texts = Vec::new();
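        // Partition inputs into cache hits (answered immediately) and misses
        // (embedded together in one inference call below).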
if let Some(ref cache) = self.cache {
for (i, text) in texts.iter().enumerate() {
let key = self.cache_key(text);
if let Some(cached) = cache.get(&key) {
results.push(EmbeddingResult {
dense: Some(cached),
sparse: None,
token_count: text.split_whitespace().count(),
});
} else {
uncached_indices.push(i);
uncached_texts.push(*text);
}
}
} else {
uncached_indices.extend(0..texts.len());
uncached_texts.extend(texts.iter());
}
if uncached_texts.is_empty() {
return Ok(results);
}
let embeddings = self.infer(&uncached_texts)?;
let mut new_results = Vec::with_capacity(uncached_texts.len());
for (i, embedding) in embeddings.into_iter().enumerate() {
if let Some(ref cache) = self.cache {
let key = self.cache_key(uncached_texts[i]);
cache.put(key, embedding.clone());
}
new_results.push(EmbeddingResult {
dense: Some(embedding),
sparse: None,
token_count: uncached_texts[i].split_whitespace().count(),
});
}
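        // Re-interleave cached and freshly computed results back into the
        // original input order.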
if self.cache.is_some() {
let mut final_results = Vec::with_capacity(texts.len());
let mut new_idx = 0;
for i in 0..texts.len() {
if uncached_indices.contains(&i) {
final_results.push(new_results[new_idx].clone());
new_idx += 1;
} else {
final_results.push(results.remove(0));
}
}
Ok(final_results)
} else {
Ok(new_results)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_onnx_provider_creation() {
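        // The model and tokenizer files are not expected to exist in the test
        // environment, so construction should fail with an embedding error.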
let result = LocalONNXEmbedding::new(
PathBuf::from("models/bge-m3.onnx"),
PathBuf::from("models/bge-m3-tokenizer.json"),
EmbeddingConfig::bge_m3(),
);
assert!(result.is_err());
}
}