1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
/// Text-embedding engine backed by a pool of ONNX Runtime sessions.
///
/// Constructed with [`EmbeddingEngine::new`] (or via `EmbeddingEngineBuilder`).
pub struct EmbeddingEngine {
    /// Pool of ONNX sessions; each sits behind its own mutex so a blocking
    /// inference task can borrow it mutably.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Counter bumped once per non-empty embed call; its value modulo the pool
    /// size selects the first session used by that call.
    next_session: AtomicUsize,
    /// Tokenizer/batching helper, shared with the blocking inference tasks.
    processor: Arc<BatchProcessor>,
    /// Configuration the engine was created from.
    config: ModelConfig,
    /// Embedding dimension, taken from `config.model.dimension()` at construction.
    dimension: usize,
}
69
70impl EmbeddingEngine {
71 #[instrument(skip_all, fields(model = %config.model))]
76 pub async fn new(config: ModelConfig) -> Result<Self> {
77 let use_gpu = std::env::var("DAKERA_USE_GPU")
80 .map(|v| v == "1")
81 .unwrap_or(config.use_gpu);
82 if use_gpu {
83 info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
84 }
85
86 info!(
87 "Initializing ONNX embedding engine with model: {}",
88 config.model
89 );
90
91 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
93
94 info!("Loading tokenizer from {:?}", tokenizer_path);
96 let tokenizer = Tokenizer::from_file(&tokenizer_path)
97 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
98
99 info!("Loading ONNX model from {:?}", onnx_path);
102 let num_threads = config.num_threads.unwrap_or(4);
103 let pool_size = config.session_pool_size.max(1);
104 let onnx_path_clone = onnx_path.clone();
105
106 let sessions: Vec<Arc<Mutex<Session>>> =
107 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
108 (0..pool_size)
109 .map(|_| {
110 let builder = Session::builder()
111 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
112 .with_optimization_level(GraphOptimizationLevel::Level3)
113 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
114 .with_intra_threads(num_threads)
115 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
116
117 let mut builder = if use_gpu {
118 builder
119 .with_execution_providers(
120 [CUDAExecutionProvider::default().build()],
121 )
122 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
123 } else {
124 builder
125 };
126
127 let s = builder
128 .commit_from_file(&onnx_path_clone)
129 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
130 Ok(Arc::new(Mutex::new(s)))
131 })
132 .collect()
133 })
134 .await
135 .map_err(|e| {
136 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
137 })??;
138
139 let dimension = config.model.dimension();
140 let processor = Arc::new(BatchProcessor::new(
141 tokenizer,
142 config.model,
143 config.max_batch_size,
144 ));
145
146 info!(
147 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
148 config.model, dimension, num_threads, pool_size
149 );
150
151 Ok(Self {
152 sessions,
153 next_session: AtomicUsize::new(0),
154 processor,
155 config,
156 dimension,
157 })
158 }
159
160 #[instrument(skip_all, fields(model = %config.model))]
166 async fn download_model_files(
167 config: &ModelConfig,
168 use_gpu: bool,
169 ) -> Result<(PathBuf, PathBuf)> {
170 let model_id = config.model.model_id();
171 let onnx_repo_id = config.model.onnx_repo_id();
172 let onnx_filename = if use_gpu {
173 config.model.onnx_filename_gpu()
174 } else {
175 config.model.onnx_filename()
176 };
177
178 info!(
179 "Resolving model files: tokenizer={}, onnx={}@{}",
180 model_id, onnx_filename, onnx_repo_id
181 );
182
183 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
184 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
185
186 let onnx_subdir = onnx_cache_dir.join("onnx");
188 std::fs::create_dir_all(&onnx_subdir)?;
189
190 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
191 let onnx_basename = Path::new(onnx_filename)
193 .file_name()
194 .and_then(|s| s.to_str())
195 .unwrap_or("model_quantized.onnx");
196 let local_onnx = onnx_subdir.join(onnx_basename);
197
198 let tokenizer_needs_download = !local_tokenizer.exists();
200
201 if use_gpu && local_onnx.exists() {
206 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
207 if cached_size <= 500_000_000 {
208 warn!(
209 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
210 download limit. Deleting for complete re-download.",
211 local_onnx, cached_size
212 );
213 let _ = std::fs::remove_file(&local_onnx);
214 }
215 }
216 let onnx_needs_download = !local_onnx.exists();
217
218 if tokenizer_needs_download || onnx_needs_download {
219 let model_id_owned = model_id.to_string();
220 let onnx_repo_id_owned = onnx_repo_id.to_string();
221 let onnx_filename_owned = onnx_filename.to_string();
222 let tokenizer_cache = tokenizer_cache_dir.clone();
223 let onnx_cache = onnx_cache_dir.clone();
224
225 tokio::task::spawn_blocking(move || {
226 if !tokenizer_cache.join("tokenizer.json").exists() {
227 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
228 .map_err(|e| {
229 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
230 })?;
231 }
232 if !onnx_cache.join(&onnx_filename_owned).exists() {
233 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
234 .map_err(|e| {
235 InferenceError::HubError(format!(
236 "Failed to download ONNX model: {}",
237 e
238 ))
239 })?;
240 }
241 Ok::<_, InferenceError>(())
242 })
243 .await
244 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
245 } else {
246 info!("All model files found in local cache");
247 }
248
249 let final_onnx = onnx_cache_dir.join(onnx_filename);
251
252 info!(
253 "Model files ready: tokenizer={:?}, onnx={:?}",
254 local_tokenizer, final_onnx
255 );
256 Ok((local_tokenizer, final_onnx))
257 }
258
259 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
261 let base = std::env::var("HF_HOME")
262 .map(PathBuf::from)
263 .unwrap_or_else(|_| {
264 let home = std::env::var("HOME").unwrap_or_else(|_| {
265 warn!("HOME environment variable not set, using /tmp for model cache");
266 "/tmp".to_string()
267 });
268 PathBuf::from(home).join(".cache").join("huggingface")
269 });
270 let dir = base.join("dakera").join(model_id.replace('/', "--"));
271 std::fs::create_dir_all(&dir)?;
272 Ok(dir)
273 }
274
    /// Public entry point for the private `download_hf_file` helper.
    ///
    /// Forwards all three arguments unchanged and returns the helper's result:
    /// the path of the file inside `cache_dir`, or an error message.
    pub fn download_hf_file_pub(
        model_id: &str,
        filename: &str,
        cache_dir: &Path,
    ) -> std::result::Result<PathBuf, String> {
        Self::download_hf_file(model_id, filename, cache_dir)
    }
