1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
/// ONNX-backed text embedding engine.
///
/// Owns a pool of ONNX sessions, the batch processor used for tokenisation,
/// and the configuration the engine was created from.
pub struct EmbeddingEngine {
    /// Session pool. Each session sits behind its own mutex so that
    /// different batches can run on different sessions concurrently.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Counter incremented once per non-empty request; taken modulo the pool
    /// size to choose the first session for that request.
    next_session: AtomicUsize,
    /// Shared text preparation / tokenisation helper.
    processor: Arc<BatchProcessor>,
    /// Configuration this engine was built from.
    config: ModelConfig,
    /// Embedding dimension reported by the configured model.
    dimension: usize,
    /// Whether the CUDA execution provider was requested when the engine was
    /// created (from `DAKERA_USE_GPU` or the config flag).
    use_gpu: bool,
}
73
74impl EmbeddingEngine {
75 #[instrument(skip_all, fields(model = %config.model))]
80 pub async fn new(config: ModelConfig) -> Result<Self> {
81 let use_gpu = std::env::var("DAKERA_USE_GPU")
84 .map(|v| v == "1")
85 .unwrap_or(config.use_gpu);
86 if use_gpu {
87 info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
88 }
89
90 info!(
91 "Initializing ONNX embedding engine with model: {}",
92 config.model
93 );
94
95 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
97
98 info!("Loading tokenizer from {:?}", tokenizer_path);
100 let tokenizer = Tokenizer::from_file(&tokenizer_path)
101 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
102
103 info!("Loading ONNX model from {:?}", onnx_path);
106 let num_threads = config.num_threads.unwrap_or(4);
107 let pool_size = config.session_pool_size.max(1);
108 let onnx_path_clone = onnx_path.clone();
109
110 let sessions: Vec<Arc<Mutex<Session>>> =
111 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
112 (0..pool_size)
113 .map(|_| {
114 let builder = Session::builder()
115 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
116 .with_optimization_level(GraphOptimizationLevel::Level3)
117 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
118 .with_intra_threads(num_threads)
119 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
120
121 let mut builder = if use_gpu {
122 builder
123 .with_execution_providers(
124 [CUDAExecutionProvider::default().build()],
125 )
126 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
127 } else {
128 builder
129 };
130
131 let s = builder
132 .commit_from_file(&onnx_path_clone)
133 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
134 Ok(Arc::new(Mutex::new(s)))
135 })
136 .collect()
137 })
138 .await
139 .map_err(|e| {
140 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
141 })??;
142
143 let dimension = config.model.dimension();
144 let processor = Arc::new(BatchProcessor::new(
145 tokenizer,
146 config.model,
147 config.max_batch_size,
148 ));
149
150 info!(
151 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
152 config.model, dimension, num_threads, pool_size
153 );
154
155 Ok(Self {
156 sessions,
157 next_session: AtomicUsize::new(0),
158 processor,
159 config,
160 dimension,
161 use_gpu,
162 })
163 }
164
165 #[instrument(skip_all, fields(model = %config.model))]
171 async fn download_model_files(
172 config: &ModelConfig,
173 use_gpu: bool,
174 ) -> Result<(PathBuf, PathBuf)> {
175 let model_id = config.model.model_id();
176 let onnx_repo_id = config.model.onnx_repo_id();
177 let onnx_filename = if use_gpu {
178 config.model.onnx_filename_gpu()
179 } else {
180 config.model.onnx_filename()
181 };
182
183 info!(
184 "Resolving model files: tokenizer={}, onnx={}@{}",
185 model_id, onnx_filename, onnx_repo_id
186 );
187
188 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
189 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
190
191 let onnx_subdir = onnx_cache_dir.join("onnx");
193 std::fs::create_dir_all(&onnx_subdir)?;
194
195 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
196 let onnx_basename = Path::new(onnx_filename)
198 .file_name()
199 .and_then(|s| s.to_str())
200 .unwrap_or("model_quantized.onnx");
201 let local_onnx = onnx_subdir.join(onnx_basename);
202
203 let tokenizer_needs_download = !local_tokenizer.exists();
205
206 if use_gpu && local_onnx.exists() {
211 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
212 if cached_size <= 500_000_000 {
213 warn!(
214 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
215 download limit. Deleting for complete re-download.",
216 local_onnx, cached_size
217 );
218 let _ = std::fs::remove_file(&local_onnx);
219 }
220 }
221 let onnx_needs_download = !local_onnx.exists();
222
223 if tokenizer_needs_download || onnx_needs_download {
224 let model_id_owned = model_id.to_string();
225 let onnx_repo_id_owned = onnx_repo_id.to_string();
226 let onnx_filename_owned = onnx_filename.to_string();
227 let tokenizer_cache = tokenizer_cache_dir.clone();
228 let onnx_cache = onnx_cache_dir.clone();
229
230 tokio::task::spawn_blocking(move || {
231 if !tokenizer_cache.join("tokenizer.json").exists() {
232 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
233 .map_err(|e| {
234 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
235 })?;
236 }
237 if !onnx_cache.join(&onnx_filename_owned).exists() {
238 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
239 .map_err(|e| {
240 InferenceError::HubError(format!(
241 "Failed to download ONNX model: {}",
242 e
243 ))
244 })?;
245 }
246 Ok::<_, InferenceError>(())
247 })
248 .await
249 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
250 } else {
251 info!("All model files found in local cache");
252 }
253
254 let final_onnx = onnx_cache_dir.join(onnx_filename);
256
257 info!(
258 "Model files ready: tokenizer={:?}, onnx={:?}",
259 local_tokenizer, final_onnx
260 );
261 Ok((local_tokenizer, final_onnx))
262 }
263
264 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
266 let base = std::env::var("HF_HOME")
267 .map(PathBuf::from)
268 .unwrap_or_else(|_| {
269 let home = std::env::var("HOME").unwrap_or_else(|_| {
270 warn!("HOME environment variable not set, using /tmp for model cache");
271 "/tmp".to_string()
272 });
273 PathBuf::from(home).join(".cache").join("huggingface")
274 });
275 let dir = base.join("dakera").join(model_id.replace('/', "--"));
276 std::fs::create_dir_all(&dir)?;
277 Ok(dir)
278 }
279
    /// Public entry point to the private `download_hf_file`.
    ///
    /// Forwards all three arguments unchanged and returns that function's
    /// result: the path of the file inside `cache_dir`, or an error message.
    pub fn download_hf_file_pub(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        Self::download_hf_file(model_id, filename, cache_dir)
    }
