1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49pub struct EmbeddingEngine {
56 sessions: Vec<Arc<Mutex<Session>>>,
60 next_session: AtomicUsize,
62 processor: Arc<BatchProcessor>,
64 config: ModelConfig,
66 dimension: usize,
68 use_gpu: bool,
72}
73
74impl EmbeddingEngine {
75 #[instrument(skip_all, fields(model = %config.model))]
80 pub async fn new(config: ModelConfig) -> Result<Self> {
81 let use_gpu = std::env::var("DAKERA_USE_GPU")
84 .map(|v| v == "1")
85 .unwrap_or(config.use_gpu);
86 if use_gpu {
87 info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
88 }
89
90 info!(
91 "Initializing ONNX embedding engine with model: {}",
92 config.model
93 );
94
95 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
97
98 info!("Loading tokenizer from {:?}", tokenizer_path);
100 let tokenizer = Tokenizer::from_file(&tokenizer_path)
101 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
102
103 info!("Loading ONNX model from {:?}", onnx_path);
106 let num_threads = config.num_threads.unwrap_or(4);
107 let pool_size = config.session_pool_size.max(1);
108 let onnx_path_clone = onnx_path.clone();
109
110 let sessions: Vec<Arc<Mutex<Session>>> =
111 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
112 (0..pool_size)
113 .map(|_| {
114 let builder = Session::builder()
115 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
116 .with_optimization_level(GraphOptimizationLevel::Level3)
117 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
118 .with_intra_threads(num_threads)
119 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
120
121 let mut builder = if use_gpu {
129 builder
130 .with_execution_providers(
131 [CUDAExecutionProvider::default().build()],
132 )
133 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
134 } else {
135 builder
136 .with_memory_pattern(false)
137 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
138 };
139
140 let s = builder
141 .commit_from_file(&onnx_path_clone)
142 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
143 Ok(Arc::new(Mutex::new(s)))
144 })
145 .collect()
146 })
147 .await
148 .map_err(|e| {
149 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
150 })??;
151
152 let dimension = config.model.dimension();
153 let processor = Arc::new(BatchProcessor::new(
154 tokenizer,
155 config.model,
156 config.max_batch_size,
157 ));
158
159 info!(
160 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
161 config.model, dimension, num_threads, pool_size
162 );
163
164 Ok(Self {
165 sessions,
166 next_session: AtomicUsize::new(0),
167 processor,
168 config,
169 dimension,
170 use_gpu,
171 })
172 }
173
174 #[instrument(skip_all, fields(model = %config.model))]
180 async fn download_model_files(
181 config: &ModelConfig,
182 use_gpu: bool,
183 ) -> Result<(PathBuf, PathBuf)> {
184 let model_id = config.model.model_id();
185 let onnx_repo_id = config.model.onnx_repo_id();
186 let onnx_filename = if use_gpu {
187 config.model.onnx_filename_gpu()
188 } else {
189 config.model.onnx_filename()
190 };
191
192 info!(
193 "Resolving model files: tokenizer={}, onnx={}@{}",
194 model_id, onnx_filename, onnx_repo_id
195 );
196
197 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
198 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
199
200 let onnx_subdir = onnx_cache_dir.join("onnx");
202 std::fs::create_dir_all(&onnx_subdir)?;
203
204 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
205 let onnx_basename = Path::new(onnx_filename)
207 .file_name()
208 .and_then(|s| s.to_str())
209 .unwrap_or("model_quantized.onnx");
210 let local_onnx = onnx_subdir.join(onnx_basename);
211
212 let tokenizer_needs_download = !local_tokenizer.exists();
214
215 if use_gpu && local_onnx.exists() {
220 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
221 if cached_size <= 500_000_000 {
222 warn!(
223 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
224 download limit. Deleting for complete re-download.",
225 local_onnx, cached_size
226 );
227 let _ = std::fs::remove_file(&local_onnx);
228 }
229 }
230 let onnx_needs_download = !local_onnx.exists();
231
232 if tokenizer_needs_download || onnx_needs_download {
233 let model_id_owned = model_id.to_string();
234 let onnx_repo_id_owned = onnx_repo_id.to_string();
235 let onnx_filename_owned = onnx_filename.to_string();
236 let tokenizer_cache = tokenizer_cache_dir.clone();
237 let onnx_cache = onnx_cache_dir.clone();
238
239 tokio::task::spawn_blocking(move || {
240 if !tokenizer_cache.join("tokenizer.json").exists() {
241 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
242 .map_err(|e| {
243 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
244 })?;
245 }
246 if !onnx_cache.join(&onnx_filename_owned).exists() {
247 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
248 .map_err(|e| {
249 InferenceError::HubError(format!(
250 "Failed to download ONNX model: {}",
251 e
252 ))
253 })?;
254 }
255 Ok::<_, InferenceError>(())
256 })
257 .await
258 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
259 } else {
260 info!("All model files found in local cache");
261 }
262
263 let final_onnx = onnx_cache_dir.join(onnx_filename);
265
266 info!(
267 "Model files ready: tokenizer={:?}, onnx={:?}",
268 local_tokenizer, final_onnx
269 );
270 Ok((local_tokenizer, final_onnx))
271 }
272
273 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
275 let base = std::env::var("HF_HOME")
276 .map(PathBuf::from)
277 .unwrap_or_else(|_| {
278 let home = std::env::var("HOME").unwrap_or_else(|_| {
279 warn!("HOME environment variable not set, using /tmp for model cache");
280 "/tmp".to_string()
281 });
282 PathBuf::from(home).join(".cache").join("huggingface")
283 });
284 let dir = base.join("dakera").join(model_id.replace('/', "--"));
285 std::fs::create_dir_all(&dir)?;
286 Ok(dir)
287 }
288
289 pub fn download_hf_file_pub(
295 model_id: &str,
296 filename: &str,
297 cache_dir: &Path,
298 ) -> std::result::Result<PathBuf, String> {
299 Self::download_hf_file(model_id, filename, cache_dir)
300 }
301
302 fn download_hf_file(
303 model_id: &str,
304 filename: &str,
305 cache_dir: &Path,
306 ) -> std::result::Result<PathBuf, String> {
307 let file_path = cache_dir.join(filename);
309 if file_path.exists() {
310 info!("Cached: {}/{}", model_id, filename);
311 return Ok(file_path);
312 }
313
314 if let Some(parent) = file_path.parent() {
316 std::fs::create_dir_all(parent)
317 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
318 }
319
320 let url = format!(
321 "https://huggingface.co/{}/resolve/main/{}",
322 model_id, filename
323 );
324 info!("Downloading: {}", url);
325
326 let hf_token = std::env::var("HF_TOKEN")
328 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
329 .ok();
330 if hf_token.is_some() {
331 info!("Using HuggingFace auth token for download");
332 }
333
334 let agent = ureq::AgentBuilder::new()
336 .redirects(0)
337 .timeout(std::time::Duration::from_secs(300))
338 .build();
339
340 let mut current_url = url.clone();
341 let mut redirects = 0;
342 let max_redirects = 10;
343
344 let response = loop {
345 let mut req = agent.get(¤t_url);
346 if let Some(ref token) = hf_token {
347 req = req.set("Authorization", &format!("Bearer {}", token));
348 }
349 let resp = req.call();
350
351 let r = match resp {
352 Ok(r) => r,
353 Err(ureq::Error::Status(_status, r)) => r,
354 Err(e) => return Err(format!("{}: {}", filename, e)),
355 };
356
357 let status = r.status();
358 if (200..300).contains(&status) {
359 break r;
360 } else if (300..400).contains(&status) {
361 redirects += 1;
362 if redirects > max_redirects {
363 return Err(format!("{}: too many redirects", filename));
364 }
365 let location = r
366 .header("location")
367 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
368 .to_string();
369
370 current_url = if location.starts_with('/') {
372 let parsed = url::Url::parse(¤t_url)
373 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
374 let host = parsed.host_str().ok_or_else(|| {
375 format!("{}: redirect URL missing host: {}", filename, current_url)
376 })?;
377 format!("{}://{}{}", parsed.scheme(), host, location)
378 } else {
379 location
380 };
381 info!("Redirect {} → {}", redirects, current_url);
382 } else {
383 return Err(format!("{}: HTTP {}", filename, status));
384 }
385 };
386
387 let expected_bytes: Option<u64> = response
391 .header("x-linked-size")
392 .or_else(|| response.header("content-length"))
393 .and_then(|v| v.parse::<u64>().ok());
394
395 let mut bytes = Vec::new();
399 response
400 .into_reader()
401 .take(2_147_483_648)
402 .read_to_end(&mut bytes)
403 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
404
405 if let Some(expected) = expected_bytes {
409 let actual = bytes.len() as u64;
410 if actual < expected {
411 return Err(format!(
412 "{}: download incomplete — received {} of {} bytes. \
413 File may exceed 2 GiB or the connection was interrupted.",
414 filename, actual, expected
415 ));
416 }
417 }
418
419 std::fs::write(&file_path, &bytes)
420 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
421
422 info!("Downloaded {} ({} bytes)", filename, bytes.len());
423 Ok(file_path)
424 }
425
426 pub fn dimension(&self) -> usize {
428 self.dimension
429 }
430
431 pub fn model(&self) -> EmbeddingModel {
433 self.config.model
434 }
435
436 pub fn pool_size(&self) -> usize {
438 self.sessions.len()
439 }
440
441 #[instrument(skip(self, text), fields(text_len = text.len()))]
445 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
446 let texts = vec![text.to_string()];
447 let prepared = self.processor.prepare_texts(&texts, true);
448 let embeddings = self.embed_batch_internal(&prepared).await?;
449 embeddings.into_iter().next().ok_or_else(|| {
450 InferenceError::InferenceError("No embedding returned for query".to_string())
451 })
452 }
453
454 #[instrument(skip(self, texts), fields(count = texts.len()))]
458 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
459 let prepared = self.processor.prepare_texts(texts, true);
460 self.embed_batch_internal(&prepared).await
461 }
462
463 #[instrument(skip(self, text), fields(text_len = text.len()))]
467 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
468 let texts = vec![text.to_string()];
469 let prepared = self.processor.prepare_texts(&texts, false);
470 let embeddings = self.embed_batch_internal(&prepared).await?;
471 embeddings.into_iter().next().ok_or_else(|| {
472 InferenceError::InferenceError("No embedding returned for document".to_string())
473 })
474 }
475
476 #[instrument(skip(self, texts), fields(count = texts.len()))]
480 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
481 let prepared = self.processor.prepare_texts(texts, false);
482 self.embed_batch_internal(&prepared).await
483 }
484
485 #[instrument(skip(self, texts), fields(count = texts.len()))]
487 pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
488 self.embed_batch_internal(texts).await
489 }
490
491 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
504 if texts.is_empty() {
505 return Ok(vec![]);
506 }
507
508 let pool_len = self.sessions.len();
509 let normalize = self.config.model.normalize_embeddings();
510 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
513
514 let mut batch_size = self.config.max_batch_size.max(1);
515 let use_gpu = self.use_gpu;
516
517 for attempt in 0_u32..=5 {
523 let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
524
525 let mut handles = Vec::with_capacity(batches.len());
527 for (i, batch_owned) in batches.into_iter().enumerate() {
528 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
529 let processor = Arc::clone(&self.processor);
530
531 let gpu_permit = if use_gpu {
536 Some(
537 std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
538 .acquire_owned()
539 .await
540 .map_err(|_| {
541 InferenceError::InferenceError(
542 "GPU inference semaphore unexpectedly closed".to_string(),
543 )
544 })?,
545 )
546 } else {
547 None
548 };
549
550 handles.push(tokio::task::spawn_blocking(move || {
551 let _gpu_permit = gpu_permit; let mut session_guard = session.lock();
553 Self::process_batch_blocking(
554 &batch_owned,
555 &mut session_guard,
556 &processor,
557 normalize,
558 )
559 }));
560 }
561
562 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
563 let mut oom: Option<InferenceError> = None;
564
565 for handle in handles {
566 match handle.await {
567 Err(panic_err) => {
568 return Err(InferenceError::InferenceError(format!(
569 "Inference task panicked: {panic_err}"
570 )));
571 }
572 Ok(Err(e)) => {
573 if attempt < 5 && Self::is_gpu_oom(&e) {
574 oom = Some(e);
576 break;
577 }
578 return Err(e);
579 }
580 Ok(Ok(batch_embs)) => {
581 all_embeddings.extend(batch_embs);
582 }
583 }
584 }
585
586 if oom.is_some() {
587 let next_batch = (batch_size / 2).max(1);
588 warn!(
589 "ONNX allocator OOM (attempt {}/5) — retrying with batch_size {} → {}",
590 attempt + 1,
591 batch_size,
592 next_batch,
593 );
594 batch_size = next_batch;
595 continue;
596 }
597
598 return Ok(all_embeddings);
599 }
600
601 Err(InferenceError::InferenceError(format!(
602 "ONNX inference failed: GPU/CPU allocator OOM persists after 5 \
603 batch-halving attempts (final batch_size={batch_size})"
604 )))
605 }
606
607 fn is_gpu_oom(err: &InferenceError) -> bool {
613 let msg = err.to_string();
614 msg.contains("BFCArena")
615 || msg.contains("Failed to allocate memory")
616 || msg.contains("CUDA_OUT_OF_MEMORY")
617 || msg.contains("CUDA out of memory")
618 || (msg.contains("allocate") && msg.contains("buffer of size"))
619 }
620
621 fn process_batch_blocking(
625 texts: &[String],
626 session: &mut Session,
627 processor: &BatchProcessor,
628 normalize: bool,
629 ) -> Result<Vec<Vec<f32>>> {
630 let prepared = processor.tokenize_batch(texts)?;
632 let batch_size = prepared.batch_size;
633 let seq_len = prepared.seq_len;
634
635 let attention_mask_flat = prepared.attention_mask.clone();
637
638 let input_ids_tensor =
640 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
641 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
642 let attention_mask_tensor =
643 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
644 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
645 let token_type_ids_tensor =
646 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
647 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
648
649 let outputs = session
651 .run(inputs![
652 "input_ids" => input_ids_tensor,
653 "attention_mask" => attention_mask_tensor,
654 "token_type_ids" => token_type_ids_tensor
655 ])
656 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
657
658 let (ort_shape, lhs_slice) = outputs[0]
662 .try_extract_tensor::<f32>()
663 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
664
665 if ort_shape.len() != 3 {
666 return Err(InferenceError::InferenceError(format!(
667 "Expected 3D last_hidden_state, got {} dims",
668 ort_shape.len()
669 )));
670 }
671 let hidden_size = ort_shape[2] as usize;
672
673 let mut embeddings = mean_pooling(
675 lhs_slice,
676 batch_size,
677 seq_len,
678 hidden_size,
679 &attention_mask_flat,
680 );
681
682 if normalize {
684 normalize_embeddings(&mut embeddings);
685 }
686
687 debug!(
688 "Generated {} embeddings of dimension {}",
689 embeddings.len(),
690 embeddings.first().map(|e| e.len()).unwrap_or(0)
691 );
692
693 Ok(embeddings)
694 }
695
696 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
698 let tokens_per_text =
700 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
701 let total_tokens = tokens_per_text * text_count as f64;
702 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
703 (total_tokens / tokens_per_second) * 1000.0
704 }
705}
706
707impl std::fmt::Debug for EmbeddingEngine {
708 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
709 f.debug_struct("EmbeddingEngine")
710 .field("model", &self.config.model)
711 .field("dimension", &self.dimension)
712 .field("max_batch_size", &self.config.max_batch_size)
713 .field("session_pool_size", &self.sessions.len())
714 .field("use_gpu", &self.use_gpu)
715 .finish()
716 }
717}
718
719pub struct EmbeddingEngineBuilder {
721 config: ModelConfig,
722}
723
724impl EmbeddingEngineBuilder {
725 pub fn new() -> Self {
727 Self {
728 config: ModelConfig::default(),
729 }
730 }
731
732 pub fn model(mut self, model: EmbeddingModel) -> Self {
734 self.config.model = model;
735 self
736 }
737
738 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
740 self.config.cache_dir = Some(dir.into());
741 self
742 }
743
744 pub fn max_batch_size(mut self, size: usize) -> Self {
746 self.config.max_batch_size = size;
747 self
748 }
749
750 pub fn use_gpu(mut self, enable: bool) -> Self {
752 self.config.use_gpu = enable;
753 self
754 }
755
756 pub fn num_threads(mut self, threads: usize) -> Self {
758 self.config.num_threads = Some(threads);
759 self
760 }
761
762 pub fn session_pool_size(mut self, size: usize) -> Self {
764 self.config.session_pool_size = size.max(1);
765 self
766 }
767
768 pub async fn build(self) -> Result<EmbeddingEngine> {
770 EmbeddingEngine::new(self.config).await
771 }
772}
773
774impl Default for EmbeddingEngineBuilder {
775 fn default() -> Self {
776 Self::new()
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    /// One lock shared by every test that changes, or depends on, process-wide
    /// environment variables (`HF_HOME`, `HOME`, `DAKERA_USE_GPU`).
    ///
    /// A function-local `static` per test would give each test its own lock
    /// and therefore no mutual exclusion at all.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`ENV_LOCK`], recovering from poisoning so that one failed test
    /// does not cascade into every other environment-dependent test.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides one environment variable and restores its previous state on
    /// drop, including when the test panics.
    struct ScopedEnv {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl ScopedEnv {
        /// Sets `key` to `value`. Requires proof that [`ENV_LOCK`] is held.
        fn set(
            _held: &std::sync::MutexGuard<'static, ()>,
            key: &'static str,
            value: impl AsRef<std::ffi::OsStr>,
        ) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: ENV_LOCK is held, so no other test in this module
            // touches the environment concurrently.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        /// Removes `key`. Requires proof that [`ENV_LOCK`] is held.
        fn unset(_held: &std::sync::MutexGuard<'static, ()>, key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            // SAFETY: a ScopedEnv is always declared after, and therefore
            // dropped before, the ENV_LOCK guard it was created under.
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    // ---- estimate / builder basics -------------------------------------

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    // ---- model_cache_dir -------------------------------------------------

    /// `HF_HOME`, when set, is the root of the cache directory.
    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let env = lock_env();
        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let result = {
            let _hf_home = ScopedEnv::set(&env, "HF_HOME", &tmp);
            EmbeddingEngine::model_cache_dir("org/my-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // ---- download_hf_file: cache hits need no network --------------------

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    // ---- builder setters ---------------------------------------------------

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // ---- estimate formula (re-derived here; no engine instance needed) ----

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    // ---- ModelConfig -------------------------------------------------------

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // ---- model_cache_dir, continued -----------------------------------------

    /// With neither `HF_HOME` nor `HOME` set, the cache falls back to `/tmp`.
    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let env = lock_env();
        let result = {
            let _home = ScopedEnv::unset(&env, "HOME");
            let _hf_home = ScopedEnv::unset(&env, "HF_HOME");
            EmbeddingEngine::model_cache_dir("test/fallback-model")
        };

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    // ---- download_hf_file, continued ----------------------------------------

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // ---- engine construction -------------------------------------------------

    /// A cache holding an unparsable `tokenizer.json` must make `new()` fail.
    // The std mutex guard is held across `.await` on purpose: it serialises
    // environment access and the test runtime is single-threaded.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let env = lock_env();

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        let result = {
            let _hf_home = ScopedEnv::set(&env, "HF_HOME", &tmp);
            let config = ModelConfig::new(EmbeddingModel::MiniLM);
            EmbeddingEngine::new(config).await
        };

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    /// Exercises the getters when the model can be loaded; skipped otherwise.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_getters_when_model_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
            assert_eq!(engine.model(), EmbeddingModel::MiniLM);
            let _ = format!("{:?}", engine);
            let ms = engine.estimate_time_ms(10, 50);
            assert!(ms >= 0.0);
        }
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_embed_empty_batch_when_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    // ---- session pool ----------------------------------------------------------

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // Mirrors the override the default is expected to honour.
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_pool_size_matches_config_when_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    /// Empty input returns before the round-robin counter is touched.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }

    // ---- GPU flag resolution -----------------------------------------------------

    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_use_gpu_defaults_to_false_when_not_configured() {
        let env = lock_env();
        let _no_gpu_override = ScopedEnv::unset(&env, "DAKERA_USE_GPU");

        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(1);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert!(
                !engine.use_gpu,
                "use_gpu must be false when DAKERA_USE_GPU is unset"
            );
        }
    }

    #[test]
    fn test_engine_use_gpu_resolved_from_env_var() {
        let env = lock_env();
        let resolved = {
            let _gpu_on = ScopedEnv::set(&env, "DAKERA_USE_GPU", "1");
            std::env::var("DAKERA_USE_GPU")
                .map(|v| v == "1")
                .unwrap_or(false)
        };

        assert!(resolved, "DAKERA_USE_GPU=1 must resolve use_gpu=true");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_debug_includes_use_gpu() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let debug_str = format!("{:?}", engine);
            assert!(
                debug_str.contains("use_gpu"),
                "Debug output must include use_gpu field: {debug_str}"
            );
        }
    }
}