1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49pub struct EmbeddingEngine {
56 sessions: Vec<Arc<Mutex<Session>>>,
60 next_session: AtomicUsize,
62 processor: Arc<BatchProcessor>,
64 config: ModelConfig,
66 dimension: usize,
68}
69
70impl EmbeddingEngine {
71 #[instrument(skip_all, fields(model = %config.model))]
76 pub async fn new(config: ModelConfig) -> Result<Self> {
77 let use_gpu = std::env::var("DAKERA_USE_GPU")
80 .map(|v| v == "1")
81 .unwrap_or(config.use_gpu);
82 if use_gpu {
83 info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
84 }
85
86 info!(
87 "Initializing ONNX embedding engine with model: {}",
88 config.model
89 );
90
91 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
93
94 info!("Loading tokenizer from {:?}", tokenizer_path);
96 let tokenizer = Tokenizer::from_file(&tokenizer_path)
97 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
98
99 info!("Loading ONNX model from {:?}", onnx_path);
102 let num_threads = config.num_threads.unwrap_or(4);
103 let pool_size = config.session_pool_size.max(1);
104 let onnx_path_clone = onnx_path.clone();
105
106 let sessions: Vec<Arc<Mutex<Session>>> =
107 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
108 (0..pool_size)
109 .map(|_| {
110 let builder = Session::builder()
111 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
112 .with_optimization_level(GraphOptimizationLevel::Level3)
113 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
114 .with_intra_threads(num_threads)
115 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
116
117 let mut builder = if use_gpu {
118 builder
119 .with_execution_providers(
120 [CUDAExecutionProvider::default().build()],
121 )
122 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
123 } else {
124 builder
125 };
126
127 let s = builder
128 .commit_from_file(&onnx_path_clone)
129 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
130 Ok(Arc::new(Mutex::new(s)))
131 })
132 .collect()
133 })
134 .await
135 .map_err(|e| {
136 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
137 })??;
138
139 let dimension = config.model.dimension();
140 let processor = Arc::new(BatchProcessor::new(
141 tokenizer,
142 config.model,
143 config.max_batch_size,
144 ));
145
146 info!(
147 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
148 config.model, dimension, num_threads, pool_size
149 );
150
151 Ok(Self {
152 sessions,
153 next_session: AtomicUsize::new(0),
154 processor,
155 config,
156 dimension,
157 })
158 }
159
160 #[instrument(skip_all, fields(model = %config.model))]
166 async fn download_model_files(
167 config: &ModelConfig,
168 use_gpu: bool,
169 ) -> Result<(PathBuf, PathBuf)> {
170 let model_id = config.model.model_id();
171 let onnx_repo_id = config.model.onnx_repo_id();
172 let onnx_filename = if use_gpu {
173 config.model.onnx_filename_gpu()
174 } else {
175 config.model.onnx_filename()
176 };
177
178 info!(
179 "Resolving model files: tokenizer={}, onnx={}@{}",
180 model_id, onnx_filename, onnx_repo_id
181 );
182
183 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
184 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
185
186 let onnx_subdir = onnx_cache_dir.join("onnx");
188 std::fs::create_dir_all(&onnx_subdir)?;
189
190 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
191 let onnx_basename = Path::new(onnx_filename)
193 .file_name()
194 .and_then(|s| s.to_str())
195 .unwrap_or("model_quantized.onnx");
196 let local_onnx = onnx_subdir.join(onnx_basename);
197
198 let tokenizer_needs_download = !local_tokenizer.exists();
200 let onnx_needs_download = !local_onnx.exists();
201
202 if tokenizer_needs_download || onnx_needs_download {
203 let model_id_owned = model_id.to_string();
204 let onnx_repo_id_owned = onnx_repo_id.to_string();
205 let onnx_filename_owned = onnx_filename.to_string();
206 let tokenizer_cache = tokenizer_cache_dir.clone();
207 let onnx_cache = onnx_cache_dir.clone();
208
209 tokio::task::spawn_blocking(move || {
210 if !tokenizer_cache.join("tokenizer.json").exists() {
211 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
212 .map_err(|e| {
213 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
214 })?;
215 }
216 if !onnx_cache.join(&onnx_filename_owned).exists() {
217 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
218 .map_err(|e| {
219 InferenceError::HubError(format!(
220 "Failed to download ONNX model: {}",
221 e
222 ))
223 })?;
224 }
225 Ok::<_, InferenceError>(())
226 })
227 .await
228 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
229 } else {
230 info!("All model files found in local cache");
231 }
232
233 let final_onnx = onnx_cache_dir.join(onnx_filename);
235
236 info!(
237 "Model files ready: tokenizer={:?}, onnx={:?}",
238 local_tokenizer, final_onnx
239 );
240 Ok((local_tokenizer, final_onnx))
241 }
242
243 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
245 let base = std::env::var("HF_HOME")
246 .map(PathBuf::from)
247 .unwrap_or_else(|_| {
248 let home = std::env::var("HOME").unwrap_or_else(|_| {
249 warn!("HOME environment variable not set, using /tmp for model cache");
250 "/tmp".to_string()
251 });
252 PathBuf::from(home).join(".cache").join("huggingface")
253 });
254 let dir = base.join("dakera").join(model_id.replace('/', "--"));
255 std::fs::create_dir_all(&dir)?;
256 Ok(dir)
257 }
258
259 pub fn download_hf_file_pub(
265 model_id: &str,
266 filename: &str,
267 cache_dir: &Path,
268 ) -> std::result::Result<PathBuf, String> {
269 Self::download_hf_file(model_id, filename, cache_dir)
270 }
271
272 fn download_hf_file(
273 model_id: &str,
274 filename: &str,
275 cache_dir: &Path,
276 ) -> std::result::Result<PathBuf, String> {
277 let file_path = cache_dir.join(filename);
279 if file_path.exists() {
280 info!("Cached: {}/{}", model_id, filename);
281 return Ok(file_path);
282 }
283
284 if let Some(parent) = file_path.parent() {
286 std::fs::create_dir_all(parent)
287 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
288 }
289
290 let url = format!(
291 "https://huggingface.co/{}/resolve/main/{}",
292 model_id, filename
293 );
294 info!("Downloading: {}", url);
295
296 let hf_token = std::env::var("HF_TOKEN")
298 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
299 .ok();
300 if hf_token.is_some() {
301 info!("Using HuggingFace auth token for download");
302 }
303
304 let agent = ureq::AgentBuilder::new()
306 .redirects(0)
307 .timeout(std::time::Duration::from_secs(300))
308 .build();
309
310 let mut current_url = url.clone();
311 let mut redirects = 0;
312 let max_redirects = 10;
313
314 let response = loop {
315 let mut req = agent.get(¤t_url);
316 if let Some(ref token) = hf_token {
317 req = req.set("Authorization", &format!("Bearer {}", token));
318 }
319 let resp = req.call();
320
321 let r = match resp {
322 Ok(r) => r,
323 Err(ureq::Error::Status(_status, r)) => r,
324 Err(e) => return Err(format!("{}: {}", filename, e)),
325 };
326
327 let status = r.status();
328 if (200..300).contains(&status) {
329 break r;
330 } else if (300..400).contains(&status) {
331 redirects += 1;
332 if redirects > max_redirects {
333 return Err(format!("{}: too many redirects", filename));
334 }
335 let location = r
336 .header("location")
337 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
338 .to_string();
339
340 current_url = if location.starts_with('/') {
342 let parsed = url::Url::parse(¤t_url)
343 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
344 let host = parsed.host_str().ok_or_else(|| {
345 format!("{}: redirect URL missing host: {}", filename, current_url)
346 })?;
347 format!("{}://{}{}", parsed.scheme(), host, location)
348 } else {
349 location
350 };
351 info!("Redirect {} → {}", redirects, current_url);
352 } else {
353 return Err(format!("{}: HTTP {}", filename, status));
354 }
355 };
356
357 let mut bytes = Vec::new();
358 response
359 .into_reader()
360 .take(500_000_000) .read_to_end(&mut bytes)
362 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
363
364 std::fs::write(&file_path, &bytes)
365 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
366
367 info!("Downloaded {} ({} bytes)", filename, bytes.len());
368 Ok(file_path)
369 }
370
371 pub fn dimension(&self) -> usize {
373 self.dimension
374 }
375
376 pub fn model(&self) -> EmbeddingModel {
378 self.config.model
379 }
380
381 pub fn pool_size(&self) -> usize {
383 self.sessions.len()
384 }
385
386 #[instrument(skip(self, text), fields(text_len = text.len()))]
390 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
391 let texts = vec![text.to_string()];
392 let prepared = self.processor.prepare_texts(&texts, true);
393 let embeddings = self.embed_batch_internal(&prepared).await?;
394 embeddings.into_iter().next().ok_or_else(|| {
395 InferenceError::InferenceError("No embedding returned for query".to_string())
396 })
397 }
398
399 #[instrument(skip(self, texts), fields(count = texts.len()))]
403 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404 let prepared = self.processor.prepare_texts(texts, true);
405 self.embed_batch_internal(&prepared).await
406 }
407
408 #[instrument(skip(self, text), fields(text_len = text.len()))]
412 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
413 let texts = vec![text.to_string()];
414 let prepared = self.processor.prepare_texts(&texts, false);
415 let embeddings = self.embed_batch_internal(&prepared).await?;
416 embeddings.into_iter().next().ok_or_else(|| {
417 InferenceError::InferenceError("No embedding returned for document".to_string())
418 })
419 }
420
421 #[instrument(skip(self, texts), fields(count = texts.len()))]
425 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
426 let prepared = self.processor.prepare_texts(texts, false);
427 self.embed_batch_internal(&prepared).await
428 }
429
430 #[instrument(skip(self, texts), fields(count = texts.len()))]
432 pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
433 self.embed_batch_internal(texts).await
434 }
435
436 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
444 if texts.is_empty() {
445 return Ok(vec![]);
446 }
447
448 let batches: Vec<Vec<String>> = self
449 .processor
450 .split_into_batches(texts)
451 .into_iter()
452 .map(|b| b.to_vec())
453 .collect();
454
455 let pool_len = self.sessions.len();
456 let normalize = self.config.model.normalize_embeddings();
457 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
460
461 let mut handles = Vec::with_capacity(batches.len());
463 for (i, batch_owned) in batches.into_iter().enumerate() {
464 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
465 let processor = Arc::clone(&self.processor);
466 handles.push(tokio::task::spawn_blocking(move || {
467 let mut session_guard = session.lock();
468 Self::process_batch_blocking(
469 &batch_owned,
470 &mut session_guard,
471 &processor,
472 normalize,
473 )
474 }));
475 }
476
477 let mut all_embeddings = Vec::with_capacity(texts.len());
478 for handle in handles {
479 let batch_embeddings = handle.await.map_err(|e| {
480 InferenceError::InferenceError(format!("Inference task panicked: {}", e))
481 })??;
482 all_embeddings.extend(batch_embeddings);
483 }
484
485 Ok(all_embeddings)
486 }
487
488 fn process_batch_blocking(
492 texts: &[String],
493 session: &mut Session,
494 processor: &BatchProcessor,
495 normalize: bool,
496 ) -> Result<Vec<Vec<f32>>> {
497 let prepared = processor.tokenize_batch(texts)?;
499 let batch_size = prepared.batch_size;
500 let seq_len = prepared.seq_len;
501
502 let attention_mask_flat = prepared.attention_mask.clone();
504
505 let input_ids_tensor =
507 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
508 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
509 let attention_mask_tensor =
510 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
511 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
512 let token_type_ids_tensor =
513 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
514 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
515
516 let outputs = session
518 .run(inputs![
519 "input_ids" => input_ids_tensor,
520 "attention_mask" => attention_mask_tensor,
521 "token_type_ids" => token_type_ids_tensor
522 ])
523 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
524
525 let (ort_shape, lhs_slice) = outputs[0]
529 .try_extract_tensor::<f32>()
530 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
531
532 if ort_shape.len() != 3 {
533 return Err(InferenceError::InferenceError(format!(
534 "Expected 3D last_hidden_state, got {} dims",
535 ort_shape.len()
536 )));
537 }
538 let hidden_size = ort_shape[2] as usize;
539
540 let mut embeddings = mean_pooling(
542 lhs_slice,
543 batch_size,
544 seq_len,
545 hidden_size,
546 &attention_mask_flat,
547 );
548
549 if normalize {
551 normalize_embeddings(&mut embeddings);
552 }
553
554 debug!(
555 "Generated {} embeddings of dimension {}",
556 embeddings.len(),
557 embeddings.first().map(|e| e.len()).unwrap_or(0)
558 );
559
560 Ok(embeddings)
561 }
562
563 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
565 let tokens_per_text =
567 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
568 let total_tokens = tokens_per_text * text_count as f64;
569 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
570 (total_tokens / tokens_per_second) * 1000.0
571 }
572}
573
574impl std::fmt::Debug for EmbeddingEngine {
575 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576 f.debug_struct("EmbeddingEngine")
577 .field("model", &self.config.model)
578 .field("dimension", &self.dimension)
579 .field("max_batch_size", &self.config.max_batch_size)
580 .field("session_pool_size", &self.sessions.len())
581 .finish()
582 }
583}
584
585pub struct EmbeddingEngineBuilder {
587 config: ModelConfig,
588}
589
590impl EmbeddingEngineBuilder {
591 pub fn new() -> Self {
593 Self {
594 config: ModelConfig::default(),
595 }
596 }
597
598 pub fn model(mut self, model: EmbeddingModel) -> Self {
600 self.config.model = model;
601 self
602 }
603
604 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
606 self.config.cache_dir = Some(dir.into());
607 self
608 }
609
610 pub fn max_batch_size(mut self, size: usize) -> Self {
612 self.config.max_batch_size = size;
613 self
614 }
615
616 pub fn use_gpu(mut self, enable: bool) -> Self {
618 self.config.use_gpu = enable;
619 self
620 }
621
622 pub fn num_threads(mut self, threads: usize) -> Self {
624 self.config.num_threads = Some(threads);
625 self
626 }
627
628 pub fn session_pool_size(mut self, size: usize) -> Self {
630 self.config.session_pool_size = size.max(1);
631 self
632 }
633
634 pub async fn build(self) -> Result<EmbeddingEngine> {
636 EmbeddingEngine::new(self.config).await
637 }
638}
639
640impl Default for EmbeddingEngineBuilder {
641 fn default() -> Self {
642 Self::new()
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes every test in this module that reads or mutates process-wide environment
    /// variables (`HF_HOME`, `HOME`).
    ///
    /// It must be one module-level lock: a `static` declared inside each test fn is a
    /// separate mutex, so such locks never exclude one another.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], tolerating poisoning so one failing test does not cascade.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// On drop, restores an environment variable to the value it had when captured.
    ///
    /// Restoration also happens during unwinding after a failed assertion. Declare it after
    /// the `env_lock()` guard so it drops (and restores) while the lock is still held.
    struct RestoreEnvVar {
        key: &'static str,
        saved: Option<std::ffi::OsString>,
    }

    impl RestoreEnvVar {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                saved: std::env::var_os(key),
            }
        }
    }

    impl Drop for RestoreEnvVar {
        fn drop(&mut self) {
            // SAFETY: best-effort under the test harness — the owning test holds ENV_LOCK,
            // which serializes it against this module's other env-touching tests.
            unsafe {
                match &self.saved {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _lock = env_lock();
        let _restore_hf = RestoreEnvVar::capture("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        // SAFETY: env mutation serialized via ENV_LOCK; restored by `_restore_hf`.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let path = EmbeddingEngine::model_cache_dir("org/my-model").unwrap();

        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let _lock = env_lock();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _lock = env_lock();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _lock = env_lock();
        let _restore_home = RestoreEnvVar::capture("HOME");
        let _restore_hf = RestoreEnvVar::capture("HF_HOME");
        // SAFETY: env mutation serialized via ENV_LOCK; both vars restored on drop.
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }

        let path = EmbeddingEngine::model_cache_dir("test/fallback-model").unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    #[tokio::test]
    // The std guard is held across `.await` on purpose: the HF_HOME override must stay in
    // place for all of `EmbeddingEngine::new`, and `#[tokio::test]` uses a current-thread
    // runtime, so no other task on this runtime contends for the lock.
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _lock = env_lock();
        let _restore_hf = RestoreEnvVar::capture("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        // SAFETY: env mutation serialized via ENV_LOCK; restored by `_restore_hf`.
        unsafe { std::env::set_var("HF_HOME", &tmp) };

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    /// Only asserts when the model can actually be loaded (cached or downloadable);
    /// otherwise the construction error is ignored.
    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
            assert_eq!(engine.model(), EmbeddingModel::MiniLM);
            let _ = format!("{:?}", engine);
            let ms = engine.estimate_time_ms(10, 50);
            assert!(ms >= 0.0);
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // The default may be overridden through DAKERA_ONNX_POOL_SIZE.
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}