292
293 fn download_hf_file(
294 model_id: &str,
295 filename: &str,
296 cache_dir: &Path,
297 ) -> std::result::Result<PathBuf, String> {
298 let file_path = cache_dir.join(filename);
300 if file_path.exists() {
301 info!("Cached: {}/{}", model_id, filename);
302 return Ok(file_path);
303 }
304
305 if let Some(parent) = file_path.parent() {
307 std::fs::create_dir_all(parent)
308 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
309 }
310
311 let url = format!(
312 "https://huggingface.co/{}/resolve/main/{}",
313 model_id, filename
314 );
315 info!("Downloading: {}", url);
316
317 let hf_token = std::env::var("HF_TOKEN")
319 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
320 .ok();
321 if hf_token.is_some() {
322 info!("Using HuggingFace auth token for download");
323 }
324
325 let agent = ureq::AgentBuilder::new()
327 .redirects(0)
328 .timeout(std::time::Duration::from_secs(300))
329 .build();
330
331 let mut current_url = url.clone();
332 let mut redirects = 0;
333 let max_redirects = 10;
334
335 let response = loop {
336 let mut req = agent.get(¤t_url);
337 if let Some(ref token) = hf_token {
338 req = req.set("Authorization", &format!("Bearer {}", token));
339 }
340 let resp = req.call();
341
342 let r = match resp {
343 Ok(r) => r,
344 Err(ureq::Error::Status(_status, r)) => r,
345 Err(e) => return Err(format!("{}: {}", filename, e)),
346 };
347
348 let status = r.status();
349 if (200..300).contains(&status) {
350 break r;
351 } else if (300..400).contains(&status) {
352 redirects += 1;
353 if redirects > max_redirects {
354 return Err(format!("{}: too many redirects", filename));
355 }
356 let location = r
357 .header("location")
358 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
359 .to_string();
360
361 current_url = if location.starts_with('/') {
363 let parsed = url::Url::parse(¤t_url)
364 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
365 let host = parsed.host_str().ok_or_else(|| {
366 format!("{}: redirect URL missing host: {}", filename, current_url)
367 })?;
368 format!("{}://{}{}", parsed.scheme(), host, location)
369 } else {
370 location
371 };
372 info!("Redirect {} → {}", redirects, current_url);
373 } else {
374 return Err(format!("{}: HTTP {}", filename, status));
375 }
376 };
377
378 let expected_bytes: Option<u64> = response
382 .header("x-linked-size")
383 .or_else(|| response.header("content-length"))
384 .and_then(|v| v.parse::<u64>().ok());
385
386 let mut bytes = Vec::new();
390 response
391 .into_reader()
392 .take(2_147_483_648)
393 .read_to_end(&mut bytes)
394 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
395
396 if let Some(expected) = expected_bytes {
400 let actual = bytes.len() as u64;
401 if actual < expected {
402 return Err(format!(
403 "{}: download incomplete — received {} of {} bytes. \
404 File may exceed 2 GiB or the connection was interrupted.",
405 filename, actual, expected
406 ));
407 }
408 }
409
410 std::fs::write(&file_path, &bytes)
411 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
412
413 info!("Downloaded {} ({} bytes)", filename, bytes.len());
414 Ok(file_path)
415 }
416
    /// Returns the embedding dimension recorded when the engine was created.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