287
288 fn download_hf_file(
289 model_id: &str,
290 filename: &str,
291 cache_dir: &Path,
292 ) -> std::result::Result<PathBuf, String> {
293 let file_path = cache_dir.join(filename);
295 if file_path.exists() {
296 info!("Cached: {}/{}", model_id, filename);
297 return Ok(file_path);
298 }
299
300 if let Some(parent) = file_path.parent() {
302 std::fs::create_dir_all(parent)
303 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
304 }
305
306 let url = format!(
307 "https://huggingface.co/{}/resolve/main/{}",
308 model_id, filename
309 );
310 info!("Downloading: {}", url);
311
312 let hf_token = std::env::var("HF_TOKEN")
314 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
315 .ok();
316 if hf_token.is_some() {
317 info!("Using HuggingFace auth token for download");
318 }
319
320 let agent = ureq::AgentBuilder::new()
322 .redirects(0)
323 .timeout(std::time::Duration::from_secs(300))
324 .build();
325
326 let mut current_url = url.clone();
327 let mut redirects = 0;
328 let max_redirects = 10;
329
330 let response = loop {
331 let mut req = agent.get(¤t_url);
332 if let Some(ref token) = hf_token {
333 req = req.set("Authorization", &format!("Bearer {}", token));
334 }
335 let resp = req.call();
336
337 let r = match resp {
338 Ok(r) => r,
339 Err(ureq::Error::Status(_status, r)) => r,
340 Err(e) => return Err(format!("{}: {}", filename, e)),
341 };
342
343 let status = r.status();
344 if (200..300).contains(&status) {
345 break r;
346 } else if (300..400).contains(&status) {
347 redirects += 1;
348 if redirects > max_redirects {
349 return Err(format!("{}: too many redirects", filename));
350 }
351 let location = r
352 .header("location")
353 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
354 .to_string();
355
356 current_url = if location.starts_with('/') {
358 let parsed = url::Url::parse(¤t_url)
359 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
360 let host = parsed.host_str().ok_or_else(|| {
361 format!("{}: redirect URL missing host: {}", filename, current_url)
362 })?;
363 format!("{}://{}{}", parsed.scheme(), host, location)
364 } else {
365 location
366 };
367 info!("Redirect {} → {}", redirects, current_url);
368 } else {
369 return Err(format!("{}: HTTP {}", filename, status));
370 }
371 };
372
373 let expected_bytes: Option<u64> = response
377 .header("x-linked-size")
378 .or_else(|| response.header("content-length"))
379 .and_then(|v| v.parse::<u64>().ok());
380
381 let mut bytes = Vec::new();
385 response
386 .into_reader()
387 .take(2_147_483_648)
388 .read_to_end(&mut bytes)
389 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
390
391 if let Some(expected) = expected_bytes {
395 let actual = bytes.len() as u64;
396 if actual < expected {
397 return Err(format!(
398 "{}: download incomplete — received {} of {} bytes. \
399 File may exceed 2 GiB or the connection was interrupted.",
400 filename, actual, expected
401 ));
402 }
403 }
404
405 std::fs::write(&file_path, &bytes)
406 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
407
408 info!("Downloaded {} ({} bytes)", filename, bytes.len());
409 Ok(file_path)
410 }
411
    /// Returns the embedding dimension stored when the engine was constructed.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
