Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
4//! - Loading ONNX INT8 embedding models from HuggingFace Hub
5//! - Generating embeddings for single texts or batches
6//! - Automatic batching and parallel processing via ONNX Runtime
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47/// The main embedding engine for generating vector embeddings.
48///
49/// This struct is thread-safe and can be shared across async tasks.
50/// ORT sessions are mutex-guarded (run() takes &mut self) and held in a
51/// pool so concurrent callers can embed without head-of-line blocking.
52/// CPU-heavy inference is offloaded via `tokio::task::spawn_blocking`.
53pub struct EmbeddingEngine {
54    /// Pool of ONNX Runtime sessions — each guarded independently.
55    /// Concurrent callers round-robin across sessions via `next_session`,
56    /// eliminating Mutex head-of-line blocking for batch embedding workloads.
57    sessions: Vec<Arc<Mutex<Session>>>,
58    /// Round-robin counter: atomically incremented per batch, wraps via modulo.
59    next_session: AtomicUsize,
60    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
61    processor: Arc<BatchProcessor>,
62    /// Model configuration
63    config: ModelConfig,
64    /// Embedding dimension
65    dimension: usize,
66}
67
68impl EmbeddingEngine {
69    /// Create a new embedding engine with the given configuration.
70    ///
71    /// Downloads the ONNX INT8 model from HuggingFace Hub if not cached.
72    #[instrument(skip_all, fields(model = %config.model))]
73    pub async fn new(config: ModelConfig) -> Result<Self> {
74        info!(
75            "Initializing ONNX embedding engine with model: {}",
76            config.model
77        );
78
79        // Download tokenizer and ONNX model files
80        let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
81
82        // Load tokenizer
83        info!("Loading tokenizer from {:?}", tokenizer_path);
84        let tokenizer = Tokenizer::from_file(&tokenizer_path)
85            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
86
87        // Build ONNX session pool — N independent sessions to serve concurrent callers.
88        // Each session has its own ORT context so pool members never block each other.
89        info!("Loading ONNX model from {:?}", onnx_path);
90        let num_threads = config.num_threads.unwrap_or(4);
91        let pool_size = config.session_pool_size.max(1);
92        let onnx_path_clone = onnx_path.clone();
93        let sessions: Vec<Arc<Mutex<Session>>> =
94            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
95                (0..pool_size)
96                    .map(|_| {
97                        let s = Session::builder()
98                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
99                            .with_optimization_level(GraphOptimizationLevel::Level3)
100                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
101                            .with_intra_threads(num_threads)
102                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
103                            .commit_from_file(&onnx_path_clone)
104                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
105                        Ok(Arc::new(Mutex::new(s)))
106                    })
107                    .collect()
108            })
109            .await
110            .map_err(|e| {
111                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
112            })??;
113
114        let dimension = config.model.dimension();
115        let processor = Arc::new(BatchProcessor::new(
116            tokenizer,
117            config.model,
118            config.max_batch_size,
119        ));
120
121        info!(
122            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
123            config.model, dimension, num_threads, pool_size
124        );
125
126        Ok(Self {
127            sessions,
128            next_session: AtomicUsize::new(0),
129            processor,
130            config,
131            dimension,
132        })
133    }
134
135    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
136    ///
137    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
138    /// - `onnx/model_quantized.onnx` — from the Xenova ONNX repo (INT8, pre-built)
139    #[instrument(skip_all, fields(model = %config.model))]
140    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
141        let model_id = config.model.model_id();
142        let onnx_repo_id = config.model.onnx_repo_id();
143        let onnx_filename = config.model.onnx_filename();
144
145        info!(
146            "Resolving model files: tokenizer={}, onnx={}@{}",
147            model_id, onnx_filename, onnx_repo_id
148        );
149
150        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
151        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
152
153        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
154        let onnx_subdir = onnx_cache_dir.join("onnx");
155        std::fs::create_dir_all(&onnx_subdir)?;
156
157        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
158        // onnx_filename is "onnx/model_quantized.onnx" — basename is the last component
159        let onnx_basename = Path::new(onnx_filename)
160            .file_name()
161            .and_then(|s| s.to_str())
162            .unwrap_or("model_quantized.onnx");
163        let local_onnx = onnx_subdir.join(onnx_basename);
164
165        // Download missing files in a blocking thread
166        let tokenizer_needs_download = !local_tokenizer.exists();
167        let onnx_needs_download = !local_onnx.exists();
168
169        if tokenizer_needs_download || onnx_needs_download {
170            let model_id_owned = model_id.to_string();
171            let onnx_repo_id_owned = onnx_repo_id.to_string();
172            let onnx_filename_owned = onnx_filename.to_string();
173            let tokenizer_cache = tokenizer_cache_dir.clone();
174            let onnx_cache = onnx_cache_dir.clone();
175
176            tokio::task::spawn_blocking(move || {
177                if !tokenizer_cache.join("tokenizer.json").exists() {
178                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
179                        .map_err(|e| {
180                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
181                        })?;
182                }
183                if !onnx_cache.join(&onnx_filename_owned).exists() {
184                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
185                        .map_err(|e| {
186                            InferenceError::HubError(format!(
187                                "Failed to download ONNX model: {}",
188                                e
189                            ))
190                        })?;
191                }
192                Ok::<_, InferenceError>(())
193            })
194            .await
195            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
196        } else {
197            info!("All model files found in local cache");
198        }
199
200        // Re-derive paths (cache dir / onnx / basename)
201        let final_onnx = onnx_cache_dir.join(onnx_filename);
202
203        info!(
204            "Model files ready: tokenizer={:?}, onnx={:?}",
205            local_tokenizer, final_onnx
206        );
207        Ok((local_tokenizer, final_onnx))
208    }
209
210    /// Get or create the local model cache directory.
211    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
212        let base = std::env::var("HF_HOME")
213            .map(PathBuf::from)
214            .unwrap_or_else(|_| {
215                let home = std::env::var("HOME").unwrap_or_else(|_| {
216                    warn!("HOME environment variable not set, using /tmp for model cache");
217                    "/tmp".to_string()
218                });
219                PathBuf::from(home).join(".cache").join("huggingface")
220            });
221        let dir = base.join("dakera").join(model_id.replace('/', "--"));
222        std::fs::create_dir_all(&dir)?;
223        Ok(dir)
224    }
225
226    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
227    ///
228    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
229    ///
230    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
231    pub fn download_hf_file_pub(
232        model_id: &str,
233        filename: &str,
234        cache_dir: &Path,
235    ) -> std::result::Result<PathBuf, String> {
236        Self::download_hf_file(model_id, filename, cache_dir)
237    }
238
239    fn download_hf_file(
240        model_id: &str,
241        filename: &str,
242        cache_dir: &Path,
243    ) -> std::result::Result<PathBuf, String> {
244        // The file may be nested (e.g. "onnx/model_quantized.onnx")
245        let file_path = cache_dir.join(filename);
246        if file_path.exists() {
247            info!("Cached: {}/{}", model_id, filename);
248            return Ok(file_path);
249        }
250
251        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
252        if let Some(parent) = file_path.parent() {
253            std::fs::create_dir_all(parent)
254                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
255        }
256
257        let url = format!(
258            "https://huggingface.co/{}/resolve/main/{}",
259            model_id, filename
260        );
261        info!("Downloading: {}", url);
262
263        // Read HuggingFace token from env (required for gated models).
264        let hf_token = std::env::var("HF_TOKEN")
265            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
266            .ok();
267        if hf_token.is_some() {
268            info!("Using HuggingFace auth token for download");
269        }
270
271        // Disable automatic redirects so we can resolve relative Location headers ourselves.
272        let agent = ureq::AgentBuilder::new()
273            .redirects(0)
274            .timeout(std::time::Duration::from_secs(300))
275            .build();
276
277        let mut current_url = url.clone();
278        let mut redirects = 0;
279        let max_redirects = 10;
280
281        let response = loop {
282            let mut req = agent.get(&current_url);
283            if let Some(ref token) = hf_token {
284                req = req.set("Authorization", &format!("Bearer {}", token));
285            }
286            let resp = req.call();
287
288            let r = match resp {
289                Ok(r) => r,
290                Err(ureq::Error::Status(_status, r)) => r,
291                Err(e) => return Err(format!("{}: {}", filename, e)),
292            };
293
294            let status = r.status();
295            if (200..300).contains(&status) {
296                break r;
297            } else if (300..400).contains(&status) {
298                redirects += 1;
299                if redirects > max_redirects {
300                    return Err(format!("{}: too many redirects", filename));
301                }
302                let location = r
303                    .header("location")
304                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
305                    .to_string();
306
307                // Resolve relative redirects against the current URL's origin
308                current_url = if location.starts_with('/') {
309                    let parsed = url::Url::parse(&current_url)
310                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
311                    let host = parsed.host_str().ok_or_else(|| {
312                        format!("{}: redirect URL missing host: {}", filename, current_url)
313                    })?;
314                    format!("{}://{}{}", parsed.scheme(), host, location)
315                } else {
316                    location
317                };
318                info!("Redirect {} → {}", redirects, current_url);
319            } else {
320                return Err(format!("{}: HTTP {}", filename, status));
321            }
322        };
323
324        let mut bytes = Vec::new();
325        response
326            .into_reader()
327            .take(500_000_000) // 500 MB safety limit
328            .read_to_end(&mut bytes)
329            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
330
331        std::fs::write(&file_path, &bytes)
332            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
333
334        info!("Downloaded {} ({} bytes)", filename, bytes.len());
335        Ok(file_path)
336    }
337
338    /// Get the embedding dimension for the loaded model.
339    pub fn dimension(&self) -> usize {
340        self.dimension
341    }
342
343    /// Get the model being used.
344    pub fn model(&self) -> EmbeddingModel {
345        self.config.model
346    }
347
348    /// Get the number of parallel ONNX sessions in the pool.
349    pub fn pool_size(&self) -> usize {
350        self.sessions.len()
351    }
352
353    /// Embed a single query text.
354    ///
355    /// For models like E5, this automatically applies the query prefix.
356    #[instrument(skip(self, text), fields(text_len = text.len()))]
357    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
358        let texts = vec![text.to_string()];
359        let prepared = self.processor.prepare_texts(&texts, true);
360        let embeddings = self.embed_batch_internal(&prepared).await?;
361        embeddings.into_iter().next().ok_or_else(|| {
362            InferenceError::InferenceError("No embedding returned for query".to_string())
363        })
364    }
365
366    /// Embed multiple query texts.
367    ///
368    /// For models like E5, this automatically applies the query prefix.
369    #[instrument(skip(self, texts), fields(count = texts.len()))]
370    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
371        let prepared = self.processor.prepare_texts(texts, true);
372        self.embed_batch_internal(&prepared).await
373    }
374
375    /// Embed a single document/passage.
376    ///
377    /// For models like E5, this automatically applies the document prefix.
378    #[instrument(skip(self, text), fields(text_len = text.len()))]
379    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
380        let texts = vec![text.to_string()];
381        let prepared = self.processor.prepare_texts(&texts, false);
382        let embeddings = self.embed_batch_internal(&prepared).await?;
383        embeddings.into_iter().next().ok_or_else(|| {
384            InferenceError::InferenceError("No embedding returned for document".to_string())
385        })
386    }
387
388    /// Embed multiple documents/passages.
389    ///
390    /// For models like E5, this automatically applies the document prefix.
391    #[instrument(skip(self, texts), fields(count = texts.len()))]
392    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
393        let prepared = self.processor.prepare_texts(texts, false);
394        self.embed_batch_internal(&prepared).await
395    }
396
397    /// Embed texts without any prefix (raw embedding).
398    #[instrument(skip(self, texts), fields(count = texts.len()))]
399    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
400        self.embed_batch_internal(texts).await
401    }
402
403    /// Internal batch embedding implementation.
404    ///
405    /// Splits `texts` into sub-batches (≤ max_batch_size each) and distributes
406    /// them across the session pool via round-robin. All sub-batches are spawned
407    /// concurrently — sessions[i % pool_len] serializes only its own sub-batches,
408    /// so pool_len sub-batches run in true parallel, eliminating head-of-line
409    /// blocking when multiple HTTP handlers embed concurrently (DAK-5547).
410    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
411        if texts.is_empty() {
412            return Ok(vec![]);
413        }
414
415        let batches: Vec<Vec<String>> = self
416            .processor
417            .split_into_batches(texts)
418            .into_iter()
419            .map(|b| b.to_vec())
420            .collect();
421
422        let pool_len = self.sessions.len();
423        let normalize = self.config.model.normalize_embeddings();
424        // Round-robin starting index: each concurrent caller gets a different slot so
425        // concurrent requests don't all contend on sessions[0].
426        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
427
428        // Spawn all sub-batches concurrently; preserve insertion order for reassembly.
429        let mut handles = Vec::with_capacity(batches.len());
430        for (i, batch_owned) in batches.into_iter().enumerate() {
431            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
432            let processor = Arc::clone(&self.processor);
433            handles.push(tokio::task::spawn_blocking(move || {
434                let mut session_guard = session.lock();
435                Self::process_batch_blocking(
436                    &batch_owned,
437                    &mut session_guard,
438                    &processor,
439                    normalize,
440                )
441            }));
442        }
443
444        let mut all_embeddings = Vec::with_capacity(texts.len());
445        for handle in handles {
446            let batch_embeddings = handle.await.map_err(|e| {
447                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
448            })??;
449            all_embeddings.extend(batch_embeddings);
450        }
451
452        Ok(all_embeddings)
453    }
454
455    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
456    ///
457    /// Designed to run inside `spawn_blocking` (takes no `&self`).
458    fn process_batch_blocking(
459        texts: &[String],
460        session: &mut Session,
461        processor: &BatchProcessor,
462        normalize: bool,
463    ) -> Result<Vec<Vec<f32>>> {
464        // Tokenize
465        let prepared = processor.tokenize_batch(texts)?;
466        let batch_size = prepared.batch_size;
467        let seq_len = prepared.seq_len;
468
469        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
470        let attention_mask_flat = prepared.attention_mask.clone();
471
472        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
473        let input_ids_tensor =
474            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
475                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
476        let attention_mask_tensor =
477            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
478                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
479        let token_type_ids_tensor =
480            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
481                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
482
483        // Run ONNX session
484        let outputs = session
485            .run(inputs![
486                "input_ids" => input_ids_tensor,
487                "attention_mask" => attention_mask_tensor,
488                "token_type_ids" => token_type_ids_tensor
489            ])
490            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
491
492        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
493        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
494        // Shape derefs to [i64], so index directly.
495        let (ort_shape, lhs_slice) = outputs[0]
496            .try_extract_tensor::<f32>()
497            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
498
499        if ort_shape.len() != 3 {
500            return Err(InferenceError::InferenceError(format!(
501                "Expected 3D last_hidden_state, got {} dims",
502                ort_shape.len()
503            )));
504        }
505        let hidden_size = ort_shape[2] as usize;
506
507        // Apply mean pooling using the saved attention mask copy
508        let mut embeddings = mean_pooling(
509            lhs_slice,
510            batch_size,
511            seq_len,
512            hidden_size,
513            &attention_mask_flat,
514        );
515
516        // L2 normalize if configured
517        if normalize {
518            normalize_embeddings(&mut embeddings);
519        }
520
521        debug!(
522            "Generated {} embeddings of dimension {}",
523            embeddings.len(),
524            embeddings.first().map(|e| e.len()).unwrap_or(0)
525        );
526
527        Ok(embeddings)
528    }
529
530    /// Estimate the time to embed a batch of texts (in milliseconds).
531    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
532        // Rough estimation based on model speed and text length (CPU path)
533        let tokens_per_text =
534            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
535        let total_tokens = tokens_per_text * text_count as f64;
536        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
537        (total_tokens / tokens_per_second) * 1000.0
538    }
539}
540
541impl std::fmt::Debug for EmbeddingEngine {
542    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543        f.debug_struct("EmbeddingEngine")
544            .field("model", &self.config.model)
545            .field("dimension", &self.dimension)
546            .field("max_batch_size", &self.config.max_batch_size)
547            .field("session_pool_size", &self.sessions.len())
548            .finish()
549    }
550}
551
552/// Builder for creating an EmbeddingEngine with fluent API.
553pub struct EmbeddingEngineBuilder {
554    config: ModelConfig,
555}
556
557impl EmbeddingEngineBuilder {
558    /// Create a new builder with default configuration.
559    pub fn new() -> Self {
560        Self {
561            config: ModelConfig::default(),
562        }
563    }
564
565    /// Set the embedding model to use.
566    pub fn model(mut self, model: EmbeddingModel) -> Self {
567        self.config.model = model;
568        self
569    }
570
571    /// Set the cache directory for model files.
572    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
573        self.config.cache_dir = Some(dir.into());
574        self
575    }
576
577    /// Set the maximum batch size.
578    pub fn max_batch_size(mut self, size: usize) -> Self {
579        self.config.max_batch_size = size;
580        self
581    }
582
583    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
584    pub fn use_gpu(mut self, enable: bool) -> Self {
585        self.config.use_gpu = enable;
586        self
587    }
588
589    /// Set the number of intra-op CPU threads for ORT inference.
590    pub fn num_threads(mut self, threads: usize) -> Self {
591        self.config.num_threads = Some(threads);
592        self
593    }
594
595    /// Set the number of parallel ONNX sessions in the pool.
596    pub fn session_pool_size(mut self, size: usize) -> Self {
597        self.config.session_pool_size = size.max(1);
598        self
599    }
600
601    /// Build the embedding engine.
602    pub async fn build(self) -> Result<EmbeddingEngine> {
603        EmbeddingEngine::new(self.config).await
604    }
605}
606
607impl Default for EmbeddingEngineBuilder {
608    fn default() -> Self {
609        Self::new()
610    }
611}
612
613#[cfg(test)]
614mod tests {
615    use super::*;
616
617    #[test]
618    fn test_estimate_time() {
619        let config = ModelConfig::new(EmbeddingModel::MiniLM);
620        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
621        assert!(tokens_per_second > 0.0);
622    }
623
624    #[test]
625    fn test_builder() {
626        let builder = EmbeddingEngineBuilder::new()
627            .model(EmbeddingModel::BgeSmall)
628            .max_batch_size(64)
629            .use_gpu(false);
630
631        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
632        assert_eq!(builder.config.max_batch_size, 64);
633        assert!(!builder.config.use_gpu);
634    }
635
636    // ── model_cache_dir ──────────────────────────────────────────────────────
637
638    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
639    ///
640    /// Uses a process-level mutex because `std::env::set_var` is not thread-safe.
641    #[test]
642    fn test_model_cache_dir_with_hf_home() {
643        use std::sync::Mutex;
644        static ENV_LOCK: Mutex<()> = Mutex::new(());
645        let _guard = ENV_LOCK.lock().unwrap();
646
647        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
648        std::env::set_var("HF_HOME", &tmp);
649        let result = EmbeddingEngine::model_cache_dir("org/my-model");
650        std::env::remove_var("HF_HOME");
651
652        let path = result.unwrap();
653        assert!(
654            path.starts_with(&tmp),
655            "expected path under {tmp:?}, got {path:?}"
656        );
657        assert!(
658            path.to_str().unwrap().contains("org--my-model"),
659            "model_id separator not applied: {path:?}"
660        );
661    }
662
663    #[test]
664    fn test_model_cache_dir_contains_dakera_subdir() {
665        let path =
666            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
667        let s = path.to_str().unwrap();
668        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
669        assert!(
670            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
671            "expected transformed model id in path: {s}"
672        );
673    }
674
675    #[test]
676    fn test_model_cache_dir_creates_directory() {
677        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
678        assert!(dir.exists(), "model_cache_dir should create the directory");
679    }
680
681    // ── download_hf_file (cached early-return path) ──────────────────────────
682
683    #[test]
684    fn test_download_hf_file_returns_path_when_already_cached() {
685        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
686        std::fs::create_dir_all(&tmp).unwrap();
687        let file_path = tmp.join("config.json");
688        std::fs::write(&file_path, b"{}").unwrap();
689
690        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
691        assert!(result.is_ok());
692        assert_eq!(result.unwrap(), file_path);
693    }
694
695    #[test]
696    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
697        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
698        let onnx_dir = tmp.join("onnx");
699        std::fs::create_dir_all(&onnx_dir).unwrap();
700        let onnx_path = onnx_dir.join("model_quantized.onnx");
701        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();
702
703        // Filename includes the subdirectory path
704        let result = EmbeddingEngine::download_hf_file(
705            "Xenova/all-MiniLM-L6-v2",
706            "onnx/model_quantized.onnx",
707            &tmp,
708        );
709        assert!(result.is_ok());
710        assert_eq!(result.unwrap(), onnx_path);
711    }
712
713    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────
714
715    #[test]
716    fn test_builder_default_impl() {
717        let b1 = EmbeddingEngineBuilder::new();
718        let b2 = EmbeddingEngineBuilder::default();
719        assert_eq!(b1.config.model, b2.config.model);
720        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
721    }
722
723    #[test]
724    fn test_builder_model_field() {
725        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
726        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
727    }
728
729    #[test]
730    fn test_builder_cache_dir() {
731        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
732        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
733    }
734
735    #[test]
736    fn test_builder_max_batch_size() {
737        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
738        assert_eq!(builder.config.max_batch_size, 128);
739    }
740
741    #[test]
742    fn test_builder_use_gpu_true() {
743        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
744        assert!(builder.config.use_gpu);
745    }
746
747    #[test]
748    fn test_builder_use_gpu_false() {
749        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
750        assert!(!builder.config.use_gpu);
751    }
752
753    #[test]
754    fn test_builder_num_threads() {
755        let builder = EmbeddingEngineBuilder::new().num_threads(4);
756        assert_eq!(builder.config.num_threads, Some(4));
757    }
758
759    #[test]
760    fn test_builder_chain_all_fields() {
761        let builder = EmbeddingEngineBuilder::new()
762            .model(EmbeddingModel::BgeSmall)
763            .cache_dir("/cache")
764            .max_batch_size(16)
765            .use_gpu(false)
766            .num_threads(2);
767
768        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
769        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
770        assert_eq!(builder.config.max_batch_size, 16);
771        assert!(!builder.config.use_gpu);
772        assert_eq!(builder.config.num_threads, Some(2));
773    }
774
775    // ── estimate_time_ms ─────────────────────────────────────────────────────
776
777    #[test]
778    fn test_estimate_time_zero_count() {
779        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
780        let estimate = (0.0 / tps) * 1000.0;
781        assert_eq!(estimate, 0.0);
782    }
783
784    #[test]
785    fn test_estimate_time_formula_cpu() {
786        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
787        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
788        let model = EmbeddingModel::MiniLM;
789        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
790        let total_tokens = tokens_per_text * 10.0;
791        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
792        assert!(
793            (estimate - 50.0).abs() < 1e-6,
794            "expected 50.0ms, got {estimate}"
795        );
796    }
797
798    #[test]
799    fn test_estimate_time_capped_at_max_seq_length() {
800        let model = EmbeddingModel::MiniLM;
801        let avg_len = 100_000;
802        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
803        assert_eq!(tokens_per_text, 256.0);
804    }
805
806    // ── ModelConfig API ───────────────────────────────────────────────────────
807
808    #[test]
809    fn test_model_config_new() {
810        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
811        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
812        assert_eq!(cfg.max_batch_size, 8);
813        assert!(!cfg.use_gpu);
814        assert!(cfg.cache_dir.is_none());
815        assert!(cfg.num_threads.is_none());
816    }
817
818    #[test]
819    fn test_model_config_default() {
820        let cfg = ModelConfig::default();
821        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
822        assert_eq!(cfg.max_batch_size, 8);
823        assert!(!cfg.use_gpu);
824    }
825
826    #[test]
827    fn test_model_config_with_cache_dir() {
828        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
829        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
830    }
831
832    #[test]
833    fn test_model_config_with_max_batch_size() {
834        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
835        assert_eq!(cfg.max_batch_size, 64);
836    }
837
838    #[test]
839    fn test_model_config_with_gpu() {
840        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
841        assert!(cfg.use_gpu);
842    }
843
844    #[test]
845    fn test_model_config_with_num_threads() {
846        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
847        assert_eq!(cfg.num_threads, Some(8));
848    }
849
850    #[test]
851    fn test_model_config_chained_builder() {
852        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
853            .with_cache_dir("/cache")
854            .with_max_batch_size(16)
855            .with_gpu(false)
856            .with_num_threads(4);
857        assert_eq!(cfg.model, EmbeddingModel::E5Small);
858        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
859        assert_eq!(cfg.max_batch_size, 16);
860        assert!(!cfg.use_gpu);
861        assert_eq!(cfg.num_threads, Some(4));
862    }
863
864    // ── model_cache_dir edge cases ────────────────────────────────────────────
865
866    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
867    #[test]
868    fn test_model_cache_dir_no_home_fallback() {
869        use std::sync::Mutex;
870        static ENV_LOCK: Mutex<()> = Mutex::new(());
871        let _guard = ENV_LOCK.lock().unwrap();
872
873        // Remove HOME and HF_HOME so we hit the /tmp fallback
874        let saved_home = std::env::var("HOME").ok();
875        let saved_hf = std::env::var("HF_HOME").ok();
876        unsafe {
877            std::env::remove_var("HOME");
878            std::env::remove_var("HF_HOME");
879        }
880
881        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");
882
883        // Restore env
884        if let Some(h) = saved_home {
885            unsafe { std::env::set_var("HOME", h) };
886        }
887        if let Some(h) = saved_hf {
888            unsafe { std::env::set_var("HF_HOME", h) };
889        }
890
891        let path = result.unwrap();
892        // Should be under /tmp since HOME was unset
893        assert!(
894            path.starts_with("/tmp"),
895            "expected path under /tmp, got {path:?}"
896        );
897    }
898
899    #[test]
900    fn test_model_cache_dir_deep_model_id() {
901        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
902        let s = path.to_str().unwrap();
903        // All slashes replaced with double-dash
904        assert!(
905            s.contains("org--sub--model-name-with-dashes"),
906            "expected transformed path, got: {s}"
907        );
908    }
909
910    #[test]
911    fn test_model_cache_dir_minilm_model_id() {
912        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
913        let s = path.to_str().unwrap();
914        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
915    }
916
917    #[test]
918    fn test_model_cache_dir_bge_model_id() {
919        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
920        let s = path.to_str().unwrap();
921        assert!(s.contains("BAAI--bge-small-en-v1.5"));
922    }
923
924    #[test]
925    fn test_model_cache_dir_e5_model_id() {
926        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
927        let s = path.to_str().unwrap();
928        assert!(s.contains("intfloat--e5-small-v2"));
929    }
930
931    // ── download_hf_file additional cache-hit variations ─────────────────────
932
933    #[test]
934    fn test_download_hf_file_pytorch_bin_cached() {
935        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
936        std::fs::create_dir_all(&tmp).unwrap();
937        let model_path = tmp.join("pytorch_model.bin");
938        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();
939
940        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
941        assert!(result.is_ok());
942        assert_eq!(result.unwrap(), model_path);
943    }
944
945    #[test]
946    fn test_download_hf_file_tokenizer_cached() {
947        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
948        std::fs::create_dir_all(&tmp).unwrap();
949        let tok_path = tmp.join("tokenizer.json");
950        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();
951
952        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
953        assert!(result.is_ok());
954        assert_eq!(result.unwrap(), tok_path);
955    }
956
957    #[test]
958    fn test_download_hf_file_config_json_cached() {
959        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
960        std::fs::create_dir_all(&tmp).unwrap();
961        let cfg_path = tmp.join("config.json");
962        std::fs::write(&cfg_path, b"{}").unwrap();
963
964        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
965        assert!(result.is_ok());
966        assert_eq!(result.unwrap(), cfg_path);
967    }
968
969    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────
970
971    /// Tests the code path through `download_model_files` (local Dakera cache hit)
972    /// and into `new()` — which then fails trying to load the tokenizer from a
973    /// fake file. No network access required; fake files are pre-seeded.
974    #[tokio::test]
975    #[allow(clippy::await_holding_lock)]
976    async fn test_new_fails_with_invalid_tokenizer_json() {
977        use std::sync::Mutex;
978        static ENV_LOCK: Mutex<()> = Mutex::new(());
979        let _guard = ENV_LOCK.lock().unwrap();
980
981        // Set up a fake Dakera model cache so download_model_files finds our files
982        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
983        let model_dir = tmp
984            .join("dakera")
985            .join("sentence-transformers--all-MiniLM-L6-v2");
986        std::fs::create_dir_all(&model_dir).unwrap();
987        // Valid-looking model weights placeholder (candle will fail on this, which is fine)
988        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
989        // Invalid tokenizer.json — will cause TokenizationError in new()
990        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
991        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();
992
993        unsafe { std::env::set_var("HF_HOME", &tmp) };
994
995        let config = ModelConfig::new(EmbeddingModel::MiniLM);
996        let result = EmbeddingEngine::new(config).await;
997
998        unsafe { std::env::remove_var("HF_HOME") };
999
1000        // Must fail — tokenizer.json is invalid JSON
1001        assert!(
1002            result.is_err(),
1003            "expected Err from new() with invalid tokenizer, got Ok"
1004        );
1005    }
1006
1007    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────
1008
1009    #[test]
1010    fn test_builder_with_all_models() {
1011        for model in [
1012            EmbeddingModel::MiniLM,
1013            EmbeddingModel::BgeSmall,
1014            EmbeddingModel::E5Small,
1015        ] {
1016            let builder = EmbeddingEngineBuilder::new().model(model);
1017            assert_eq!(builder.config.model, model);
1018        }
1019    }
1020
1021    #[test]
1022    fn test_builder_max_batch_size_one() {
1023        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
1024        assert_eq!(builder.config.max_batch_size, 1);
1025    }
1026
1027    #[test]
1028    fn test_builder_num_threads_zero() {
1029        let builder = EmbeddingEngineBuilder::new().num_threads(0);
1030        assert_eq!(builder.config.num_threads, Some(0));
1031    }
1032
1033    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──
1034
1035    /// If the embedding model is already cached on this machine, exercise the
1036    /// full `new()` path and test getters. On machines without a cached model
1037    /// the test passes silently — it is intentionally non-gating.
1038    #[tokio::test]
1039    async fn test_engine_getters_when_model_cached() {
1040        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1041        match EmbeddingEngine::new(config).await {
1042            Ok(engine) => {
1043                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
1044                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
1045                // Device should be CPU in test environments (device() removed in CE-3 ONNX migration)
1046                // Debug impl should not panic
1047                let _ = format!("{:?}", engine);
1048                // estimate_time_ms should return a non-negative value
1049                let ms = engine.estimate_time_ms(10, 50);
1050                assert!(ms >= 0.0);
1051            }
1052            Err(_) => {
1053                // Model not in cache — skip; CI runner may or may not have it
1054            }
1055        }
1056    }
1057
1058    /// When model is cached: embed an empty batch must return immediately with
1059    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
1060    #[tokio::test]
1061    async fn test_engine_embed_empty_batch_when_cached() {
1062        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1063        if let Ok(engine) = EmbeddingEngine::new(config).await {
1064            let result = engine.embed_raw(&[]).await;
1065            assert!(result.is_ok());
1066            assert!(result.unwrap().is_empty());
1067        }
1068    }
1069
1070    // ── Session pool (DAK-5547) ──────────────────────────────────────────────
1071
1072    #[test]
1073    fn test_session_pool_default_is_4() {
1074        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
1075        // DAKERA_ONNX_POOL_SIZE env var still allows override.
1076        let config = ModelConfig::default();
1077        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
1078            .ok()
1079            .and_then(|v| v.parse::<usize>().ok())
1080            .filter(|&n| n >= 1)
1081            .unwrap_or(4);
1082        assert_eq!(config.session_pool_size, expected);
1083    }
1084
1085    #[test]
1086    fn test_session_pool_size_builder_roundtrip() {
1087        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
1088        assert_eq!(builder.config.session_pool_size, 8);
1089    }
1090
1091    #[test]
1092    fn test_session_pool_size_min_enforced() {
1093        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
1094        assert_eq!(
1095            builder.config.session_pool_size, 1,
1096            "pool size 0 must clamp to 1"
1097        );
1098    }
1099
1100    #[test]
1101    fn test_model_config_with_session_pool_size() {
1102        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1103        assert_eq!(cfg.session_pool_size, 2);
1104    }
1105
1106    /// When model is cached: verify the session pool has the expected size.
1107    #[tokio::test]
1108    async fn test_engine_pool_size_matches_config_when_cached() {
1109        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1110        if let Ok(engine) = EmbeddingEngine::new(config).await {
1111            assert_eq!(
1112                engine.pool_size(),
1113                2,
1114                "engine should hold exactly 2 sessions"
1115            );
1116        }
1117    }
1118
1119    // ── next_session round-robin ──────────────────────────────────────────────
1120
1121    /// Round-robin counter distributes batches across pool slots without panic.
1122    #[test]
1123    fn test_round_robin_index_stays_in_bounds() {
1124        let pool_len = 4_usize;
1125        let counter = AtomicUsize::new(0);
1126        for expected_idx in 0..100_usize {
1127            let start = counter.fetch_add(1, Ordering::Relaxed);
1128            let slot = start % pool_len;
1129            assert!(slot < pool_len);
1130            assert_eq!(slot, expected_idx % pool_len);
1131        }
1132    }
1133
1134    /// Pool size 1 degrades to single-session behavior without panicking.
1135    #[test]
1136    fn test_round_robin_pool_size_one() {
1137        let pool_len = 1_usize;
1138        let counter = AtomicUsize::new(0);
1139        for _ in 0..50 {
1140            let start = counter.fetch_add(1, Ordering::Relaxed);
1141            assert_eq!(start % pool_len, 0);
1142        }
1143    }
1144
1145    /// When model is cached: empty batch short-circuits before touching the pool.
1146    #[tokio::test]
1147    async fn test_embed_empty_does_not_advance_pool_counter() {
1148        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1149        if let Ok(engine) = EmbeddingEngine::new(config).await {
1150            let result = engine.embed_raw(&[]).await;
1151            assert!(result.is_ok());
1152            assert!(result.unwrap().is_empty());
1153            // Empty batch returns before fetch_add — counter stays at 0.
1154            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
1155        }
1156    }
1157}