421
    /// Returns the embedding model this engine was configured with.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }
426
    /// Returns the number of ONNX sessions held in the pool.
    pub fn pool_size(&self) -> usize {
        self.sessions.len()
    }
431
432 #[instrument(skip(self, text), fields(text_len = text.len()))]
436 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
437 let texts = vec![text.to_string()];
438 let prepared = self.processor.prepare_texts(&texts, true);
439 let embeddings = self.embed_batch_internal(&prepared).await?;
440 embeddings.into_iter().next().ok_or_else(|| {
441 InferenceError::InferenceError("No embedding returned for query".to_string())
442 })
443 }
444
    /// Embeds several query texts, returning one vector per input.
    ///
    /// Texts go through `prepare_texts` with the second argument `true`
    /// (presumably the "is query" flag — confirm in `BatchProcessor`).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, true);
        self.embed_batch_internal(&prepared).await
    }
453
454 #[instrument(skip(self, text), fields(text_len = text.len()))]
458 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
459 let texts = vec![text.to_string()];
460 let prepared = self.processor.prepare_texts(&texts, false);
461 let embeddings = self.embed_batch_internal(&prepared).await?;
462 embeddings.into_iter().next().ok_or_else(|| {
463 InferenceError::InferenceError("No embedding returned for document".to_string())
464 })
465 }
466
    /// Embeds several document texts, returning one vector per input.
    ///
    /// Texts go through `prepare_texts` with the second argument `false`
    /// (presumably the "is query" flag — confirm in `BatchProcessor`).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, false);
        self.embed_batch_internal(&prepared).await
    }
475
    /// Embeds `texts` exactly as given, without the `prepare_texts` step the
    /// query/document variants apply.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }
481
    /// Runs inference over `texts` and returns one embedding per text, in
    /// input order.
    ///
    /// The input is split into chunks of at most `batch_size` texts; each
    /// chunk runs on a blocking task against one session of the pool, with
    /// sessions assigned round-robin from this request's starting index.
    /// When GPU mode is on, each chunk first takes a permit from
    /// `crate::GPU_INFERENCE_SEMAPHORE`, held until its task finishes.
    ///
    /// If a chunk fails with an error `is_gpu_oom` recognises, the whole
    /// request is retried with the chunk size halved (minimum 1), up to three
    /// retries. Any other error, or a panicked task, is returned immediately.
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Return before touching `next_session`, so an empty request does not
        // advance the round-robin counter.
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let pool_len = self.sessions.len();
        let normalize = self.config.model.normalize_embeddings();
        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);

        // Guard against a configured batch size of 0, which `chunks` rejects.
        let mut batch_size = self.config.max_batch_size.max(1);
        let use_gpu = self.use_gpu;

        // Attempt 0 is the initial try; attempts 1..=3 are OOM retries.
        for attempt in 0_u32..=3 {
            // Owned copies are needed because the chunks move into
            // `'static` blocking tasks.
            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();

            let mut handles = Vec::with_capacity(batches.len());
            for (i, batch_owned) in batches.into_iter().enumerate() {
                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
                let processor = Arc::clone(&self.processor);

                // Acquire the permit here, before spawning, so waiting for
                // GPU capacity happens on the async side.
                let gpu_permit = if use_gpu {
                    Some(
                        std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
                            .acquire_owned()
                            .await
                            .map_err(|_| {
                                InferenceError::InferenceError(
                                    "GPU inference semaphore unexpectedly closed".to_string(),
                                )
                            })?,
                    )
                } else {
                    None
                };

                handles.push(tokio::task::spawn_blocking(move || {
                    // Bound to a name so the permit lives until the task ends.
                    let _gpu_permit = gpu_permit;
                    let mut session_guard = session.lock();
                    Self::process_batch_blocking(
                        &batch_owned,
                        &mut session_guard,
                        &processor,
                        normalize,
                    )
                }));
            }

            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
            let mut oom: Option<InferenceError> = None;

            // Await in spawn order so the output keeps the input order.
            for handle in handles {
                match handle.await {
                    Err(panic_err) => {
                        return Err(InferenceError::InferenceError(format!(
                            "Inference task panicked: {panic_err}"
                        )));
                    }
                    Ok(Err(e)) => {
                        // OOM on a non-final attempt: drop this attempt and
                        // retry with smaller chunks. The remaining handles
                        // are dropped without being awaited.
                        if attempt < 3 && Self::is_gpu_oom(&e) {
                            oom = Some(e);
                            break;
                        }
                        return Err(e);
                    }
                    Ok(Ok(batch_embs)) => {
                        all_embeddings.extend(batch_embs);
                    }
                }
            }

            if oom.is_some() {
                let next_batch = (batch_size / 2).max(1);
                warn!(
                    "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
                    attempt + 1,
                    batch_size,
                    next_batch,
                );
                batch_size = next_batch;
                continue;
            }

            return Ok(all_embeddings);
        }

        // Fallback after the loop; the final attempt above returns its own
        // error directly, because the OOM branch requires `attempt < 3`.
        Err(InferenceError::InferenceError(format!(
            "ONNX inference failed: GPU/CPU allocator OOM persists after 3 \
             batch-halving attempts (final batch_size={batch_size})"
        )))
    }