416
    /// Returns the embedding model this engine was configured with.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }
421
    /// Returns the number of ONNX sessions held in the pool.
    pub fn pool_size(&self) -> usize {
        self.sessions.len()
    }
426
427 #[instrument(skip(self, text), fields(text_len = text.len()))]
431 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
432 let texts = vec![text.to_string()];
433 let prepared = self.processor.prepare_texts(&texts, true);
434 let embeddings = self.embed_batch_internal(&prepared).await?;
435 embeddings.into_iter().next().ok_or_else(|| {
436 InferenceError::InferenceError("No embedding returned for query".to_string())
437 })
438 }
439
    /// Embeds a batch of query strings, returning one vector per input.
    ///
    /// Texts go through `prepare_texts` with the flag set to `true`
    /// (presumably the "is query" switch — confirm against `BatchProcessor`).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, true);
        self.embed_batch_internal(&prepared).await
    }
448
449 #[instrument(skip(self, text), fields(text_len = text.len()))]
453 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
454 let texts = vec![text.to_string()];
455 let prepared = self.processor.prepare_texts(&texts, false);
456 let embeddings = self.embed_batch_internal(&prepared).await?;
457 embeddings.into_iter().next().ok_or_else(|| {
458 InferenceError::InferenceError("No embedding returned for document".to_string())
459 })
460 }
461
    /// Embeds a batch of document strings, returning one vector per input.
    ///
    /// Texts go through `prepare_texts` with the flag set to `false`
    /// (presumably the "is query" switch — confirm against `BatchProcessor`).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, false);
        self.embed_batch_internal(&prepared).await
    }
470
    /// Embeds `texts` exactly as given, skipping the `prepare_texts` step used
    /// by the query/document variants.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }
