1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
/// ONNX-backed text embedding engine.
///
/// Holds a pool of ONNX sessions for a single model together with the
/// tokenizer/batching helper used to feed them.
pub struct EmbeddingEngine {
    /// Session pool. Every session has its own mutex, so batches assigned to
    /// different sessions do not wait on each other's lock.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Counter advanced once per non-empty embedding call; its value modulo
    /// the pool size selects the session used for the first batch.
    next_session: AtomicUsize,
    /// Shared helper that prepares, splits and tokenizes input texts.
    processor: Arc<BatchProcessor>,
    /// Configuration the engine was created from.
    config: ModelConfig,
    /// Embedding dimension reported by the configured model at construction.
    dimension: usize,
}
67
68impl EmbeddingEngine {
69 #[instrument(skip_all, fields(model = %config.model))]
73 pub async fn new(config: ModelConfig) -> Result<Self> {
74 info!(
75 "Initializing ONNX embedding engine with model: {}",
76 config.model
77 );
78
79 let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
81
82 info!("Loading tokenizer from {:?}", tokenizer_path);
84 let tokenizer = Tokenizer::from_file(&tokenizer_path)
85 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
86
87 info!("Loading ONNX model from {:?}", onnx_path);
90 let num_threads = config.num_threads.unwrap_or(4);
91 let pool_size = config.session_pool_size.max(1);
92 let onnx_path_clone = onnx_path.clone();
93 let sessions: Vec<Arc<Mutex<Session>>> =
94 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
95 (0..pool_size)
96 .map(|_| {
97 let s = Session::builder()
98 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
99 .with_optimization_level(GraphOptimizationLevel::Level3)
100 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
101 .with_intra_threads(num_threads)
102 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
103 .commit_from_file(&onnx_path_clone)
104 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
105 Ok(Arc::new(Mutex::new(s)))
106 })
107 .collect()
108 })
109 .await
110 .map_err(|e| {
111 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
112 })??;
113
114 let dimension = config.model.dimension();
115 let processor = Arc::new(BatchProcessor::new(
116 tokenizer,
117 config.model,
118 config.max_batch_size,
119 ));
120
121 info!(
122 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
123 config.model, dimension, num_threads, pool_size
124 );
125
126 Ok(Self {
127 sessions,
128 next_session: AtomicUsize::new(0),
129 processor,
130 config,
131 dimension,
132 })
133 }
134
135 #[instrument(skip_all, fields(model = %config.model))]
140 async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
141 let model_id = config.model.model_id();
142 let onnx_repo_id = config.model.onnx_repo_id();
143 let onnx_filename = config.model.onnx_filename();
144
145 info!(
146 "Resolving model files: tokenizer={}, onnx={}@{}",
147 model_id, onnx_filename, onnx_repo_id
148 );
149
150 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
151 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
152
153 let onnx_subdir = onnx_cache_dir.join("onnx");
155 std::fs::create_dir_all(&onnx_subdir)?;
156
157 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
158 let onnx_basename = Path::new(onnx_filename)
160 .file_name()
161 .and_then(|s| s.to_str())
162 .unwrap_or("model_quantized.onnx");
163 let local_onnx = onnx_subdir.join(onnx_basename);
164
165 let tokenizer_needs_download = !local_tokenizer.exists();
167 let onnx_needs_download = !local_onnx.exists();
168
169 if tokenizer_needs_download || onnx_needs_download {
170 let model_id_owned = model_id.to_string();
171 let onnx_repo_id_owned = onnx_repo_id.to_string();
172 let onnx_filename_owned = onnx_filename.to_string();
173 let tokenizer_cache = tokenizer_cache_dir.clone();
174 let onnx_cache = onnx_cache_dir.clone();
175
176 tokio::task::spawn_blocking(move || {
177 if !tokenizer_cache.join("tokenizer.json").exists() {
178 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
179 .map_err(|e| {
180 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
181 })?;
182 }
183 if !onnx_cache.join(&onnx_filename_owned).exists() {
184 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
185 .map_err(|e| {
186 InferenceError::HubError(format!(
187 "Failed to download ONNX model: {}",
188 e
189 ))
190 })?;
191 }
192 Ok::<_, InferenceError>(())
193 })
194 .await
195 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
196 } else {
197 info!("All model files found in local cache");
198 }
199
200 let final_onnx = onnx_cache_dir.join(onnx_filename);
202
203 info!(
204 "Model files ready: tokenizer={:?}, onnx={:?}",
205 local_tokenizer, final_onnx
206 );
207 Ok((local_tokenizer, final_onnx))
208 }
209
210 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
212 let base = std::env::var("HF_HOME")
213 .map(PathBuf::from)
214 .unwrap_or_else(|_| {
215 let home = std::env::var("HOME").unwrap_or_else(|_| {
216 warn!("HOME environment variable not set, using /tmp for model cache");
217 "/tmp".to_string()
218 });
219 PathBuf::from(home).join(".cache").join("huggingface")
220 });
221 let dir = base.join("dakera").join(model_id.replace('/', "--"));
222 std::fs::create_dir_all(&dir)?;
223 Ok(dir)
224 }
225
    /// Public entry point to the private downloader.
    ///
    /// Forwards `model_id`, `filename` and `cache_dir` unchanged to
    /// `download_hf_file` and returns its result: the local path of the file
    /// on success, or a message describing the failure.
    pub fn download_hf_file_pub(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        Self::download_hf_file(model_id, filename, cache_dir)
    }
