Skip to main content

inference/
engine.rs

//! Core embedding engine for generating vector embeddings from text.
//!
//! The `EmbeddingEngine` provides a high-level interface for:
//! - Loading ONNX embedding models from HuggingFace Hub (INT8 quantized for CPU
//!   inference, FP32 when the CUDA execution provider is enabled)
//! - Generating embeddings for single texts or batches
//! - Automatic batching, with sub-batches spread across a pool of ONNX Runtime
//!   sessions
//!
//! # Example
//!
//! ```no_run
//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
//!
//! #[tokio::main]
//! async fn main() {
//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
//!     let engine = EmbeddingEngine::new(config).await.unwrap();
//!
//!     // Embed a single query
//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
//!     println!("Embedding dimension: {}", embedding.len());
//!
//!     // Embed multiple documents
//!     let docs = vec![
//!         "Machine learning is a subset of AI.".to_string(),
//!         "Deep learning uses neural networks.".to_string(),
//!     ];
//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
//!     println!("Generated {} embeddings", embeddings.len());
//! }
//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49/// The main embedding engine for generating vector embeddings.
50///
51/// This struct is thread-safe and can be shared across async tasks.
52/// ORT sessions are mutex-guarded (run() takes &mut self) and held in a
53/// pool so concurrent callers can embed without head-of-line blocking.
54/// CPU-heavy inference is offloaded via `tokio::task::spawn_blocking`.
55pub struct EmbeddingEngine {
56    /// Pool of ONNX Runtime sessions — each guarded independently.
57    /// Concurrent callers round-robin across sessions via `next_session`,
58    /// eliminating Mutex head-of-line blocking for batch embedding workloads.
59    sessions: Vec<Arc<Mutex<Session>>>,
60    /// Round-robin counter: atomically incremented per batch, wraps via modulo.
61    next_session: AtomicUsize,
62    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
63    processor: Arc<BatchProcessor>,
64    /// Model configuration
65    config: ModelConfig,
66    /// Embedding dimension
67    dimension: usize,
68}
69
70impl EmbeddingEngine {
71    /// Create a new embedding engine with the given configuration.
72    ///
73    /// Downloads the FP32 ONNX model (`model.onnx`) when CUDA EP is enabled, or the
74    /// INT8 quantized model (`model_quantized.onnx`) otherwise.
75    #[instrument(skip_all, fields(model = %config.model))]
76    pub async fn new(config: ModelConfig) -> Result<Self> {
77        // Resolve GPU mode before downloading model files — GPU requires the FP32 model.
78        // DAKERA_USE_GPU=1 overrides config.use_gpu (env var is the production knob).
79        let use_gpu = std::env::var("DAKERA_USE_GPU")
80            .map(|v| v == "1")
81            .unwrap_or(config.use_gpu);
82        if use_gpu {
83            info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
84        }
85
86        info!(
87            "Initializing ONNX embedding engine with model: {}",
88            config.model
89        );
90
91        // Download tokenizer and ONNX model files (FP32 for GPU, INT8 for CPU)
92        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
93
94        // Load tokenizer
95        info!("Loading tokenizer from {:?}", tokenizer_path);
96        let tokenizer = Tokenizer::from_file(&tokenizer_path)
97            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
98
99        // Build ONNX session pool — N independent sessions to serve concurrent callers.
100        // Each session has its own ORT context so pool members never block each other.
101        info!("Loading ONNX model from {:?}", onnx_path);
102        let num_threads = config.num_threads.unwrap_or(4);
103        let pool_size = config.session_pool_size.max(1);
104        let onnx_path_clone = onnx_path.clone();
105
106        let sessions: Vec<Arc<Mutex<Session>>> =
107            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
108                (0..pool_size)
109                    .map(|_| {
110                        let builder = Session::builder()
111                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
112                            .with_optimization_level(GraphOptimizationLevel::Level3)
113                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
114                            .with_intra_threads(num_threads)
115                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
116
117                        let mut builder = if use_gpu {
118                            builder
119                                .with_execution_providers(
120                                    [CUDAExecutionProvider::default().build()],
121                                )
122                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
123                        } else {
124                            builder
125                        };
126
127                        let s = builder
128                            .commit_from_file(&onnx_path_clone)
129                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
130                        Ok(Arc::new(Mutex::new(s)))
131                    })
132                    .collect()
133            })
134            .await
135            .map_err(|e| {
136                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
137            })??;
138
139        let dimension = config.model.dimension();
140        let processor = Arc::new(BatchProcessor::new(
141            tokenizer,
142            config.model,
143            config.max_batch_size,
144        ));
145
146        info!(
147            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
148            config.model, dimension, num_threads, pool_size
149        );
150
151        Ok(Self {
152            sessions,
153            next_session: AtomicUsize::new(0),
154            processor,
155            config,
156            dimension,
157        })
158    }
159
160    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
161    ///
162    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
163    /// - GPU: `onnx/model.onnx` (FP32) — no Memcpy round-trips, CUDA EP runs entirely on-device
164    /// - CPU: `onnx/model_quantized.onnx` (INT8) — fastest for CPU-only deployments
165    #[instrument(skip_all, fields(model = %config.model))]
166    async fn download_model_files(
167        config: &ModelConfig,
168        use_gpu: bool,
169    ) -> Result<(PathBuf, PathBuf)> {
170        let model_id = config.model.model_id();
171        let onnx_repo_id = config.model.onnx_repo_id();
172        let onnx_filename = if use_gpu {
173            config.model.onnx_filename_gpu()
174        } else {
175            config.model.onnx_filename()
176        };
177
178        info!(
179            "Resolving model files: tokenizer={}, onnx={}@{}",
180            model_id, onnx_filename, onnx_repo_id
181        );
182
183        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
184        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
185
186        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
187        let onnx_subdir = onnx_cache_dir.join("onnx");
188        std::fs::create_dir_all(&onnx_subdir)?;
189
190        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
191        // onnx_filename is e.g. "onnx/model_quantized.onnx" or "onnx/model.onnx" — use the basename
192        let onnx_basename = Path::new(onnx_filename)
193            .file_name()
194            .and_then(|s| s.to_str())
195            .unwrap_or("model_quantized.onnx");
196        let local_onnx = onnx_subdir.join(onnx_basename);
197
198        // Download missing files in a blocking thread
199        let tokenizer_needs_download = !local_tokenizer.exists();
200        let onnx_needs_download = !local_onnx.exists();
201
202        if tokenizer_needs_download || onnx_needs_download {
203            let model_id_owned = model_id.to_string();
204            let onnx_repo_id_owned = onnx_repo_id.to_string();
205            let onnx_filename_owned = onnx_filename.to_string();
206            let tokenizer_cache = tokenizer_cache_dir.clone();
207            let onnx_cache = onnx_cache_dir.clone();
208
209            tokio::task::spawn_blocking(move || {
210                if !tokenizer_cache.join("tokenizer.json").exists() {
211                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
212                        .map_err(|e| {
213                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
214                        })?;
215                }
216                if !onnx_cache.join(&onnx_filename_owned).exists() {
217                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
218                        .map_err(|e| {
219                            InferenceError::HubError(format!(
220                                "Failed to download ONNX model: {}",
221                                e
222                            ))
223                        })?;
224                }
225                Ok::<_, InferenceError>(())
226            })
227            .await
228            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
229        } else {
230            info!("All model files found in local cache");
231        }
232
233        // Re-derive paths (cache dir / onnx / basename)
234        let final_onnx = onnx_cache_dir.join(onnx_filename);
235
236        info!(
237            "Model files ready: tokenizer={:?}, onnx={:?}",
238            local_tokenizer, final_onnx
239        );
240        Ok((local_tokenizer, final_onnx))
241    }
242
243    /// Get or create the local model cache directory.
244    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
245        let base = std::env::var("HF_HOME")
246            .map(PathBuf::from)
247            .unwrap_or_else(|_| {
248                let home = std::env::var("HOME").unwrap_or_else(|_| {
249                    warn!("HOME environment variable not set, using /tmp for model cache");
250                    "/tmp".to_string()
251                });
252                PathBuf::from(home).join(".cache").join("huggingface")
253            });
254        let dir = base.join("dakera").join(model_id.replace('/', "--"));
255        std::fs::create_dir_all(&dir)?;
256        Ok(dir)
257    }
258
259    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
260    ///
261    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
262    ///
263    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
264    pub fn download_hf_file_pub(
265        model_id: &str,
266        filename: &str,
267        cache_dir: &Path,
268    ) -> std::result::Result<PathBuf, String> {
269        Self::download_hf_file(model_id, filename, cache_dir)
270    }
271
272    fn download_hf_file(
273        model_id: &str,
274        filename: &str,
275        cache_dir: &Path,
276    ) -> std::result::Result<PathBuf, String> {
277        // The file may be nested (e.g. "onnx/model_quantized.onnx")
278        let file_path = cache_dir.join(filename);
279        if file_path.exists() {
280            info!("Cached: {}/{}", model_id, filename);
281            return Ok(file_path);
282        }
283
284        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
285        if let Some(parent) = file_path.parent() {
286            std::fs::create_dir_all(parent)
287                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
288        }
289
290        let url = format!(
291            "https://huggingface.co/{}/resolve/main/{}",
292            model_id, filename
293        );
294        info!("Downloading: {}", url);
295
296        // Read HuggingFace token from env (required for gated models).
297        let hf_token = std::env::var("HF_TOKEN")
298            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
299            .ok();
300        if hf_token.is_some() {
301            info!("Using HuggingFace auth token for download");
302        }
303
304        // Disable automatic redirects so we can resolve relative Location headers ourselves.
305        let agent = ureq::AgentBuilder::new()
306            .redirects(0)
307            .timeout(std::time::Duration::from_secs(300))
308            .build();
309
310        let mut current_url = url.clone();
311        let mut redirects = 0;
312        let max_redirects = 10;
313
314        let response = loop {
315            let mut req = agent.get(&current_url);
316            if let Some(ref token) = hf_token {
317                req = req.set("Authorization", &format!("Bearer {}", token));
318            }
319            let resp = req.call();
320
321            let r = match resp {
322                Ok(r) => r,
323                Err(ureq::Error::Status(_status, r)) => r,
324                Err(e) => return Err(format!("{}: {}", filename, e)),
325            };
326
327            let status = r.status();
328            if (200..300).contains(&status) {
329                break r;
330            } else if (300..400).contains(&status) {
331                redirects += 1;
332                if redirects > max_redirects {
333                    return Err(format!("{}: too many redirects", filename));
334                }
335                let location = r
336                    .header("location")
337                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
338                    .to_string();
339
340                // Resolve relative redirects against the current URL's origin
341                current_url = if location.starts_with('/') {
342                    let parsed = url::Url::parse(&current_url)
343                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
344                    let host = parsed.host_str().ok_or_else(|| {
345                        format!("{}: redirect URL missing host: {}", filename, current_url)
346                    })?;
347                    format!("{}://{}{}", parsed.scheme(), host, location)
348                } else {
349                    location
350                };
351                info!("Redirect {} → {}", redirects, current_url);
352            } else {
353                return Err(format!("{}: HTTP {}", filename, status));
354            }
355        };
356
357        let mut bytes = Vec::new();
358        response
359            .into_reader()
360            .take(500_000_000) // 500 MB safety limit
361            .read_to_end(&mut bytes)
362            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
363
364        std::fs::write(&file_path, &bytes)
365            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
366
367        info!("Downloaded {} ({} bytes)", filename, bytes.len());
368        Ok(file_path)
369    }
370
371    /// Get the embedding dimension for the loaded model.
372    pub fn dimension(&self) -> usize {
373        self.dimension
374    }
375
376    /// Get the model being used.
377    pub fn model(&self) -> EmbeddingModel {
378        self.config.model
379    }
380
381    /// Get the number of parallel ONNX sessions in the pool.
382    pub fn pool_size(&self) -> usize {
383        self.sessions.len()
384    }
385
386    /// Embed a single query text.
387    ///
388    /// For models like E5, this automatically applies the query prefix.
389    #[instrument(skip(self, text), fields(text_len = text.len()))]
390    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
391        let texts = vec![text.to_string()];
392        let prepared = self.processor.prepare_texts(&texts, true);
393        let embeddings = self.embed_batch_internal(&prepared).await?;
394        embeddings.into_iter().next().ok_or_else(|| {
395            InferenceError::InferenceError("No embedding returned for query".to_string())
396        })
397    }
398
399    /// Embed multiple query texts.
400    ///
401    /// For models like E5, this automatically applies the query prefix.
402    #[instrument(skip(self, texts), fields(count = texts.len()))]
403    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404        let prepared = self.processor.prepare_texts(texts, true);
405        self.embed_batch_internal(&prepared).await
406    }
407
408    /// Embed a single document/passage.
409    ///
410    /// For models like E5, this automatically applies the document prefix.
411    #[instrument(skip(self, text), fields(text_len = text.len()))]
412    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
413        let texts = vec![text.to_string()];
414        let prepared = self.processor.prepare_texts(&texts, false);
415        let embeddings = self.embed_batch_internal(&prepared).await?;
416        embeddings.into_iter().next().ok_or_else(|| {
417            InferenceError::InferenceError("No embedding returned for document".to_string())
418        })
419    }
420
421    /// Embed multiple documents/passages.
422    ///
423    /// For models like E5, this automatically applies the document prefix.
424    #[instrument(skip(self, texts), fields(count = texts.len()))]
425    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
426        let prepared = self.processor.prepare_texts(texts, false);
427        self.embed_batch_internal(&prepared).await
428    }
429
430    /// Embed texts without any prefix (raw embedding).
431    #[instrument(skip(self, texts), fields(count = texts.len()))]
432    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
433        self.embed_batch_internal(texts).await
434    }
435
436    /// Internal batch embedding implementation.
437    ///
438    /// Splits `texts` into sub-batches (≤ max_batch_size each) and distributes
439    /// them across the session pool via round-robin. All sub-batches are spawned
440    /// concurrently — sessions[i % pool_len] serializes only its own sub-batches,
441    /// so pool_len sub-batches run in true parallel, eliminating head-of-line
442    /// blocking when multiple HTTP handlers embed concurrently (DAK-5547).
443    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
444        if texts.is_empty() {
445            return Ok(vec![]);
446        }
447
448        let batches: Vec<Vec<String>> = self
449            .processor
450            .split_into_batches(texts)
451            .into_iter()
452            .map(|b| b.to_vec())
453            .collect();
454
455        let pool_len = self.sessions.len();
456        let normalize = self.config.model.normalize_embeddings();
457        // Round-robin starting index: each concurrent caller gets a different slot so
458        // concurrent requests don't all contend on sessions[0].
459        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
460
461        // Spawn all sub-batches concurrently; preserve insertion order for reassembly.
462        let mut handles = Vec::with_capacity(batches.len());
463        for (i, batch_owned) in batches.into_iter().enumerate() {
464            let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
465            let processor = Arc::clone(&self.processor);
466            handles.push(tokio::task::spawn_blocking(move || {
467                let mut session_guard = session.lock();
468                Self::process_batch_blocking(
469                    &batch_owned,
470                    &mut session_guard,
471                    &processor,
472                    normalize,
473                )
474            }));
475        }
476
477        let mut all_embeddings = Vec::with_capacity(texts.len());
478        for handle in handles {
479            let batch_embeddings = handle.await.map_err(|e| {
480                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
481            })??;
482            all_embeddings.extend(batch_embeddings);
483        }
484
485        Ok(all_embeddings)
486    }
487
488    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
489    ///
490    /// Designed to run inside `spawn_blocking` (takes no `&self`).
491    fn process_batch_blocking(
492        texts: &[String],
493        session: &mut Session,
494        processor: &BatchProcessor,
495        normalize: bool,
496    ) -> Result<Vec<Vec<f32>>> {
497        // Tokenize
498        let prepared = processor.tokenize_batch(texts)?;
499        let batch_size = prepared.batch_size;
500        let seq_len = prepared.seq_len;
501
502        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
503        let attention_mask_flat = prepared.attention_mask.clone();
504
505        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
506        let input_ids_tensor =
507            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
508                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
509        let attention_mask_tensor =
510            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
511                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
512        let token_type_ids_tensor =
513            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
514                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
515
516        // Run ONNX session
517        let outputs = session
518            .run(inputs![
519                "input_ids" => input_ids_tensor,
520                "attention_mask" => attention_mask_tensor,
521                "token_type_ids" => token_type_ids_tensor
522            ])
523            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
524
525        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
526        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
527        // Shape derefs to [i64], so index directly.
528        let (ort_shape, lhs_slice) = outputs[0]
529            .try_extract_tensor::<f32>()
530            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
531
532        if ort_shape.len() != 3 {
533            return Err(InferenceError::InferenceError(format!(
534                "Expected 3D last_hidden_state, got {} dims",
535                ort_shape.len()
536            )));
537        }
538        let hidden_size = ort_shape[2] as usize;
539
540        // Apply mean pooling using the saved attention mask copy
541        let mut embeddings = mean_pooling(
542            lhs_slice,
543            batch_size,
544            seq_len,
545            hidden_size,
546            &attention_mask_flat,
547        );
548
549        // L2 normalize if configured
550        if normalize {
551            normalize_embeddings(&mut embeddings);
552        }
553
554        debug!(
555            "Generated {} embeddings of dimension {}",
556            embeddings.len(),
557            embeddings.first().map(|e| e.len()).unwrap_or(0)
558        );
559
560        Ok(embeddings)
561    }
562
563    /// Estimate the time to embed a batch of texts (in milliseconds).
564    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
565        // Rough estimation based on model speed and text length (CPU path)
566        let tokens_per_text =
567            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
568        let total_tokens = tokens_per_text * text_count as f64;
569        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
570        (total_tokens / tokens_per_second) * 1000.0
571    }
572}
573
574impl std::fmt::Debug for EmbeddingEngine {
575    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576        f.debug_struct("EmbeddingEngine")
577            .field("model", &self.config.model)
578            .field("dimension", &self.dimension)
579            .field("max_batch_size", &self.config.max_batch_size)
580            .field("session_pool_size", &self.sessions.len())
581            .finish()
582    }
583}
584
585/// Builder for creating an EmbeddingEngine with fluent API.
586pub struct EmbeddingEngineBuilder {
587    config: ModelConfig,
588}
589
590impl EmbeddingEngineBuilder {
591    /// Create a new builder with default configuration.
592    pub fn new() -> Self {
593        Self {
594            config: ModelConfig::default(),
595        }
596    }
597
598    /// Set the embedding model to use.
599    pub fn model(mut self, model: EmbeddingModel) -> Self {
600        self.config.model = model;
601        self
602    }
603
604    /// Set the cache directory for model files.
605    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
606        self.config.cache_dir = Some(dir.into());
607        self
608    }
609
610    /// Set the maximum batch size.
611    pub fn max_batch_size(mut self, size: usize) -> Self {
612        self.config.max_batch_size = size;
613        self
614    }
615
616    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
617    pub fn use_gpu(mut self, enable: bool) -> Self {
618        self.config.use_gpu = enable;
619        self
620    }
621
622    /// Set the number of intra-op CPU threads for ORT inference.
623    pub fn num_threads(mut self, threads: usize) -> Self {
624        self.config.num_threads = Some(threads);
625        self
626    }
627
628    /// Set the number of parallel ONNX sessions in the pool.
629    pub fn session_pool_size(mut self, size: usize) -> Self {
630        self.config.session_pool_size = size.max(1);
631        self
632    }
633
634    /// Build the embedding engine.
635    pub async fn build(self) -> Result<EmbeddingEngine> {
636        EmbeddingEngine::new(self.config).await
637    }
638}
639
640impl Default for EmbeddingEngineBuilder {
641    fn default() -> Self {
642        Self::new()
643    }
644}
645
646#[cfg(test)]
647mod tests {
648    use super::*;
649
650    #[test]
651    fn test_estimate_time() {
652        let config = ModelConfig::new(EmbeddingModel::MiniLM);
653        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
654        assert!(tokens_per_second > 0.0);
655    }
656
657    #[test]
658    fn test_builder() {
659        let builder = EmbeddingEngineBuilder::new()
660            .model(EmbeddingModel::BgeSmall)
661            .max_batch_size(64)
662            .use_gpu(false);
663
664        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
665        assert_eq!(builder.config.max_batch_size, 64);
666        assert!(!builder.config.use_gpu);
667    }
668
669    // ── model_cache_dir ──────────────────────────────────────────────────────
670
671    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
672    ///
673    /// Uses a process-level mutex because `std::env::set_var` is not thread-safe.
674    #[test]
675    fn test_model_cache_dir_with_hf_home() {
676        use std::sync::Mutex;
677        static ENV_LOCK: Mutex<()> = Mutex::new(());
678        let _guard = ENV_LOCK.lock().unwrap();
679
680        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
681        std::env::set_var("HF_HOME", &tmp);
682        let result = EmbeddingEngine::model_cache_dir("org/my-model");
683        std::env::remove_var("HF_HOME");
684
685        let path = result.unwrap();
686        assert!(
687            path.starts_with(&tmp),
688            "expected path under {tmp:?}, got {path:?}"
689        );
690        assert!(
691            path.to_str().unwrap().contains("org--my-model"),
692            "model_id separator not applied: {path:?}"
693        );
694    }
695
696    #[test]
697    fn test_model_cache_dir_contains_dakera_subdir() {
698        let path =
699            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
700        let s = path.to_str().unwrap();
701        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
702        assert!(
703            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
704            "expected transformed model id in path: {s}"
705        );
706    }
707
708    #[test]
709    fn test_model_cache_dir_creates_directory() {
710        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
711        assert!(dir.exists(), "model_cache_dir should create the directory");
712    }
713
714    // ── download_hf_file (cached early-return path) ──────────────────────────
715
716    #[test]
717    fn test_download_hf_file_returns_path_when_already_cached() {
718        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
719        std::fs::create_dir_all(&tmp).unwrap();
720        let file_path = tmp.join("config.json");
721        std::fs::write(&file_path, b"{}").unwrap();
722
723        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
724        assert!(result.is_ok());
725        assert_eq!(result.unwrap(), file_path);
726    }
727
728    #[test]
729    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
730        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
731        let onnx_dir = tmp.join("onnx");
732        std::fs::create_dir_all(&onnx_dir).unwrap();
733        let onnx_path = onnx_dir.join("model_quantized.onnx");
734        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();
735
736        // Filename includes the subdirectory path
737        let result = EmbeddingEngine::download_hf_file(
738            "Xenova/all-MiniLM-L6-v2",
739            "onnx/model_quantized.onnx",
740            &tmp,
741        );
742        assert!(result.is_ok());
743        assert_eq!(result.unwrap(), onnx_path);
744    }
745
746    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────
747
748    #[test]
749    fn test_builder_default_impl() {
750        let b1 = EmbeddingEngineBuilder::new();
751        let b2 = EmbeddingEngineBuilder::default();
752        assert_eq!(b1.config.model, b2.config.model);
753        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
754    }
755
756    #[test]
757    fn test_builder_model_field() {
758        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
759        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
760    }
761
762    #[test]
763    fn test_builder_cache_dir() {
764        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
765        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
766    }
767
768    #[test]
769    fn test_builder_max_batch_size() {
770        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
771        assert_eq!(builder.config.max_batch_size, 128);
772    }
773
774    #[test]
775    fn test_builder_use_gpu_true() {
776        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
777        assert!(builder.config.use_gpu);
778    }
779
780    #[test]
781    fn test_builder_use_gpu_false() {
782        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
783        assert!(!builder.config.use_gpu);
784    }
785
786    #[test]
787    fn test_builder_num_threads() {
788        let builder = EmbeddingEngineBuilder::new().num_threads(4);
789        assert_eq!(builder.config.num_threads, Some(4));
790    }
791
792    #[test]
793    fn test_builder_chain_all_fields() {
794        let builder = EmbeddingEngineBuilder::new()
795            .model(EmbeddingModel::BgeSmall)
796            .cache_dir("/cache")
797            .max_batch_size(16)
798            .use_gpu(false)
799            .num_threads(2);
800
801        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
802        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
803        assert_eq!(builder.config.max_batch_size, 16);
804        assert!(!builder.config.use_gpu);
805        assert_eq!(builder.config.num_threads, Some(2));
806    }
807
808    // ── estimate_time_ms ─────────────────────────────────────────────────────
809
810    #[test]
811    fn test_estimate_time_zero_count() {
812        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
813        let estimate = (0.0 / tps) * 1000.0;
814        assert_eq!(estimate, 0.0);
815    }
816
817    #[test]
818    fn test_estimate_time_formula_cpu() {
819        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
820        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
821        let model = EmbeddingModel::MiniLM;
822        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
823        let total_tokens = tokens_per_text * 10.0;
824        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
825        assert!(
826            (estimate - 50.0).abs() < 1e-6,
827            "expected 50.0ms, got {estimate}"
828        );
829    }
830
831    #[test]
832    fn test_estimate_time_capped_at_max_seq_length() {
833        let model = EmbeddingModel::MiniLM;
834        let avg_len = 100_000;
835        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
836        assert_eq!(tokens_per_text, 256.0);
837    }
838
839    // ── ModelConfig API ───────────────────────────────────────────────────────
840
841    #[test]
842    fn test_model_config_new() {
843        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
844        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
845        assert_eq!(cfg.max_batch_size, 32);
846        assert!(!cfg.use_gpu);
847        assert!(cfg.cache_dir.is_none());
848        assert!(cfg.num_threads.is_none());
849    }
850
851    #[test]
852    fn test_model_config_default() {
853        let cfg = ModelConfig::default();
854        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
855        assert_eq!(cfg.max_batch_size, 32);
856        assert!(!cfg.use_gpu);
857    }
858
859    #[test]
860    fn test_model_config_with_cache_dir() {
861        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
862        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
863    }
864
865    #[test]
866    fn test_model_config_with_max_batch_size() {
867        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
868        assert_eq!(cfg.max_batch_size, 64);
869    }
870
871    #[test]
872    fn test_model_config_with_gpu() {
873        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
874        assert!(cfg.use_gpu);
875    }
876
877    #[test]
878    fn test_model_config_with_num_threads() {
879        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
880        assert_eq!(cfg.num_threads, Some(8));
881    }
882
883    #[test]
884    fn test_model_config_chained_builder() {
885        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
886            .with_cache_dir("/cache")
887            .with_max_batch_size(16)
888            .with_gpu(false)
889            .with_num_threads(4);
890        assert_eq!(cfg.model, EmbeddingModel::E5Small);
891        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
892        assert_eq!(cfg.max_batch_size, 16);
893        assert!(!cfg.use_gpu);
894        assert_eq!(cfg.num_threads, Some(4));
895    }
896
897    // ── model_cache_dir edge cases ────────────────────────────────────────────
898
899    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
900    #[test]
901    fn test_model_cache_dir_no_home_fallback() {
902        use std::sync::Mutex;
903        static ENV_LOCK: Mutex<()> = Mutex::new(());
904        let _guard = ENV_LOCK.lock().unwrap();
905
906        // Remove HOME and HF_HOME so we hit the /tmp fallback
907        let saved_home = std::env::var("HOME").ok();
908        let saved_hf = std::env::var("HF_HOME").ok();
909        unsafe {
910            std::env::remove_var("HOME");
911            std::env::remove_var("HF_HOME");
912        }
913
914        let result = EmbeddingEngine::model_cache_dir("test/fallback-model");
915
916        // Restore env
917        if let Some(h) = saved_home {
918            unsafe { std::env::set_var("HOME", h) };
919        }
920        if let Some(h) = saved_hf {
921            unsafe { std::env::set_var("HF_HOME", h) };
922        }
923
924        let path = result.unwrap();
925        // Should be under /tmp since HOME was unset
926        assert!(
927            path.starts_with("/tmp"),
928            "expected path under /tmp, got {path:?}"
929        );
930    }
931
932    #[test]
933    fn test_model_cache_dir_deep_model_id() {
934        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
935        let s = path.to_str().unwrap();
936        // All slashes replaced with double-dash
937        assert!(
938            s.contains("org--sub--model-name-with-dashes"),
939            "expected transformed path, got: {s}"
940        );
941    }
942
943    #[test]
944    fn test_model_cache_dir_minilm_model_id() {
945        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
946        let s = path.to_str().unwrap();
947        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
948    }
949
950    #[test]
951    fn test_model_cache_dir_bge_model_id() {
952        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
953        let s = path.to_str().unwrap();
954        assert!(s.contains("BAAI--bge-small-en-v1.5"));
955    }
956
957    #[test]
958    fn test_model_cache_dir_e5_model_id() {
959        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
960        let s = path.to_str().unwrap();
961        assert!(s.contains("intfloat--e5-small-v2"));
962    }
963
964    // ── download_hf_file additional cache-hit variations ─────────────────────
965
966    #[test]
967    fn test_download_hf_file_pytorch_bin_cached() {
968        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
969        std::fs::create_dir_all(&tmp).unwrap();
970        let model_path = tmp.join("pytorch_model.bin");
971        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();
972
973        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
974        assert!(result.is_ok());
975        assert_eq!(result.unwrap(), model_path);
976    }
977
978    #[test]
979    fn test_download_hf_file_tokenizer_cached() {
980        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
981        std::fs::create_dir_all(&tmp).unwrap();
982        let tok_path = tmp.join("tokenizer.json");
983        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();
984
985        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
986        assert!(result.is_ok());
987        assert_eq!(result.unwrap(), tok_path);
988    }
989
990    #[test]
991    fn test_download_hf_file_config_json_cached() {
992        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
993        std::fs::create_dir_all(&tmp).unwrap();
994        let cfg_path = tmp.join("config.json");
995        std::fs::write(&cfg_path, b"{}").unwrap();
996
997        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
998        assert!(result.is_ok());
999        assert_eq!(result.unwrap(), cfg_path);
1000    }
1001
1002    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────
1003
1004    /// Tests the code path through `download_model_files` (local Dakera cache hit)
1005    /// and into `new()` — which then fails trying to load the tokenizer from a
1006    /// fake file. No network access required; fake files are pre-seeded.
1007    #[tokio::test]
1008    #[allow(clippy::await_holding_lock)]
1009    async fn test_new_fails_with_invalid_tokenizer_json() {
1010        use std::sync::Mutex;
1011        static ENV_LOCK: Mutex<()> = Mutex::new(());
1012        let _guard = ENV_LOCK.lock().unwrap();
1013
1014        // Set up a fake Dakera model cache so download_model_files finds our files
1015        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
1016        let model_dir = tmp
1017            .join("dakera")
1018            .join("sentence-transformers--all-MiniLM-L6-v2");
1019        std::fs::create_dir_all(&model_dir).unwrap();
1020        // Valid-looking model weights placeholder (candle will fail on this, which is fine)
1021        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
1022        // Invalid tokenizer.json — will cause TokenizationError in new()
1023        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
1024        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();
1025
1026        unsafe { std::env::set_var("HF_HOME", &tmp) };
1027
1028        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1029        let result = EmbeddingEngine::new(config).await;
1030
1031        unsafe { std::env::remove_var("HF_HOME") };
1032
1033        // Must fail — tokenizer.json is invalid JSON
1034        assert!(
1035            result.is_err(),
1036            "expected Err from new() with invalid tokenizer, got Ok"
1037        );
1038    }
1039
1040    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────
1041
1042    #[test]
1043    fn test_builder_with_all_models() {
1044        for model in [
1045            EmbeddingModel::MiniLM,
1046            EmbeddingModel::BgeSmall,
1047            EmbeddingModel::E5Small,
1048        ] {
1049            let builder = EmbeddingEngineBuilder::new().model(model);
1050            assert_eq!(builder.config.model, model);
1051        }
1052    }
1053
1054    #[test]
1055    fn test_builder_max_batch_size_one() {
1056        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
1057        assert_eq!(builder.config.max_batch_size, 1);
1058    }
1059
1060    #[test]
1061    fn test_builder_num_threads_zero() {
1062        let builder = EmbeddingEngineBuilder::new().num_threads(0);
1063        assert_eq!(builder.config.num_threads, Some(0));
1064    }
1065
1066    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──
1067
1068    /// If the embedding model is already cached on this machine, exercise the
1069    /// full `new()` path and test getters. On machines without a cached model
1070    /// the test passes silently — it is intentionally non-gating.
1071    #[tokio::test]
1072    async fn test_engine_getters_when_model_cached() {
1073        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1074        match EmbeddingEngine::new(config).await {
1075            Ok(engine) => {
1076                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
1077                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
1078                // Device should be CPU in test environments (device() removed in CE-3 ONNX migration)
1079                // Debug impl should not panic
1080                let _ = format!("{:?}", engine);
1081                // estimate_time_ms should return a non-negative value
1082                let ms = engine.estimate_time_ms(10, 50);
1083                assert!(ms >= 0.0);
1084            }
1085            Err(_) => {
1086                // Model not in cache — skip; CI runner may or may not have it
1087            }
1088        }
1089    }
1090
1091    /// When model is cached: embed an empty batch must return immediately with
1092    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
1093    #[tokio::test]
1094    async fn test_engine_embed_empty_batch_when_cached() {
1095        let config = ModelConfig::new(EmbeddingModel::MiniLM);
1096        if let Ok(engine) = EmbeddingEngine::new(config).await {
1097            let result = engine.embed_raw(&[]).await;
1098            assert!(result.is_ok());
1099            assert!(result.unwrap().is_empty());
1100        }
1101    }
1102
1103    // ── Session pool (DAK-5547) ──────────────────────────────────────────────
1104
1105    #[test]
1106    fn test_session_pool_default_is_4() {
1107        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
1108        // DAKERA_ONNX_POOL_SIZE env var still allows override.
1109        let config = ModelConfig::default();
1110        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
1111            .ok()
1112            .and_then(|v| v.parse::<usize>().ok())
1113            .filter(|&n| n >= 1)
1114            .unwrap_or(4);
1115        assert_eq!(config.session_pool_size, expected);
1116    }
1117
1118    #[test]
1119    fn test_session_pool_size_builder_roundtrip() {
1120        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
1121        assert_eq!(builder.config.session_pool_size, 8);
1122    }
1123
1124    #[test]
1125    fn test_session_pool_size_min_enforced() {
1126        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
1127        assert_eq!(
1128            builder.config.session_pool_size, 1,
1129            "pool size 0 must clamp to 1"
1130        );
1131    }
1132
1133    #[test]
1134    fn test_model_config_with_session_pool_size() {
1135        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1136        assert_eq!(cfg.session_pool_size, 2);
1137    }
1138
1139    /// When model is cached: verify the session pool has the expected size.
1140    #[tokio::test]
1141    async fn test_engine_pool_size_matches_config_when_cached() {
1142        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1143        if let Ok(engine) = EmbeddingEngine::new(config).await {
1144            assert_eq!(
1145                engine.pool_size(),
1146                2,
1147                "engine should hold exactly 2 sessions"
1148            );
1149        }
1150    }
1151
1152    // ── next_session round-robin ──────────────────────────────────────────────
1153
1154    /// Round-robin counter distributes batches across pool slots without panic.
1155    #[test]
1156    fn test_round_robin_index_stays_in_bounds() {
1157        let pool_len = 4_usize;
1158        let counter = AtomicUsize::new(0);
1159        for expected_idx in 0..100_usize {
1160            let start = counter.fetch_add(1, Ordering::Relaxed);
1161            let slot = start % pool_len;
1162            assert!(slot < pool_len);
1163            assert_eq!(slot, expected_idx % pool_len);
1164        }
1165    }
1166
1167    /// Pool size 1 degrades to single-session behavior without panicking.
1168    #[test]
1169    fn test_round_robin_pool_size_one() {
1170        let pool_len = 1_usize;
1171        let counter = AtomicUsize::new(0);
1172        for _ in 0..50 {
1173            let start = counter.fetch_add(1, Ordering::Relaxed);
1174            assert_eq!(start % pool_len, 0);
1175        }
1176    }
1177
1178    /// When model is cached: empty batch short-circuits before touching the pool.
1179    #[tokio::test]
1180    async fn test_embed_empty_does_not_advance_pool_counter() {
1181        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
1182        if let Ok(engine) = EmbeddingEngine::new(config).await {
1183            let result = engine.embed_raw(&[]).await;
1184            assert!(result.is_ok());
1185            assert!(result.unwrap().is_empty());
1186            // Empty batch returns before fetch_add — counter stays at 0.
1187            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
1188        }
1189    }
1190}