1use std::path::{Path, PathBuf};
41use std::sync::Mutex;
42
43use ndarray::Array2;
44use ort::session::builder::GraphOptimizationLevel;
45use ort::session::Session;
46use tokenizers::Tokenizer;
47use tracing::{debug, info};
48
49use crate::embedding::EmbeddingService;
50use crate::error::{PulseDBError, Result};
51use crate::types::Embedding;
52
/// Default sentence-transformer model (produces 384-dimensional embeddings).
const DEFAULT_MODEL_NAME: &str = "all-MiniLM-L6-v2";
/// Embedding dimension of the default model.
const DEFAULT_DIMENSION: usize = 384;
/// Maximum token sequence length for the default model.
const DEFAULT_MAX_LENGTH: usize = 256;

/// Alternative 768-dimensional model.
/// NOTE(review): its dimension (768) appears as a bare literal in several
/// match arms below; consider introducing a BGE_DIMENSION constant.
const BGE_MODEL_NAME: &str = "bge-base-en-v1.5";
/// Maximum token sequence length for the BGE model.
const BGE_MAX_LENGTH: usize = 512;

/// File names expected inside a model directory.
const MODEL_FILENAME: &str = "model.onnx";
const TOKENIZER_FILENAME: &str = "tokenizer.json";
69
/// Local text-embedding backend built on ONNX Runtime.
///
/// Wraps an ONNX session plus a Hugging Face tokenizer and produces
/// mean-pooled, L2-normalized sentence embeddings (see the
/// `EmbeddingService` impl below).
pub struct OnnxEmbedding {
    // Running an ort session requires `&mut`, so access is serialized with a
    // Mutex; this also makes the type Send + Sync (asserted in tests).
    session: Mutex<Session>,

    /// Tokenizer matching the loaded model (truncation configured at load).
    tokenizer: Tokenizer,

    /// Output embedding dimension (e.g. 384 or 768).
    dimension: usize,

    /// Maximum token sequence length; longer inputs are truncated.
    max_length: usize,
}
101
impl OnnxEmbedding {
    /// Creates an embedder using the default 384-dimensional model
    /// (all-MiniLM-L6-v2).
    ///
    /// `model_path`, when provided, must be a directory containing
    /// `model.onnx` and `tokenizer.json`; otherwise the per-user cache
    /// directory for the default model is used.
    pub fn new(model_path: Option<PathBuf>) -> Result<Self> {
        Self::with_dimension(model_path, DEFAULT_DIMENSION)
    }

    /// Creates an embedder producing vectors of the given `dimension`.
    ///
    /// # Errors
    /// Returns an error if no model directory can be resolved for this
    /// dimension, or if loading the model/tokenizer files fails.
    pub fn with_dimension(model_path: Option<PathBuf>, dimension: usize) -> Result<Self> {
        // Pick the sequence-length cap associated with the known model for
        // this dimension; unknown dimensions fall back to the default cap.
        let max_length = match dimension {
            DEFAULT_DIMENSION => DEFAULT_MAX_LENGTH,
            768 => BGE_MAX_LENGTH,
            _ => DEFAULT_MAX_LENGTH,
        };

        let model_dir = resolve_model_dir(model_path.as_deref(), dimension)?;

        info!(
            model_dir = %model_dir.display(),
            dimension,
            max_length,
            "Loading ONNX embedding model"
        );

        Self::load_from_dir(&model_dir, dimension, max_length)
    }

