use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::{EmbeddingModel, ModelConfig};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};

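/// CPU-oriented text embedding engine backed by an ONNX Runtime session.
///
/// The session sits behind an `Arc<Mutex<_>>` so one loaded model can be
/// shared across async tasks; inference itself runs on Tokio's blocking
/// thread pool (see `embed_batch_internal`).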
pub struct EmbeddingEngine {
    /// ONNX Runtime session; locked for exclusive use during each batch.
    session: Arc<Mutex<Session>>,
    /// Tokenization and batching helper, shared with blocking tasks.
    processor: Arc<BatchProcessor>,
    /// Model selection and runtime options.
    config: ModelConfig,
    /// Output embedding dimension of the configured model.
    dimension: usize,
}

impl EmbeddingEngine {
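    /// Resolves (downloading if necessary) the tokenizer and ONNX model
    /// files, then loads both and prepares the batch processor.
    ///
    /// A minimal usage sketch; marked `ignore` because the crate path is
    /// project-specific, and it assumes a Tokio runtime plus network access
    /// on the first run:
    ///
    /// ```ignore
    /// let engine = EmbeddingEngine::new(ModelConfig::default()).await?;
    /// let embedding = engine.embed_query("what is an embedding?").await?;
    /// assert_eq!(embedding.len(), engine.dimension());
    /// ```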
    #[instrument(skip_all, fields(model = %config.model))]
    pub async fn new(config: ModelConfig) -> Result<Self> {
        info!(
            "Initializing ONNX embedding engine with model: {}",
            config.model
        );

        let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;

        info!("Loading tokenizer from {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        info!("Loading ONNX model from {:?}", onnx_path);
        let num_threads = config.num_threads.unwrap_or(1);
        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(num_threads)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        let dimension = config.model.dimension();
        let processor = Arc::new(BatchProcessor::new(
            tokenizer,
            config.model,
            config.max_batch_size,
        ));

        info!(
            "ONNX embedding engine ready: model={}, dimension={}, threads={}",
            config.model, dimension, num_threads
        );

        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            processor,
            config,
            dimension,
        })
    }

    #[instrument(skip_all, fields(model = %config.model))]
    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
        let model_id = config.model.model_id();
        let onnx_repo_id = config.model.onnx_repo_id();
        let onnx_filename = config.model.onnx_filename();

        info!(
            "Resolving model files: tokenizer={}, onnx={}@{}",
            model_id, onnx_filename, onnx_repo_id
        );

        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;

        // ONNX filenames like "onnx/model_quantized.onnx" include a
        // subdirectory; make sure it exists before downloading into it.
        let onnx_subdir = onnx_cache_dir.join("onnx");
        std::fs::create_dir_all(&onnx_subdir)?;

        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
        let onnx_basename = Path::new(onnx_filename)
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("model_quantized.onnx");
        let local_onnx = onnx_subdir.join(onnx_basename);

        let tokenizer_needs_download = !local_tokenizer.exists();
        let onnx_needs_download = !local_onnx.exists();

        if tokenizer_needs_download || onnx_needs_download {
            let model_id_owned = model_id.to_string();
            let onnx_repo_id_owned = onnx_repo_id.to_string();
            let onnx_filename_owned = onnx_filename.to_string();
            let tokenizer_cache = tokenizer_cache_dir.clone();
            let onnx_cache = onnx_cache_dir.clone();

            // Downloads are synchronous (ureq), so run them off the async
            // runtime on the blocking thread pool.
            tokio::task::spawn_blocking(move || {
                if !tokenizer_cache.join("tokenizer.json").exists() {
                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
                        .map_err(|e| {
                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
                        })?;
                }
                if !onnx_cache.join(&onnx_filename_owned).exists() {
                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
                        .map_err(|e| {
                            InferenceError::HubError(format!(
                                "Failed to download ONNX model: {}",
                                e
                            ))
                        })?;
                }
                Ok::<_, InferenceError>(())
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
        } else {
            info!("All model files found in local cache");
        }

        let final_onnx = onnx_cache_dir.join(onnx_filename);

        info!(
            "Model files ready: tokenizer={:?}, onnx={:?}",
            local_tokenizer, final_onnx
        );
        Ok((local_tokenizer, final_onnx))
    }

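    /// Resolves (and creates) the on-disk cache directory for a model id.
    /// Layout: `$HF_HOME` (or `$HOME/.cache/huggingface`, falling back to
    /// `/tmp`) joined with `dakera/{model_id with '/' replaced by "--"}`,
    /// e.g. `~/.cache/huggingface/dakera/sentence-transformers--all-MiniLM-L6-v2`.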
    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| {
                    warn!("HOME environment variable not set, using /tmp for model cache");
                    "/tmp".to_string()
                });
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }

    /// Public wrapper around [`Self::download_hf_file`], exposed so code
    /// outside this module can prefetch files into the same cache.
    pub fn download_hf_file_pub(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        Self::download_hf_file(model_id, filename, cache_dir)
    }

    fn download_hf_file(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        let file_path = cache_dir.join(filename);
        if file_path.exists() {
            info!("Cached: {}/{}", model_id, filename);
            return Ok(file_path);
        }

        if let Some(parent) = file_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
        }

        let url = format!(
            "https://huggingface.co/{}/resolve/main/{}",
            model_id, filename
        );
        info!("Downloading: {}", url);

        // Disable automatic redirects so they can be followed manually,
        // keeping relative Location headers and cross-host hops explicit.
        let agent = ureq::AgentBuilder::new()
            .redirects(0)
            .timeout(std::time::Duration::from_secs(300))
            .build();

        let mut current_url = url.clone();
        let mut redirects = 0;
        let max_redirects = 10;

        let response = loop {
            let resp = agent.get(&current_url).call();

            let r = match resp {
                Ok(r) => r,
                Err(ureq::Error::Status(_status, r)) => r,
                Err(e) => return Err(format!("{}: {}", filename, e)),
            };

            let status = r.status();
            if (200..300).contains(&status) {
                break r;
            } else if (300..400).contains(&status) {
                redirects += 1;
                if redirects > max_redirects {
                    return Err(format!("{}: too many redirects", filename));
                }
                let location = r
                    .header("location")
                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
                    .to_string();

                // A relative Location is resolved against the current URL's
                // scheme and host; an absolute one is used as-is.
                current_url = if location.starts_with('/') {
                    let parsed = url::Url::parse(&current_url)
                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
                    let host = parsed.host_str().ok_or_else(|| {
                        format!("{}: redirect URL missing host: {}", filename, current_url)
                    })?;
                    format!("{}://{}{}", parsed.scheme(), host, location)
                } else {
                    location
                };
                info!("Redirect {} → {}", redirects, current_url);
            } else {
                return Err(format!("{}: HTTP {}", filename, status));
            }
        };

        let mut bytes = Vec::new();
        response
            .into_reader()
            // Cap the read at 500 MB as a safety limit on response size.
            .take(500_000_000)
            .read_to_end(&mut bytes)
            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;

        std::fs::write(&file_path, &bytes)
            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;

        info!("Downloaded {} ({} bytes)", filename, bytes.len());
        Ok(file_path)
    }

    /// Embedding dimension of the configured model.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// The configured embedding model.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }

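    /// Embeds a single query string. Queries are prepared in query mode
    /// (`prepare_texts(.., true)`), as opposed to the document mode used by
    /// [`Self::embed_document`]; for prefix-style models such as E5 this
    /// typically means a query prefix is prepended, though the exact
    /// behavior is owned by `BatchProcessor`.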
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let texts = vec![text.to_string()];
        let prepared = self.processor.prepare_texts(&texts, true);
        let embeddings = self.embed_batch_internal(&prepared).await?;
        embeddings.into_iter().next().ok_or_else(|| {
            InferenceError::InferenceError("No embedding returned for query".to_string())
        })
    }

    /// Embeds a batch of query strings (query-mode preparation).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, true);
        self.embed_batch_internal(&prepared).await
    }

    /// Embeds a single document string (document-mode preparation).
    #[instrument(skip(self, text), fields(text_len = text.len()))]
    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
        let texts = vec![text.to_string()];
        let prepared = self.processor.prepare_texts(&texts, false);
        let embeddings = self.embed_batch_internal(&prepared).await?;
        embeddings.into_iter().next().ok_or_else(|| {
            InferenceError::InferenceError("No embedding returned for document".to_string())
        })
    }

    /// Embeds a batch of document strings (document-mode preparation).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, false);
        self.embed_batch_internal(&prepared).await
    }

    /// Embeds texts as-is, skipping query/document preparation entirely.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }

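    /// Splits the input into batches of at most `max_batch_size`, then runs
    /// each batch on the blocking thread pool, taking the session lock per
    /// batch so concurrent callers can interleave between batches.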
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let batches = self.processor.split_into_batches(texts);
        let mut all_embeddings = Vec::with_capacity(texts.len());

        for batch in batches {
            let batch_owned: Vec<String> = batch.to_vec();
            let session = Arc::clone(&self.session);
            let processor = Arc::clone(&self.processor);
            let normalize = self.config.model.normalize_embeddings();

            let batch_embeddings = tokio::task::spawn_blocking(move || {
                let mut session_guard = session.lock();
                Self::process_batch_blocking(
                    &batch_owned,
                    &mut session_guard,
                    &processor,
                    normalize,
                )
            })
            .await
            .map_err(|e| {
                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
            })??;

            all_embeddings.extend(batch_embeddings);
        }

        Ok(all_embeddings)
    }

    /// Tokenizes one batch, runs the ONNX session, mean-pools the token
    /// embeddings, and optionally normalizes them. Blocking: intended to
    /// run on the blocking thread pool, never directly on the async runtime.
    fn process_batch_blocking(
        texts: &[String],
        session: &mut Session,
        processor: &BatchProcessor,
        normalize: bool,
    ) -> Result<Vec<Vec<f32>>> {
        let prepared = processor.tokenize_batch(texts)?;
        let batch_size = prepared.batch_size;
        let seq_len = prepared.seq_len;

        // Keep a copy of the attention mask for mean pooling; the original
        // is moved into the input tensor below.
        let attention_mask_flat = prepared.attention_mask.clone();

        let input_ids_tensor =
            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
        let attention_mask_tensor =
            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
        let token_type_ids_tensor =
            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;

        let outputs = session
            .run(inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor
            ])
            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;

        // The first output is last_hidden_state, a flat f32 slice with
        // shape [batch_size, seq_len, hidden_size].
        let (ort_shape, lhs_slice) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;

        if ort_shape.len() != 3 {
            return Err(InferenceError::InferenceError(format!(
                "Expected 3D last_hidden_state, got {} dims",
                ort_shape.len()
            )));
        }
        let hidden_size = ort_shape[2] as usize;

        let mut embeddings = mean_pooling(
            lhs_slice,
            batch_size,
            seq_len,
            hidden_size,
            &attention_mask_flat,
        );

        if normalize {
            normalize_embeddings(&mut embeddings);
        }

        debug!(
            "Generated {} embeddings of dimension {}",
            embeddings.len(),
            embeddings.first().map(|e| e.len()).unwrap_or(0)
        );

        Ok(embeddings)
    }

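    /// Rough CPU latency estimate in milliseconds, assuming ~4 characters
    /// per token (capped at the model's max sequence length) and the
    /// model's nominal CPU throughput. For example, 10 texts of 100 chars
    /// are ~250 tokens; the formula-check test below expects that to come
    /// out at 50 ms for MiniLM.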
    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
        let tokens_per_text =
            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * text_count as f64;
        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
        (total_tokens / tokens_per_second) * 1000.0
    }
}

impl std::fmt::Debug for EmbeddingEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingEngine")
            .field("model", &self.config.model)
            .field("dimension", &self.dimension)
            .field("max_batch_size", &self.config.max_batch_size)
            .finish()
    }
}

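/// Fluent builder for [`EmbeddingEngine`].
///
/// A usage sketch (marked `ignore` because the crate path is
/// project-specific):
///
/// ```ignore
/// let engine = EmbeddingEngineBuilder::new()
///     .model(EmbeddingModel::BgeSmall)
///     .max_batch_size(64)
///     .num_threads(4)
///     .build()
///     .await?;
/// ```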
pub struct EmbeddingEngineBuilder {
    config: ModelConfig,
}

impl EmbeddingEngineBuilder {
    /// Starts from the default [`ModelConfig`].
    pub fn new() -> Self {
        Self {
            config: ModelConfig::default(),
        }
    }

    /// Selects the embedding model.
    pub fn model(mut self, model: EmbeddingModel) -> Self {
        self.config.model = model;
        self
    }

    /// Overrides the model cache directory recorded in the config.
    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.config.cache_dir = Some(dir.into());
        self
    }

    /// Sets the maximum number of texts per inference batch.
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.config.max_batch_size = size;
        self
    }

    /// Records the GPU preference; note that `EmbeddingEngine::new`
    /// currently builds a CPU session regardless of this flag.
    pub fn use_gpu(mut self, enable: bool) -> Self {
        self.config.use_gpu = enable;
        self
    }

    /// Sets the ONNX Runtime intra-op thread count (default: 1).
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = Some(threads);
        self
    }

    /// Builds the engine, downloading model files on first use.
    pub async fn build(self) -> Result<EmbeddingEngine> {
        EmbeddingEngine::new(self.config).await
    }
}

impl Default for EmbeddingEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    // std::env is process-global; every env-mutating test serializes on
    // this one shared lock so they cannot race each other.
    static ENV_LOCK: StdMutex<()> = StdMutex::new(());

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _guard = ENV_LOCK.lock().unwrap();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        unsafe { std::env::remove_var("HF_HOME") };

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // download_hf_file returns the cached path without touching the
    // network when the target file already exists.
    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        // 100 chars ≈ 25 tokens per text; 10 texts → 250 tokens, which at
        // MiniLM's nominal CPU throughput should come out to 50 ms.
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::MiniLM);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // With neither HF_HOME nor HOME set, the cache falls back to /tmp.
    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _guard = ENV_LOCK.lock().unwrap();

        // Save and restore the real values so other tests are unaffected.
        let saved_home = std::env::var("HOME").ok();
        let saved_hf = std::env::var("HF_HOME").ok();
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }

        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");

        if let Some(h) = saved_home {
            unsafe { std::env::set_var("HOME", h) };
        }
        if let Some(h) = saved_hf {
            unsafe { std::env::set_var("HF_HOME", h) };
        }

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // Point HF_HOME at a fake cache containing an invalid tokenizer.json;
    // new() must fail, either while parsing the tokenizer or (when offline)
    // while trying to fetch the missing ONNX file.
    #[tokio::test]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _guard = ENV_LOCK.lock().unwrap();

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        unsafe { std::env::set_var("HF_HOME", &tmp) };

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        unsafe { std::env::remove_var("HF_HOME") };

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // These integration-style tests only assert when the model is already
    // cached locally; otherwise they skip so CI can run without network.
    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), 384);
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                // Exercise the Debug impl.
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Model not cached and download unavailable; skip.
            }
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                let result = engine.embed_raw(&[]).await;
                assert!(result.is_ok());
                assert!(result.unwrap().is_empty());
            }
            Err(_) => {}
        }
    }
}