Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
4//! - Loading ONNX INT8 embedding models from HuggingFace Hub
5//! - Generating embeddings for single texts or batches
6//! - Automatic batching and parallel processing via ONNX Runtime
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::Arc;
43use tokenizers::Tokenizer;
44use tracing::{debug, info, instrument, warn};
45
46/// The main embedding engine for generating vector embeddings.
47///
48/// This struct is thread-safe and can be shared across async tasks.
49/// The ORT session is wrapped in `Arc<Mutex<_>>` because `Session::run`
50/// requires `&mut self`. CPU-heavy inference is offloaded to blocking
51/// threads via `tokio::task::spawn_blocking`.
52pub struct EmbeddingEngine {
53    /// The loaded ONNX Runtime session (mutex-guarded because run() takes &mut self)
54    session: Arc<Mutex<Session>>,
55    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
56    processor: Arc<BatchProcessor>,
57    /// Model configuration
58    config: ModelConfig,
59    /// Embedding dimension
60    dimension: usize,
61}
62
63impl EmbeddingEngine {
64    /// Create a new embedding engine with the given configuration.
65    ///
66    /// Downloads the ONNX INT8 model from HuggingFace Hub if not cached.
67    #[instrument(skip_all, fields(model = %config.model))]
68    pub async fn new(config: ModelConfig) -> Result<Self> {
69        info!(
70            "Initializing ONNX embedding engine with model: {}",
71            config.model
72        );
73
74        // Download tokenizer and ONNX model files
75        let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
76
77        // Load tokenizer
78        info!("Loading tokenizer from {:?}", tokenizer_path);
79        let tokenizer = Tokenizer::from_file(&tokenizer_path)
80            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
81
82        // Build ORT session
83        info!("Loading ONNX model from {:?}", onnx_path);
84        let num_threads = config.num_threads.unwrap_or(1);
85        let session = Session::builder()
86            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
87            .with_optimization_level(GraphOptimizationLevel::Level3)
88            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
89            .with_intra_threads(num_threads)
90            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
91            .commit_from_file(&onnx_path)
92            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
93
94        let dimension = config.model.dimension();
95        let processor = Arc::new(BatchProcessor::new(
96            tokenizer,
97            config.model,
98            config.max_batch_size,
99        ));
100
101        info!(
102            "ONNX embedding engine ready: model={}, dimension={}, threads={}",
103            config.model, dimension, num_threads
104        );
105
106        Ok(Self {
107            session: Arc::new(Mutex::new(session)),
108            processor,
109            config,
110            dimension,
111        })
112    }
113
114    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
115    ///
116    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
117    /// - `onnx/model_quantized.onnx` — from the Xenova ONNX repo (INT8, pre-built)
118    #[instrument(skip_all, fields(model = %config.model))]
119    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
120        let model_id = config.model.model_id();
121        let onnx_repo_id = config.model.onnx_repo_id();
122        let onnx_filename = config.model.onnx_filename();
123
124        info!(
125            "Resolving model files: tokenizer={}, onnx={}@{}",
126            model_id, onnx_filename, onnx_repo_id
127        );
128
129        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
130        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
131
132        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
133        let onnx_subdir = onnx_cache_dir.join("onnx");
134        std::fs::create_dir_all(&onnx_subdir)?;
135
136        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
137        // onnx_filename is "onnx/model_quantized.onnx" — basename is the last component
138        let onnx_basename = Path::new(onnx_filename)
139            .file_name()
140            .and_then(|s| s.to_str())
141            .unwrap_or("model_quantized.onnx");
142        let local_onnx = onnx_subdir.join(onnx_basename);
143
144        // Download missing files in a blocking thread
145        let tokenizer_needs_download = !local_tokenizer.exists();
146        let onnx_needs_download = !local_onnx.exists();
147
148        if tokenizer_needs_download || onnx_needs_download {
149            let model_id_owned = model_id.to_string();
150            let onnx_repo_id_owned = onnx_repo_id.to_string();
151            let onnx_filename_owned = onnx_filename.to_string();
152            let tokenizer_cache = tokenizer_cache_dir.clone();
153            let onnx_cache = onnx_cache_dir.clone();
154
155            tokio::task::spawn_blocking(move || {
156                if !tokenizer_cache.join("tokenizer.json").exists() {
157                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
158                        .map_err(|e| {
159                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
160                        })?;
161                }
162                if !onnx_cache.join(&onnx_filename_owned).exists() {
163                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
164                        .map_err(|e| {
165                            InferenceError::HubError(format!(
166                                "Failed to download ONNX model: {}",
167                                e
168                            ))
169                        })?;
170                }
171                Ok::<_, InferenceError>(())
172            })
173            .await
174            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
175        } else {
176            info!("All model files found in local cache");
177        }
178
179        // Re-derive paths (cache dir / onnx / basename)
180        let final_onnx = onnx_cache_dir.join(onnx_filename);
181
182        info!(
183            "Model files ready: tokenizer={:?}, onnx={:?}",
184            local_tokenizer, final_onnx
185        );
186        Ok((local_tokenizer, final_onnx))
187    }
188
189    /// Get or create the local model cache directory.
190    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
191        let base = std::env::var("HF_HOME")
192            .map(PathBuf::from)
193            .unwrap_or_else(|_| {
194                let home = std::env::var("HOME").unwrap_or_else(|_| {
195                    warn!("HOME environment variable not set, using /tmp for model cache");
196                    "/tmp".to_string()
197                });
198                PathBuf::from(home).join(".cache").join("huggingface")
199            });
200        let dir = base.join("dakera").join(model_id.replace('/', "--"));
201        std::fs::create_dir_all(&dir)?;
202        Ok(dir)
203    }
204
205    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
206    ///
207    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
208    ///
209    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
210    pub fn download_hf_file_pub(
211        model_id: &str,
212        filename: &str,
213        cache_dir: &Path,
214    ) -> std::result::Result<PathBuf, String> {
215        Self::download_hf_file(model_id, filename, cache_dir)
216    }
217
218    fn download_hf_file(
219        model_id: &str,
220        filename: &str,
221        cache_dir: &Path,
222    ) -> std::result::Result<PathBuf, String> {
223        // The file may be nested (e.g. "onnx/model_quantized.onnx")
224        let file_path = cache_dir.join(filename);
225        if file_path.exists() {
226            info!("Cached: {}/{}", model_id, filename);
227            return Ok(file_path);
228        }
229
230        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
231        if let Some(parent) = file_path.parent() {
232            std::fs::create_dir_all(parent)
233                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
234        }
235
236        let url = format!(
237            "https://huggingface.co/{}/resolve/main/{}",
238            model_id, filename
239        );
240        info!("Downloading: {}", url);
241
242        // Disable automatic redirects so we can resolve relative Location headers ourselves.
243        let agent = ureq::AgentBuilder::new()
244            .redirects(0)
245            .timeout(std::time::Duration::from_secs(300))
246            .build();
247
248        let mut current_url = url.clone();
249        let mut redirects = 0;
250        let max_redirects = 10;
251
252        let response = loop {
253            let resp = agent.get(&current_url).call();
254
255            let r = match resp {
256                Ok(r) => r,
257                Err(ureq::Error::Status(_status, r)) => r,
258                Err(e) => return Err(format!("{}: {}", filename, e)),
259            };
260
261            let status = r.status();
262            if (200..300).contains(&status) {
263                break r;
264            } else if (300..400).contains(&status) {
265                redirects += 1;
266                if redirects > max_redirects {
267                    return Err(format!("{}: too many redirects", filename));
268                }
269                let location = r
270                    .header("location")
271                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
272                    .to_string();
273
274                // Resolve relative redirects against the current URL's origin
275                current_url = if location.starts_with('/') {
276                    let parsed = url::Url::parse(&current_url)
277                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
278                    let host = parsed.host_str().ok_or_else(|| {
279                        format!("{}: redirect URL missing host: {}", filename, current_url)
280                    })?;
281                    format!("{}://{}{}", parsed.scheme(), host, location)
282                } else {
283                    location
284                };
285                info!("Redirect {} → {}", redirects, current_url);
286            } else {
287                return Err(format!("{}: HTTP {}", filename, status));
288            }
289        };
290
291        let mut bytes = Vec::new();
292        response
293            .into_reader()
294            .take(500_000_000) // 500 MB safety limit
295            .read_to_end(&mut bytes)
296            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
297
298        std::fs::write(&file_path, &bytes)
299            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
300
301        info!("Downloaded {} ({} bytes)", filename, bytes.len());
302        Ok(file_path)
303    }
304
305    /// Get the embedding dimension for the loaded model.
306    pub fn dimension(&self) -> usize {
307        self.dimension
308    }
309
310    /// Get the model being used.
311    pub fn model(&self) -> EmbeddingModel {
312        self.config.model
313    }
314
315    /// Embed a single query text.
316    ///
317    /// For models like E5, this automatically applies the query prefix.
318    #[instrument(skip(self, text), fields(text_len = text.len()))]
319    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
320        let texts = vec![text.to_string()];
321        let prepared = self.processor.prepare_texts(&texts, true);
322        let embeddings = self.embed_batch_internal(&prepared).await?;
323        embeddings.into_iter().next().ok_or_else(|| {
324            InferenceError::InferenceError("No embedding returned for query".to_string())
325        })
326    }
327
328    /// Embed multiple query texts.
329    ///
330    /// For models like E5, this automatically applies the query prefix.
331    #[instrument(skip(self, texts), fields(count = texts.len()))]
332    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
333        let prepared = self.processor.prepare_texts(texts, true);
334        self.embed_batch_internal(&prepared).await
335    }
336
337    /// Embed a single document/passage.
338    ///
339    /// For models like E5, this automatically applies the document prefix.
340    #[instrument(skip(self, text), fields(text_len = text.len()))]
341    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
342        let texts = vec![text.to_string()];
343        let prepared = self.processor.prepare_texts(&texts, false);
344        let embeddings = self.embed_batch_internal(&prepared).await?;
345        embeddings.into_iter().next().ok_or_else(|| {
346            InferenceError::InferenceError("No embedding returned for document".to_string())
347        })
348    }
349
350    /// Embed multiple documents/passages.
351    ///
352    /// For models like E5, this automatically applies the document prefix.
353    #[instrument(skip(self, texts), fields(count = texts.len()))]
354    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
355        let prepared = self.processor.prepare_texts(texts, false);
356        self.embed_batch_internal(&prepared).await
357    }
358
359    /// Embed texts without any prefix (raw embedding).
360    #[instrument(skip(self, texts), fields(count = texts.len()))]
361    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
362        self.embed_batch_internal(texts).await
363    }
364
365    /// Internal batch embedding implementation.
366    ///
367    /// Splits into batches then offloads each batch to a blocking thread via
368    /// `spawn_blocking` so ONNX inference does not block the Tokio runtime.
369    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
370        if texts.is_empty() {
371            return Ok(vec![]);
372        }
373
374        let batches = self.processor.split_into_batches(texts);
375        let mut all_embeddings = Vec::with_capacity(texts.len());
376
377        for batch in batches {
378            let batch_owned: Vec<String> = batch.to_vec();
379            let session = Arc::clone(&self.session);
380            let processor = Arc::clone(&self.processor);
381            let normalize = self.config.model.normalize_embeddings();
382
383            let batch_embeddings = tokio::task::spawn_blocking(move || {
384                let mut session_guard = session.lock();
385                Self::process_batch_blocking(
386                    &batch_owned,
387                    &mut session_guard,
388                    &processor,
389                    normalize,
390                )
391            })
392            .await
393            .map_err(|e| {
394                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
395            })??;
396
397            all_embeddings.extend(batch_embeddings);
398        }
399
400        Ok(all_embeddings)
401    }
402
403    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
404    ///
405    /// Designed to run inside `spawn_blocking` (takes no `&self`).
406    fn process_batch_blocking(
407        texts: &[String],
408        session: &mut Session,
409        processor: &BatchProcessor,
410        normalize: bool,
411    ) -> Result<Vec<Vec<f32>>> {
412        // Tokenize
413        let prepared = processor.tokenize_batch(texts)?;
414        let batch_size = prepared.batch_size;
415        let seq_len = prepared.seq_len;
416
417        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
418        let attention_mask_flat = prepared.attention_mask.clone();
419
420        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
421        let input_ids_tensor =
422            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
423                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
424        let attention_mask_tensor =
425            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
426                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
427        let token_type_ids_tensor =
428            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
429                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
430
431        // Run ONNX session
432        let outputs = session
433            .run(inputs![
434                "input_ids" => input_ids_tensor,
435                "attention_mask" => attention_mask_tensor,
436                "token_type_ids" => token_type_ids_tensor
437            ])
438            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
439
440        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
441        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
442        // Shape derefs to [i64], so index directly.
443        let (ort_shape, lhs_slice) = outputs[0]
444            .try_extract_tensor::<f32>()
445            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
446
447        if ort_shape.len() != 3 {
448            return Err(InferenceError::InferenceError(format!(
449                "Expected 3D last_hidden_state, got {} dims",
450                ort_shape.len()
451            )));
452        }
453        let hidden_size = ort_shape[2] as usize;
454
455        // Apply mean pooling using the saved attention mask copy
456        let mut embeddings = mean_pooling(
457            lhs_slice,
458            batch_size,
459            seq_len,
460            hidden_size,
461            &attention_mask_flat,
462        );
463
464        // L2 normalize if configured
465        if normalize {
466            normalize_embeddings(&mut embeddings);
467        }
468
469        debug!(
470            "Generated {} embeddings of dimension {}",
471            embeddings.len(),
472            embeddings.first().map(|e| e.len()).unwrap_or(0)
473        );
474
475        Ok(embeddings)
476    }
477
478    /// Estimate the time to embed a batch of texts (in milliseconds).
479    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
480        // Rough estimation based on model speed and text length (CPU path)
481        let tokens_per_text =
482            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
483        let total_tokens = tokens_per_text * text_count as f64;
484        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
485        (total_tokens / tokens_per_second) * 1000.0
486    }
487}
488
489impl std::fmt::Debug for EmbeddingEngine {
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        f.debug_struct("EmbeddingEngine")
492            .field("model", &self.config.model)
493            .field("dimension", &self.dimension)
494            .field("max_batch_size", &self.config.max_batch_size)
495            .finish()
496    }
497}
498
499/// Builder for creating an EmbeddingEngine with fluent API.
500pub struct EmbeddingEngineBuilder {
501    config: ModelConfig,
502}
503
504impl EmbeddingEngineBuilder {
505    /// Create a new builder with default configuration.
506    pub fn new() -> Self {
507        Self {
508            config: ModelConfig::default(),
509        }
510    }
511
512    /// Set the embedding model to use.
513    pub fn model(mut self, model: EmbeddingModel) -> Self {
514        self.config.model = model;
515        self
516    }
517
518    /// Set the cache directory for model files.
519    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
520        self.config.cache_dir = Some(dir.into());
521        self
522    }
523
524    /// Set the maximum batch size.
525    pub fn max_batch_size(mut self, size: usize) -> Self {
526        self.config.max_batch_size = size;
527        self
528    }
529
530    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
531    pub fn use_gpu(mut self, enable: bool) -> Self {
532        self.config.use_gpu = enable;
533        self
534    }
535
536    /// Set the number of intra-op CPU threads for ORT inference.
537    pub fn num_threads(mut self, threads: usize) -> Self {
538        self.config.num_threads = Some(threads);
539        self
540    }
541
542    /// Build the embedding engine.
543    pub async fn build(self) -> Result<EmbeddingEngine> {
544        EmbeddingEngine::new(self.config).await
545    }
546}
547
548impl Default for EmbeddingEngineBuilder {
549    fn default() -> Self {
550        Self::new()
551    }
552}
553
554#[cfg(test)]
555mod tests {
556    use super::*;
557
558    #[test]
559    fn test_estimate_time() {
560        let config = ModelConfig::new(EmbeddingModel::MiniLM);
561        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
562        assert!(tokens_per_second > 0.0);
563    }
564
565    #[test]
566    fn test_builder() {
567        let builder = EmbeddingEngineBuilder::new()
568            .model(EmbeddingModel::BgeSmall)
569            .max_batch_size(64)
570            .use_gpu(false);
571
572        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
573        assert_eq!(builder.config.max_batch_size, 64);
574        assert!(!builder.config.use_gpu);
575    }
576
577    // ── model_cache_dir ──────────────────────────────────────────────────────
578
579    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
580    ///
581    /// Uses a process-level mutex because `std::env::set_var` is not thread-safe.
582    #[test]
583    fn test_model_cache_dir_with_hf_home() {
584        use std::sync::Mutex;
585        static ENV_LOCK: Mutex<()> = Mutex::new(());
586        let _guard = ENV_LOCK.lock().unwrap();
587
588        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
589        std::env::set_var("HF_HOME", &tmp);
590        let result = EmbeddingEngine::model_cache_dir("org/my-model");
591        std::env::remove_var("HF_HOME");
592
593        let path = result.unwrap();
594        assert!(
595            path.starts_with(&tmp),
596            "expected path under {tmp:?}, got {path:?}"
597        );
598        assert!(
599            path.to_str().unwrap().contains("org--my-model"),
600            "model_id separator not applied: {path:?}"
601        );
602    }
603
604    #[test]
605    fn test_model_cache_dir_contains_dakera_subdir() {
606        let path =
607            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
608        let s = path.to_str().unwrap();
609        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
610        assert!(
611            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
612            "expected transformed model id in path: {s}"
613        );
614    }
615
616    #[test]
617    fn test_model_cache_dir_creates_directory() {
618        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
619        assert!(dir.exists(), "model_cache_dir should create the directory");
620    }
621
622    // ── download_hf_file (cached early-return path) ──────────────────────────
623
624    #[test]
625    fn test_download_hf_file_returns_path_when_already_cached() {
626        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
627        std::fs::create_dir_all(&tmp).unwrap();
628        let file_path = tmp.join("config.json");
629        std::fs::write(&file_path, b"{}").unwrap();
630
631        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
632        assert!(result.is_ok());
633        assert_eq!(result.unwrap(), file_path);
634    }
635
636    #[test]
637    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
638        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
639        let onnx_dir = tmp.join("onnx");
640        std::fs::create_dir_all(&onnx_dir).unwrap();
641        let onnx_path = onnx_dir.join("model_quantized.onnx");
642        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();
643
644        // Filename includes the subdirectory path
645        let result = EmbeddingEngine::download_hf_file(
646            "Xenova/all-MiniLM-L6-v2",
647            "onnx/model_quantized.onnx",
648            &tmp,
649        );
650        assert!(result.is_ok());
651        assert_eq!(result.unwrap(), onnx_path);
652    }
653
654    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────
655
656    #[test]
657    fn test_builder_default_impl() {
658        let b1 = EmbeddingEngineBuilder::new();
659        let b2 = EmbeddingEngineBuilder::default();
660        assert_eq!(b1.config.model, b2.config.model);
661        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
662    }
663
664    #[test]
665    fn test_builder_model_field() {
666        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
667        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
668    }
669
670    #[test]
671    fn test_builder_cache_dir() {
672        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
673        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
674    }
675
676    #[test]
677    fn test_builder_max_batch_size() {
678        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
679        assert_eq!(builder.config.max_batch_size, 128);
680    }
681
682    #[test]
683    fn test_builder_use_gpu_true() {
684        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
685        assert!(builder.config.use_gpu);
686    }
687
688    #[test]
689    fn test_builder_use_gpu_false() {
690        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
691        assert!(!builder.config.use_gpu);
692    }
693
694    #[test]
695    fn test_builder_num_threads() {
696        let builder = EmbeddingEngineBuilder::new().num_threads(4);
697        assert_eq!(builder.config.num_threads, Some(4));
698    }
699
700    #[test]
701    fn test_builder_chain_all_fields() {
702        let builder = EmbeddingEngineBuilder::new()
703            .model(EmbeddingModel::BgeSmall)
704            .cache_dir("/cache")
705            .max_batch_size(16)
706            .use_gpu(false)
707            .num_threads(2);
708
709        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
710        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
711        assert_eq!(builder.config.max_batch_size, 16);
712        assert!(!builder.config.use_gpu);
713        assert_eq!(builder.config.num_threads, Some(2));
714    }
715
716    // ── estimate_time_ms ─────────────────────────────────────────────────────
717
718    #[test]
719    fn test_estimate_time_zero_count() {
720        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
721        let estimate = (0.0 / tps) * 1000.0;
722        assert_eq!(estimate, 0.0);
723    }
724
725    #[test]
726    fn test_estimate_time_formula_cpu() {
727        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
728        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
729        let model = EmbeddingModel::MiniLM;
730        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
731        let total_tokens = tokens_per_text * 10.0;
732        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
733        assert!(
734            (estimate - 50.0).abs() < 1e-6,
735            "expected 50.0ms, got {estimate}"
736        );
737    }
738
739    #[test]
740    fn test_estimate_time_capped_at_max_seq_length() {
741        let model = EmbeddingModel::MiniLM;
742        let avg_len = 100_000;
743        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
744        assert_eq!(tokens_per_text, 256.0);
745    }
746
747    // ── ModelConfig API ───────────────────────────────────────────────────────
748
749    #[test]
750    fn test_model_config_new() {
751        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
752        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
753        assert_eq!(cfg.max_batch_size, 32);
754        assert!(!cfg.use_gpu);
755        assert!(cfg.cache_dir.is_none());
756        assert!(cfg.num_threads.is_none());
757    }
758
759    #[test]
760    fn test_model_config_default() {
761        let cfg = ModelConfig::default();
762        assert_eq!(cfg.model, EmbeddingModel::MiniLM);
763        assert_eq!(cfg.max_batch_size, 32);
764        assert!(!cfg.use_gpu);
765    }
766
767    #[test]
768    fn test_model_config_with_cache_dir() {
769        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
770        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
771    }
772
773    #[test]
774    fn test_model_config_with_max_batch_size() {
775        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
776        assert_eq!(cfg.max_batch_size, 64);
777    }
778
779    #[test]
780    fn test_model_config_with_gpu() {
781        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
782        assert!(cfg.use_gpu);
783    }
784
785    #[test]
786    fn test_model_config_with_num_threads() {
787        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
788        assert_eq!(cfg.num_threads, Some(8));
789    }
790
791    #[test]
792    fn test_model_config_chained_builder() {
793        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
794            .with_cache_dir("/cache")
795            .with_max_batch_size(16)
796            .with_gpu(false)
797            .with_num_threads(4);
798        assert_eq!(cfg.model, EmbeddingModel::E5Small);
799        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
800        assert_eq!(cfg.max_batch_size, 16);
801        assert!(!cfg.use_gpu);
802        assert_eq!(cfg.num_threads, Some(4));
803    }
804
805    // ── model_cache_dir edge cases ────────────────────────────────────────────
806
807    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
808    #[test]
809    fn test_model_cache_dir_no_home_fallback() {
810        use std::sync::Mutex;
811        static ENV_LOCK: Mutex<()> = Mutex::new(());
812        let _guard = ENV_LOCK.lock().unwrap();
813
814        // Remove HOME and HF_HOME so we hit the /tmp fallback
815        let saved_home = std::env::var("HOME").ok();
816        let saved_hf = std::env::var("HF_HOME").ok();
817        unsafe {
818            std::env::remove_var("HOME");
819            std::env::remove_var("HF_HOME");
820        }
821
822        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");
823
824        // Restore env
825        if let Some(h) = saved_home {
826            unsafe { std::env::set_var("HOME", h) };
827        }
828        if let Some(h) = saved_hf {
829            unsafe { std::env::set_var("HF_HOME", h) };
830        }
831
832        let path = result.unwrap();
833        // Should be under /tmp since HOME was unset
834        assert!(
835            path.starts_with("/tmp"),
836            "expected path under /tmp, got {path:?}"
837        );
838    }
839
840    #[test]
841    fn test_model_cache_dir_deep_model_id() {
842        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
843        let s = path.to_str().unwrap();
844        // All slashes replaced with double-dash
845        assert!(
846            s.contains("org--sub--model-name-with-dashes"),
847            "expected transformed path, got: {s}"
848        );
849    }
850
851    #[test]
852    fn test_model_cache_dir_minilm_model_id() {
853        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
854        let s = path.to_str().unwrap();
855        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
856    }
857
858    #[test]
859    fn test_model_cache_dir_bge_model_id() {
860        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
861        let s = path.to_str().unwrap();
862        assert!(s.contains("BAAI--bge-small-en-v1.5"));
863    }
864
865    #[test]
866    fn test_model_cache_dir_e5_model_id() {
867        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
868        let s = path.to_str().unwrap();
869        assert!(s.contains("intfloat--e5-small-v2"));
870    }
871
872    // ── download_hf_file additional cache-hit variations ─────────────────────
873
874    #[test]
875    fn test_download_hf_file_pytorch_bin_cached() {
876        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
877        std::fs::create_dir_all(&tmp).unwrap();
878        let model_path = tmp.join("pytorch_model.bin");
879        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();
880
881        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
882        assert!(result.is_ok());
883        assert_eq!(result.unwrap(), model_path);
884    }
885
886    #[test]
887    fn test_download_hf_file_tokenizer_cached() {
888        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
889        std::fs::create_dir_all(&tmp).unwrap();
890        let tok_path = tmp.join("tokenizer.json");
891        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();
892
893        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
894        assert!(result.is_ok());
895        assert_eq!(result.unwrap(), tok_path);
896    }
897
898    #[test]
899    fn test_download_hf_file_config_json_cached() {
900        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
901        std::fs::create_dir_all(&tmp).unwrap();
902        let cfg_path = tmp.join("config.json");
903        std::fs::write(&cfg_path, b"{}").unwrap();
904
905        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
906        assert!(result.is_ok());
907        assert_eq!(result.unwrap(), cfg_path);
908    }
909
910    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────
911
912    /// Tests the code path through `download_model_files` (local Dakera cache hit)
913    /// and into `new()` — which then fails trying to load the tokenizer from a
914    /// fake file. No network access required; fake files are pre-seeded.
915    #[tokio::test]
916    async fn test_new_fails_with_invalid_tokenizer_json() {
917        use std::sync::Mutex;
918        static ENV_LOCK: Mutex<()> = Mutex::new(());
919        let _guard = ENV_LOCK.lock().unwrap();
920
921        // Set up a fake Dakera model cache so download_model_files finds our files
922        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
923        let model_dir = tmp
924            .join("dakera")
925            .join("sentence-transformers--all-MiniLM-L6-v2");
926        std::fs::create_dir_all(&model_dir).unwrap();
927        // Valid-looking model weights placeholder (candle will fail on this, which is fine)
928        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
929        // Invalid tokenizer.json — will cause TokenizationError in new()
930        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
931        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();
932
933        unsafe { std::env::set_var("HF_HOME", &tmp) };
934
935        let config = ModelConfig::new(EmbeddingModel::MiniLM);
936        let result = EmbeddingEngine::new(config).await;
937
938        unsafe { std::env::remove_var("HF_HOME") };
939
940        // Must fail — tokenizer.json is invalid JSON
941        assert!(
942            result.is_err(),
943            "expected Err from new() with invalid tokenizer, got Ok"
944        );
945    }
946
947    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────
948
949    #[test]
950    fn test_builder_with_all_models() {
951        for model in [
952            EmbeddingModel::MiniLM,
953            EmbeddingModel::BgeSmall,
954            EmbeddingModel::E5Small,
955        ] {
956            let builder = EmbeddingEngineBuilder::new().model(model);
957            assert_eq!(builder.config.model, model);
958        }
959    }
960
961    #[test]
962    fn test_builder_max_batch_size_one() {
963        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
964        assert_eq!(builder.config.max_batch_size, 1);
965    }
966
967    #[test]
968    fn test_builder_num_threads_zero() {
969        let builder = EmbeddingEngineBuilder::new().num_threads(0);
970        assert_eq!(builder.config.num_threads, Some(0));
971    }
972
973    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──
974
975    /// If the embedding model is already cached on this machine, exercise the
976    /// full `new()` path and test getters. On machines without a cached model
977    /// the test passes silently — it is intentionally non-gating.
978    #[tokio::test]
979    async fn test_engine_getters_when_model_cached() {
980        let config = ModelConfig::new(EmbeddingModel::MiniLM);
981        match EmbeddingEngine::new(config).await {
982            Ok(engine) => {
983                assert_eq!(engine.dimension(), 384);
984                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
985                // Device should be CPU in test environments (device() removed in CE-3 ONNX migration)
986                // Debug impl should not panic
987                let _ = format!("{:?}", engine);
988                // estimate_time_ms should return a non-negative value
989                let ms = engine.estimate_time_ms(10, 50);
990                assert!(ms >= 0.0);
991            }
992            Err(_) => {
993                // Model not in cache — skip; CI runner may or may not have it
994            }
995        }
996    }
997
998    /// When model is cached: embed an empty batch must return immediately with
999    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
1000    #[tokio::test]
1001    async fn test_engine_embed_empty_batch_when_cached() {
1002        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1003        match EmbeddingEngine::new(config).await {
1004            Ok(engine) => {
1005                let result = engine.embed_raw(&[]).await;
1006                assert!(result.is_ok());
1007                assert!(result.unwrap().is_empty());
1008            }
1009            Err(_) => {}
1010        }
1011    }
1012}