476
477 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
485 if texts.is_empty() {
486 return Ok(vec![]);
487 }
488
489 let pool_len = self.sessions.len();
490 let normalize = self.config.model.normalize_embeddings();
491 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
494
495 let mut batch_size = self.config.max_batch_size.max(1);
496
497 for attempt in 0_u32..=3 {
499 let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
500
501 let mut handles = Vec::with_capacity(batches.len());
503 for (i, batch_owned) in batches.into_iter().enumerate() {
504 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
505 let processor = Arc::clone(&self.processor);
506 handles.push(tokio::task::spawn_blocking(move || {
507 let mut session_guard = session.lock();
508 Self::process_batch_blocking(
509 &batch_owned,
510 &mut session_guard,
511 &processor,
512 normalize,
513 )
514 }));
515 }
516
517 let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
518 let mut oom: Option<InferenceError> = None;
519
520 for handle in handles {
521 match handle.await {
522 Err(panic_err) => {
523 return Err(InferenceError::InferenceError(format!(
524 "Inference task panicked: {panic_err}"
525 )));
526 }
527 Ok(Err(e)) => {
528 if attempt < 3 && Self::is_gpu_oom(&e) {
529 oom = Some(e);
531 break;
532 }
533 return Err(e);
534 }
535 Ok(Ok(batch_embs)) => {
536 all_embeddings.extend(batch_embs);
537 }
538 }
539 }
540
541 if oom.is_some() {
542 let next_batch = (batch_size / 2).max(1);
543 warn!(
544 "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
545 attempt + 1,
546 batch_size,
547 next_batch,
548 );
549 batch_size = next_batch;
550 continue;
551 }
552
553 return Ok(all_embeddings);
554 }
555
556 Err(InferenceError::InferenceError(format!(
557 "ONNX inference failed: GPU/CPU allocator OOM persists after 3 \
558 batch-halving attempts (final batch_size={batch_size})"
559 )))
560 }
561
562 fn is_gpu_oom(err: &InferenceError) -> bool {
568 let msg = err.to_string();
569 msg.contains("BFCArena")
570 || msg.contains("Failed to allocate memory")
571 || msg.contains("CUDA_OUT_OF_MEMORY")
572 || msg.contains("CUDA out of memory")
573 || (msg.contains("allocate") && msg.contains("buffer of size"))
574 }
575
576 fn process_batch_blocking(
580 texts: &[String],
581 session: &mut Session,
582 processor: &BatchProcessor,
583 normalize: bool,
584 ) -> Result<Vec<Vec<f32>>> {
585 let prepared = processor.tokenize_batch(texts)?;
587 let batch_size = prepared.batch_size;
588 let seq_len = prepared.seq_len;
589
590 let attention_mask_flat = prepared.attention_mask.clone();
592
593 let input_ids_tensor =
595 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
596 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
597 let attention_mask_tensor =
598 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
599 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
600 let token_type_ids_tensor =
601 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
602 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
603
604 let outputs = session
606 .run(inputs![
607 "input_ids" => input_ids_tensor,
608 "attention_mask" => attention_mask_tensor,
609 "token_type_ids" => token_type_ids_tensor
610 ])
611 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
612
613 let (ort_shape, lhs_slice) = outputs[0]
617 .try_extract_tensor::<f32>()
618 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
619
620 if ort_shape.len() != 3 {
621 return Err(InferenceError::InferenceError(format!(
622 "Expected 3D last_hidden_state, got {} dims",
623 ort_shape.len()
624 )));
625 }
626 let hidden_size = ort_shape[2] as usize;
627
628 let mut embeddings = mean_pooling(
630 lhs_slice,
631 batch_size,
632 seq_len,
633 hidden_size,
634 &attention_mask_flat,
635 );
636
637 if normalize {
639 normalize_embeddings(&mut embeddings);
640 }
641
642 debug!(
643 "Generated {} embeddings of dimension {}",
644 embeddings.len(),
645 embeddings.first().map(|e| e.len()).unwrap_or(0)
646 );
647
648 Ok(embeddings)
649 }
650
651 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
653 let tokens_per_text =
655 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
656 let total_tokens = tokens_per_text * text_count as f64;
657 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
658 (total_tokens / tokens_per_second) * 1000.0
659 }
660}
661
662impl std::fmt::Debug for EmbeddingEngine {
663 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664 f.debug_struct("EmbeddingEngine")
665 .field("model", &self.config.model)
666 .field("dimension", &self.dimension)
667 .field("max_batch_size", &self.config.max_batch_size)
668 .field("session_pool_size", &self.sessions.len())
669 .finish()
670 }
671}
672
/// Fluent builder that assembles a `ModelConfig` and creates an
/// `EmbeddingEngine` from it.
pub struct EmbeddingEngineBuilder {
    /// Configuration accumulated by the setter methods.
    config: ModelConfig,
}
677
678impl EmbeddingEngineBuilder {
679 pub fn new() -> Self {
681 Self {
682 config: ModelConfig::default(),
683 }
684 }
685
686 pub fn model(mut self, model: EmbeddingModel) -> Self {
688 self.config.model = model;
689 self
690 }
691
692 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
694 self.config.cache_dir = Some(dir.into());
695 self
696 }
697
698 pub fn max_batch_size(mut self, size: usize) -> Self {
700 self.config.max_batch_size = size;
701 self
702 }
703
704 pub fn use_gpu(mut self, enable: bool) -> Self {
706 self.config.use_gpu = enable;
707 self
708 }
709
710 pub fn num_threads(mut self, threads: usize) -> Self {
712 self.config.num_threads = Some(threads);
713 self
714 }
715
716 pub fn session_pool_size(mut self, size: usize) -> Self {
718 self.config.session_pool_size = size.max(1);
719 self
720 }
721
722 pub async fn build(self) -> Result<EmbeddingEngine> {
724 EmbeddingEngine::new(self.config).await
725 }
726}
727
/// `default()` is equivalent to [`EmbeddingEngineBuilder::new`].
impl Default for EmbeddingEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
733
#[cfg(test)]
mod tests {
    use super::*;

