use crate::embedding::EmbeddingProvider;
use crate::error::{Error, Result};
#[cfg(feature = "onnx")]
mod inner {
use super::*;
use ndarray::Array2;
use ort::Session;
use std::path::Path;
use std::sync::Arc;
use tokenizers::Tokenizer;
/// Local embedding provider backed by an ONNX model executed through
/// ONNX Runtime (`ort`), with a Hugging Face `tokenizers` tokenizer loaded
/// from a `tokenizer.json` next to the model file.
pub struct OnnxEmbedding {
    // Expected embedding width; checked against the model's output hidden
    // dimension at inference time.
    dimensions: usize,
    // Path the model was loaded from; exposed via `model_path()` and `Debug`.
    model_path: String,
    // Arc-wrapped so inference can move clones onto the blocking thread pool.
    session: Arc<Session>,
    tokenizer: Arc<Tokenizer>,
}
/// Hand-written `Debug` that prints only the lightweight configuration
/// fields and elides the session/tokenizer handles.
impl std::fmt::Debug for OnnxEmbedding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OnnxEmbedding");
        out.field("dimensions", &self.dimensions);
        out.field("model_path", &self.model_path);
        out.finish_non_exhaustive()
    }
}
impl OnnxEmbedding {
/// Loads an ONNX model from `model_path` together with a `tokenizer.json`
/// expected in the same directory, and prepares a reusable inference session.
///
/// # Errors
/// * `Error::Validation` if the model file does not exist.
/// * `Error::Embedding` if the tokenizer file is missing, or if building the
///   session, loading the model, or loading the tokenizer fails.
pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
    let model = Path::new(model_path);
    if !model.exists() {
        return Err(Error::Validation(format!(
            "ONNX model not found at: {model_path}"
        )));
    }
    // The tokenizer is conventionally shipped alongside the model file; a
    // path with no parent falls back to the current directory.
    let tokenizer_path = match model.parent() {
        Some(dir) => dir.join("tokenizer.json"),
        None => Path::new("tokenizer.json").to_path_buf(),
    };
    if !tokenizer_path.exists() {
        return Err(Error::Embedding(format!(
            "tokenizer.json not found next to ONNX model (expected at {})",
            tokenizer_path.display()
        )));
    }
    let builder = Session::builder().map_err(|e| {
        Error::Embedding(format!("failed to create ONNX session builder: {e}"))
    })?;
    let session = builder
        .with_intra_threads(4)
        .map_err(|e| Error::Embedding(format!("failed to set intra threads: {e}")))?
        .commit_from_file(model_path)
        .map_err(|e| Error::Embedding(format!("failed to load ONNX model: {e}")))?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| Error::Embedding(format!("failed to load tokenizer: {e}")))?;
    Ok(Self {
        dimensions,
        model_path: model_path.to_string(),
        session: Arc::new(session),
        tokenizer: Arc::new(tokenizer),
    })
}
/// Returns the filesystem path this provider's model was loaded from.
#[must_use]
pub fn model_path(&self) -> &str {
    self.model_path.as_str()
}
/// Encodes `texts` with the tokenizer and packs the results into three
/// `[batch, max_len]` i64 matrices (input ids, attention mask, token type
/// ids), zero-padded on the right to the longest sequence in the batch.
///
/// # Errors
/// `Error::Embedding` if batch encoding fails.
fn tokenize_batch(
    tokenizer: &Tokenizer,
    texts: &[&str],
) -> Result<(Array2<i64>, Array2<i64>, Array2<i64>)> {
    let encoded = tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| Error::Embedding(format!("tokenization failed: {e}")))?;
    let rows = encoded.len();
    // Pad every row to the longest encoding; zero is the padding value.
    let cols = encoded.iter().map(|e| e.get_ids().len()).max().unwrap_or(0);
    let mut ids = Array2::<i64>::zeros((rows, cols));
    let mut mask = Array2::<i64>::zeros((rows, cols));
    let mut types = Array2::<i64>::zeros((rows, cols));
    for (row, enc) in encoded.iter().enumerate() {
        // One copy routine shared by all three matrices.
        let mut fill = |dst: &mut Array2<i64>, src: &[u32]| {
            for (col, &value) in src.iter().enumerate() {
                dst[[row, col]] = i64::from(value);
            }
        };
        fill(&mut ids, enc.get_ids());
        fill(&mut mask, enc.get_attention_mask());
        fill(&mut types, enc.get_type_ids());
    }
    Ok((ids, mask, types))
}
/// Mean-pools token-level hidden states over the attended (mask > 0) tokens
/// of each sequence, then L2-normalizes each pooled vector.
///
/// `hidden` is the flattened `[batch * seq_len, hidden_dim]` view of a
/// `[batch, seq_len, hidden_dim]` tensor (row-major: row `i * seq_len + j`
/// is token `j` of sequence `i`), and `mask` is the `[batch, seq_len]`
/// attention mask where 0 marks padding.
///
/// A sequence whose mask is all zeros yields an all-zero vector (the mean
/// is skipped to avoid dividing by zero), and a zero-norm vector is
/// returned unnormalized for the same reason.
fn mean_pool_and_normalize(
    hidden: &Array2<f32>,
    mask: &Array2<i64>,
    batch_size: usize,
    seq_len: usize,
    hidden_dim: usize,
) -> Vec<Vec<f32>> {
    // NOTE: the original body contained a dead `let _ = seq_len;` — a
    // leftover unused-parameter suppression contradicted by the actual use
    // of `seq_len` below; it has been removed.
    let mut results = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        let mut pooled = vec![0.0f32; hidden_dim];
        let mut count = 0.0f32;
        for j in 0..seq_len {
            // Padding tokens (mask == 0) contribute nothing to the pool.
            let m = mask[[i, j]] as f32;
            if m > 0.0 {
                for k in 0..hidden_dim {
                    pooled[k] += hidden[[i * seq_len + j, k]] * m;
                }
                count += m;
            }
        }
        // Mean over attended tokens; guarded for all-padding sequences.
        if count > 0.0 {
            for v in &mut pooled {
                *v /= count;
            }
        }
        // L2-normalize to unit length; zero vectors are left untouched.
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in &mut pooled {
                *v /= norm;
            }
        }
        results.push(pooled);
    }
    results
}
/// Tokenizes `texts` and runs the ONNX session, returning one L2-normalized
/// embedding vector per input text, in input order.
///
/// The CPU-bound work (tokenization + model execution) is moved onto
/// tokio's blocking thread pool via `spawn_blocking` so the async executor
/// is not stalled.
///
/// # Errors
/// `Error::Embedding` on tokenization, inference, tensor-extraction, or
/// shape-mismatch failures, or if the blocking task panics.
async fn run_inference(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    // Clone the Arc handles and copy the borrowed texts so the closure
    // below can satisfy spawn_blocking's 'static requirement.
    let session = Arc::clone(&self.session);
    let tokenizer = Arc::clone(&self.tokenizer);
    let dims = self.dimensions;
    let owned_texts: Vec<String> = texts.iter().map(|t| (*t).to_string()).collect();
    let result = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
        let text_refs: Vec<&str> = owned_texts.iter().map(String::as_str).collect();
        let (input_ids, attention_mask, token_type_ids) =
            Self::tokenize_batch(&tokenizer, &text_refs)?;
        let batch_size = input_ids.nrows();
        let seq_len = input_ids.ncols();
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids.view(),
                "attention_mask" => attention_mask.view(),
                "token_type_ids" => token_type_ids.view(),
            ].map_err(|e| Error::Embedding(format!("failed to create inputs: {e}")))?)
            .map_err(|e| Error::Embedding(format!("ONNX inference failed: {e}")))?;
        // Prefer the conventional "last_hidden_state" output name; fall back
        // to the model's first output whatever it is called.
        let output_tensor = outputs
            .get("last_hidden_state")
            .or_else(|| outputs.iter().next().map(|(_, v)| v))
            .ok_or_else(|| Error::Embedding("no output tensor from ONNX model".to_string()))?;
        let output_array = output_tensor
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Embedding(format!("failed to extract output tensor: {e}")))?;
        let shape = output_array.shape();
        if shape.len() == 3 {
            // 3-D output: token-level hidden states — presumably
            // [batch, seq_len, hidden] (TODO confirm for each model) — so
            // mean-pool over the token axis.
            let hidden_dim = shape[2];
            if hidden_dim != dims {
                return Err(Error::Embedding(format!(
                    "model hidden dim ({hidden_dim}) does not match configured dimensions ({dims})"
                )));
            }
            // Flatten to [batch * seq_len, hidden] so pooling can address
            // token j of sequence i as row i * seq_len + j.
            let flat = output_array
                .to_shape((batch_size * seq_len, hidden_dim))
                .map_err(|e| Error::Embedding(format!("reshape failed: {e}")))?;
            let flat_owned: Array2<f32> = flat.to_owned();
            Ok(Self::mean_pool_and_normalize(
                &flat_owned,
                &attention_mask,
                batch_size,
                seq_len,
                hidden_dim,
            ))
        } else if shape.len() == 2 {
            // 2-D output: presumably an already-pooled [batch, hidden]
            // embedding — only L2-normalize each row.
            let hidden_dim = shape[1];
            if hidden_dim != dims {
                return Err(Error::Embedding(format!(
                    "model hidden dim ({hidden_dim}) does not match configured dimensions ({dims})"
                )));
            }
            let mut results = Vec::with_capacity(batch_size);
            for i in 0..batch_size {
                let mut vec = Vec::with_capacity(hidden_dim);
                for j in 0..hidden_dim {
                    vec.push(output_array[[i, j]]);
                }
                // Zero-norm vectors are left as-is to avoid dividing by zero.
                let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for v in &mut vec {
                        *v /= norm;
                    }
                }
                results.push(vec);
            }
            Ok(results)
        } else {
            Err(Error::Embedding(format!(
                "unexpected output tensor shape: {shape:?}"
            )))
        }
    })
    .await
    .map_err(|e| Error::Embedding(format!("inference task panicked: {e}")))?;
    result
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for OnnxEmbedding {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut results = self.run_inference(&[text]).await?;
results
.pop()
.ok_or_else(|| Error::Embedding("empty inference result".to_string()))
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
self.run_inference(texts).await
}
fn dimensions(&self) -> usize {
self.dimensions
}
}
}
#[cfg(not(feature = "onnx"))]
mod inner {
use super::*;
#[derive(Debug)]
pub struct OnnxEmbedding {
dimensions: usize,
model_path: String,
}
impl OnnxEmbedding {
pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
if !std::path::Path::new(model_path).exists() {
return Err(Error::Validation(format!(
"ONNX model not found at: {model_path}"
)));
}
Ok(Self {
dimensions,
model_path: model_path.to_string(),
})
}
#[must_use]
pub fn model_path(&self) -> &str {
&self.model_path
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for OnnxEmbedding {
async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
Err(Error::Embedding(
"ONNX Runtime not available: compile with full onnx dependencies \
(ort, tokenizers, ndarray) to enable local inference"
.to_string(),
))
}
async fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Vec<f32>>> {
Err(Error::Embedding(
"ONNX Runtime not available: compile with full onnx dependencies \
(ort, tokenizers, ndarray) to enable local inference"
.to_string(),
))
}
fn dimensions(&self) -> usize {
self.dimensions
}
}
}
pub use inner::OnnxEmbedding;
#[cfg(test)]
mod tests {
    use super::*;