    /// Downloads the default model files for `dimension` into the local
    /// cache directory and returns that directory.
    ///
    /// Safe to call concurrently from multiple processes: an exclusive
    /// advisory file lock serializes downloads, and files already present
    /// are not fetched again.
    ///
    /// # Errors
    /// Returns an error for unsupported dimensions (only 384 and 768 have
    /// default models) or if any filesystem or network step fails.
    pub fn download_default_model(dimension: usize) -> Result<PathBuf> {
        let (model_name, model_url, tokenizer_url) = match dimension {
            DEFAULT_DIMENSION => (
                DEFAULT_MODEL_NAME,
                "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
                "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json",
            ),
            768 => (
                BGE_MODEL_NAME,
                "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx",
                "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json",
            ),
            _ => {
                return Err(PulseDBError::embedding(format!(
                    "No default model for dimension {dimension}. \
                     Supported: 384 (all-MiniLM-L6-v2), 768 (bge-base-en-v1.5)"
                )));
            }
        };

        let cache_dir = default_cache_dir(model_name);

        std::fs::create_dir_all(&cache_dir).map_err(|e| {
            PulseDBError::embedding(format!(
                "Failed to create model cache directory {}: {e}",
                cache_dir.display()
            ))
        })?;

        // Take the exclusive lock BEFORE checking for existing files so two
        // processes can't race each other into duplicate downloads. The lock
        // is released when `lock_file` is dropped at the end of this fn.
        let lock_path = cache_dir.join(".download.lock");
        let lock_file = std::fs::File::create(&lock_path)
            .map_err(|e| PulseDBError::embedding(format!("Failed to create download lock: {e}")))?;
        use fs2::FileExt;
        lock_file.lock_exclusive().map_err(|e| {
            PulseDBError::embedding(format!("Failed to acquire download lock: {e}"))
        })?;

        let model_path = cache_dir.join(MODEL_FILENAME);
        let tokenizer_path = cache_dir.join(TOKENIZER_FILENAME);

        // Another process may have completed the download while we were
        // blocked waiting for the lock.
        if model_path.exists() && tokenizer_path.exists() {
            info!(dir = %cache_dir.display(), "Model files already downloaded by another process");
            return Ok(cache_dir);
        }

        // Each file is fetched independently so a partial earlier run
        // (e.g. model present, tokenizer missing) is completed, not redone.
        if !model_path.exists() {
            info!(url = model_url, dest = %model_path.display(), "Downloading ONNX model");
            download_file(model_url, &model_path)?;
        }

        if !tokenizer_path.exists() {
            info!(url = tokenizer_url, dest = %tokenizer_path.display(), "Downloading tokenizer");
            download_file(tokenizer_url, &tokenizer_path)?;
        }

        info!(dir = %cache_dir.display(), "Model files ready");
        Ok(cache_dir)
    }

    /// Loads the session and tokenizer from `model_dir`, which must contain
    /// `model.onnx` and `tokenizer.json`. Errors name the missing file and
    /// how to obtain it.
    fn load_from_dir(model_dir: &Path, dimension: usize, max_length: usize) -> Result<Self> {
        let model_path = model_dir.join(MODEL_FILENAME);
        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);

        if !model_path.exists() {
            return Err(PulseDBError::embedding(format!(
                "Model file not found: {}. \
                 Download with OnnxEmbedding::download_default_model({dimension}) \
                 or provide a directory containing '{MODEL_FILENAME}'",
                model_path.display()
            )));
        }
        if !tokenizer_path.exists() {
            return Err(PulseDBError::embedding(format!(
                "Tokenizer file not found: {}. \
                 The model directory must contain '{TOKENIZER_FILENAME}'",
                tokenizer_path.display()
            )));
        }

        let session = create_session(&model_path)?;
        let tokenizer = load_tokenizer(&tokenizer_path, max_length)?;

        debug!(dimension, max_length, "ONNX embedding model loaded");

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            dimension,
            max_length,
        })
    }
}
277
impl EmbeddingService for OnnxEmbedding {
    /// Embeds a single text: tokenize, run the ONNX model, mean-pool the
    /// token embeddings, then L2-normalize.
    ///
    /// # Errors
    /// Returns an error for empty input, tokenization failure, tensor
    /// construction failure, or ONNX inference failure.
    fn embed(&self, text: &str) -> Result<Embedding> {
        if text.is_empty() {
            return Err(PulseDBError::embedding("Cannot embed empty text"));
        }

        // `true` => add special tokens; the tokenizer was configured at load
        // time to truncate at `max_length`.
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| PulseDBError::embedding(format!("Tokenization failed: {e}")))?;

        let ids = encoding.get_ids();
        let mask = encoding.get_attention_mask();

        // Defensive cap; tokenizer truncation should already enforce this.
        let len = ids.len().min(self.max_length);

        let input_ids: Vec<i64> = ids[..len].iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = mask[..len].iter().map(|&x| x as i64).collect();
        // Single-segment input: all token type ids are zero.
        let token_type_ids: Vec<i64> = vec![0i64; len];