238
239 fn download_hf_file(
240 model_id: &str,
241 filename: &str,
242 cache_dir: &Path,
243 ) -> std::result::Result<PathBuf, String> {
244 let file_path = cache_dir.join(filename);
246 if file_path.exists() {
247 info!("Cached: {}/{}", model_id, filename);
248 return Ok(file_path);
249 }
250
251 if let Some(parent) = file_path.parent() {
253 std::fs::create_dir_all(parent)
254 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
255 }
256
257 let url = format!(
258 "https://huggingface.co/{}/resolve/main/{}",
259 model_id, filename
260 );
261 info!("Downloading: {}", url);
262
263 let agent = ureq::AgentBuilder::new()
265 .redirects(0)
266 .timeout(std::time::Duration::from_secs(300))
267 .build();
268
269 let mut current_url = url.clone();
270 let mut redirects = 0;
271 let max_redirects = 10;
272
273 let response = loop {
274 let resp = agent.get(¤t_url).call();
275
276 let r = match resp {
277 Ok(r) => r,
278 Err(ureq::Error::Status(_status, r)) => r,
279 Err(e) => return Err(format!("{}: {}", filename, e)),
280 };
281
282 let status = r.status();
283 if (200..300).contains(&status) {
284 break r;
285 } else if (300..400).contains(&status) {
286 redirects += 1;
287 if redirects > max_redirects {
288 return Err(format!("{}: too many redirects", filename));
289 }
290 let location = r
291 .header("location")
292 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
293 .to_string();
294
295 current_url = if location.starts_with('/') {
297 let parsed = url::Url::parse(¤t_url)
298 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
299 let host = parsed.host_str().ok_or_else(|| {
300 format!("{}: redirect URL missing host: {}", filename, current_url)
301 })?;
302 format!("{}://{}{}", parsed.scheme(), host, location)
303 } else {
304 location
305 };
306 info!("Redirect {} → {}", redirects, current_url);
307 } else {
308 return Err(format!("{}: HTTP {}", filename, status));
309 }
310 };
311
312 let mut bytes = Vec::new();
313 response
314 .into_reader()
315 .take(500_000_000) .read_to_end(&mut bytes)
317 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
318
319 std::fs::write(&file_path, &bytes)
320 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
321
322 info!("Downloaded {} ({} bytes)", filename, bytes.len());
323 Ok(file_path)
324 }
325
    /// Returns the embedding dimension stored when the engine was built.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