    /// A nonexistent model path must be rejected up front in both builds.
    #[test]
    fn test_onnx_missing_model() {
        let result = OnnxEmbedding::new("/nonexistent/path/model.onnx", 384);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ONNX model not found"),
            "unexpected error message: {msg}"
        );
    }

    /// Uses Cargo.toml as a file that exists but is not a model: the stub
    /// build accepts any existing path, while the real build should fail on
    /// the missing sibling tokenizer.json (assumes the crate root contains
    /// no tokenizer.json — TODO confirm).
    #[test]
    fn test_onnx_dimensions() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
            assert_eq!(provider.dimensions(), 384);
        }
        #[cfg(feature = "onnx")]
        {
            let result = OnnxEmbedding::new(path, 384);
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("tokenizer.json"),
                "expected tokenizer.json error, got: {msg}"
            );
        }
    }

    /// Same existing-but-not-a-model trick as above: the stub stores and
    /// returns the path verbatim; the real build errors during construction.
    #[test]
    fn test_onnx_model_path() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(path, 768).expect("file should exist");
            assert_eq!(provider.model_path(), path);
        }
        #[cfg(feature = "onnx")]
        {
            let result = OnnxEmbedding::new(path, 768);
            assert!(result.is_err());
        }
    }

    /// Stub build: single-text embed fails with the runtime-unavailable error.
    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_returns_error_without_runtime() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
        let result = provider.embed("hello world").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ONNX Runtime not available"),
            "unexpected error: {msg}"
        );
    }

    /// Stub build: batch embed fails with the same runtime-unavailable error.
    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_batch_returns_error_without_runtime() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
        let result = provider.embed_batch(&["a", "b"]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ONNX Runtime not available"),
            "unexpected error: {msg}"
        );
    }
}