        // Shape (1, len): a batch containing one sequence.
        let ids_array = Array2::from_shape_vec((1, len), input_ids)
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
        // Clone: `attention_mask` is reused for pooling after inference.
        let mask_array = Array2::from_shape_vec((1, len), attention_mask.clone())
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
        let type_array = Array2::from_shape_vec((1, len), token_type_ids)
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;

        let ids_tensor = ort::value::Tensor::from_array(ids_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
        let mask_tensor = ort::value::Tensor::from_array(mask_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
        let type_tensor = ort::value::Tensor::from_array(type_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;

        // ort sessions need `&mut` to run, hence the Mutex.
        let mut session = self
            .session
            .lock()
            .map_err(|e| PulseDBError::embedding(format!("Session lock poisoned: {e}")))?;
        let outputs = session
            .run(ort::inputs![
                "input_ids" => ids_tensor,
                "attention_mask" => mask_tensor,
                "token_type_ids" => type_tensor,
            ])
            .map_err(|e| PulseDBError::embedding(format!("ONNX inference failed: {e}")))?;

        // First output is read below as a row-major (1, len, dimension)
        // buffer of token-level hidden states — assumption about the
        // exported model's output order; confirm against the model.
        let token_embeddings = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| PulseDBError::embedding(format!("Output extraction failed: {e}")))?;

        let mask_u32: Vec<u32> = attention_mask.iter().map(|&x| x as u32).collect();

        // `.1` is the raw data slice (`.0` is the shape).
        let pooled = mean_pool_raw(token_embeddings.1, &mask_u32, self.dimension, len);
        Ok(l2_normalize(&pooled))
    }

    /// Embeds several texts in one model invocation.
    ///
    /// Sequences are zero-padded to the longest (truncated) length in the
    /// batch; padded positions have mask 0 and so do not affect pooling.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        // Single text: skip the padding machinery entirely.
        if texts.len() == 1 {
            return Ok(vec![self.embed(texts[0])?]);
        }

        let encodings: Vec<_> = texts
            .iter()
            .map(|t| self.tokenizer.encode(*t, true))
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(|e| PulseDBError::embedding(format!("Batch tokenization failed: {e}")))?;

        // Pad to the longest sequence in the batch (capped at max_length).
        let max_len = encodings
            .iter()
            .map(|enc| enc.get_ids().len().min(self.max_length))
            .max()
            .unwrap_or(0);

        let batch_size = texts.len();

        // Flat row-major (batch_size, max_len) buffers, zero-initialized so
        // positions past each sequence's length act as padding (mask 0).
        let mut input_ids = vec![0i64; batch_size * max_len];
        let mut attention_mask = vec![0i64; batch_size * max_len];
        let token_type_ids = vec![0i64; batch_size * max_len];

        for (i, enc) in encodings.iter().enumerate() {
            let ids = enc.get_ids();
            let mask = enc.get_attention_mask();
            let len = ids.len().min(self.max_length);

            for j in 0..len {
                input_ids[i * max_len + j] = ids[j] as i64;
                attention_mask[i * max_len + j] = mask[j] as i64;
            }
        }

        let ids_array = Array2::from_shape_vec((batch_size, max_len), input_ids)
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
        // Clone: `attention_mask` is reused for per-text pooling below.
        let mask_array = Array2::from_shape_vec((batch_size, max_len), attention_mask.clone())
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
        let type_array = Array2::from_shape_vec((batch_size, max_len), token_type_ids)
            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;

        let ids_tensor = ort::value::Tensor::from_array(ids_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
        let mask_tensor = ort::value::Tensor::from_array(mask_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
        let type_tensor = ort::value::Tensor::from_array(type_array)
            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;

        let mut session = self
            .session
            .lock()
            .map_err(|e| PulseDBError::embedding(format!("Session lock poisoned: {e}")))?;
        let outputs = session
            .run(ort::inputs![
                "input_ids" => ids_tensor,
                "attention_mask" => mask_tensor,
                "token_type_ids" => type_tensor,
            ])
            .map_err(|e| PulseDBError::embedding(format!("ONNX inference failed: {e}")))?;

        let token_embeddings = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| PulseDBError::embedding(format!("Output extraction failed: {e}")))?;

        let (_shape, data) = token_embeddings;

        // Pool and normalize each text's slice of the flat output buffer,
        // interpreted as row-major (batch_size, max_len, dimension).
        let mut results = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let text_mask: Vec<u32> = (0..max_len)
                .map(|j| attention_mask[i * max_len + j] as u32)
                .collect();

            let offset = i * max_len * self.dimension;
            let text_data = &data[offset..offset + max_len * self.dimension];

            let pooled = mean_pool_raw(text_data, &text_mask, self.dimension, max_len);
            results.push(l2_normalize(&pooled));
        }

        Ok(results)
    }

