1use async_trait::async_trait;
2use ort::session::{Session, builder::GraphOptimizationLevel};
3use ort::value::Tensor;
4use std::sync::{Arc, Mutex};
5use tokenizers::Tokenizer;
6use tracing::info;
7
8use crate::{
9 config::OnnxEmbeddingConfig,
10 download::{ModelUrls, ensure_model_exists, ensure_tokenizer_exists},
11 engine::EmbeddingEngine,
12 error::{EmbeddingError, EmbeddingResult},
13 utils::{l2_normalize, mean_pool},
14};
/// `(input_ids, attention_mask)` rows produced by `tokenize_batch`, one inner `Vec` per text.
type TokenizationBatch = (Vec<Vec<i64>>, Vec<Vec<i64>>);
/// Text-embedding engine backed by an ONNX Runtime [`Session`] and a `tokenizers` [`Tokenizer`].
///
/// Both the session and the tokenizer sit behind `Arc<Mutex<_>>`; the session `Arc` is cloned
/// into a blocking task for each inference run and locked there.
pub struct OnnxEmbeddingEngine {
    /// ONNX Runtime session for the loaded model; locked for the duration of each run.
    session: Arc<Mutex<Session>>,
    /// Tokenizer loaded from `config.tokenizer_path`; locked while a batch is tokenized.
    tokenizer: Arc<Mutex<Tokenizer>>,
    /// Configuration the engine was built from (paths, dimensions, batch/sequence limits).
    config: OnnxEmbeddingConfig,
}
26
27impl OnnxEmbeddingEngine {
28 pub fn new(config: OnnxEmbeddingConfig) -> EmbeddingResult<Self> {
47 ort::init().commit();
48
49 if !config.model_path.exists() {
50 return Err(EmbeddingError::ModelLoadError(format!(
51 "Model file not found: {:?}",
52 config.model_path
53 )));
54 }
55
56 info!("Loading tokenizer: {:?}", config.tokenizer_path);
57 let tokenizer = Tokenizer::from_file(&config.tokenizer_path).map_err(|e| {
58 EmbeddingError::TokenizerError(format!(
59 "Failed to load tokenizer from {:?}: {}. Please ensure tokenizer.json file exists.",
60 config.tokenizer_path, e
61 ))
62 })?;
63
64 info!("Loading ONNX model: {:?}", config.model_path);
65 let session = Session::builder()
66 .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?
67 .with_optimization_level(GraphOptimizationLevel::Level3)
68 .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?
69 .commit_from_file(&config.model_path)
70 .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?;
71
72 info!(
73 "✓ Loaded {} (dims: {}, max_seq_len: {})",
74 config.model_name, config.dimensions, config.max_sequence_length
75 );
76
77 Ok(Self {
78 session: Arc::new(Mutex::new(session)),
79 tokenizer: Arc::new(Mutex::new(tokenizer)),
80 config,
81 })
82 }
83
84 pub async fn with_auto_download(config: OnnxEmbeddingConfig) -> EmbeddingResult<Self> {
104 let (model_url, tokenizer_url) = match config.model_name.to_lowercase().as_str() {
105 "bge-small-en-v1.5" | "bge-small-v1.5" => (
106 ModelUrls::BGE_SMALL.model_url,
107 ModelUrls::BGE_SMALL.tokenizer_url,
108 ),
109 "all-minilm-l6-v2" => (
110 ModelUrls::MINILM_L6.model_url,
111 ModelUrls::MINILM_L6.tokenizer_url,
112 ),
113 _ => {
114 return Err(EmbeddingError::ModelLoadError(format!(
115 "Unknown model name '{}'. Supported: 'bge-small-en-v1.5', 'all-MiniLM-L6-v2'",
116 config.model_name
117 )));
118 }
119 };
120
121 let model_downloaded = ensure_model_exists(&config.model_path, model_url).await?;
122 if model_downloaded {
123 info!("✓ Downloaded model to {:?}", config.model_path);
124 }
125
126 let tokenizer_downloaded =
127 ensure_tokenizer_exists(&config.tokenizer_path, tokenizer_url).await?;
128 if tokenizer_downloaded {
129 info!("✓ Downloaded tokenizer to {:?}", config.tokenizer_path);
130 }
131
132 Self::new(config)
133 }
134
135 fn tokenize_batch(&self, texts: &[&str]) -> EmbeddingResult<TokenizationBatch> {
145 #[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
146 let tokenizer = self.tokenizer.lock().unwrap(); let max_len = self.config.max_sequence_length;
148
149 let mut input_ids_batch = Vec::new();
150 let mut attention_mask_batch = Vec::new();
151
152 for text in texts {
153 let encoding = tokenizer
154 .encode(*text, true)
155 .map_err(|e| EmbeddingError::TokenizerError(e.to_string()))?;
156
157 let mut ids = encoding
158 .get_ids()
159 .iter()
160 .map(|&id| id as i64)
161 .collect::<Vec<_>>();
162 let mut mask = encoding
163 .get_attention_mask()
164 .iter()
165 .map(|&m| m as i64)
166 .collect::<Vec<_>>();
167
168 if ids.len() > max_len {
169 ids.truncate(max_len);
170 mask.truncate(max_len);
171 }
172
173 while ids.len() < max_len {
174 ids.push(0); mask.push(0); }
177
178 input_ids_batch.push(ids);
179 attention_mask_batch.push(mask);
180 }
181
182 Ok((input_ids_batch, attention_mask_batch))
183 }
184
185 fn extract_embedding(
189 &self,
190 output_data: &[f32],
191 output_shape: &[usize],
192 attention_mask: &[i64],
193 ) -> EmbeddingResult<Vec<f32>> {
194 let output_dim = self.config.dimensions;
195
196 if output_shape.len() == 3 {
197 let seq_len = output_shape[1];
198 let hidden_dim = output_shape[2];
199
200 let pooled = mean_pool(output_data, seq_len, hidden_dim, attention_mask, output_dim);
201 Ok(l2_normalize(&pooled))
202 } else if output_shape.len() == 2 {
203 let embedding: Vec<f32> = output_data.iter().take(output_dim).copied().collect();
204 Ok(l2_normalize(&embedding))
205 } else {
206 Err(EmbeddingError::InferenceError(format!(
207 "Unexpected output shape: {output_shape:?}"
208 )))
209 }
210 }
211}
212
213impl OnnxEmbeddingEngine {
214 async fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
222 if texts.is_empty() {
223 return Ok(Vec::new());
224 }
225
226 let batch_size = texts.len();
227 let seq_len = self.config.max_sequence_length;
228
229 let (input_ids_batch, attention_mask_batch) = self.tokenize_batch(texts)?;
230
231 let input_ids_flat: Vec<i64> = input_ids_batch.iter().flatten().copied().collect();
232 let attention_mask_flat: Vec<i64> =
233 attention_mask_batch.iter().flatten().copied().collect();
234
235 let input_ids_tensor = Tensor::from_array((vec![batch_size, seq_len], input_ids_flat))
236 .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
237 let attention_mask_tensor =
238 Tensor::from_array((vec![batch_size, seq_len], attention_mask_flat))
239 .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
240 let token_type_ids_tensor =
241 Tensor::from_array((vec![batch_size, seq_len], vec![0i64; batch_size * seq_len]))
242 .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
243
244 let session = Arc::clone(&self.session);
245 let attention_masks = attention_mask_batch.clone();
246
247 let (output_shape, output_data) = tokio::task::spawn_blocking(move || {
248 #[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
249 let mut session = session.lock().unwrap(); let outputs = session.run(ort::inputs! {
251 "input_ids" => input_ids_tensor,
252 "attention_mask" => attention_mask_tensor,
253 "token_type_ids" => token_type_ids_tensor,
254 })?;
255
256 let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
257 let shape_usize: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
258 Ok::<_, Box<dyn std::error::Error + Send + Sync>>((shape_usize, data.to_vec()))
259 })
260 .await
261 .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?
262 .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
263
264 let mut embeddings = Vec::with_capacity(batch_size);
265
266 if output_shape.len() == 3 {
267 let seq_len = output_shape[1];
268 let hidden_dim = output_shape[2];
269 let sample_size = seq_len * hidden_dim;
270
271 for (i, mask) in attention_masks.iter().enumerate().take(batch_size) {
272 let start = i * sample_size;
273 let end = start + sample_size;
274 let sample_data = &output_data[start..end];
275
276 let embedding =
277 self.extract_embedding(sample_data, &[1, seq_len, hidden_dim], mask)?;
278
279 embeddings.push(embedding);
280 }
281 } else if output_shape.len() == 2 && batch_size == 1 {
282 let embedding =
283 self.extract_embedding(&output_data, &output_shape, &attention_masks[0])?;
284 embeddings.push(embedding);
285 } else {
286 return Err(EmbeddingError::InferenceError(format!(
287 "Unexpected output shape: {output_shape:?} for batch_size {batch_size}"
288 )));
289 }
290
291 Ok(embeddings)
292 }
293}
294
295#[async_trait]
296impl EmbeddingEngine for OnnxEmbeddingEngine {
297 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
304 if texts.is_empty() {
305 return Ok(Vec::new());
306 }
307 let batch = self.config.batch_size.max(1);
308 if texts.len() <= batch {
309 return self.embed_batch(texts).await;
310 }
311 let mut embeddings = Vec::with_capacity(texts.len());
312 for chunk in texts.chunks(batch) {
313 embeddings.extend(self.embed_batch(chunk).await?);
314 }
315 Ok(embeddings)
316 }
317
318 fn dimension(&self) -> usize {
319 self.config.dimensions
320 }
321
322 fn batch_size(&self) -> usize {
323 self.config.batch_size
324 }
325
326 fn max_sequence_length(&self) -> usize {
327 self.config.max_sequence_length
328 }
329}
330
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Smoke-tests the BGE tokenizer when its file is present; skipped otherwise.
    #[test]
    fn test_tokenization() {
        let tokenizer_path = "../../target/models/bge-small-tokenizer.json";
        if !std::path::Path::new(tokenizer_path).exists() {
            return;
        }

        let tokenizer = Tokenizer::from_file(tokenizer_path).expect("Failed to load tokenizer");

        let encoding = tokenizer.encode("Hello world", true).unwrap();
        let ids = encoding.get_ids();

        assert!(!ids.is_empty());
        // Expects the `[CLS]` id of a BERT-uncased vocabulary as the first token.
        assert_eq!(ids[0], 101);
    }

    /// `l2_normalize` should produce a unit-norm vector.
    #[test]
    fn test_l2_normalization() {
        use crate::utils::{compute_norm, l2_normalize};

        let vec = vec![3.0, 4.0];
        let normalized = l2_normalize(&vec);
        let norm = compute_norm(&normalized);

        assert!((norm - 1.0).abs() < 0.001);
    }

    /// Construction with the default config either succeeds (model files present) or fails
    /// with a descriptive model/tokenizer error.
    #[test]
    fn test_engine_creation() {
        let config = OnnxEmbeddingConfig::default();
        if let Err(e) = OnnxEmbeddingEngine::new(config) {
            let msg = e.to_string();
            assert!(msg.contains("Model file not found") || msg.contains("tokenizer"));
        }
    }

    /// Sub-batched `embed` must match a single `embed_batch` run over the same inputs.
    /// Skipped when the model or tokenizer file is missing.
    #[tokio::test]
    async fn embed_sub_batches_large_inputs() {
        let model = "../../target/models/BGE-Small-v1.5-model_quantized.onnx";
        let tok = "../../target/models/bge-small-tokenizer.json";
        if !std::path::Path::new(model).exists() || !std::path::Path::new(tok).exists() {
            return;
        }

        let config = OnnxEmbeddingConfig {
            model_path: model.into(),
            tokenizer_path: tok.into(),
            // Smaller than the input count so `embed` must split into several sub-batches.
            batch_size: 4,
            ..Default::default()
        };

        let engine = OnnxEmbeddingEngine::new(config).expect("engine creation");

        let texts: Vec<String> = (0..10).map(|i| format!("sentence number {i}")).collect();
        let refs: Vec<&str> = texts.iter().map(String::as_str).collect();

        let chunked = engine.embed(&refs).await.expect("embed");
        assert_eq!(
            chunked.len(),
            10,
            "one embedding per input across sub-batches"
        );
        assert_eq!(chunked[0].len(), engine.dimension());

        let single = engine.embed_batch(&refs).await.expect("embed_batch");
        assert_eq!(single.len(), chunked.len());
        for (a, b) in chunked.iter().zip(single.iter()) {
            assert_eq!(a.len(), b.len());
            // Both vectors are L2-normalized, so the dot product is their cosine similarity.
            let cos: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
            assert!(cos > 0.999, "chunked vs single-batch diverged: cos={cos}");
        }
    }
}