1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47pub struct EmbeddingEngine {
54 sessions: Vec<Arc<Mutex<Session>>>,
58 next_session: AtomicUsize,
60 processor: Arc<BatchProcessor>,
62 config: ModelConfig,
64 dimension: usize,
66}
67
68impl EmbeddingEngine {
69 #[instrument(skip_all, fields(model = %config.model))]
73 pub async fn new(config: ModelConfig) -> Result<Self> {
74 info!(
75 "Initializing ONNX embedding engine with model: {}",
76 config.model
77 );
78
79 let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
81
82 info!("Loading tokenizer from {:?}", tokenizer_path);
84 let tokenizer = Tokenizer::from_file(&tokenizer_path)
85 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
86
87 info!("Loading ONNX model from {:?}", onnx_path);
90 let num_threads = config.num_threads.unwrap_or(4);
91 let pool_size = config.session_pool_size.max(1);
92 let onnx_path_clone = onnx_path.clone();
93 let sessions: Vec<Arc<Mutex<Session>>> =
94 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
95 (0..pool_size)
96 .map(|_| {
97 let s = Session::builder()
98 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
99 .with_optimization_level(GraphOptimizationLevel::Level3)
100 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
101 .with_intra_threads(num_threads)
102 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
103 .commit_from_file(&onnx_path_clone)
104 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
105 Ok(Arc::new(Mutex::new(s)))
106 })
107 .collect()
108 })
109 .await
110 .map_err(|e| {
111 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
112 })??;
113
114 let dimension = config.model.dimension();
115 let processor = Arc::new(BatchProcessor::new(
116 tokenizer,
117 config.model,
118 config.max_batch_size,
119 ));
120
121 info!(
122 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
123 config.model, dimension, num_threads, pool_size
124 );
125
126 Ok(Self {
127 sessions,
128 next_session: AtomicUsize::new(0),
129 processor,
130 config,
131 dimension,
132 })
133 }
134
135 #[instrument(skip_all, fields(model = %config.model))]
140 async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
141 let model_id = config.model.model_id();
142 let onnx_repo_id = config.model.onnx_repo_id();
143 let onnx_filename = config.model.onnx_filename();
144
145 info!(
146 "Resolving model files: tokenizer={}, onnx={}@{}",
147 model_id, onnx_filename, onnx_repo_id
148 );
149
150 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
151 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
152
153 let onnx_subdir = onnx_cache_dir.join("onnx");
155 std::fs::create_dir_all(&onnx_subdir)?;
156
157 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
158 let onnx_basename = Path::new(onnx_filename)
160 .file_name()
161 .and_then(|s| s.to_str())
162 .unwrap_or("model_quantized.onnx");
163 let local_onnx = onnx_subdir.join(onnx_basename);
164
165 let tokenizer_needs_download = !local_tokenizer.exists();
167 let onnx_needs_download = !local_onnx.exists();
168
169 if tokenizer_needs_download || onnx_needs_download {
170 let model_id_owned = model_id.to_string();
171 let onnx_repo_id_owned = onnx_repo_id.to_string();
172 let onnx_filename_owned = onnx_filename.to_string();
173 let tokenizer_cache = tokenizer_cache_dir.clone();
174 let onnx_cache = onnx_cache_dir.clone();
175
176 tokio::task::spawn_blocking(move || {
177 if !tokenizer_cache.join("tokenizer.json").exists() {
178 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
179 .map_err(|e| {
180 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
181 })?;
182 }
183 if !onnx_cache.join(&onnx_filename_owned).exists() {
184 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
185 .map_err(|e| {
186 InferenceError::HubError(format!(
187 "Failed to download ONNX model: {}",
188 e
189 ))
190 })?;
191 }
192 Ok::<_, InferenceError>(())
193 })
194 .await
195 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
196 } else {
197 info!("All model files found in local cache");
198 }
199
200 let final_onnx = onnx_cache_dir.join(onnx_filename);
202
203 info!(
204 "Model files ready: tokenizer={:?}, onnx={:?}",
205 local_tokenizer, final_onnx
206 );
207 Ok((local_tokenizer, final_onnx))
208 }
209
210 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
212 let base = std::env::var("HF_HOME")
213 .map(PathBuf::from)
214 .unwrap_or_else(|_| {
215 let home = std::env::var("HOME").unwrap_or_else(|_| {
216 warn!("HOME environment variable not set, using /tmp for model cache");
217 "/tmp".to_string()
218 });
219 PathBuf::from(home).join(".cache").join("huggingface")
220 });
221 let dir = base.join("dakera").join(model_id.replace('/', "--"));
222 std::fs::create_dir_all(&dir)?;
223 Ok(dir)
224 }
225
226 pub fn download_hf_file_pub(
232 model_id: &str,
233 filename: &str,
234 cache_dir: &Path,
235 ) -> std::result::Result<PathBuf, String> {
236 Self::download_hf_file(model_id, filename, cache_dir)
237 }
238
239 fn download_hf_file(
240 model_id: &str,
241 filename: &str,
242 cache_dir: &Path,
243 ) -> std::result::Result<PathBuf, String> {
244 let file_path = cache_dir.join(filename);
246 if file_path.exists() {
247 info!("Cached: {}/{}", model_id, filename);
248 return Ok(file_path);
249 }
250
251 if let Some(parent) = file_path.parent() {
253 std::fs::create_dir_all(parent)
254 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
255 }
256
257 let url = format!(
258 "https://huggingface.co/{}/resolve/main/{}",
259 model_id, filename
260 );
261 info!("Downloading: {}", url);
262
263 let hf_token = std::env::var("HF_TOKEN")
265 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
266 .ok();
267 if hf_token.is_some() {
268 info!("Using HuggingFace auth token for download");
269 }
270
271 let agent = ureq::AgentBuilder::new()
273 .redirects(0)
274 .timeout(std::time::Duration::from_secs(300))
275 .build();
276
277 let mut current_url = url.clone();
278 let mut redirects = 0;
279 let max_redirects = 10;
280
281 let response = loop {
282 let mut req = agent.get(¤t_url);
283 if let Some(ref token) = hf_token {
284 req = req.set("Authorization", &format!("Bearer {}", token));
285 }
286 let resp = req.call();
287
288 let r = match resp {
289 Ok(r) => r,
290 Err(ureq::Error::Status(_status, r)) => r,
291 Err(e) => return Err(format!("{}: {}", filename, e)),
292 };
293
294 let status = r.status();
295 if (200..300).contains(&status) {
296 break r;
297 } else if (300..400).contains(&status) {
298 redirects += 1;
299 if redirects > max_redirects {
300 return Err(format!("{}: too many redirects", filename));
301 }
302 let location = r
303 .header("location")
304 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
305 .to_string();
306
307 current_url = if location.starts_with('/') {
309 let parsed = url::Url::parse(¤t_url)
310 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
311 let host = parsed.host_str().ok_or_else(|| {
312 format!("{}: redirect URL missing host: {}", filename, current_url)
313 })?;
314 format!("{}://{}{}", parsed.scheme(), host, location)
315 } else {
316 location
317 };
318 info!("Redirect {} → {}", redirects, current_url);
319 } else {
320 return Err(format!("{}: HTTP {}", filename, status));
321 }
322 };
323
324 let mut bytes = Vec::new();
325 response
326 .into_reader()
327 .take(500_000_000) .read_to_end(&mut bytes)
329 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
330
331 std::fs::write(&file_path, &bytes)
332 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
333
334 info!("Downloaded {} ({} bytes)", filename, bytes.len());
335 Ok(file_path)
336 }
337
338 pub fn dimension(&self) -> usize {
340 self.dimension
341 }
342
343 pub fn model(&self) -> EmbeddingModel {
345 self.config.model
346 }
347
348 pub fn pool_size(&self) -> usize {
350 self.sessions.len()
351 }
352
353 #[instrument(skip(self, text), fields(text_len = text.len()))]
357 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
358 let texts = vec![text.to_string()];
359 let prepared = self.processor.prepare_texts(&texts, true);
360 let embeddings = self.embed_batch_internal(&prepared).await?;
361 embeddings.into_iter().next().ok_or_else(|| {
362 InferenceError::InferenceError("No embedding returned for query".to_string())
363 })
364 }
365
366 #[instrument(skip(self, texts), fields(count = texts.len()))]
370 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
371 let prepared = self.processor.prepare_texts(texts, true);
372 self.embed_batch_internal(&prepared).await
373 }
374
375 #[instrument(skip(self, text), fields(text_len = text.len()))]
379 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
380 let texts = vec![text.to_string()];
381 let prepared = self.processor.prepare_texts(&texts, false);
382 let embeddings = self.embed_batch_internal(&prepared).await?;
383 embeddings.into_iter().next().ok_or_else(|| {
384 InferenceError::InferenceError("No embedding returned for document".to_string())
385 })
386 }
387
388 #[instrument(skip(self, texts), fields(count = texts.len()))]
392 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
393 let prepared = self.processor.prepare_texts(texts, false);
394 self.embed_batch_internal(&prepared).await
395 }
396
397 #[instrument(skip(self, texts), fields(count = texts.len()))]
399 pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
400 self.embed_batch_internal(texts).await
401 }
402
403 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
411 if texts.is_empty() {
412 return Ok(vec![]);
413 }
414
415 let batches: Vec<Vec<String>> = self
416 .processor
417 .split_into_batches(texts)
418 .into_iter()
419 .map(|b| b.to_vec())
420 .collect();
421
422 let pool_len = self.sessions.len();
423 let normalize = self.config.model.normalize_embeddings();
424 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
427
428 let mut handles = Vec::with_capacity(batches.len());
430 for (i, batch_owned) in batches.into_iter().enumerate() {
431 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
432 let processor = Arc::clone(&self.processor);
433 handles.push(tokio::task::spawn_blocking(move || {
434 let mut session_guard = session.lock();
435 Self::process_batch_blocking(
436 &batch_owned,
437 &mut session_guard,
438 &processor,
439 normalize,
440 )
441 }));
442 }
443
444 let mut all_embeddings = Vec::with_capacity(texts.len());
445 for handle in handles {
446 let batch_embeddings = handle.await.map_err(|e| {
447 InferenceError::InferenceError(format!("Inference task panicked: {}", e))
448 })??;
449 all_embeddings.extend(batch_embeddings);
450 }
451
452 Ok(all_embeddings)
453 }
454
455 fn process_batch_blocking(
459 texts: &[String],
460 session: &mut Session,
461 processor: &BatchProcessor,
462 normalize: bool,
463 ) -> Result<Vec<Vec<f32>>> {
464 let prepared = processor.tokenize_batch(texts)?;
466 let batch_size = prepared.batch_size;
467 let seq_len = prepared.seq_len;
468
469 let attention_mask_flat = prepared.attention_mask.clone();
471
472 let input_ids_tensor =
474 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
475 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
476 let attention_mask_tensor =
477 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
478 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
479 let token_type_ids_tensor =
480 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
481 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
482
483 let outputs = session
485 .run(inputs![
486 "input_ids" => input_ids_tensor,
487 "attention_mask" => attention_mask_tensor,
488 "token_type_ids" => token_type_ids_tensor
489 ])
490 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
491
492 let (ort_shape, lhs_slice) = outputs[0]
496 .try_extract_tensor::<f32>()
497 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
498
499 if ort_shape.len() != 3 {
500 return Err(InferenceError::InferenceError(format!(
501 "Expected 3D last_hidden_state, got {} dims",
502 ort_shape.len()
503 )));
504 }
505 let hidden_size = ort_shape[2] as usize;
506
507 let mut embeddings = mean_pooling(
509 lhs_slice,
510 batch_size,
511 seq_len,
512 hidden_size,
513 &attention_mask_flat,
514 );
515
516 if normalize {
518 normalize_embeddings(&mut embeddings);
519 }
520
521 debug!(
522 "Generated {} embeddings of dimension {}",
523 embeddings.len(),
524 embeddings.first().map(|e| e.len()).unwrap_or(0)
525 );
526
527 Ok(embeddings)
528 }
529
530 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
532 let tokens_per_text =
534 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
535 let total_tokens = tokens_per_text * text_count as f64;
536 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
537 (total_tokens / tokens_per_second) * 1000.0
538 }
539}
540
541impl std::fmt::Debug for EmbeddingEngine {
542 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543 f.debug_struct("EmbeddingEngine")
544 .field("model", &self.config.model)
545 .field("dimension", &self.dimension)
546 .field("max_batch_size", &self.config.max_batch_size)
547 .field("session_pool_size", &self.sessions.len())
548 .finish()
549 }
550}
551
552pub struct EmbeddingEngineBuilder {
554 config: ModelConfig,
555}
556
557impl EmbeddingEngineBuilder {
558 pub fn new() -> Self {
560 Self {
561 config: ModelConfig::default(),
562 }
563 }
564
565 pub fn model(mut self, model: EmbeddingModel) -> Self {
567 self.config.model = model;
568 self
569 }
570
571 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
573 self.config.cache_dir = Some(dir.into());
574 self
575 }
576
577 pub fn max_batch_size(mut self, size: usize) -> Self {
579 self.config.max_batch_size = size;
580 self
581 }
582
583 pub fn use_gpu(mut self, enable: bool) -> Self {
585 self.config.use_gpu = enable;
586 self
587 }
588
589 pub fn num_threads(mut self, threads: usize) -> Self {
591 self.config.num_threads = Some(threads);
592 self
593 }
594
595 pub fn session_pool_size(mut self, size: usize) -> Self {
597 self.config.session_pool_size = size.max(1);
598 self
599 }
600
601 pub async fn build(self) -> Result<EmbeddingEngine> {
603 EmbeddingEngine::new(self.config).await
604 }
605}
606
607impl Default for EmbeddingEngineBuilder {
608 fn default() -> Self {
609 Self::new()
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};

    /// Serialises every test in this module that reads or mutates process environment
    /// variables.
    ///
    /// Previously each such test declared its own function-local `static ENV_LOCK`. That
    /// gave every test a different mutex, so there was no mutual exclusion at all.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so one failing test does not
    /// cascade into failures of every later env-dependent test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets or removes an environment variable and restores its previous value on drop,
    /// including when the test panics.
    ///
    /// Acquire [`env_lock`] *before* creating a guard so the lock is released last.
    struct EnvVarGuard {
        key: &'static str,
        saved: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let saved = std::env::var_os(key);
            // SAFETY: callers hold `env_lock()`, serialising this with every other
            // env-dependent test in this module.
            unsafe { std::env::set_var(key, value) };
            Self { key, saved }
        }

        fn remove(key: &'static str) -> Self {
            let saved = std::env::var_os(key);
            // SAFETY: as in `set`.
            unsafe { std::env::remove_var(key) };
            Self { key, saved }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: as in `set`; the env lock is still held because it was acquired
            // before this guard was created.
            unsafe {
                match self.saved.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _lock = env_lock();
        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let result = {
            let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);
            EmbeddingEngine::model_cache_dir("org/my-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let _lock = env_lock();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _lock = env_lock();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 8);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 8);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _lock = env_lock();
        let result = {
            let _home = EnvVarGuard::remove("HOME");
            let _hf_home = EnvVarGuard::remove("HF_HOME");
            EmbeddingEngine::model_cache_dir("test/fallback-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // The std mutex guard is held across `.await` deliberately: `#[tokio::test]` runs on
    // a current-thread runtime where nothing else contends for it, and the lock must
    // cover the whole `new()` call that reads `HF_HOME`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _lock = env_lock();
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);

        let model = EmbeddingModel::MiniLM;

        // Corrupt tokenizer in the exact cache slot `new()` reads.
        let tokenizer_dir = EmbeddingEngine::model_cache_dir(model.model_id()).unwrap();
        std::fs::write(tokenizer_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();

        // Placeholder ONNX file so every model file is already cached and the test never
        // touches the network. The tokenizer is loaded (and fails) before the graph.
        let onnx_path = EmbeddingEngine::model_cache_dir(model.onnx_repo_id())
            .unwrap()
            .join(model.onnx_filename());
        std::fs::create_dir_all(onnx_path.parent().unwrap()).unwrap();
        std::fs::write(&onnx_path, b"not_a_real_onnx_graph").unwrap();

        let result = EmbeddingEngine::new(ModelConfig::new(model)).await;

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // Lock held across `.await` on a current-thread runtime; see
    // `test_new_fails_with_invalid_tokenizer_json`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_getters_when_model_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Model not cached and not downloadable here: nothing to check.
            }
        }
    }

    // Lock held across `.await` on a current-thread runtime; see
    // `test_new_fails_with_invalid_tokenizer_json`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_embed_empty_batch_when_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    // Lock held across `.await` on a current-thread runtime; see
    // `test_new_fails_with_invalid_tokenizer_json`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    // Lock held across `.await` on a current-thread runtime; see
    // `test_new_fails_with_invalid_tokenizer_json`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}