593
594 fn is_gpu_oom(err: &InferenceError) -> bool {
600 let msg = err.to_string();
601 msg.contains("BFCArena")
602 || msg.contains("Failed to allocate memory")
603 || msg.contains("CUDA_OUT_OF_MEMORY")
604 || msg.contains("CUDA out of memory")
605 || (msg.contains("allocate") && msg.contains("buffer of size"))
606 }
607
608 fn process_batch_blocking(
612 texts: &[String],
613 session: &mut Session,
614 processor: &BatchProcessor,
615 normalize: bool,
616 ) -> Result<Vec<Vec<f32>>> {
617 let prepared = processor.tokenize_batch(texts)?;
619 let batch_size = prepared.batch_size;
620 let seq_len = prepared.seq_len;
621
622 let attention_mask_flat = prepared.attention_mask.clone();
624
625 let input_ids_tensor =
627 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
628 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
629 let attention_mask_tensor =
630 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
631 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
632 let token_type_ids_tensor =
633 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
634 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
635
636 let outputs = session
638 .run(inputs![
639 "input_ids" => input_ids_tensor,
640 "attention_mask" => attention_mask_tensor,
641 "token_type_ids" => token_type_ids_tensor
642 ])
643 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
644
645 let (ort_shape, lhs_slice) = outputs[0]
649 .try_extract_tensor::<f32>()
650 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
651
652 if ort_shape.len() != 3 {
653 return Err(InferenceError::InferenceError(format!(
654 "Expected 3D last_hidden_state, got {} dims",
655 ort_shape.len()
656 )));
657 }
658 let hidden_size = ort_shape[2] as usize;
659
660 let mut embeddings = mean_pooling(
662 lhs_slice,
663 batch_size,
664 seq_len,
665 hidden_size,
666 &attention_mask_flat,
667 );
668
669 if normalize {
671 normalize_embeddings(&mut embeddings);
672 }
673
674 debug!(
675 "Generated {} embeddings of dimension {}",
676 embeddings.len(),
677 embeddings.first().map(|e| e.len()).unwrap_or(0)
678 );
679
680 Ok(embeddings)
681 }
682
683 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
685 let tokens_per_text =
687 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
688 let total_tokens = tokens_per_text * text_count as f64;
689 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
690 (total_tokens / tokens_per_second) * 1000.0
691 }
692}
693
694impl std::fmt::Debug for EmbeddingEngine {
695 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696 f.debug_struct("EmbeddingEngine")
697 .field("model", &self.config.model)
698 .field("dimension", &self.dimension)
699 .field("max_batch_size", &self.config.max_batch_size)
700 .field("session_pool_size", &self.sessions.len())
701 .field("use_gpu", &self.use_gpu)
702 .finish()
703 }
704}
705
/// Fluent builder for [`EmbeddingEngine`].
///
/// Collects settings into a [`ModelConfig`] and hands it to
/// [`EmbeddingEngine::new`] when `build` is called.
pub struct EmbeddingEngineBuilder {
    /// Configuration accumulated by the builder methods.
    config: ModelConfig,
}
710
impl EmbeddingEngineBuilder {
    /// Creates a builder that starts from `ModelConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: ModelConfig::default(),
        }
    }

    /// Selects the embedding model.
    pub fn model(mut self, model: EmbeddingModel) -> Self {
        self.config.model = model;
        self
    }

    /// Stores a cache directory in the configuration.
    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.config.cache_dir = Some(dir.into());
        self
    }

    /// Sets the maximum number of texts per inference batch.
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.config.max_batch_size = size;
        self
    }

    /// Sets the GPU flag in the configuration.
    ///
    /// `EmbeddingEngine::new` lets the `DAKERA_USE_GPU` environment variable
    /// override this value when it is set.
    pub fn use_gpu(mut self, enable: bool) -> Self {
        self.config.use_gpu = enable;
        self
    }

    /// Sets the intra-op thread count used for each ONNX session.
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = Some(threads);
        self
    }

    /// Sets the number of ONNX sessions in the pool; values below 1 are
    /// raised to 1.
    pub fn session_pool_size(mut self, size: usize) -> Self {
        self.config.session_pool_size = size.max(1);
        self
    }

    /// Consumes the builder and creates the engine via
    /// `EmbeddingEngine::new`.
    pub async fn build(self) -> Result<EmbeddingEngine> {
        EmbeddingEngine::new(self.config).await
    }
}
760
/// `Default` is equivalent to [`EmbeddingEngineBuilder::new`].
impl Default for EmbeddingEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// One lock shared by every test that mutates process-wide environment
    /// variables. It has to live at module level: a `static` declared inside
    /// a test function is private to that function and serialises nothing
    /// across tests.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the shared environment lock, recovering from poisoning so that
    /// one failed test does not cascade into every other env test.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides (or removes) an environment variable and restores its
    /// previous state on drop, even if the test panics.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, remembering the old value.
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: callers hold `ENV_LOCK`, so no other test in this
            // module mutates the environment at the same time.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        /// Removes `key`, remembering the old value.
        fn unset(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: guards are declared after the `ENV_LOCK` guard in each
            // test, so they drop first, while the lock is still held.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Writes `contents` to `rel_path` under a scratch directory and checks
    /// that `download_hf_file` returns that path from the cache.
    fn assert_cache_hit(scratch: &str, model_id: &str, rel_path: &str, contents: &[u8]) {
        let dir = std::env::temp_dir().join(scratch);
        let expected = dir.join(rel_path);
        let parent = expected.parent().expect("joined path has a parent");
        std::fs::create_dir_all(parent).unwrap();
        std::fs::write(&expected, contents).unwrap();

        let result = EmbeddingEngine::download_hf_file(model_id, rel_path, &dir);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _env = lock_env();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let result = {
            let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);
            EmbeddingEngine::model_cache_dir("org/my-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        assert_cache_hit("dakera_test_cached_file", "test/model", "config.json", b"{}");
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        assert_cache_hit(
            "dakera_test_cached_onnx",
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            b"fake_onnx_model",
        );
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _env = lock_env();

        let result = {
            let _home = EnvVarGuard::unset("HOME");
            let _hf_home = EnvVarGuard::unset("HF_HOME");
            EmbeddingEngine::model_cache_dir("test/fallback-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        assert_cache_hit(
            "dakera_test_pytorch_bin",
            "test/model",
            "pytorch_model.bin",
            b"fake_pytorch_weights",
        );
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        assert_cache_hit(
            "dakera_test_tokenizer_cached",
            "test/model",
            "tokenizer.json",
            br#"{"version":"1.0"}"#,
        );
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        assert_cache_hit("dakera_test_config_cached", "test/model", "config.json", b"{}");
    }

    #[tokio::test]
    // The env lock must stay held across `new().await` so that no other test
    // changes HF_HOME while the engine resolves its cache directory.
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _env = lock_env();

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        let result = {
            let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);
            let config = ModelConfig::new(EmbeddingModel::MiniLM);
            EmbeddingEngine::new(config).await
        };

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Engine could not be created in this environment; there is
                // nothing to check.
            }
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // Mirror the override the test expects the default to honour.
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }

    // The env lock must stay held across `new().await` so that the GPU env
    // test cannot set DAKERA_USE_GPU while the engine is being created.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn test_engine_use_gpu_defaults_to_false_when_not_configured() {
        let _env = lock_env();
        let _gpu_flag = EnvVarGuard::unset("DAKERA_USE_GPU");

        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(1);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert!(
                !engine.use_gpu,
                "use_gpu must be false when DAKERA_USE_GPU is unset"
            );
        }
    }

    #[test]
    fn test_engine_use_gpu_resolved_from_env_var() {
        let _env = lock_env();

        let resolved = {
            let _gpu_flag = EnvVarGuard::set("DAKERA_USE_GPU", "1");
            std::env::var("DAKERA_USE_GPU")
                .map(|v| v == "1")
                .unwrap_or(false)
        };

        assert!(resolved, "DAKERA_USE_GPU=1 must resolve use_gpu=true");
    }

    #[tokio::test]
    async fn test_engine_debug_includes_use_gpu() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let debug_str = format!("{:?}", engine);
            assert!(
                debug_str.contains("use_gpu"),
                "Debug output must include use_gpu field: {debug_str}"
            );
        }
    }
}