    /// Embedding dimension produced by this backend.
    fn dimension(&self) -> usize {
        self.dimension
    }
}
439
440fn create_session(model_path: &Path) -> Result<Session> {
446 Session::builder()
447 .map_err(|e| PulseDBError::embedding(format!("Failed to create session builder: {e}")))?
448 .with_optimization_level(GraphOptimizationLevel::Level3)
450 .map_err(|e| PulseDBError::embedding(format!("Failed to set optimization level: {e}")))?
451 .commit_from_file(model_path)
452 .map_err(|e| {
453 PulseDBError::embedding(format!(
454 "Failed to load ONNX model from {}: {e}",
455 model_path.display()
456 ))
457 })
458}
459
460fn load_tokenizer(tokenizer_path: &Path, max_length: usize) -> Result<Tokenizer> {
462 let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
463 PulseDBError::embedding(format!(
464 "Failed to load tokenizer from {}: {e}",
465 tokenizer_path.display()
466 ))
467 })?;
468
469 tokenizer
471 .with_truncation(Some(tokenizers::TruncationParams {
472 max_length,
473 strategy: tokenizers::TruncationStrategy::LongestFirst,
474 ..Default::default()
475 }))
476 .map_err(|e| PulseDBError::embedding(format!("Failed to set truncation: {e}")))?;
477
478 tokenizer.with_padding(None);
481
482 Ok(tokenizer)
483}
484
485fn resolve_model_dir(model_path: Option<&Path>, dimension: usize) -> Result<PathBuf> {
487 match model_path {
488 Some(path) => {
489 if !path.exists() {
490 return Err(PulseDBError::embedding(format!(
491 "Model directory not found: {}",
492 path.display()
493 )));
494 }
495 Ok(path.to_path_buf())
496 }
497 None => {
498 let model_name = match dimension {
500 DEFAULT_DIMENSION => DEFAULT_MODEL_NAME,
501 768 => BGE_MODEL_NAME,
502 _ => {
503 return Err(PulseDBError::embedding(format!(
504 "No default model for dimension {dimension}. \
505 Provide a model_path for custom dimensions, \
506 or use 384 (all-MiniLM-L6-v2) or 768 (bge-base-en-v1.5)"
507 )));
508 }
509 };
510
511 let cache_dir = default_cache_dir(model_name);
512
513 if !cache_dir.join(MODEL_FILENAME).exists() {
514 return Err(PulseDBError::embedding(format!(
515 "Model not found at {}. \
516 Download with: OnnxEmbedding::download_default_model({dimension})",
517 cache_dir.display()
518 )));
519 }
520
521 Ok(cache_dir)
522 }
523 }
524}
525
526fn default_cache_dir(model_name: &str) -> PathBuf {
533 dirs::cache_dir()
534 .unwrap_or_else(|| PathBuf::from(".cache"))
535 .join("pulsedb")
536 .join("models")
537 .join(model_name)
538}
539
/// Masked mean pooling over token embeddings.
///
/// `data` is a row-major buffer of `seq_len` rows, each `dim` floats long.
/// Tokens with mask value 0 contribute nothing; if every token is masked,
/// the zero vector is returned (no division by zero).
fn mean_pool_raw(data: &[f32], attention_mask: &[u32], dim: usize, seq_len: usize) -> Vec<f32> {
    let mut sums = vec![0.0f32; dim];
    let mut total_weight = 0.0f32;

    for (token, &m) in attention_mask.iter().take(seq_len).enumerate() {
        let weight = m as f32;
        total_weight += weight;
        let row = &data[token * dim..(token + 1) * dim];
        for (acc, &value) in sums.iter_mut().zip(row) {
            *acc += value * weight;
        }
    }

    if total_weight > 0.0 {
        for acc in &mut sums {
            *acc /= total_weight;
        }
    }

    sums
}
576
/// Scales `v` to unit Euclidean length.
///
/// A zero-length input (norm not strictly positive) is returned unchanged
/// to avoid dividing by zero.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    match norm > 0.0 {
        true => v.iter().map(|x| x / norm).collect(),
        false => v.to_vec(),
    }
}
590
591fn download_file(url: &str, dest: &Path) -> Result<()> {
596 let response = ureq::get(url)
597 .call()
598 .map_err(|e| PulseDBError::embedding(format!("Download failed for {url}: {e}")))?;
599
600 let temp = dest.with_extension("tmp");
602 let mut reader = response.into_body().into_reader();
603 let mut file = std::fs::File::create(&temp).map_err(|e| {
604 PulseDBError::embedding(format!("Failed to create file {}: {e}", temp.display()))
605 })?;
606
607 if let Err(e) = std::io::copy(&mut reader, &mut file) {
608 let _ = std::fs::remove_file(&temp);
609 return Err(PulseDBError::embedding(format!(
610 "Failed to write to {}: {e}",
611 dest.display()
612 )));
613 }
614
615 std::fs::rename(&temp, dest).map_err(|e| {
617 let _ = std::fs::remove_file(&temp);
618 PulseDBError::embedding(format!(
619 "Failed to finalize download {}: {e}",
620 dest.display()
621 ))
622 })?;
623
624 Ok(())
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_normalize_basic() {
        // 3-4-5 triangle: components scale to 0.6 and 0.8.
        let normalized = l2_normalize(&[3.0, 4.0]);
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        // The output must itself have unit length.
        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_normalize_zero_vector() {
        // Zero vector must be returned unchanged, not divided by zero.
        let normalized = l2_normalize(&[0.0, 0.0, 0.0]);
        assert_eq!(normalized, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_l2_normalize_already_unit() {
        let normalized = l2_normalize(&[1.0, 0.0, 0.0]);
        assert!((normalized[0] - 1.0).abs() < 1e-6);
        assert!((normalized[1] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_pool_uniform_mask() {
        // Two tokens of dimension 3, both attended: plain average.
        let data = vec![1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
        let mask = vec![1u32, 1];

        let pooled = mean_pool_raw(&data, &mask, 3, 2);
        assert!((pooled[0] - 3.0).abs() < 1e-6);
        assert!((pooled[1] - 4.0).abs() < 1e-6);
        assert!((pooled[2] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_pool_partial_mask() {
        // Second token is masked out; its 99s must not contribute.
        let data = vec![1.0, 2.0, 3.0, 99.0, 99.0, 99.0];
        let mask = vec![1u32, 0];

        let pooled = mean_pool_raw(&data, &mask, 3, 2);
        assert!((pooled[0] - 1.0).abs() < 1e-6);
        assert!((pooled[1] - 2.0).abs() < 1e-6);
        assert!((pooled[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_pool_zero_mask() {
        // Fully masked input pools to the zero vector.
        let data = vec![99.0, 99.0, 99.0];
        let mask = vec![0u32];

        assert_eq!(mean_pool_raw(&data, &mask, 3, 1), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_resolve_model_dir_custom_path_missing() {
        let result = resolve_model_dir(Some(Path::new("/nonexistent/path")), 384);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found"), "Error: {err}");
    }

    #[test]
    fn test_resolve_model_dir_unsupported_dimension() {
        let result = resolve_model_dir(None, 999);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("No default model"), "Error: {err}");
    }

    #[test]
    fn test_default_cache_dir_format() {
        let dir = default_cache_dir("test-model");
        let path_str = dir.to_string_lossy();
        assert!(path_str.contains("pulsedb"), "Path: {path_str}");
        assert!(path_str.contains("models"), "Path: {path_str}");
        assert!(path_str.contains("test-model"), "Path: {path_str}");
    }

    #[test]
    fn test_onnx_embedding_is_send_sync() {
        // Compile-time assertion: the Mutex-wrapped session keeps the type
        // shareable across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<OnnxEmbedding>();
    }
}