1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49pub struct EmbeddingEngine {
56 sessions: Vec<Arc<Mutex<Session>>>,
60 next_session: AtomicUsize,
62 processor: Arc<BatchProcessor>,
64 config: ModelConfig,
66 dimension: usize,
68}
69
70impl EmbeddingEngine {
71 #[instrument(skip_all, fields(model = %config.model))]
76 pub async fn new(config: ModelConfig) -> Result<Self> {
77 let use_gpu = std::env::var("DAKERA_USE_GPU")
80 .map(|v| v == "1")
81 .unwrap_or(config.use_gpu);
82 if use_gpu {
83 info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
84 }
85
86 info!(
87 "Initializing ONNX embedding engine with model: {}",
88 config.model
89 );
90
91 let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
93
94 info!("Loading tokenizer from {:?}", tokenizer_path);
96 let tokenizer = Tokenizer::from_file(&tokenizer_path)
97 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
98
99 info!("Loading ONNX model from {:?}", onnx_path);
102 let num_threads = config.num_threads.unwrap_or(4);
103 let pool_size = config.session_pool_size.max(1);
104 let onnx_path_clone = onnx_path.clone();
105
106 let sessions: Vec<Arc<Mutex<Session>>> =
107 tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
108 (0..pool_size)
109 .map(|_| {
110 let builder = Session::builder()
111 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
112 .with_optimization_level(GraphOptimizationLevel::Level3)
113 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
114 .with_intra_threads(num_threads)
115 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
116
117 let mut builder = if use_gpu {
118 builder
119 .with_execution_providers(
120 [CUDAExecutionProvider::default().build()],
121 )
122 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
123 } else {
124 builder
125 };
126
127 let s = builder
128 .commit_from_file(&onnx_path_clone)
129 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
130 Ok(Arc::new(Mutex::new(s)))
131 })
132 .collect()
133 })
134 .await
135 .map_err(|e| {
136 InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
137 })??;
138
139 let dimension = config.model.dimension();
140 let processor = Arc::new(BatchProcessor::new(
141 tokenizer,
142 config.model,
143 config.max_batch_size,
144 ));
145
146 info!(
147 "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
148 config.model, dimension, num_threads, pool_size
149 );
150
151 Ok(Self {
152 sessions,
153 next_session: AtomicUsize::new(0),
154 processor,
155 config,
156 dimension,
157 })
158 }
159
160 #[instrument(skip_all, fields(model = %config.model))]
166 async fn download_model_files(
167 config: &ModelConfig,
168 use_gpu: bool,
169 ) -> Result<(PathBuf, PathBuf)> {
170 let model_id = config.model.model_id();
171 let onnx_repo_id = config.model.onnx_repo_id();
172 let onnx_filename = if use_gpu {
173 config.model.onnx_filename_gpu()
174 } else {
175 config.model.onnx_filename()
176 };
177
178 info!(
179 "Resolving model files: tokenizer={}, onnx={}@{}",
180 model_id, onnx_filename, onnx_repo_id
181 );
182
183 let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
184 let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
185
186 let onnx_subdir = onnx_cache_dir.join("onnx");
188 std::fs::create_dir_all(&onnx_subdir)?;
189
190 let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
191 let onnx_basename = Path::new(onnx_filename)
193 .file_name()
194 .and_then(|s| s.to_str())
195 .unwrap_or("model_quantized.onnx");
196 let local_onnx = onnx_subdir.join(onnx_basename);
197
198 let tokenizer_needs_download = !local_tokenizer.exists();
200
201 if use_gpu && local_onnx.exists() {
206 let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
207 if cached_size <= 500_000_000 {
208 warn!(
209 "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
210 download limit. Deleting for complete re-download.",
211 local_onnx, cached_size
212 );
213 let _ = std::fs::remove_file(&local_onnx);
214 }
215 }
216 let onnx_needs_download = !local_onnx.exists();
217
218 if tokenizer_needs_download || onnx_needs_download {
219 let model_id_owned = model_id.to_string();
220 let onnx_repo_id_owned = onnx_repo_id.to_string();
221 let onnx_filename_owned = onnx_filename.to_string();
222 let tokenizer_cache = tokenizer_cache_dir.clone();
223 let onnx_cache = onnx_cache_dir.clone();
224
225 tokio::task::spawn_blocking(move || {
226 if !tokenizer_cache.join("tokenizer.json").exists() {
227 Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
228 .map_err(|e| {
229 InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
230 })?;
231 }
232 if !onnx_cache.join(&onnx_filename_owned).exists() {
233 Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
234 .map_err(|e| {
235 InferenceError::HubError(format!(
236 "Failed to download ONNX model: {}",
237 e
238 ))
239 })?;
240 }
241 Ok::<_, InferenceError>(())
242 })
243 .await
244 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
245 } else {
246 info!("All model files found in local cache");
247 }
248
249 let final_onnx = onnx_cache_dir.join(onnx_filename);
251
252 info!(
253 "Model files ready: tokenizer={:?}, onnx={:?}",
254 local_tokenizer, final_onnx
255 );
256 Ok((local_tokenizer, final_onnx))
257 }
258
259 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
261 let base = std::env::var("HF_HOME")
262 .map(PathBuf::from)
263 .unwrap_or_else(|_| {
264 let home = std::env::var("HOME").unwrap_or_else(|_| {
265 warn!("HOME environment variable not set, using /tmp for model cache");
266 "/tmp".to_string()
267 });
268 PathBuf::from(home).join(".cache").join("huggingface")
269 });
270 let dir = base.join("dakera").join(model_id.replace('/', "--"));
271 std::fs::create_dir_all(&dir)?;
272 Ok(dir)
273 }
274
275 pub fn download_hf_file_pub(
281 model_id: &str,
282 filename: &str,
283 cache_dir: &Path,
284 ) -> std::result::Result<PathBuf, String> {
285 Self::download_hf_file(model_id, filename, cache_dir)
286 }
287
288 fn download_hf_file(
289 model_id: &str,
290 filename: &str,
291 cache_dir: &Path,
292 ) -> std::result::Result<PathBuf, String> {
293 let file_path = cache_dir.join(filename);
295 if file_path.exists() {
296 info!("Cached: {}/{}", model_id, filename);
297 return Ok(file_path);
298 }
299
300 if let Some(parent) = file_path.parent() {
302 std::fs::create_dir_all(parent)
303 .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
304 }
305
306 let url = format!(
307 "https://huggingface.co/{}/resolve/main/{}",
308 model_id, filename
309 );
310 info!("Downloading: {}", url);
311
312 let hf_token = std::env::var("HF_TOKEN")
314 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
315 .ok();
316 if hf_token.is_some() {
317 info!("Using HuggingFace auth token for download");
318 }
319
320 let agent = ureq::AgentBuilder::new()
322 .redirects(0)
323 .timeout(std::time::Duration::from_secs(300))
324 .build();
325
326 let mut current_url = url.clone();
327 let mut redirects = 0;
328 let max_redirects = 10;
329
330 let response = loop {
331 let mut req = agent.get(¤t_url);
332 if let Some(ref token) = hf_token {
333 req = req.set("Authorization", &format!("Bearer {}", token));
334 }
335 let resp = req.call();
336
337 let r = match resp {
338 Ok(r) => r,
339 Err(ureq::Error::Status(_status, r)) => r,
340 Err(e) => return Err(format!("{}: {}", filename, e)),
341 };
342
343 let status = r.status();
344 if (200..300).contains(&status) {
345 break r;
346 } else if (300..400).contains(&status) {
347 redirects += 1;
348 if redirects > max_redirects {
349 return Err(format!("{}: too many redirects", filename));
350 }
351 let location = r
352 .header("location")
353 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
354 .to_string();
355
356 current_url = if location.starts_with('/') {
358 let parsed = url::Url::parse(¤t_url)
359 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
360 let host = parsed.host_str().ok_or_else(|| {
361 format!("{}: redirect URL missing host: {}", filename, current_url)
362 })?;
363 format!("{}://{}{}", parsed.scheme(), host, location)
364 } else {
365 location
366 };
367 info!("Redirect {} → {}", redirects, current_url);
368 } else {
369 return Err(format!("{}: HTTP {}", filename, status));
370 }
371 };
372
373 let expected_bytes: Option<u64> = response
377 .header("x-linked-size")
378 .or_else(|| response.header("content-length"))
379 .and_then(|v| v.parse::<u64>().ok());
380
381 let mut bytes = Vec::new();
385 response
386 .into_reader()
387 .take(2_147_483_648)
388 .read_to_end(&mut bytes)
389 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
390
391 if let Some(expected) = expected_bytes {
395 let actual = bytes.len() as u64;
396 if actual < expected {
397 return Err(format!(
398 "{}: download incomplete — received {} of {} bytes. \
399 File may exceed 2 GiB or the connection was interrupted.",
400 filename, actual, expected
401 ));
402 }
403 }
404
405 std::fs::write(&file_path, &bytes)
406 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
407
408 info!("Downloaded {} ({} bytes)", filename, bytes.len());
409 Ok(file_path)
410 }
411
412 pub fn dimension(&self) -> usize {
414 self.dimension
415 }
416
417 pub fn model(&self) -> EmbeddingModel {
419 self.config.model
420 }
421
422 pub fn pool_size(&self) -> usize {
424 self.sessions.len()
425 }
426
427 #[instrument(skip(self, text), fields(text_len = text.len()))]
431 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
432 let texts = vec![text.to_string()];
433 let prepared = self.processor.prepare_texts(&texts, true);
434 let embeddings = self.embed_batch_internal(&prepared).await?;
435 embeddings.into_iter().next().ok_or_else(|| {
436 InferenceError::InferenceError("No embedding returned for query".to_string())
437 })
438 }
439
440 #[instrument(skip(self, texts), fields(count = texts.len()))]
444 pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
445 let prepared = self.processor.prepare_texts(texts, true);
446 self.embed_batch_internal(&prepared).await
447 }
448
449 #[instrument(skip(self, text), fields(text_len = text.len()))]
453 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
454 let texts = vec![text.to_string()];
455 let prepared = self.processor.prepare_texts(&texts, false);
456 let embeddings = self.embed_batch_internal(&prepared).await?;
457 embeddings.into_iter().next().ok_or_else(|| {
458 InferenceError::InferenceError("No embedding returned for document".to_string())
459 })
460 }
461
462 #[instrument(skip(self, texts), fields(count = texts.len()))]
466 pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
467 let prepared = self.processor.prepare_texts(texts, false);
468 self.embed_batch_internal(&prepared).await
469 }
470
471 #[instrument(skip(self, texts), fields(count = texts.len()))]
473 pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
474 self.embed_batch_internal(texts).await
475 }
476
477 async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
485 if texts.is_empty() {
486 return Ok(vec![]);
487 }
488
489 let batches: Vec<Vec<String>> = self
490 .processor
491 .split_into_batches(texts)
492 .into_iter()
493 .map(|b| b.to_vec())
494 .collect();
495
496 let pool_len = self.sessions.len();
497 let normalize = self.config.model.normalize_embeddings();
498 let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
501
502 let mut handles = Vec::with_capacity(batches.len());
504 for (i, batch_owned) in batches.into_iter().enumerate() {
505 let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
506 let processor = Arc::clone(&self.processor);
507 handles.push(tokio::task::spawn_blocking(move || {
508 let mut session_guard = session.lock();
509 Self::process_batch_blocking(
510 &batch_owned,
511 &mut session_guard,
512 &processor,
513 normalize,
514 )
515 }));
516 }
517
518 let mut all_embeddings = Vec::with_capacity(texts.len());
519 for handle in handles {
520 let batch_embeddings = handle.await.map_err(|e| {
521 InferenceError::InferenceError(format!("Inference task panicked: {}", e))
522 })??;
523 all_embeddings.extend(batch_embeddings);
524 }
525
526 Ok(all_embeddings)
527 }
528
529 fn process_batch_blocking(
533 texts: &[String],
534 session: &mut Session,
535 processor: &BatchProcessor,
536 normalize: bool,
537 ) -> Result<Vec<Vec<f32>>> {
538 let prepared = processor.tokenize_batch(texts)?;
540 let batch_size = prepared.batch_size;
541 let seq_len = prepared.seq_len;
542
543 let attention_mask_flat = prepared.attention_mask.clone();
545
546 let input_ids_tensor =
548 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
549 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
550 let attention_mask_tensor =
551 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
552 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
553 let token_type_ids_tensor =
554 Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
555 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
556
557 let outputs = session
559 .run(inputs![
560 "input_ids" => input_ids_tensor,
561 "attention_mask" => attention_mask_tensor,
562 "token_type_ids" => token_type_ids_tensor
563 ])
564 .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
565
566 let (ort_shape, lhs_slice) = outputs[0]
570 .try_extract_tensor::<f32>()
571 .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
572
573 if ort_shape.len() != 3 {
574 return Err(InferenceError::InferenceError(format!(
575 "Expected 3D last_hidden_state, got {} dims",
576 ort_shape.len()
577 )));
578 }
579 let hidden_size = ort_shape[2] as usize;
580
581 let mut embeddings = mean_pooling(
583 lhs_slice,
584 batch_size,
585 seq_len,
586 hidden_size,
587 &attention_mask_flat,
588 );
589
590 if normalize {
592 normalize_embeddings(&mut embeddings);
593 }
594
595 debug!(
596 "Generated {} embeddings of dimension {}",
597 embeddings.len(),
598 embeddings.first().map(|e| e.len()).unwrap_or(0)
599 );
600
601 Ok(embeddings)
602 }
603
604 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
606 let tokens_per_text =
608 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
609 let total_tokens = tokens_per_text * text_count as f64;
610 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
611 (total_tokens / tokens_per_second) * 1000.0
612 }
613}
614
615impl std::fmt::Debug for EmbeddingEngine {
616 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617 f.debug_struct("EmbeddingEngine")
618 .field("model", &self.config.model)
619 .field("dimension", &self.dimension)
620 .field("max_batch_size", &self.config.max_batch_size)
621 .field("session_pool_size", &self.sessions.len())
622 .finish()
623 }
624}
625
626pub struct EmbeddingEngineBuilder {
628 config: ModelConfig,
629}
630
631impl EmbeddingEngineBuilder {
632 pub fn new() -> Self {
634 Self {
635 config: ModelConfig::default(),
636 }
637 }
638
639 pub fn model(mut self, model: EmbeddingModel) -> Self {
641 self.config.model = model;
642 self
643 }
644
645 pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
647 self.config.cache_dir = Some(dir.into());
648 self
649 }
650
651 pub fn max_batch_size(mut self, size: usize) -> Self {
653 self.config.max_batch_size = size;
654 self
655 }
656
657 pub fn use_gpu(mut self, enable: bool) -> Self {
659 self.config.use_gpu = enable;
660 self
661 }
662
663 pub fn num_threads(mut self, threads: usize) -> Self {
665 self.config.num_threads = Some(threads);
666 self
667 }
668
669 pub fn session_pool_size(mut self, size: usize) -> Self {
671 self.config.session_pool_size = size.max(1);
672 self
673 }
674
675 pub async fn build(self) -> Result<EmbeddingEngine> {
677 EmbeddingEngine::new(self.config).await
678 }
679}
680
681impl Default for EmbeddingEngineBuilder {
682 fn default() -> Self {
683 Self::new()
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard type returned by [`env_guard`].
    type EnvGuard = std::sync::MutexGuard<'static, ()>;

    /// One lock shared by every test that mutates process environment
    /// variables. It must be a single module-level static: a `static` declared
    /// inside a test function is private to that function and would not
    /// serialise it against the other tests.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`]. A poisoned lock is still usable — it only means
    /// an earlier environment test panicked — so poisoning is ignored instead
    /// of cascading the failure into unrelated tests.
    fn env_guard() -> EnvGuard {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets `key` to `value`, or removes it when `value` is `None`.
    /// Requiring a reference to the guard makes it impossible to call this
    /// without holding [`ENV_LOCK`].
    fn set_env(_held: &EnvGuard, key: &str, value: Option<&std::ffi::OsStr>) {
        // SAFETY: the caller holds `ENV_LOCK`, so no other test in this module
        // mutates the environment concurrently; as far as this file shows, the
        // remaining tests only read it through `std::env`.
        unsafe {
            match value {
                Some(v) => std::env::set_var(key, v),
                None => std::env::remove_var(key),
            }
        }
    }

    // ---- basic sanity ----------------------------------------------------

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    // ---- model_cache_dir -------------------------------------------------

    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let guard = env_guard();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        set_env(&guard, "HF_HOME", Some(tmp.as_os_str()));
        let result = EmbeddingEngine::model_cache_dir("org/my-model");
        // Put the caller's value back before asserting, so that a failing
        // assertion cannot leave the override behind.
        set_env(&guard, "HF_HOME", saved_hf.as_deref());

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // ---- download_hf_file: cached-file short circuit -----------------------

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    // ---- builder -----------------------------------------------------------

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // ---- estimate_time_ms formula (recomputed without an engine) -----------

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    // ---- ModelConfig -------------------------------------------------------

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // ---- model_cache_dir: fallbacks and id mangling ------------------------

    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let guard = env_guard();

        let saved_home = std::env::var_os("HOME");
        let saved_hf = std::env::var_os("HF_HOME");
        set_env(&guard, "HOME", None);
        set_env(&guard, "HF_HOME", None);

        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");

        // Restore before asserting so a failure cannot leak the altered env.
        set_env(&guard, "HOME", saved_home.as_deref());
        set_env(&guard, "HF_HOME", saved_hf.as_deref());

        let path = result.unwrap();
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    // ---- download_hf_file: more cached-file cases --------------------------

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // ---- engine construction failure ----------------------------------------

    // The std mutex guard is deliberately held across the `.await`: the
    // `HF_HOME` override must stay in place for the whole call to `new`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let guard = env_guard();
        let saved_hf = std::env::var_os("HF_HOME");

        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        set_env(&guard, "HF_HOME", Some(tmp.as_os_str()));

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        set_env(&guard, "HF_HOME", saved_hf.as_deref());

        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    // ---- builder edge cases ------------------------------------------------

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // ---- tests that only assert when an engine can actually be built -------

    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        // If the engine cannot be built there is nothing to check.
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
            assert_eq!(engine.model(), EmbeddingModel::MiniLM);
            let _ = format!("{:?}", engine);
            let ms = engine.estimate_time_ms(10, 50);
            assert!(ms >= 0.0);
        }
    }

    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    // ---- session pool ------------------------------------------------------

    #[test]
    fn test_session_pool_default_is_4() {
        let config = ModelConfig::default();
        // The expected default follows DAKERA_ONNX_POOL_SIZE when it holds a
        // valid value of at least 1, otherwise 4.
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}