use std::path::{Path, PathBuf};
use std::sync::Mutex;
use ort::session::Session;
use ort::session::builder::GraphOptimizationLevel;
use ort::value::TensorRef;
use crate::embedder::traits::Embedder;
use crate::error::EmbedderError;
// Remote locations and local cache filenames for the quantized
// all-MiniLM-L6-v2 sentence-transformer assets hosted on Hugging Face.
const MODEL_URL: &str = "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model_quantized.onnx";
const MODEL_FILENAME: &str = "all-MiniLM-L6-v2-quantized.onnx";
const TOKENIZER_URL: &str =
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json";
const TOKENIZER_FILENAME: &str = "tokenizer.json";
// Dimensionality of the embedding vectors this model produces.
const EMBEDDING_DIM: usize = 384;
// Fixed input length: token sequences are truncated/padded to exactly this
// many tokens before inference (see `tokenize`).
const MAX_SEQ_LENGTH: usize = 256;
/// Sentence embedder backed by a quantized all-MiniLM-L6-v2 ONNX model.
///
/// The ONNX Runtime session sits behind a `Mutex` because running inference
/// requires mutable access to the session (see `run_inference`), while the
/// embedder itself is used through `&self`.
pub struct OnnxEmbedder {
// ONNX Runtime inference session; locked for the duration of each run.
session: Mutex<Session>,
// HuggingFace tokenizer loaded from the downloaded tokenizer.json.
tokenizer: tokenizers::Tokenizer,
// Directory where the model and tokenizer files are cached on disk.
model_dir: PathBuf,
}
impl OnnxEmbedder {
/// Creates a new embedder, downloading the ONNX model and tokenizer into
/// `model_dir` on first use. Files already present on disk are reused, so
/// the network is only hit once per cache directory.
///
/// # Errors
/// Returns [`EmbedderError`] when the cache directory cannot be created,
/// a download fails, or the ONNX session / tokenizer cannot be loaded.
pub fn new(model_dir: &Path) -> Result<Self, EmbedderError> {
    std::fs::create_dir_all(model_dir).map_err(EmbedderError::Io)?;

    // Fetch the quantized model only when it is not already cached.
    let model_path = model_dir.join(MODEL_FILENAME);
    if !model_path.exists() {
        tracing::info!("Downloading ONNX model to {}...", model_path.display());
        download_file(MODEL_URL, &model_path)?;
        tracing::info!("Model downloaded successfully.");
    }

    let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
    if !tokenizer_path.exists() {
        tracing::info!("Downloading tokenizer...");
        download_file(TOKENIZER_URL, &tokenizer_path)?;
        tracing::info!("Tokenizer downloaded successfully.");
    }

    // Every builder step now surfaces its error as OnnxError. The previous
    // code silently "recovered" from failures to set the optimization level
    // and intra-op thread count, which could hide a misconfigured runtime;
    // failing loudly here is consistent with the rest of this constructor.
    let session = Session::builder()
        .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
        .with_intra_threads(4)
        .map_err(|e| EmbedderError::OnnxError(e.to_string()))?
        .commit_from_file(&model_path)
        .map_err(|e| EmbedderError::OnnxError(format!("Failed to load model: {}", e)))?;

    let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| EmbedderError::OnnxError(format!("Failed to load tokenizer: {}", e)))?;