330
    /// Returns the model this engine was configured with.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }
335
    /// Returns the number of sessions held in the pool.
    pub fn pool_size(&self) -> usize {
        self.sessions.len()
    }
340
341 #[instrument(skip(self, text), fields(text_len = text.len()))]
345 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
346 let texts = vec![text.to_string()];
347 let prepared = self.processor.prepare_texts(&texts, true);
348 let embeddings = self.embed_batch_internal(&prepared).await?;
349 embeddings.into_iter().next().ok_or_else(|| {
350 InferenceError::InferenceError("No embedding returned for query".to_string())
351 })
352 }
353
354 #[instrument(skip(self, texts), fields(count = texts.len()))]
358 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
359 let prepared = self.processor.prepare_texts(texts, true);
360 self.embed_batch_internal(&prepared).await
361 }
362
363 #[instrument(skip(self, text), fields(text_len = text.len()))]
367 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
368 let texts = vec![text.to_string()];
369 let prepared = self.processor.prepare_texts(&texts, false);
370 let embeddings = self.embed_batch_internal(&prepared).await?;
371 embeddings.into_iter().next().ok_or_else(|| {
372 InferenceError::InferenceError("No embedding returned for document".to_string())
373 })
374 }
375
376 #[instrument(skip(self, texts), fields(count = texts.len()))]
380 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
381 let prepared = self.processor.prepare_texts(texts, false);
382 self.embed_batch_internal(&prepared).await
383 }
384
    /// Embeds `texts` exactly as given, skipping the `prepare_texts` step that
    /// the query/document entry points apply first.
    ///
    /// # Errors
    ///
    /// Propagates tokenization and inference failures.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }
390
391 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
399 if texts.is_empty() {
400 return Ok(vec![]);
401 }
402
403 let batches: Vec<Vec<String>> = self
404 .processor
405 .split_into_batches(texts)
406 .into_iter()
407 .map(|b| b.to_vec())
408 .collect();
409
410 let pool_len = self.sessions.len();
411 let normalize = self.config.model.normalize_embeddings();
412 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
415
416 let mut handles = Vec::with_capacity(batches.len());
418 for (i, batch_owned) in batches.into_iter().enumerate() {
419 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
420 let processor = Arc::clone(&self.processor);
421 handles.push(tokio::task::spawn_blocking(move || {
422 let mut session_guard = session.lock();
423 Self::process_batch_blocking(
424 &batch_owned,
425 &mut session_guard,
426 &processor,
427 normalize,
428 )
429 }));
430 }
431
432 let mut all_embeddings = Vec::with_capacity(texts.len());
433 for handle in handles {
434 let batch_embeddings = handle.await.map_err(|e| {
435 InferenceError::InferenceError(format!("Inference task panicked: {}", e))
436 })??;
437 all_embeddings.extend(batch_embeddings);
438 }
439
440 Ok(all_embeddings)
441 }
442
443 fn process_batch_blocking(
447 texts: &[String],
448 session: &mut Session,
449 processor: &BatchProcessor,
450 normalize: bool,
451 ) -> Result<Vec<Vec<f32>>> {
452 let prepared = processor.tokenize_batch(texts)?;
454 let batch_size = prepared.batch_size;
455 let seq_len = prepared.seq_len;
456
457 let attention_mask_flat = prepared.attention_mask.clone();
459
460 let input_ids_tensor =
462 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
463 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
464 let attention_mask_tensor =
465 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
466 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
467 let token_type_ids_tensor =
468 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
469 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
470
471 let outputs = session
473 .run(inputs![
474 "input_ids" => input_ids_tensor,
475 "attention_mask" => attention_mask_tensor,
476 "token_type_ids" => token_type_ids_tensor
477 ])
478 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
479
480 let (ort_shape, lhs_slice) = outputs[0]
484 .try_extract_tensor::<f32>()
485 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
486
487 if ort_shape.len() != 3 {
488 return Err(InferenceError::InferenceError(format!(
489 "Expected 3D last_hidden_state, got {} dims",
490 ort_shape.len()
491 )));
492 }
493 let hidden_size = ort_shape[2] as usize;
494
495 let mut embeddings = mean_pooling(
497 lhs_slice,
498 batch_size,
499 seq_len,
500 hidden_size,
501 &attention_mask_flat,
502 );
503
504 if normalize {
506 normalize_embeddings(&mut embeddings);
507 }
508
509 debug!(
510 "Generated {} embeddings of dimension {}",
511 embeddings.len(),
512 embeddings.first().map(|e| e.len()).unwrap_or(0)
513 );
514
515 Ok(embeddings)
516 }
517
518 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
520 let tokens_per_text =
522 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
523 let total_tokens = tokens_per_text * text_count as f64;
524 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
525 (total_tokens / tokens_per_second) * 1000.0
526 }
527}
528
529impl std::fmt::Debug for EmbeddingEngine {
530 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531 f.debug_struct("EmbeddingEngine")
532 .field("model", &self.config.model)
533 .field("dimension", &self.dimension)
534 .field("max_batch_size", &self.config.max_batch_size)
535 .field("session_pool_size", &self.sessions.len())
536 .finish()
537 }
538}
539
/// Fluent builder that assembles a [`ModelConfig`] and turns it into an
/// [`EmbeddingEngine`].
pub struct EmbeddingEngineBuilder {
    /// Configuration accumulated by the setter methods.
    config: ModelConfig,
}
544
545impl EmbeddingEngineBuilder {
546 pub fn new() -> Self {
548 Self {
549 config: ModelConfig::default(),
550 }
551 }
552
553 pub fn model(mut self, model: EmbeddingModel) -> Self {
555 self.config.model = model;
556 self
557 }
558
559 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
561 self.config.cache_dir = Some(dir.into());
562 self
563 }
564
565 pub fn max_batch_size(mut self, size: usize) -> Self {
567 self.config.max_batch_size = size;
568 self
569 }
570
571 pub fn use_gpu(mut self, enable: bool) -> Self {
573 self.config.use_gpu = enable;
574 self
575 }
576
577 pub fn num_threads(mut self, threads: usize) -> Self {
579 self.config.num_threads = Some(threads);
580 self
581 }
582
583 pub fn session_pool_size(mut self, size: usize) -> Self {
585 self.config.session_pool_size = size.max(1);
586 self
587 }
588
589 pub async fn build(self) -> Result<EmbeddingEngine> {
591 EmbeddingEngine::new(self.config).await
592 }
593}
594
/// `Default` is the same as [`EmbeddingEngineBuilder::new`].
impl Default for EmbeddingEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
600
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex as StdMutex, MutexGuard as StdMutexGuard, PoisonError};

    /// Serialises every test in this module that reads or rewrites `HOME` /
    /// `HF_HOME`.
    ///
    /// The lock has to be module-wide: a `static` declared inside a test
    /// function belongs to that function alone and excludes no other test.
    static ENV_LOCK: StdMutex<()> = StdMutex::new(());

    /// Takes [`ENV_LOCK`]. Poisoning is ignored so that one failing test does
    /// not make every other environment-dependent test fail as well.
    fn env_lock() -> StdMutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Puts `key` back to `saved`, removing it when it was unset before.
    /// Callers must hold [`ENV_LOCK`].
    fn restore_env(key: &str, saved: Option<OsString>) {
        // SAFETY: the caller holds ENV_LOCK, which serialises the tests of
        // this module that touch these variables.
        unsafe {
            match saved {
                Some(value) => std::env::set_var(key, value),
                None => std::env::remove_var(key),
            }
        }
    }

    /// Asserts that `model_cache_dir(model_id)` yields a path containing
    /// `fragment`.
    fn assert_cache_dir_contains(model_id: &str, fragment: &str) {
        let _env = env_lock();
        let path = EmbeddingEngine::model_cache_dir(model_id).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains(fragment), "expected {fragment:?} in path: {s}");
    }

    /// Pre-creates `rel_path` under a temp directory and asserts that
    /// `download_hf_file` returns that file without needing the network.
    fn assert_cache_hit(dir_name: &str, repo: &str, rel_path: &str, contents: &[u8]) {
        let dir = std::env::temp_dir().join(dir_name);
        // Built component by component, independently of how the function
        // under test joins the path.
        let expected = rel_path
            .split('/')
            .fold(dir.clone(), |path, part| path.join(part));
        std::fs::create_dir_all(expected.parent().unwrap()).unwrap();
        std::fs::write(&expected, contents).unwrap();

        let result = EmbeddingEngine::download_hf_file(repo, rel_path, &dir);
        assert_eq!(result, Ok(expected));
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _env = env_lock();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        // SAFETY: ENV_LOCK is held, serialising the tests of this module that
        // touch this variable.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        // Restore before asserting so a failure cannot leak the override.
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let _env = env_lock();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _env = env_lock();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        assert_cache_hit("dakera_test_cached_file", "test/model", "config.json", b"{}");
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        assert_cache_hit(
            "dakera_test_cached_onnx",
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            b"fake_onnx_model",
        );
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // The three tests below re-compute the arithmetic of `estimate_time_ms`
    // from the model constants; they do not construct an engine.

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 1);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 1);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _env = env_lock();

        let saved_home = std::env::var_os("HOME");
        let saved_hf = std::env::var_os("HF_HOME");
        // SAFETY: ENV_LOCK is held, serialising the tests of this module that
        // touch these variables.
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }

        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");

        // Restore before asserting so a failure cannot leak the changes.
        restore_env("HOME", saved_home);
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        assert_cache_dir_contains(
            "org/sub/model-name-with-dashes",
            "org--sub--model-name-with-dashes",
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        assert_cache_dir_contains(
            EmbeddingModel::MiniLM.model_id(),
            "sentence-transformers--all-MiniLM-L6-v2",
        );
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        assert_cache_dir_contains(
            EmbeddingModel::BgeSmall.model_id(),
            "BAAI--bge-small-en-v1.5",
        );
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        assert_cache_dir_contains(EmbeddingModel::E5Small.model_id(), "intfloat--e5-small-v2");
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        assert_cache_hit(
            "dakera_test_pytorch_bin",
            "test/model",
            "pytorch_model.bin",
            b"fake_pytorch_weights",
        );
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        assert_cache_hit(
            "dakera_test_tokenizer_cached",
            "test/model",
            "tokenizer.json",
            br#"{"version":"1.0"}"#,
        );
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        assert_cache_hit("dakera_test_config_cached", "test/model", "config.json", b"{}");
    }

    #[tokio::test]
    // The guard is held across `.await` on purpose: HF_HOME must stay
    // overridden for the whole call.
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _env = env_lock();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        // SAFETY: ENV_LOCK is held, serialising the tests of this module that
        // touch this variable.
        unsafe { std::env::set_var("HF_HOME", &tmp) };

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        restore_env("HF_HOME", saved_hf);

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    #[tokio::test]
    // Held across `.await` so the cache location cannot change during `new`.
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_getters_when_model_cached() {
        let _env = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // The engine could not be built in this environment (for
                // example, model files unavailable); nothing to check.
            }
        }
    }

    #[tokio::test]
    // Held across `.await` so the cache location cannot change during `new`.
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_embed_empty_batch_when_cached() {
        let _env = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // The default is 4 unless DAKERA_ONNX_POOL_SIZE holds a valid
        // override (an integer >= 1).
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    // Held across `.await` so the cache location cannot change during `new`.
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let _env = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    // Held across `.await` so the cache location cannot change during `new`.
    #[allow(clippy::await_holding_lock)]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let _env = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            // An empty request returns before the counter is advanced.
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}