    /// Single lock shared by every test that mutates process environment
    /// variables. A per-test static would give each test its own lock and
    /// serialize nothing.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the shared environment lock. A poisoned lock is still usable: it
    /// only means another env test failed, which must not cascade.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Puts `key` back to `saved` (or removes it when it was unset).
    /// Callers must hold the lock from `lock_env`.
    fn restore_env(key: &str, saved: Option<std::ffi::OsString>) {
        // SAFETY: env-mutating tests in this module are serialized by ENV_LOCK,
        // which the caller holds.
        unsafe {
            match saved {
                Some(value) => std::env::set_var(key, value),
                None => std::env::remove_var(key),
            }
        }
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _guard = lock_env();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        // SAFETY: env-mutating tests are serialized by ENV_LOCK, held above.
        unsafe { std::env::set_var("HF_HOME", &tmp) };
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _guard = lock_env();

        let saved_home = std::env::var_os("HOME");
        let saved_hf = std::env::var_os("HF_HOME");
        // SAFETY: env-mutating tests are serialized by ENV_LOCK, held above.
        unsafe {
            std::env::remove_var("HOME");
            std::env::remove_var("HF_HOME");
        }

        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");

        // Restore before asserting so a failure cannot leak the modified env.
        restore_env("HOME", saved_home);
        restore_env("HF_HOME", saved_hf);

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // The env lock is deliberately held across the await so HF_HOME stays
    // stable for the whole `new()` call.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _guard = lock_env();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        // SAFETY: env-mutating tests are serialized by ENV_LOCK, held above.
        unsafe { std::env::set_var("HF_HOME", &tmp) };

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        restore_env("HF_HOME", saved_hf);

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                let _ = format!("{:?}", engine);
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Engine could not be created in this environment; nothing to assert.
            }
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // Mirror the override the test expects the default to honour.
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}