1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49pub struct EmbeddingEngine {
56 sessions: Vec<Arc<Mutex<Session>>>,
60 next_session: AtomicUsize,
62 processor: Arc<BatchProcessor>,
64 config: ModelConfig,
66 dimension: usize,
68}
69
70impl EmbeddingEngine {
71 #[instrument(skip_all, fields(model = %config.model))]
75 pub async fn new(config: ModelConfig) -> Result<Self> {
76 info!(
77 "Initializing ONNX embedding engine with model: {}",
78 config.model
79 );
80
81 let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
83
84 info!("Loading tokenizer from {:?}", tokenizer_path);
86 let tokenizer = Tokenizer::from_file(&tokenizer_path)
87 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
88
89 info!("Loading ONNX model from {:?}", onnx_path);
92 let num_threads = config.num_threads.unwrap_or(4);
93 let pool_size = config.session_pool_size.max(1);
94 let onnx_path_clone = onnx_path.clone();
95
96 let use_gpu = std::env::var("DAKERA_USE_GPU")
97 .map(|v| v == "1")
98 .unwrap_or(false);
99 if use_gpu {
100 info!("CUDA execution provider enabled (DAKERA_USE_GPU=1)");
101 }
102
103 let sessions: Vec<Arc<Mutex<Session>>> =
104 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
105 (0..pool_size)
106 .map(|_| {
107 let builder = Session::builder()
108 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
109 .with_optimization_level(GraphOptimizationLevel::Level3)
110 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
111 .with_intra_threads(num_threads)
112 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
113
114 let mut builder = if use_gpu {
115 builder
116 .with_execution_providers(
117 [CUDAExecutionProvider::default().build()],
118 )
119 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
120 } else {
121 builder
122 };
123
124 let s = builder
125 .commit_from_file(&onnx_path_clone)
126 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
127 Ok(Arc::new(Mutex::new(s)))
128 })
129 .collect()
130 })
131 .await
132 .map_err(|e| {
133 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
134 })??;
135
136 let dimension = config.model.dimension();
137 let processor = Arc::new(BatchProcessor::new(
138 tokenizer,
139 config.model,
140 config.max_batch_size,
141 ));
142
143 info!(
144 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
145 config.model, dimension, num_threads, pool_size
146 );
147
148 Ok(Self {
149 sessions,
150 next_session: AtomicUsize::new(0),
151 processor,
152 config,
153 dimension,
154 })
155 }
156
157 #[instrument(skip_all, fields(model = %config.model))]
162 async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
163 let model_id = config.model.model_id();
164 let onnx_repo_id = config.model.onnx_repo_id();
165 let onnx_filename = config.model.onnx_filename();
166
167 info!(
168 "Resolving model files: tokenizer={}, onnx={}@{}",
169 model_id, onnx_filename, onnx_repo_id
170 );
171
172 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
173 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
174
175 let onnx_subdir = onnx_cache_dir.join("onnx");
177 std::fs::create_dir_all(&onnx_subdir)?;
178
179 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
180 let onnx_basename = Path::new(onnx_filename)
182 .file_name()
183 .and_then(|s| s.to_str())
184 .unwrap_or("model_quantized.onnx");
185 let local_onnx = onnx_subdir.join(onnx_basename);
186
187 let tokenizer_needs_download = !local_tokenizer.exists();
189 let onnx_needs_download = !local_onnx.exists();
190
191 if tokenizer_needs_download || onnx_needs_download {
192 let model_id_owned = model_id.to_string();
193 let onnx_repo_id_owned = onnx_repo_id.to_string();
194 let onnx_filename_owned = onnx_filename.to_string();
195 let tokenizer_cache = tokenizer_cache_dir.clone();
196 let onnx_cache = onnx_cache_dir.clone();
197
198 tokio::task::spawn_blocking(move || {
199 if !tokenizer_cache.join("tokenizer.json").exists() {
200 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
201 .map_err(|e| {
202 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
203 })?;
204 }
205 if !onnx_cache.join(&onnx_filename_owned).exists() {
206 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
207 .map_err(|e| {
208 InferenceError::HubError(format!(
209 "Failed to download ONNX model: {}",
210 e
211 ))
212 })?;
213 }
214 Ok::<_, InferenceError>(())
215 })
216 .await
217 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
218 } else {
219 info!("All model files found in local cache");
220 }
221
222 let final_onnx = onnx_cache_dir.join(onnx_filename);
224
225 info!(
226 "Model files ready: tokenizer={:?}, onnx={:?}",
227 local_tokenizer, final_onnx
228 );
229 Ok((local_tokenizer, final_onnx))
230 }
231
232 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
234 let base = std::env::var("HF_HOME")
235 .map(PathBuf::from)
236 .unwrap_or_else(|_| {
237 let home = std::env::var("HOME").unwrap_or_else(|_| {
238 warn!("HOME environment variable not set, using /tmp for model cache");
239 "/tmp".to_string()
240 });
241 PathBuf::from(home).join(".cache").join("huggingface")
242 });
243 let dir = base.join("dakera").join(model_id.replace('/', "--"));
244 std::fs::create_dir_all(&dir)?;
245 Ok(dir)
246 }
247
248 pub fn download_hf_file_pub(
254 model_id: &str,
255 filename: &str,
256 cache_dir: &Path,
257 ) -> std::result::Result<PathBuf, String> {
258 Self::download_hf_file(model_id, filename, cache_dir)
259 }
260
261 fn download_hf_file(
262 model_id: &str,
263 filename: &str,
264 cache_dir: &Path,
265 ) -> std::result::Result<PathBuf, String> {
266 let file_path = cache_dir.join(filename);
268 if file_path.exists() {
269 info!("Cached: {}/{}", model_id, filename);
270 return Ok(file_path);
271 }
272
273 if let Some(parent) = file_path.parent() {
275 std::fs::create_dir_all(parent)
276 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
277 }
278
279 let url = format!(
280 "https://huggingface.co/{}/resolve/main/{}",
281 model_id, filename
282 );
283 info!("Downloading: {}", url);
284
285 let hf_token = std::env::var("HF_TOKEN")
287 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
288 .ok();
289 if hf_token.is_some() {
290 info!("Using HuggingFace auth token for download");
291 }
292
293 let agent = ureq::AgentBuilder::new()
295 .redirects(0)
296 .timeout(std::time::Duration::from_secs(300))
297 .build();
298
299 let mut current_url = url.clone();
300 let mut redirects = 0;
301 let max_redirects = 10;
302
303 let response = loop {
304 let mut req = agent.get(¤t_url);
305 if let Some(ref token) = hf_token {
306 req = req.set("Authorization", &format!("Bearer {}", token));
307 }
308 let resp = req.call();
309
310 let r = match resp {
311 Ok(r) => r,
312 Err(ureq::Error::Status(_status, r)) => r,
313 Err(e) => return Err(format!("{}: {}", filename, e)),
314 };
315
316 let status = r.status();
317 if (200..300).contains(&status) {
318 break r;
319 } else if (300..400).contains(&status) {
320 redirects += 1;
321 if redirects > max_redirects {
322 return Err(format!("{}: too many redirects", filename));
323 }
324 let location = r
325 .header("location")
326 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
327 .to_string();
328
329 current_url = if location.starts_with('/') {
331 let parsed = url::Url::parse(¤t_url)
332 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
333 let host = parsed.host_str().ok_or_else(|| {
334 format!("{}: redirect URL missing host: {}", filename, current_url)
335 })?;
336 format!("{}://{}{}", parsed.scheme(), host, location)
337 } else {
338 location
339 };
340 info!("Redirect {} → {}", redirects, current_url);
341 } else {
342 return Err(format!("{}: HTTP {}", filename, status));
343 }
344 };
345
346 let mut bytes = Vec::new();
347 response
348 .into_reader()
349 .take(500_000_000) .read_to_end(&mut bytes)
351 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
352
353 std::fs::write(&file_path, &bytes)
354 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
355
356 info!("Downloaded {} ({} bytes)", filename, bytes.len());
357 Ok(file_path)
358 }
359
360 pub fn dimension(&self) -> usize {
362 self.dimension
363 }
364
365 pub fn model(&self) -> EmbeddingModel {
367 self.config.model
368 }
369
370 pub fn pool_size(&self) -> usize {
372 self.sessions.len()
373 }
374
375 #[instrument(skip(self, text), fields(text_len = text.len()))]
379 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
380 let texts = vec![text.to_string()];
381 let prepared = self.processor.prepare_texts(&texts, true);
382 let embeddings = self.embed_batch_internal(&prepared).await?;
383 embeddings.into_iter().next().ok_or_else(|| {
384 InferenceError::InferenceError("No embedding returned for query".to_string())
385 })
386 }
387
388 #[instrument(skip(self, texts), fields(count = texts.len()))]
392 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
393 let prepared = self.processor.prepare_texts(texts, true);
394 self.embed_batch_internal(&prepared).await
395 }
396
397 #[instrument(skip(self, text), fields(text_len = text.len()))]
401 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
402 let texts = vec![text.to_string()];
403 let prepared = self.processor.prepare_texts(&texts, false);
404 let embeddings = self.embed_batch_internal(&prepared).await?;
405 embeddings.into_iter().next().ok_or_else(|| {
406 InferenceError::InferenceError("No embedding returned for document".to_string())
407 })
408 }
409
410 #[instrument(skip(self, texts), fields(count = texts.len()))]
414 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
415 let prepared = self.processor.prepare_texts(texts, false);
416 self.embed_batch_internal(&prepared).await
417 }
418
419 #[instrument(skip(self, texts), fields(count = texts.len()))]
421 pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
422 self.embed_batch_internal(texts).await
423 }
424
425 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
433 if texts.is_empty() {
434 return Ok(vec![]);
435 }
436
437 let batches: Vec<Vec<String>> = self
438 .processor
439 .split_into_batches(texts)
440 .into_iter()
441 .map(|b| b.to_vec())
442 .collect();
443
444 let pool_len = self.sessions.len();
445 let normalize = self.config.model.normalize_embeddings();
446 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
449
450 let mut handles = Vec::with_capacity(batches.len());
452 for (i, batch_owned) in batches.into_iter().enumerate() {
453 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
454 let processor = Arc::clone(&self.processor);
455 handles.push(tokio::task::spawn_blocking(move || {
456 let mut session_guard = session.lock();
457 Self::process_batch_blocking(
458 &batch_owned,
459 &mut session_guard,
460 &processor,
461 normalize,
462 )
463 }));
464 }
465
466 let mut all_embeddings = Vec::with_capacity(texts.len());
467 for handle in handles {
468 let batch_embeddings = handle.await.map_err(|e| {
469 InferenceError::InferenceError(format!("Inference task panicked: {}", e))
470 })??;
471 all_embeddings.extend(batch_embeddings);
472 }
473
474 Ok(all_embeddings)
475 }
476
477 fn process_batch_blocking(
481 texts: &[String],
482 session: &mut Session,
483 processor: &BatchProcessor,
484 normalize: bool,
485 ) -> Result<Vec<Vec<f32>>> {
486 let prepared = processor.tokenize_batch(texts)?;
488 let batch_size = prepared.batch_size;
489 let seq_len = prepared.seq_len;
490
491 let attention_mask_flat = prepared.attention_mask.clone();
493
494 let input_ids_tensor =
496 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
497 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
498 let attention_mask_tensor =
499 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
500 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
501 let token_type_ids_tensor =
502 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
503 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
504
505 let outputs = session
507 .run(inputs![
508 "input_ids" => input_ids_tensor,
509 "attention_mask" => attention_mask_tensor,
510 "token_type_ids" => token_type_ids_tensor
511 ])
512 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
513
514 let (ort_shape, lhs_slice) = outputs[0]
518 .try_extract_tensor::<f32>()
519 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
520
521 if ort_shape.len() != 3 {
522 return Err(InferenceError::InferenceError(format!(
523 "Expected 3D last_hidden_state, got {} dims",
524 ort_shape.len()
525 )));
526 }
527 let hidden_size = ort_shape[2] as usize;
528
529 let mut embeddings = mean_pooling(
531 lhs_slice,
532 batch_size,
533 seq_len,
534 hidden_size,
535 &attention_mask_flat,
536 );
537
538 if normalize {
540 normalize_embeddings(&mut embeddings);
541 }
542
543 debug!(
544 "Generated {} embeddings of dimension {}",
545 embeddings.len(),
546 embeddings.first().map(|e| e.len()).unwrap_or(0)
547 );
548
549 Ok(embeddings)
550 }
551
552 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
554 let tokens_per_text =
556 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
557 let total_tokens = tokens_per_text * text_count as f64;
558 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
559 (total_tokens / tokens_per_second) * 1000.0
560 }
561}
562
563impl std::fmt::Debug for EmbeddingEngine {
564 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565 f.debug_struct("EmbeddingEngine")
566 .field("model", &self.config.model)
567 .field("dimension", &self.dimension)
568 .field("max_batch_size", &self.config.max_batch_size)
569 .field("session_pool_size", &self.sessions.len())
570 .finish()
571 }
572}
573
574pub struct EmbeddingEngineBuilder {
576 config: ModelConfig,
577}
578
579impl EmbeddingEngineBuilder {
580 pub fn new() -> Self {
582 Self {
583 config: ModelConfig::default(),
584 }
585 }
586
587 pub fn model(mut self, model: EmbeddingModel) -> Self {
589 self.config.model = model;
590 self
591 }
592
593 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
595 self.config.cache_dir = Some(dir.into());
596 self
597 }
598
599 pub fn max_batch_size(mut self, size: usize) -> Self {
601 self.config.max_batch_size = size;
602 self
603 }
604
605 pub fn use_gpu(mut self, enable: bool) -> Self {
607 self.config.use_gpu = enable;
608 self
609 }
610
611 pub fn num_threads(mut self, threads: usize) -> Self {
613 self.config.num_threads = Some(threads);
614 self
615 }
616
617 pub fn session_pool_size(mut self, size: usize) -> Self {
619 self.config.session_pool_size = size.max(1);
620 self
621 }
622
623 pub async fn build(self) -> Result<EmbeddingEngine> {
625 EmbeddingEngine::new(self.config).await
626 }
627}
628
629impl Default for EmbeddingEngineBuilder {
630 fn default() -> Self {
631 Self::new()
632 }
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises every test in this module that mutates process-wide
    /// environment variables (`HF_HOME`, `HOME`). It must be one shared lock:
    /// a separate static per test would not exclude the other tests.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failing
    /// test does not cascade into failures of every other env-mutating test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Restores `key` to a value previously captured with `var_os`, removing
    /// it if it was unset. Callers must hold [`ENV_LOCK`].
    fn restore_env(key: &str, saved: Option<std::ffi::OsString>) {
        match saved {
            // SAFETY: callers hold ENV_LOCK, so env mutation is serialised
            // between the tests of this module.
            Some(value) => unsafe { std::env::set_var(key, value) },
            // SAFETY: as above.
            None => unsafe { std::env::remove_var(key) },
        }
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _guard = env_lock();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let saved_hf = std::env::var_os("HF_HOME");
        // SAFETY: ENV_LOCK is held for the whole mutation window.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 8);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 8);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _guard = env_lock();

        let saved_home = std::env::var_os("HOME");
        let saved_hf = std::env::var_os("HF_HOME");
        // SAFETY: ENV_LOCK is held for the whole mutation window.
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }

        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");

        restore_env("HOME", saved_home);
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    #[tokio::test]
    // Holding the std guard across `.await` is intentional: it only serialises
    // env access with other tests, and this test's runtime never re-locks it.
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _guard = env_lock();

        let model = EmbeddingModel::MiniLM;
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let base = tmp.join("dakera");

        let tok_dir = base.join(model.model_id().replace('/', "--"));
        std::fs::create_dir_all(&tok_dir).unwrap();
        std::fs::write(tok_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();

        // Seed a placeholder ONNX file at its cached location so `new()` never
        // reaches the network; the invalid tokenizer is parsed before any
        // session is built, so the error is deterministic.
        let onnx_path = base
            .join(model.onnx_repo_id().replace('/', "--"))
            .join(model.onnx_filename());
        if let Some(parent) = onnx_path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        std::fs::write(&onnx_path, b"not_a_real_onnx_model").unwrap();

        let saved_hf = std::env::var_os("HF_HOME");
        // SAFETY: ENV_LOCK is held for the whole mutation window.
        unsafe { std::env::set_var("HF_HOME", &tmp) };

        let config = ModelConfig::new(model);
        let result = EmbeddingEngine::new(config).await;

        restore_env("HF_HOME", saved_hf);

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    /// Only asserts when the model can actually be loaded (cached or
    /// downloadable); otherwise there is nothing to check.
    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
            assert_eq!(engine.model(), EmbeddingModel::MiniLM);
            let _ = format!("{:?}", engine);
            let ms = engine.estimate_time_ms(10, 50);
            assert!(ms >= 0.0);
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}