    Ok(Self {
        session: Mutex::new(session),
        tokenizer,
        model_dir: model_dir.to_path_buf(),
    })
}
/// Returns the directory where the model and tokenizer files are cached.
pub fn model_dir(&self) -> &Path {
    self.model_dir.as_path()
}
/// Tokenizes `text` into `(input_ids, attention_mask)` pairs of exactly
/// `MAX_SEQ_LENGTH` elements: over-long inputs are truncated (keeping a
/// trailing [SEP] so the sequence stays well-formed), short inputs are
/// right-padded with [PAD] and a zeroed attention mask.
fn tokenize(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
    // BERT-vocabulary special token ids used by all-MiniLM-L6-v2.
    const SEP_TOKEN_ID: i64 = 102; // [SEP]
    const PAD_TOKEN_ID: i64 = 0; // [PAD]

    // If the input cannot be encoded, fall back to encoding the empty
    // string; if even that fails, use an empty encoding rather than
    // panicking (the previous fallback contained an unconditional unwrap).
    let encoding = self
        .tokenizer
        .encode(text, true)
        .or_else(|_| self.tokenizer.encode("", true))
        .unwrap_or_default();

    let mut input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
    let mut attention_mask: Vec<i64> = encoding
        .get_attention_mask()
        .iter()
        .map(|&m| m as i64)
        .collect();

    // Truncate to the model's fixed sequence length, replacing the last
    // kept token with [SEP] so the sequence still terminates correctly.
    if input_ids.len() > MAX_SEQ_LENGTH {
        input_ids.truncate(MAX_SEQ_LENGTH);
        attention_mask.truncate(MAX_SEQ_LENGTH);
        if let Some(last) = input_ids.last_mut() {
            *last = SEP_TOKEN_ID;
        }
    }

    // Right-pad up to the fixed length; padded positions are masked out.
    input_ids.resize(MAX_SEQ_LENGTH, PAD_TOKEN_ID);
    attention_mask.resize(MAX_SEQ_LENGTH, 0);

    (input_ids, attention_mask)
}
/// Runs the ONNX model over one tokenized sequence and returns a
/// masked-mean-pooled, L2-normalized embedding vector.
///
/// `input_ids` and `attention_mask` must have equal length (both are
/// produced by `tokenize`, which pads them to `MAX_SEQ_LENGTH`).
fn run_inference(
&self,
input_ids: &[i64],
attention_mask: &[i64],
) -> Result<Vec<f32>, EmbedderError> {
let seq_len = input_ids.len();
// Reshape the flat slices into (batch=1, seq_len) arrays for the model.
let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids.to_vec())
.map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
let attention_mask_array =
ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.to_vec())
.map_err(|e| EmbedderError::OnnxError(format!("Shape error: {}", e)))?;
// token_type_ids are all zeros: single-segment (one-sentence) input.
let token_type_ids_array = ndarray::Array2::<i64>::zeros((1, seq_len));
let input_ids_tensor = TensorRef::from_array_view(&input_ids_array)
.map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
let attention_mask_tensor = TensorRef::from_array_view(&attention_mask_array)
.map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
let token_type_ids_tensor = TensorRef::from_array_view(&token_type_ids_array)
.map_err(|e| EmbedderError::OnnxError(format!("Tensor creation error: {}", e)))?;
// Running the session requires mutable access, hence the Mutex; a
// poisoned lock (a prior panic while locked) is reported as OnnxError.
let mut session = self
.session
.lock()
.map_err(|e| EmbedderError::OnnxError(format!("Session lock poisoned: {}", e)))?;
let outputs = session
.run(ort::inputs![
"input_ids" => input_ids_tensor,
"attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor
])
.map_err(|e| EmbedderError::OnnxError(format!("Inference error: {}", e)))?;
// Prefer the conventional output names; fall back to the first output
// when neither is present (export tools name this tensor differently).
let output = if outputs.contains_key("last_hidden_state") {
&outputs["last_hidden_state"]
} else if outputs.contains_key("token_embeddings") {
&outputs["token_embeddings"]
} else {
&outputs[0]
};
let tensor = output
.try_extract_array::<f32>()
.map_err(|e| EmbedderError::OnnxError(format!("Extract error: {}", e)))?;
// Expect a rank-3 tensor; its axes are used below as
// [batch, sequence, hidden] — TODO confirm against the exported model.
let shape = tensor.shape();
if shape.len() != 3 {
return Err(EmbedderError::OnnxError(format!(
"Unexpected output shape: {:?}",
shape
)));
}
let hidden_size = shape[2];
let mut pooled = vec![0.0f32; hidden_size];
// Masked mean pooling: sum token vectors where the attention mask is
// set, then divide by the number of active (non-padding) tokens.
let active_tokens: f32 = attention_mask.iter().map(|&m| m as f32).sum();
if active_tokens > 0.0 {
for seq_idx in 0..shape[1] {
let mask = attention_mask.get(seq_idx).copied().unwrap_or(0) as f32;
if mask > 0.0 {
for dim in 0..hidden_size {
pooled[dim] += tensor[[0, seq_idx, dim]];
}
}
}
for val in &mut pooled {
*val /= active_tokens;
}
}
// L2-normalize so cosine similarity reduces to a dot product; an
// all-zero vector (no active tokens) is returned unchanged.
let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut pooled {
*x /= norm;
}
}
Ok(pooled)
}
}
impl Embedder for OnnxEmbedder {
    /// Embeds one text into a normalized `EMBEDDING_DIM`-element vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
        let (ids, mask) = self.tokenize(text);
        self.run_inference(&ids, &mask)
    }

    /// Embeds each text independently, stopping at the first failure.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed(text)?);
        }
        Ok(embeddings)
    }

    /// Dimensionality of every vector produced by this embedder.
    fn dimension(&self) -> usize {
        EMBEDDING_DIM
    }
}
/// Downloads `url` to `dest` near-atomically: the payload is streamed to a
/// sibling `.tmp` file and renamed into place only on success, so a partial
/// download never masquerades as a complete file.
///
/// Streaming replaces the previous implementation's whole-body buffering
/// (`response.bytes()`), which held the full model (tens of MB) in memory.
///
/// # Errors
/// Returns [`EmbedderError::DownloadFailed`] for HTTP or transfer failures
/// (including an empty body) and [`EmbedderError::Io`] for filesystem errors.
fn download_file(url: &str, dest: &Path) -> Result<(), EmbedderError> {
    let mut response = reqwest::blocking::get(url)
        .map_err(|e| EmbedderError::DownloadFailed(format!("HTTP request failed: {}", e)))?;
    if !response.status().is_success() {
        return Err(EmbedderError::DownloadFailed(format!(
            "HTTP {} for {}",
            response.status(),
            url
        )));
    }

    let tmp_path = dest.with_extension("tmp");
    // Closure isolates the fallible write/rename steps so the error arm
    // below can clean up the temp file in one place.
    let result = (|| {
        let mut file = std::fs::File::create(&tmp_path).map_err(EmbedderError::Io)?;
        // Stream the body straight to disk.
        let bytes_written = response.copy_to(&mut file).map_err(|e| {
            EmbedderError::DownloadFailed(format!("Failed to read response: {}", e))
        })?;
        if bytes_written == 0 {
            return Err(EmbedderError::DownloadFailed(
                "Downloaded file is empty".to_string(),
            ));
        }
        std::fs::rename(&tmp_path, dest).map_err(EmbedderError::Io)?;
        Ok(bytes_written)
    })();

    match result {
        Ok(bytes_written) => {
            tracing::info!("Downloaded {} bytes to {}", bytes_written, dest.display());
            Ok(())
        }
        Err(e) => {
            // Best-effort cleanup: don't leave a stale .tmp behind (the
            // previous version leaked it on rename failure).
            let _ = std::fs::remove_file(&tmp_path);
            Err(e)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_output_length() {
        let model_dir = std::env::temp_dir().join("seekr_test_tokenizer");
        // Construction downloads model assets; when that fails (e.g. no
        // network), skip the tokenization checks rather than erroring.
        let Ok(embedder) = OnnxEmbedder::new(&model_dir) else {
            return;
        };
        let (ids, mask) = embedder.tokenize("hello world");
        assert_eq!(ids.len(), MAX_SEQ_LENGTH);
        assert_eq!(mask.len(), MAX_SEQ_LENGTH);
        // BERT vocab: sequence must start with the [CLS] token (id 101).
        assert_eq!(ids[0], 101);
        let active: i64 = mask.iter().sum();
        assert!(active > 0, "Should have at least some active tokens");
    }

    #[test]
    fn test_embedding_dimension() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    #[test]
    fn test_max_seq_length() {
        assert_eq!(MAX_SEQ_LENGTH, 256);
    }
}