Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
4//! - Loading ONNX INT8 embedding models from HuggingFace Hub
5//! - Generating embeddings for single texts or batches
6//! - Automatic batching and parallel processing via ONNX Runtime
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49/// The main embedding engine for generating vector embeddings.
50///
51/// This struct is thread-safe and can be shared across async tasks.
52/// ORT sessions are mutex-guarded (run() takes &mut self) and held in a
53/// pool so concurrent callers can embed without head-of-line blocking.
54/// CPU-heavy inference is offloaded via `tokio::task::spawn_blocking`.
55pub struct EmbeddingEngine {
56    /// Pool of ONNX Runtime sessions — each guarded independently.
57    /// Concurrent callers round-robin across sessions via `next_session`,
58    /// eliminating Mutex head-of-line blocking for batch embedding workloads.
59    sessions: Vec<Arc<Mutex<Session>>>,
60    /// Round-robin counter: atomically incremented per batch, wraps via modulo.
61    next_session: AtomicUsize,
62    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
63    processor: Arc<BatchProcessor>,
64    /// Model configuration
65    config: ModelConfig,
66    /// Embedding dimension
67    dimension: usize,
68}
69
70impl EmbeddingEngine {
71    /// Create a new embedding engine with the given configuration.
72    ///
73    /// Downloads the ONNX INT8 model from HuggingFace Hub if not cached.
74    #[instrument(skip_all, fields(model = %config.model))]
75    pub async fn new(config: ModelConfig) -> Result<Self> {
76        info!(
77            "Initializing ONNX embedding engine with model: {}",
78            config.model
79        );
80
81        // Download tokenizer and ONNX model files
82        let (tokenizer_path, onnx_path) = Self::download_model_files(&config).await?;
83
84        // Load tokenizer
85        info!("Loading tokenizer from {:?}", tokenizer_path);
86        let tokenizer = Tokenizer::from_file(&tokenizer_path)
87            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
88
89        // Build ONNX session pool — N independent sessions to serve concurrent callers.
90        // Each session has its own ORT context so pool members never block each other.
91        info!("Loading ONNX model from {:?}", onnx_path);
92        let num_threads = config.num_threads.unwrap_or(4);
93        let pool_size = config.session_pool_size.max(1);
94        let onnx_path_clone = onnx_path.clone();
95
96        let use_gpu = std::env::var("DAKERA_USE_GPU")
97            .map(|v| v == "1")
98            .unwrap_or(false);
99        if use_gpu {
100            info!("CUDA execution provider enabled (DAKERA_USE_GPU=1)");
101        }
102
103        let sessions: Vec<Arc<Mutex<Session>>> =
104            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
105                (0..pool_size)
106                    .map(|_| {
107                        let builder = Session::builder()
108                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
109                            .with_optimization_level(GraphOptimizationLevel::Level3)
110                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
111                            .with_intra_threads(num_threads)
112                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
113
114                        let mut builder = if use_gpu {
115                            builder
116                                .with_execution_providers(
117                                    [CUDAExecutionProvider::default().build()],
118                                )
119                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
120                        } else {
121                            builder
122                        };
123
124                        let s = builder
125                            .commit_from_file(&onnx_path_clone)
126                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
127                        Ok(Arc::new(Mutex::new(s)))
128                    })
129                    .collect()
130            })
131            .await
132            .map_err(|e| {
133                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
134            })??;
135
136        let dimension = config.model.dimension();
137        let processor = Arc::new(BatchProcessor::new(
138            tokenizer,
139            config.model,
140            config.max_batch_size,
141        ));
142
143        info!(
144            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
145            config.model, dimension, num_threads, pool_size
146        );
147
148        Ok(Self {
149            sessions,
150            next_session: AtomicUsize::new(0),
151            processor,
152            config,
153            dimension,
154        })
155    }
156
157    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
158    ///
159    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
160    /// - `onnx/model_quantized.onnx` — from the Xenova ONNX repo (INT8, pre-built)
161    #[instrument(skip_all, fields(model = %config.model))]
162    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf)> {
163        let model_id = config.model.model_id();
164        let onnx_repo_id = config.model.onnx_repo_id();
165        let onnx_filename = config.model.onnx_filename();
166
167        info!(
168            "Resolving model files: tokenizer={}, onnx={}@{}",
169            model_id, onnx_filename, onnx_repo_id
170        );
171
172        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
173        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
174
175        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
176        let onnx_subdir = onnx_cache_dir.join("onnx");
177        std::fs::create_dir_all(&onnx_subdir)?;
178
179        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
180        // onnx_filename is "onnx/model_quantized.onnx" — basename is the last component
181        let onnx_basename = Path::new(onnx_filename)
182            .file_name()
183            .and_then(|s| s.to_str())
184            .unwrap_or("model_quantized.onnx");
185        let local_onnx = onnx_subdir.join(onnx_basename);
186
187        // Download missing files in a blocking thread
188        let tokenizer_needs_download = !local_tokenizer.exists();
189        let onnx_needs_download = !local_onnx.exists();
190
191        if tokenizer_needs_download || onnx_needs_download {
192            let model_id_owned = model_id.to_string();
193            let onnx_repo_id_owned = onnx_repo_id.to_string();
194            let onnx_filename_owned = onnx_filename.to_string();
195            let tokenizer_cache = tokenizer_cache_dir.clone();
196            let onnx_cache = onnx_cache_dir.clone();
197
198            tokio::task::spawn_blocking(move || {
199                if !tokenizer_cache.join("tokenizer.json").exists() {
200                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
201                        .map_err(|e| {
202                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
203                        })?;
204                }
205                if !onnx_cache.join(&onnx_filename_owned).exists() {
206                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
207                        .map_err(|e| {
208                            InferenceError::HubError(format!(
209                                "Failed to download ONNX model: {}",
210                                e
211                            ))
212                        })?;
213                }
214                Ok::<_, InferenceError>(())
215            })
216            .await
217            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
218        } else {
219            info!("All model files found in local cache");
220        }
221
222        // Re-derive paths (cache dir / onnx / basename)
223        let final_onnx = onnx_cache_dir.join(onnx_filename);
224
225        info!(
226            "Model files ready: tokenizer={:?}, onnx={:?}",
227            local_tokenizer, final_onnx
228        );
229        Ok((local_tokenizer, final_onnx))
230    }
231
232    /// Get or create the local model cache directory.
233    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
234        let base = std::env::var("HF_HOME")
235            .map(PathBuf::from)
236            .unwrap_or_else(|_| {
237                let home = std::env::var("HOME").unwrap_or_else(|_| {
238                    warn!("HOME environment variable not set, using /tmp for model cache");
239                    "/tmp".to_string()
240                });
241                PathBuf::from(home).join(".cache").join("huggingface")
242            });
243        let dir = base.join("dakera").join(model_id.replace('/', "--"));
244        std::fs::create_dir_all(&dir)?;
245        Ok(dir)
246    }
247
248    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
249    ///
250    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
251    ///
252    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
253    pub fn download_hf_file_pub(
254        model_id: &str,
255        filename: &str,
256        cache_dir: &Path,
257    ) -> std::result::Result<PathBuf, String> {
258        Self::download_hf_file(model_id, filename, cache_dir)
259    }
260
261    fn download_hf_file(
262        model_id: &str,
263        filename: &str,
264        cache_dir: &Path,
265    ) -> std::result::Result<PathBuf, String> {
266        // The file may be nested (e.g. "onnx/model_quantized.onnx")
267        let file_path = cache_dir.join(filename);
268        if file_path.exists() {
269            info!("Cached: {}/{}", model_id, filename);
270            return Ok(file_path);
271        }
272
273        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
274        if let Some(parent) = file_path.parent() {
275            std::fs::create_dir_all(parent)
276                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
277        }
278
279        let url = format!(
280            "https://huggingface.co/{}/resolve/main/{}",
281            model_id, filename
282        );
283        info!("Downloading: {}", url);
284
285        // Read HuggingFace token from env (required for gated models).
286        let hf_token = std::env::var("HF_TOKEN")
287            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
288            .ok();
289        if hf_token.is_some() {
290            info!("Using HuggingFace auth token for download");
291        }
292
293        // Disable automatic redirects so we can resolve relative Location headers ourselves.
294        let agent = ureq::AgentBuilder::new()
295            .redirects(0)
296            .timeout(std::time::Duration::from_secs(300))
297            .build();
298
299        let mut current_url = url.clone();
300        let mut redirects = 0;
301        let max_redirects = 10;
302
303        let response = loop {
304            let mut req = agent.get(&current_url);
305            if let Some(ref token) = hf_token {
306                req = req.set("Authorization", &format!("Bearer {}", token));
307            }
308            let resp = req.call();
309
310            let r = match resp {
311                Ok(r) => r,
312                Err(ureq::Error::Status(_status, r)) => r,
313                Err(e) => return Err(format!("{}: {}", filename, e)),
314            };
315
316            let status = r.status();
317            if (200..300).contains(&status) {
318                break r;
319            } else if (300..400).contains(&status) {
320                redirects += 1;
321                if redirects > max_redirects {
322                    return Err(format!("{}: too many redirects", filename));
323                }
324                let location = r
325                    .header("location")
326                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
327                    .to_string();
328
329                // Resolve relative redirects against the current URL's origin
330                current_url = if location.starts_with('/') {
331                    let parsed = url::Url::parse(&current_url)
332                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
333                    let host = parsed.host_str().ok_or_else(|| {
334                        format!("{}: redirect URL missing host: {}", filename, current_url)
335                    })?;
336                    format!("{}://{}{}", parsed.scheme(), host, location)
337                } else {
338                    location
339                };
340                info!("Redirect {} → {}", redirects, current_url);
341            } else {
342                return Err(format!("{}: HTTP {}", filename, status));
343            }
344        };
345
346        let mut bytes = Vec::new();
347        response
348            .into_reader()
349            .take(500_000_000) // 500 MB safety limit
350            .read_to_end(&mut bytes)
351            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
352
353        std::fs::write(&file_path, &bytes)
354            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
355
356        info!("Downloaded {} ({} bytes)", filename, bytes.len());
357        Ok(file_path)
358    }
359
360    /// Get the embedding dimension for the loaded model.
361    pub fn dimension(&self) -> usize {
362        self.dimension
363    }
364
365    /// Get the model being used.
366    pub fn model(&self) -> EmbeddingModel {
367        self.config.model
368    }
369
370    /// Get the number of parallel ONNX sessions in the pool.
371    pub fn pool_size(&self) -> usize {
372        self.sessions.len()
373    }
374
375    /// Embed a single query text.
376    ///
377    /// For models like E5, this automatically applies the query prefix.
378    #[instrument(skip(self, text), fields(text_len = text.len()))]
379    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
380        let texts = vec![text.to_string()];
381        let prepared = self.processor.prepare_texts(&texts, true);
382        let embeddings = self.embed_batch_internal(&prepared).await?;
383        embeddings.into_iter().next().ok_or_else(|| {
384            InferenceError::InferenceError("No embedding returned for query".to_string())
385        })
386    }
387
388    /// Embed multiple query texts.
389    ///
390    /// For models like E5, this automatically applies the query prefix.
391    #[instrument(skip(self, texts), fields(count = texts.len()))]
392    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
393        let prepared = self.processor.prepare_texts(texts, true);
394        self.embed_batch_internal(&prepared).await
395    }
396
397    /// Embed a single document/passage.
398    ///
399    /// For models like E5, this automatically applies the document prefix.
400    #[instrument(skip(self, text), fields(text_len = text.len()))]
401    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
402        let texts = vec![text.to_string()];
403        let prepared = self.processor.prepare_texts(&texts, false);
404        let embeddings = self.embed_batch_internal(&prepared).await?;
405        embeddings.into_iter().next().ok_or_else(|| {
406            InferenceError::InferenceError("No embedding returned for document".to_string())
407        })
408    }
409
410    /// Embed multiple documents/passages.
411    ///
412    /// For models like E5, this automatically applies the document prefix.
413    #[instrument(skip(self, texts), fields(count = texts.len()))]
414    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
415        let prepared = self.processor.prepare_texts(texts, false);
416        self.embed_batch_internal(&prepared).await
417    }
418
419    /// Embed texts without any prefix (raw embedding).
420    #[instrument(skip(self, texts), fields(count = texts.len()))]
421    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
422        self.embed_batch_internal(texts).await
423    }
424
425    /// Internal batch embedding implementation.
426    ///
427    /// Splits `texts` into sub-batches (≤ max_batch_size each) and distributes
428    /// them across the session pool via round-robin. All sub-batches are spawned
429    /// concurrently — sessions[i % pool_len] serializes only its own sub-batches,
430    /// so pool_len sub-batches run in true parallel, eliminating head-of-line
431    /// blocking when multiple HTTP handlers embed concurrently (DAK-5547).
432    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
433        if texts.is_empty() {
434            return Ok(vec![]);
435        }
436
437        let batches: Vec<Vec<String>> = self
438            .processor
439            .split_into_batches(texts)
440            .into_iter()
441            .map(|b| b.to_vec())
442            .collect();
443
444        let pool_len = self.sessions.len();
445        let normalize = self.config.model.normalize_embeddings();
446        // Round-robin starting index: each concurrent caller gets a different slot so
447        // concurrent requests don't all contend on sessions[0].
448        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
449
450        // Spawn all sub-batches concurrently; preserve insertion order for reassembly.
451        let mut handles = Vec::with_capacity(batches.len());
452        for (i, batch_owned) in batches.into_iter().enumerate() {
453            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
454            let processor = Arc::clone(&self.processor);
455            handles.push(tokio::task::spawn_blocking(move || {
456                let mut session_guard = session.lock();
457                Self::process_batch_blocking(
458                    &batch_owned,
459                    &mut session_guard,
460                    &processor,
461                    normalize,
462                )
463            }));
464        }
465
466        let mut all_embeddings = Vec::with_capacity(texts.len());
467        for handle in handles {
468            let batch_embeddings = handle.await.map_err(|e| {
469                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
470            })??;
471            all_embeddings.extend(batch_embeddings);
472        }
473
474        Ok(all_embeddings)
475    }
476
477    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
478    ///
479    /// Designed to run inside `spawn_blocking` (takes no `&self`).
480    fn process_batch_blocking(
481        texts: &[String],
482        session: &mut Session,
483        processor: &BatchProcessor,
484        normalize: bool,
485    ) -> Result<Vec<Vec<f32>>> {
486        // Tokenize
487        let prepared = processor.tokenize_batch(texts)?;
488        let batch_size = prepared.batch_size;
489        let seq_len = prepared.seq_len;
490
491        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
492        let attention_mask_flat = prepared.attention_mask.clone();
493
494        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
495        let input_ids_tensor =
496            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
497                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
498        let attention_mask_tensor =
499            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
500                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
501        let token_type_ids_tensor =
502            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
503                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
504
505        // Run ONNX session
506        let outputs = session
507            .run(inputs![
508                "input_ids" => input_ids_tensor,
509                "attention_mask" => attention_mask_tensor,
510                "token_type_ids" => token_type_ids_tensor
511            ])
512            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
513
514        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
515        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
516        // Shape derefs to [i64], so index directly.
517        let (ort_shape, lhs_slice) = outputs[0]
518            .try_extract_tensor::<f32>()
519            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
520
521        if ort_shape.len() != 3 {
522            return Err(InferenceError::InferenceError(format!(
523                "Expected 3D last_hidden_state, got {} dims",
524                ort_shape.len()
525            )));
526        }
527        let hidden_size = ort_shape[2] as usize;
528
529        // Apply mean pooling using the saved attention mask copy
530        let mut embeddings = mean_pooling(
531            lhs_slice,
532            batch_size,
533            seq_len,
534            hidden_size,
535            &attention_mask_flat,
536        );
537
538        // L2 normalize if configured
539        if normalize {
540            normalize_embeddings(&mut embeddings);
541        }
542
543        debug!(
544            "Generated {} embeddings of dimension {}",
545            embeddings.len(),
546            embeddings.first().map(|e| e.len()).unwrap_or(0)
547        );
548
549        Ok(embeddings)
550    }
551
552    /// Estimate the time to embed a batch of texts (in milliseconds).
553    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
554        // Rough estimation based on model speed and text length (CPU path)
555        let tokens_per_text =
556            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
557        let total_tokens = tokens_per_text * text_count as f64;
558        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
559        (total_tokens / tokens_per_second) * 1000.0
560    }
561}
562
563impl std::fmt::Debug for EmbeddingEngine {
564    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
565        f.debug_struct("EmbeddingEngine")
566            .field("model", &self.config.model)
567            .field("dimension", &self.dimension)
568            .field("max_batch_size", &self.config.max_batch_size)
569            .field("session_pool_size", &self.sessions.len())
570            .finish()
571    }
572}
573
574/// Builder for creating an EmbeddingEngine with fluent API.
575pub struct EmbeddingEngineBuilder {
576    config: ModelConfig,
577}
578
579impl EmbeddingEngineBuilder {
580    /// Create a new builder with default configuration.
581    pub fn new() -> Self {
582        Self {
583            config: ModelConfig::default(),
584        }
585    }
586
587    /// Set the embedding model to use.
588    pub fn model(mut self, model: EmbeddingModel) -> Self {
589        self.config.model = model;
590        self
591    }
592
593    /// Set the cache directory for model files.
594    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
595        self.config.cache_dir = Some(dir.into());
596        self
597    }
598
599    /// Set the maximum batch size.
600    pub fn max_batch_size(mut self, size: usize) -> Self {
601        self.config.max_batch_size = size;
602        self
603    }
604
605    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
606    pub fn use_gpu(mut self, enable: bool) -> Self {
607        self.config.use_gpu = enable;
608        self
609    }
610
611    /// Set the number of intra-op CPU threads for ORT inference.
612    pub fn num_threads(mut self, threads: usize) -> Self {
613        self.config.num_threads = Some(threads);
614        self
615    }
616
617    /// Set the number of parallel ONNX sessions in the pool.
618    pub fn session_pool_size(mut self, size: usize) -> Self {
619        self.config.session_pool_size = size.max(1);
620        self
621    }
622
623    /// Build the embedding engine.
624    pub async fn build(self) -> Result<EmbeddingEngine> {
625        EmbeddingEngine::new(self.config).await
626    }
627}
628
629impl Default for EmbeddingEngineBuilder {
630    fn default() -> Self {
631        Self::new()
632    }
633}
634
635#[cfg(test)]
636mod tests {
637    use super::*;
638
639    #[test]
640    fn test_estimate_time() {
641        let config = ModelConfig::new(EmbeddingModel::MiniLM);
642        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
643        assert!(tokens_per_second > 0.0);
644    }
645
646    #[test]
647    fn test_builder() {
648        let builder = EmbeddingEngineBuilder::new()
649            .model(EmbeddingModel::BgeSmall)
650            .max_batch_size(64)
651            .use_gpu(false);
652
653        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
654        assert_eq!(builder.config.max_batch_size, 64);
655        assert!(!builder.config.use_gpu);
656    }
657
658    // ── model_cache_dir ──────────────────────────────────────────────────────
659
660    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
661    ///
662    /// Uses a process-level mutex because `std::env::set_var` is not thread-safe.
663    #[test]
664    fn test_model_cache_dir_with_hf_home() {
665        use std::sync::Mutex;
666        static ENV_LOCK: Mutex<()> = Mutex::new(());
667        let _guard = ENV_LOCK.lock().unwrap();
668
669        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
670        std::env::set_var("HF_HOME", &tmp);
671        let result = EmbeddingEngine::model_cache_dir("org/my-model");
672        std::env::remove_var("HF_HOME");
673
674        let path = result.unwrap();
675        assert!(
676            path.starts_with(&tmp),
677            "expected path under {tmp:?}, got {path:?}"
678        );
679        assert!(
680            path.to_str().unwrap().contains("org--my-model"),
681            "model_id separator not applied: {path:?}"
682        );
683    }
684
685    #[test]
686    fn test_model_cache_dir_contains_dakera_subdir() {
687        let path =
688            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
689        let s = path.to_str().unwrap();
690        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
691        assert!(
692            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
693            "expected transformed model id in path: {s}"
694        );
695    }
696
697    #[test]
698    fn test_model_cache_dir_creates_directory() {
699        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
700        assert!(dir.exists(), "model_cache_dir should create the directory");
701    }
702
703    // ── download_hf_file (cached early-return path) ──────────────────────────
704
705    #[test]
706    fn test_download_hf_file_returns_path_when_already_cached() {
707        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
708        std::fs::create_dir_all(&tmp).unwrap();
709        let file_path = tmp.join("config.json");
710        std::fs::write(&file_path, b"{}").unwrap();
711
712        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
713        assert!(result.is_ok());
714        assert_eq!(result.unwrap(), file_path);
715    }
716
717    #[test]
718    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
719        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
720        let onnx_dir = tmp.join("onnx");
721        std::fs::create_dir_all(&onnx_dir).unwrap();
722        let onnx_path = onnx_dir.join("model_quantized.onnx");
723        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();
724
725        // Filename includes the subdirectory path
726        let result = EmbeddingEngine::download_hf_file(
727            "Xenova/all-MiniLM-L6-v2",
728            "onnx/model_quantized.onnx",
729            &tmp,
730        );
731        assert!(result.is_ok());
732        assert_eq!(result.unwrap(), onnx_path);
733    }
734
735    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────
736
737    #[test]
738    fn test_builder_default_impl() {
739        let b1 = EmbeddingEngineBuilder::new();
740        let b2 = EmbeddingEngineBuilder::default();
741        assert_eq!(b1.config.model, b2.config.model);
742        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
743    }
744
745    #[test]
746    fn test_builder_model_field() {
747        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
748        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
749    }
750
751    #[test]
752    fn test_builder_cache_dir() {
753        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
754        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
755    }
756
757    #[test]
758    fn test_builder_max_batch_size() {
759        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
760        assert_eq!(builder.config.max_batch_size, 128);
761    }
762
763    #[test]
764    fn test_builder_use_gpu_true() {
765        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
766        assert!(builder.config.use_gpu);
767    }
768
769    #[test]
770    fn test_builder_use_gpu_false() {
771        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
772        assert!(!builder.config.use_gpu);
773    }
774
775    #[test]
776    fn test_builder_num_threads() {
777        let builder = EmbeddingEngineBuilder::new().num_threads(4);
778        assert_eq!(builder.config.num_threads, Some(4));
779    }
780
781    #[test]
782    fn test_builder_chain_all_fields() {
783        let builder = EmbeddingEngineBuilder::new()
784            .model(EmbeddingModel::BgeSmall)
785            .cache_dir("/cache")
786            .max_batch_size(16)
787            .use_gpu(false)
788            .num_threads(2);
789
790        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
791        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
792        assert_eq!(builder.config.max_batch_size, 16);
793        assert!(!builder.config.use_gpu);
794        assert_eq!(builder.config.num_threads, Some(2));
795    }
796
797    // ── estimate_time_ms ─────────────────────────────────────────────────────
798
799    #[test]
800    fn test_estimate_time_zero_count() {
801        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
802        let estimate = (0.0 / tps) * 1000.0;
803        assert_eq!(estimate, 0.0);
804    }
805
806    #[test]
807    fn test_estimate_time_formula_cpu() {
808        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
809        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
810        let model = EmbeddingModel::MiniLM;
811        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
812        let total_tokens = tokens_per_text * 10.0;
813        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
814        assert!(
815            (estimate - 50.0).abs() < 1e-6,
816            "expected 50.0ms, got {estimate}"
817        );
818    }
819
820    #[test]
821    fn test_estimate_time_capped_at_max_seq_length() {
822        let model = EmbeddingModel::MiniLM;
823        let avg_len = 100_000;
824        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
825        assert_eq!(tokens_per_text, 256.0);
826    }
827
828    // ── ModelConfig API ───────────────────────────────────────────────────────
829
830    #[test]
831    fn test_model_config_new() {
832        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
833        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
834        assert_eq!(cfg.max_batch_size, 8);
835        assert!(!cfg.use_gpu);
836        assert!(cfg.cache_dir.is_none());
837        assert!(cfg.num_threads.is_none());
838    }
839
840    #[test]
841    fn test_model_config_default() {
842        let cfg = ModelConfig::default();
843        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
844        assert_eq!(cfg.max_batch_size, 8);
845        assert!(!cfg.use_gpu);
846    }
847
848    #[test]
849    fn test_model_config_with_cache_dir() {
850        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
851        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
852    }
853
854    #[test]
855    fn test_model_config_with_max_batch_size() {
856        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
857        assert_eq!(cfg.max_batch_size, 64);
858    }
859
860    #[test]
861    fn test_model_config_with_gpu() {
862        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
863        assert!(cfg.use_gpu);
864    }
865
866    #[test]
867    fn test_model_config_with_num_threads() {
868        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
869        assert_eq!(cfg.num_threads, Some(8));
870    }
871
872    #[test]
873    fn test_model_config_chained_builder() {
874        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
875            .with_cache_dir("/cache")
876            .with_max_batch_size(16)
877            .with_gpu(false)
878            .with_num_threads(4);
879        assert_eq!(cfg.model, EmbeddingModel::E5Small);
880        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
881        assert_eq!(cfg.max_batch_size, 16);
882        assert!(!cfg.use_gpu);
883        assert_eq!(cfg.num_threads, Some(4));
884    }
885
886    // ── model_cache_dir edge cases ────────────────────────────────────────────
887
888    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
889    #[test]
890    fn test_model_cache_dir_no_home_fallback() {
891        use std::sync::Mutex;
892        static ENV_LOCK: Mutex<()> = Mutex::new(());
893        let _guard = ENV_LOCK.lock().unwrap();
894
895        // Remove HOME and HF_HOME so we hit the /tmp fallback
896        let saved_home = std::env::var("HOME").ok();
897        let saved_hf = std::env::var("HF_HOME").ok();
898        unsafe {
899            std::env::remove_var("HOME");
900            std::env::remove_var("HF_HOME");
901        }
902
903        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");
904
905        // Restore env
906        if let Some(h) = saved_home {
907            unsafe { std::env::set_var("HOME", h) };
908        }
909        if let Some(h) = saved_hf {
910            unsafe { std::env::set_var("HF_HOME", h) };
911        }
912
913        let path = result.unwrap();
914        // Should be under /tmp since HOME was unset
915        assert!(
916            path.starts_with("/tmp"),
917            "expected path under /tmp, got {path:?}"
918        );
919    }
920
921    #[test]
922    fn test_model_cache_dir_deep_model_id() {
923        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
924        let s = path.to_str().unwrap();
925        // All slashes replaced with double-dash
926        assert!(
927            s.contains("org--sub--model-name-with-dashes"),
928            "expected transformed path, got: {s}"
929        );
930    }
931
932    #[test]
933    fn test_model_cache_dir_minilm_model_id() {
934        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
935        let s = path.to_str().unwrap();
936        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
937    }
938
939    #[test]
940    fn test_model_cache_dir_bge_model_id() {
941        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
942        let s = path.to_str().unwrap();
943        assert!(s.contains("BAAI--bge-small-en-v1.5"));
944    }
945
946    #[test]
947    fn test_model_cache_dir_e5_model_id() {
948        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
949        let s = path.to_str().unwrap();
950        assert!(s.contains("intfloat--e5-small-v2"));
951    }
952
953    // ── download_hf_file additional cache-hit variations ─────────────────────
954
955    #[test]
956    fn test_download_hf_file_pytorch_bin_cached() {
957        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
958        std::fs::create_dir_all(&tmp).unwrap();
959        let model_path = tmp.join("pytorch_model.bin");
960        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();
961
962        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
963        assert!(result.is_ok());
964        assert_eq!(result.unwrap(), model_path);
965    }
966
967    #[test]
968    fn test_download_hf_file_tokenizer_cached() {
969        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
970        std::fs::create_dir_all(&tmp).unwrap();
971        let tok_path = tmp.join("tokenizer.json");
972        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();
973
974        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
975        assert!(result.is_ok());
976        assert_eq!(result.unwrap(), tok_path);
977    }
978
979    #[test]
980    fn test_download_hf_file_config_json_cached() {
981        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
982        std::fs::create_dir_all(&tmp).unwrap();
983        let cfg_path = tmp.join("config.json");
984        std::fs::write(&cfg_path, b"{}").unwrap();
985
986        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
987        assert!(result.is_ok());
988        assert_eq!(result.unwrap(), cfg_path);
989    }
990
991    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────
992
993    /// Tests the code path through `download_model_files` (local Dakera cache hit)
994    /// and into `new()` — which then fails trying to load the tokenizer from a
995    /// fake file. No network access required; fake files are pre-seeded.
996    #[tokio::test]
997    #[allow(clippy::await_holding_lock)]
998    async fn test_new_fails_with_invalid_tokenizer_json() {
999        use std::sync::Mutex;
1000        static ENV_LOCK: Mutex<()> = Mutex::new(());
1001        let _guard = ENV_LOCK.lock().unwrap();
1002
1003        // Set up a fake Dakera model cache so download_model_files finds our files
1004        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
1005        let model_dir = tmp
1006            .join("dakera")
1007            .join("sentence-transformers--all-MiniLM-L6-v2");
1008        std::fs::create_dir_all(&model_dir).unwrap();
1009        // Valid-looking model weights placeholder (candle will fail on this, which is fine)
1010        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
1011        // Invalid tokenizer.json — will cause TokenizationError in new()
1012        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
1013        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();
1014
1015        unsafe { std::env::set_var("HF_HOME", &tmp) };
1016
1017        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1018        let result = EmbeddingEngine::new(config).await;
1019
1020        unsafe { std::env::remove_var("HF_HOME") };
1021
1022        // Must fail — tokenizer.json is invalid JSON
1023        assert!(
1024            result.is_err(),
1025            "expected Err from new() with invalid tokenizer, got Ok"
1026        );
1027    }
1028
1029    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────
1030
1031    #[test]
1032    fn test_builder_with_all_models() {
1033        for model in [
1034            EmbeddingModel::MiniLM,
1035            EmbeddingModel::BgeSmall,
1036            EmbeddingModel::E5Small,
1037        ] {
1038            let builder = EmbeddingEngineBuilder::new().model(model);
1039            assert_eq!(builder.config.model, model);
1040        }
1041    }
1042
1043    #[test]
1044    fn test_builder_max_batch_size_one() {
1045        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
1046        assert_eq!(builder.config.max_batch_size, 1);
1047    }
1048
1049    #[test]
1050    fn test_builder_num_threads_zero() {
1051        let builder = EmbeddingEngineBuilder::new().num_threads(0);
1052        assert_eq!(builder.config.num_threads, Some(0));
1053    }
1054
1055    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──
1056
1057    /// If the embedding model is already cached on this machine, exercise the
1058    /// full `new()` path and test getters. On machines without a cached model
1059    /// the test passes silently — it is intentionally non-gating.
1060    #[tokio::test]
1061    async fn test_engine_getters_when_model_cached() {
1062        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1063        match EmbeddingEngine::new(config).await {
1064            Ok(engine) => {
1065                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
1066                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
1067                // Device should be CPU in test environments (device() removed in CE-3 ONNX migration)
1068                // Debug impl should not panic
1069                let _ = format!("{:?}", engine);
1070                // estimate_time_ms should return a non-negative value
1071                let ms = engine.estimate_time_ms(10, 50);
1072                assert!(ms >= 0.0);
1073            }
1074            Err(_) => {
1075                // Model not in cache — skip; CI runner may or may not have it
1076            }
1077        }
1078    }
1079
1080    /// When model is cached: embed an empty batch must return immediately with
1081    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
1082    #[tokio::test]
1083    async fn test_engine_embed_empty_batch_when_cached() {
1084        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1085        if let Ok(engine) = EmbeddingEngine::new(config).await {
1086            let result = engine.embed_raw(&[]).await;
1087            assert!(result.is_ok());
1088            assert!(result.unwrap().is_empty());
1089        }
1090    }
1091
1092    // ── Session pool (DAK-5547) ──────────────────────────────────────────────
1093
1094    #[test]
1095    fn test_session_pool_default_is_4() {
1096        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
1097        // DAKERA_ONNX_POOL_SIZE env var still allows override.
1098        let config = ModelConfig::default();
1099        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
1100            .ok()
1101            .and_then(|v| v.parse::<usize>().ok())
1102            .filter(|&n| n >= 1)
1103            .unwrap_or(4);
1104        assert_eq!(config.session_pool_size, expected);
1105    }
1106
1107    #[test]
1108    fn test_session_pool_size_builder_roundtrip() {
1109        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
1110        assert_eq!(builder.config.session_pool_size, 8);
1111    }
1112
1113    #[test]
1114    fn test_session_pool_size_min_enforced() {
1115        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
1116        assert_eq!(
1117            builder.config.session_pool_size, 1,
1118            "pool size 0 must clamp to 1"
1119        );
1120    }
1121
1122    #[test]
1123    fn test_model_config_with_session_pool_size() {
1124        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1125        assert_eq!(cfg.session_pool_size, 2);
1126    }
1127
1128    /// When model is cached: verify the session pool has the expected size.
1129    #[tokio::test]
1130    async fn test_engine_pool_size_matches_config_when_cached() {
1131        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1132        if let Ok(engine) = EmbeddingEngine::new(config).await {
1133            assert_eq!(
1134                engine.pool_size(),
1135                2,
1136                "engine should hold exactly 2 sessions"
1137            );
1138        }
1139    }
1140
1141    // ── next_session round-robin ──────────────────────────────────────────────
1142
1143    /// Round-robin counter distributes batches across pool slots without panic.
1144    #[test]
1145    fn test_round_robin_index_stays_in_bounds() {
1146        let pool_len = 4_usize;
1147        let counter = AtomicUsize::new(0);
1148        for expected_idx in 0..100_usize {
1149            let start = counter.fetch_add(1, Ordering::Relaxed);
1150            let slot = start % pool_len;
1151            assert!(slot < pool_len);
1152            assert_eq!(slot, expected_idx % pool_len);
1153        }
1154    }
1155
1156    /// Pool size 1 degrades to single-session behavior without panicking.
1157    #[test]
1158    fn test_round_robin_pool_size_one() {
1159        let pool_len = 1_usize;
1160        let counter = AtomicUsize::new(0);
1161        for _ in 0..50 {
1162            let start = counter.fetch_add(1, Ordering::Relaxed);
1163            assert_eq!(start % pool_len, 0);
1164        }
1165    }
1166
1167    /// When model is cached: empty batch short-circuits before touching the pool.
1168    #[tokio::test]
1169    async fn test_embed_empty_does_not_advance_pool_counter() {
1170        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1171        if let Ok(engine) = EmbeddingEngine::new(config).await {
1172            let result = engine.embed_raw(&[]).await;
1173            assert!(result.is_ok());
1174            assert!(result.unwrap().is_empty());
1175            // Empty batch returns before fetch_add — counter stays at 0.
1176            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
1177        }
1178    }
1179}