Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
//! - Loading ONNX embedding models from HuggingFace Hub (INT8-quantized for CPU,
//!   FP32 when the CUDA execution provider is enabled)
//! - Generating embeddings for single texts or batches
//! - Automatic batching, with sub-batches spread across a pool of ONNX Runtime sessions
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
/// The main embedding engine for generating vector embeddings.
///
/// This struct is thread-safe and can be shared across async tasks:
/// - the ORT sessions sit behind per-session mutexes, because `Session::run` needs `&mut Session`;
/// - the round-robin cursor is atomic;
/// - every other field is only read after construction.
///
/// Sessions are held in a pool so concurrent callers can embed without
/// head-of-line blocking on a single lock. CPU-heavy inference is offloaded via
/// `tokio::task::spawn_blocking`.
pub struct EmbeddingEngine {
    /// Pool of ONNX Runtime sessions, each guarded by its own mutex.
    /// Sub-batches are assigned to sessions round-robin starting from the value
    /// taken from `next_session`, so concurrent callers usually land on different
    /// locks instead of all contending on `sessions[0]`.
    sessions: Vec<Arc<Mutex<Session>>>,
    /// Round-robin cursor: incremented once per embedding call (`fetch_add`
    /// wraps on overflow) and reduced modulo the pool size when picking a session.
    next_session: AtomicUsize,
    /// Batch processor used for prefixing and tokenization; `Arc`-wrapped so it
    /// can be moved into `spawn_blocking` closures.
    processor: Arc<BatchProcessor>,
    /// Model configuration the engine was built with.
    config: ModelConfig,
    /// Embedding dimension, taken from `config.model.dimension()` at construction.
    dimension: usize,
    /// Whether the CUDA execution provider is active (resolved once at construction;
    /// the `DAKERA_USE_GPU` env var overrides `config.use_gpu`).
    /// When true, each sub-batch acquires a permit from the global
    /// GPU_INFERENCE_SEMAPHORE before spawning (documented as a single permit), so at
    /// most one CUDA forward pass runs at a time, preventing VRAM exhaustion.
    use_gpu: bool,
}
73
74impl EmbeddingEngine {
75    /// Create a new embedding engine with the given configuration.
76    ///
77    /// Downloads the FP32 ONNX model (`model.onnx`) when CUDA EP is enabled, or the
78    /// INT8 quantized model (`model_quantized.onnx`) otherwise.
79    #[instrument(skip_all, fields(model = %config.model))]
80    pub async fn new(config: ModelConfig) -> Result<Self> {
81        // Resolve GPU mode before downloading model files — GPU requires the FP32 model.
82        // DAKERA_USE_GPU=1 overrides config.use_gpu (env var is the production knob).
83        let use_gpu = std::env::var("DAKERA_USE_GPU")
84            .map(|v| v == "1")
85            .unwrap_or(config.use_gpu);
86        if use_gpu {
87            info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
88        }
89
90        info!(
91            "Initializing ONNX embedding engine with model: {}",
92            config.model
93        );
94
95        // Download tokenizer and ONNX model files (FP32 for GPU, INT8 for CPU)
96        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
97
98        // Load tokenizer
99        info!("Loading tokenizer from {:?}", tokenizer_path);
100        let tokenizer = Tokenizer::from_file(&tokenizer_path)
101            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
102
103        // Build ONNX session pool — N independent sessions to serve concurrent callers.
104        // Each session has its own ORT context so pool members never block each other.
105        info!("Loading ONNX model from {:?}", onnx_path);
106        let num_threads = config.num_threads.unwrap_or(4);
107        let pool_size = config.session_pool_size.max(1);
108        let onnx_path_clone = onnx_path.clone();
109
110        let sessions: Vec<Arc<Mutex<Session>>> =
111            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
112                (0..pool_size)
113                    .map(|_| {
114                        let builder = Session::builder()
115                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
116                            .with_optimization_level(GraphOptimizationLevel::Level3)
117                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
118                            .with_intra_threads(num_threads)
119                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
120
121                        let mut builder = if use_gpu {
122                            builder
123                                .with_execution_providers(
124                                    [CUDAExecutionProvider::default().build()],
125                                )
126                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
127                        } else {
128                            builder
129                        };
130
131                        let s = builder
132                            .commit_from_file(&onnx_path_clone)
133                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
134                        Ok(Arc::new(Mutex::new(s)))
135                    })
136                    .collect()
137            })
138            .await
139            .map_err(|e| {
140                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
141            })??;
142
143        let dimension = config.model.dimension();
144        let processor = Arc::new(BatchProcessor::new(
145            tokenizer,
146            config.model,
147            config.max_batch_size,
148        ));
149
150        info!(
151            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
152            config.model, dimension, num_threads, pool_size
153        );
154
155        Ok(Self {
156            sessions,
157            next_session: AtomicUsize::new(0),
158            processor,
159            config,
160            dimension,
161            use_gpu,
162        })
163    }
164
165    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
166    ///
167    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
168    /// - GPU: `onnx/model.onnx` (FP32) — no Memcpy round-trips, CUDA EP runs entirely on-device
169    /// - CPU: `onnx/model_quantized.onnx` (INT8) — fastest for CPU-only deployments
170    #[instrument(skip_all, fields(model = %config.model))]
171    async fn download_model_files(
172        config: &ModelConfig,
173        use_gpu: bool,
174    ) -> Result<(PathBuf, PathBuf)> {
175        let model_id = config.model.model_id();
176        let onnx_repo_id = config.model.onnx_repo_id();
177        let onnx_filename = if use_gpu {
178            config.model.onnx_filename_gpu()
179        } else {
180            config.model.onnx_filename()
181        };
182
183        info!(
184            "Resolving model files: tokenizer={}, onnx={}@{}",
185            model_id, onnx_filename, onnx_repo_id
186        );
187
188        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
189        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
190
191        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
192        let onnx_subdir = onnx_cache_dir.join("onnx");
193        std::fs::create_dir_all(&onnx_subdir)?;
194
195        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
196        // onnx_filename is e.g. "onnx/model_quantized.onnx" or "onnx/model.onnx" — use the basename
197        let onnx_basename = Path::new(onnx_filename)
198            .file_name()
199            .and_then(|s| s.to_str())
200            .unwrap_or("model_quantized.onnx");
201        let local_onnx = onnx_subdir.join(onnx_basename);
202
203        // Download missing files in a blocking thread
204        let tokenizer_needs_download = !local_tokenizer.exists();
205
206        // GPU FP32 model (BGE-large) is 1.27 GB. The old 500 MB download limit silently
207        // truncated it, leaving a corrupt ONNX file that ORT rejects with "Protobuf parsing
208        // failed" (DAK-5976). If the cached GPU model is ≤500 MB, it's the truncated artifact
209        // — delete it so download_hf_file fetches the complete file.
210        if use_gpu && local_onnx.exists() {
211            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
212            if cached_size <= 500_000_000 {
213                warn!(
214                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
215                     download limit. Deleting for complete re-download.",
216                    local_onnx, cached_size
217                );
218                let _ = std::fs::remove_file(&local_onnx);
219            }
220        }
221        let onnx_needs_download = !local_onnx.exists();
222
223        if tokenizer_needs_download || onnx_needs_download {
224            let model_id_owned = model_id.to_string();
225            let onnx_repo_id_owned = onnx_repo_id.to_string();
226            let onnx_filename_owned = onnx_filename.to_string();
227            let tokenizer_cache = tokenizer_cache_dir.clone();
228            let onnx_cache = onnx_cache_dir.clone();
229
230            tokio::task::spawn_blocking(move || {
231                if !tokenizer_cache.join("tokenizer.json").exists() {
232                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
233                        .map_err(|e| {
234                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
235                        })?;
236                }
237                if !onnx_cache.join(&onnx_filename_owned).exists() {
238                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
239                        .map_err(|e| {
240                            InferenceError::HubError(format!(
241                                "Failed to download ONNX model: {}",
242                                e
243                            ))
244                        })?;
245                }
246                Ok::<_, InferenceError>(())
247            })
248            .await
249            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
250        } else {
251            info!("All model files found in local cache");
252        }
253
254        // Re-derive paths (cache dir / onnx / basename)
255        let final_onnx = onnx_cache_dir.join(onnx_filename);
256
257        info!(
258            "Model files ready: tokenizer={:?}, onnx={:?}",
259            local_tokenizer, final_onnx
260        );
261        Ok((local_tokenizer, final_onnx))
262    }
263
264    /// Get or create the local model cache directory.
265    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
266        let base = std::env::var("HF_HOME")
267            .map(PathBuf::from)
268            .unwrap_or_else(|_| {
269                let home = std::env::var("HOME").unwrap_or_else(|_| {
270                    warn!("HOME environment variable not set, using /tmp for model cache");
271                    "/tmp".to_string()
272                });
273                PathBuf::from(home).join(".cache").join("huggingface")
274            });
275        let dir = base.join("dakera").join(model_id.replace('/', "--"));
276        std::fs::create_dir_all(&dir)?;
277        Ok(dir)
278    }
279
280    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
281    ///
282    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
283    ///
284    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
285    pub fn download_hf_file_pub(
286        model_id: &str,
287        filename: &str,
288        cache_dir: &Path,
289    ) -> std::result::Result<PathBuf, String> {
290        Self::download_hf_file(model_id, filename, cache_dir)
291    }
292
293    fn download_hf_file(
294        model_id: &str,
295        filename: &str,
296        cache_dir: &Path,
297    ) -> std::result::Result<PathBuf, String> {
298        // The file may be nested (e.g. "onnx/model_quantized.onnx")
299        let file_path = cache_dir.join(filename);
300        if file_path.exists() {
301            info!("Cached: {}/{}", model_id, filename);
302            return Ok(file_path);
303        }
304
305        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
306        if let Some(parent) = file_path.parent() {
307            std::fs::create_dir_all(parent)
308                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
309        }
310
311        let url = format!(
312            "https://huggingface.co/{}/resolve/main/{}",
313            model_id, filename
314        );
315        info!("Downloading: {}", url);
316
317        // Read HuggingFace token from env (required for gated models).
318        let hf_token = std::env::var("HF_TOKEN")
319            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
320            .ok();
321        if hf_token.is_some() {
322            info!("Using HuggingFace auth token for download");
323        }
324
325        // Disable automatic redirects so we can resolve relative Location headers ourselves.
326        let agent = ureq::AgentBuilder::new()
327            .redirects(0)
328            .timeout(std::time::Duration::from_secs(300))
329            .build();
330
331        let mut current_url = url.clone();
332        let mut redirects = 0;
333        let max_redirects = 10;
334
335        let response = loop {
336            let mut req = agent.get(&current_url);
337            if let Some(ref token) = hf_token {
338                req = req.set("Authorization", &format!("Bearer {}", token));
339            }
340            let resp = req.call();
341
342            let r = match resp {
343                Ok(r) => r,
344                Err(ureq::Error::Status(_status, r)) => r,
345                Err(e) => return Err(format!("{}: {}", filename, e)),
346            };
347
348            let status = r.status();
349            if (200..300).contains(&status) {
350                break r;
351            } else if (300..400).contains(&status) {
352                redirects += 1;
353                if redirects > max_redirects {
354                    return Err(format!("{}: too many redirects", filename));
355                }
356                let location = r
357                    .header("location")
358                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
359                    .to_string();
360
361                // Resolve relative redirects against the current URL's origin
362                current_url = if location.starts_with('/') {
363                    let parsed = url::Url::parse(&current_url)
364                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
365                    let host = parsed.host_str().ok_or_else(|| {
366                        format!("{}: redirect URL missing host: {}", filename, current_url)
367                    })?;
368                    format!("{}://{}{}", parsed.scheme(), host, location)
369                } else {
370                    location
371                };
372                info!("Redirect {} → {}", redirects, current_url);
373            } else {
374                return Err(format!("{}: HTTP {}", filename, status));
375            }
376        };
377
378        // Capture expected file size before consuming the response body.
379        // HuggingFace CDN reports LFS blob sizes via x-linked-size; standard
380        // HTTP servers use content-length. Either may be absent (chunked transfer).
381        let expected_bytes: Option<u64> = response
382            .header("x-linked-size")
383            .or_else(|| response.header("content-length"))
384            .and_then(|v| v.parse::<u64>().ok());
385
386        // 2 GiB limit — BGE-large FP32 is 1.27 GB; the old 500 MB limit silently
387        // truncated it, producing a corrupt ONNX that ORT rejects with
388        // "Protobuf parsing failed" (DAK-5976).
389        let mut bytes = Vec::new();
390        response
391            .into_reader()
392            .take(2_147_483_648)
393            .read_to_end(&mut bytes)
394            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
395
396        // Fail fast if the received payload is shorter than the declared size.
397        // This prevents writing a corrupt file that would pass the cache-hit check
398        // on the next boot but fail when ORT tries to parse it.
399        if let Some(expected) = expected_bytes {
400            let actual = bytes.len() as u64;
401            if actual < expected {
402                return Err(format!(
403                    "{}: download incomplete — received {} of {} bytes. \
404                     File may exceed 2 GiB or the connection was interrupted.",
405                    filename, actual, expected
406                ));
407            }
408        }
409
410        std::fs::write(&file_path, &bytes)
411            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
412
413        info!("Downloaded {} ({} bytes)", filename, bytes.len());
414        Ok(file_path)
415    }
416
417    /// Get the embedding dimension for the loaded model.
418    pub fn dimension(&self) -> usize {
419        self.dimension
420    }
421
422    /// Get the model being used.
423    pub fn model(&self) -> EmbeddingModel {
424        self.config.model
425    }
426
427    /// Get the number of parallel ONNX sessions in the pool.
428    pub fn pool_size(&self) -> usize {
429        self.sessions.len()
430    }
431
432    /// Embed a single query text.
433    ///
434    /// For models like E5, this automatically applies the query prefix.
435    #[instrument(skip(self, text), fields(text_len = text.len()))]
436    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
437        let texts = vec![text.to_string()];
438        let prepared = self.processor.prepare_texts(&texts, true);
439        let embeddings = self.embed_batch_internal(&prepared).await?;
440        embeddings.into_iter().next().ok_or_else(|| {
441            InferenceError::InferenceError("No embedding returned for query".to_string())
442        })
443    }
444
445    /// Embed multiple query texts.
446    ///
447    /// For models like E5, this automatically applies the query prefix.
448    #[instrument(skip(self, texts), fields(count = texts.len()))]
449    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
450        let prepared = self.processor.prepare_texts(texts, true);
451        self.embed_batch_internal(&prepared).await
452    }
453
454    /// Embed a single document/passage.
455    ///
456    /// For models like E5, this automatically applies the document prefix.
457    #[instrument(skip(self, text), fields(text_len = text.len()))]
458    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
459        let texts = vec![text.to_string()];
460        let prepared = self.processor.prepare_texts(&texts, false);
461        let embeddings = self.embed_batch_internal(&prepared).await?;
462        embeddings.into_iter().next().ok_or_else(|| {
463            InferenceError::InferenceError("No embedding returned for document".to_string())
464        })
465    }
466
467    /// Embed multiple documents/passages.
468    ///
469    /// For models like E5, this automatically applies the document prefix.
470    #[instrument(skip(self, texts), fields(count = texts.len()))]
471    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
472        let prepared = self.processor.prepare_texts(texts, false);
473        self.embed_batch_internal(&prepared).await
474    }
475
476    /// Embed texts without any prefix (raw embedding).
477    #[instrument(skip(self, texts), fields(count = texts.len()))]
478    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
479        self.embed_batch_internal(texts).await
480    }
481
482    /// Internal batch embedding implementation.
483    ///
484    /// Splits `texts` into sub-batches (≤ `batch_size` each) and distributes them
485    /// across the session pool via round-robin.  On GPU BFCArena OOM the batch size
486    /// is halved and the full input is retried (up to 3 halvings; DAK-6002).  This
487    /// handles LME-style long documents whose attention tensor (batch × heads × seq²)
488    /// exceeds the VRAM headroom at the configured batch size but fits at batch/2.
489    ///
490    /// When GPU mode is active (`use_gpu=true`), each sub-batch acquires the global
491    /// `GPU_INFERENCE_SEMAPHORE` (1 permit) before spawning.  This serializes CUDA
492    /// forward passes so at most one batch occupies VRAM at a time.  The permit is
493    /// moved into the blocking task and dropped on completion, releasing the slot.
494    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
495        if texts.is_empty() {
496            return Ok(vec![]);
497        }
498
499        let pool_len = self.sessions.len();
500        let normalize = self.config.model.normalize_embeddings();
501        // Round-robin starting index: each concurrent caller gets a different slot so
502        // concurrent requests don't all contend on sessions[0].
503        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
504
505        let mut batch_size = self.config.max_batch_size.max(1);
506        let use_gpu = self.use_gpu;
507
508        // Up to 3 halvings: 32 → 16 → 8 → 4 → hard fail.
509        for attempt in 0_u32..=3 {
510            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
511
512            // Spawn all sub-batches; on GPU serialize via semaphore, on CPU run concurrently.
513            let mut handles = Vec::with_capacity(batches.len());
514            for (i, batch_owned) in batches.into_iter().enumerate() {
515                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
516                let processor = Arc::clone(&self.processor);
517
518                // GPU: acquire the global semaphore before spawning. The acquired permit is
519                // moved into the blocking closure and dropped when inference completes,
520                // releasing the slot for the next sub-batch. With 1 permit this guarantees
521                // at most one CUDA forward pass runs at a time across all ingest workers.
522                let gpu_permit = if use_gpu {
523                    Some(
524                        std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
525                            .acquire_owned()
526                            .await
527                            .map_err(|_| {
528                                InferenceError::InferenceError(
529                                    "GPU inference semaphore unexpectedly closed".to_string(),
530                                )
531                            })?,
532                    )
533                } else {
534                    None
535                };
536
537                handles.push(tokio::task::spawn_blocking(move || {
538                    let _gpu_permit = gpu_permit; // held until this task completes
539                    let mut session_guard = session.lock();
540                    Self::process_batch_blocking(
541                        &batch_owned,
542                        &mut session_guard,
543                        &processor,
544                        normalize,
545                    )
546                }));
547            }
548
549            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
550            let mut oom: Option<InferenceError> = None;
551
552            for handle in handles {
553                match handle.await {
554                    Err(panic_err) => {
555                        return Err(InferenceError::InferenceError(format!(
556                            "Inference task panicked: {panic_err}"
557                        )));
558                    }
559                    Ok(Err(e)) => {
560                        if attempt < 3 && Self::is_gpu_oom(&e) {
561                            // First OOM wins; remaining in-flight tasks detach harmlessly.
562                            oom = Some(e);
563                            break;
564                        }
565                        return Err(e);
566                    }
567                    Ok(Ok(batch_embs)) => {
568                        all_embeddings.extend(batch_embs);
569                    }
570                }
571            }
572
573            if oom.is_some() {
574                let next_batch = (batch_size / 2).max(1);
575                warn!(
576                    "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
577                    attempt + 1,
578                    batch_size,
579                    next_batch,
580                );
581                batch_size = next_batch;
582                continue;
583            }
584
585            return Ok(all_embeddings);
586        }
587
588        Err(InferenceError::InferenceError(format!(
589            "ONNX inference failed: GPU/CPU allocator OOM persists after 3 \
590             batch-halving attempts (final batch_size={batch_size})"
591        )))
592    }
593
594    /// Returns `true` when the error message indicates an allocator OOM that can be
595    /// recovered by reducing the inference batch size (DAK-6002).
596    ///
597    /// Covers CUDA BFCArena (`BFCArena`, `Failed to allocate memory … buffer of size`)
598    /// and CPU arena OOM patterns seen in ONNX Runtime 2.x.
599    fn is_gpu_oom(err: &InferenceError) -> bool {
600        let msg = err.to_string();
601        msg.contains("BFCArena")
602            || msg.contains("Failed to allocate memory")
603            || msg.contains("CUDA_OUT_OF_MEMORY")
604            || msg.contains("CUDA out of memory")
605            || (msg.contains("allocate") && msg.contains("buffer of size"))
606    }
607
608    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
609    ///
610    /// Designed to run inside `spawn_blocking` (takes no `&self`).
611    fn process_batch_blocking(
612        texts: &[String],
613        session: &mut Session,
614        processor: &BatchProcessor,
615        normalize: bool,
616    ) -> Result<Vec<Vec<f32>>> {
617        // Tokenize
618        let prepared = processor.tokenize_batch(texts)?;
619        let batch_size = prepared.batch_size;
620        let seq_len = prepared.seq_len;
621
622        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
623        let attention_mask_flat = prepared.attention_mask.clone();
624
625        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
626        let input_ids_tensor =
627            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
628                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
629        let attention_mask_tensor =
630            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
631                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
632        let token_type_ids_tensor =
633            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
634                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
635
636        // Run ONNX session
637        let outputs = session
638            .run(inputs![
639                "input_ids" => input_ids_tensor,
640                "attention_mask" => attention_mask_tensor,
641                "token_type_ids" => token_type_ids_tensor
642            ])
643            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
644
645        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
646        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
647        // Shape derefs to [i64], so index directly.
648        let (ort_shape, lhs_slice) = outputs[0]
649            .try_extract_tensor::<f32>()
650            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
651
652        if ort_shape.len() != 3 {
653            return Err(InferenceError::InferenceError(format!(
654                "Expected 3D last_hidden_state, got {} dims",
655                ort_shape.len()
656            )));
657        }
658        let hidden_size = ort_shape[2] as usize;
659
660        // Apply mean pooling using the saved attention mask copy
661        let mut embeddings = mean_pooling(
662            lhs_slice,
663            batch_size,
664            seq_len,
665            hidden_size,
666            &attention_mask_flat,
667        );
668
669        // L2 normalize if configured
670        if normalize {
671            normalize_embeddings(&mut embeddings);
672        }
673
674        debug!(
675            "Generated {} embeddings of dimension {}",
676            embeddings.len(),
677            embeddings.first().map(|e| e.len()).unwrap_or(0)
678        );
679
680        Ok(embeddings)
681    }
682
683    /// Estimate the time to embed a batch of texts (in milliseconds).
684    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
685        // Rough estimation based on model speed and text length (CPU path)
686        let tokens_per_text =
687            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
688        let total_tokens = tokens_per_text * text_count as f64;
689        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
690        (total_tokens / tokens_per_second) * 1000.0
691    }
692}
693
694impl std::fmt::Debug for EmbeddingEngine {
695    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696        f.debug_struct("EmbeddingEngine")
697            .field("model", &self.config.model)
698            .field("dimension", &self.dimension)
699            .field("max_batch_size", &self.config.max_batch_size)
700            .field("session_pool_size", &self.sessions.len())
701            .field("use_gpu", &self.use_gpu)
702            .finish()
703    }
704}
705
/// Builder for creating an [`EmbeddingEngine`] with a fluent API.
///
/// Starts from `ModelConfig::default()`. Each setter overrides one configuration
/// field, and `build` hands the result to `EmbeddingEngine::new`.
pub struct EmbeddingEngineBuilder {
    /// Configuration accumulated by the setter methods.
    config: ModelConfig,
}
710
711impl EmbeddingEngineBuilder {
712    /// Create a new builder with default configuration.
713    pub fn new() -> Self {
714        Self {
715            config: ModelConfig::default(),
716        }
717    }
718
719    /// Set the embedding model to use.
720    pub fn model(mut self, model: EmbeddingModel) -> Self {
721        self.config.model = model;
722        self
723    }
724
725    /// Set the cache directory for model files.
726    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
727        self.config.cache_dir = Some(dir.into());
728        self
729    }
730
731    /// Set the maximum batch size.
732    pub fn max_batch_size(mut self, size: usize) -> Self {
733        self.config.max_batch_size = size;
734        self
735    }
736
737    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
738    pub fn use_gpu(mut self, enable: bool) -> Self {
739        self.config.use_gpu = enable;
740        self
741    }
742
743    /// Set the number of intra-op CPU threads for ORT inference.
744    pub fn num_threads(mut self, threads: usize) -> Self {
745        self.config.num_threads = Some(threads);
746        self
747    }
748
749    /// Set the number of parallel ONNX sessions in the pool.
750    pub fn session_pool_size(mut self, size: usize) -> Self {
751        self.config.session_pool_size = size.max(1);
752        self
753    }
754
755    /// Build the embedding engine.
756    pub async fn build(self) -> Result<EmbeddingEngine> {
757        EmbeddingEngine::new(self.config).await
758    }
759}
760
761impl Default for EmbeddingEngineBuilder {
762    fn default() -> Self {
763        Self::new()
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{OsStr, OsString};
    use std::sync::{Mutex as StdMutex, MutexGuard as StdMutexGuard, PoisonError};

    // ── Environment isolation ────────────────────────────────────────────────

    /// Module-wide lock that serializes every test which reads or mutates
    /// process environment variables (`HF_HOME`, `HOME`, `DAKERA_USE_GPU`),
    /// including tests that call `model_cache_dir` or `EmbeddingEngine::new`.
    ///
    /// This must be a single module-level `static`. A `static` declared inside
    /// a test function is private to that function and never excludes other
    /// tests.
    static ENV_LOCK: StdMutex<()> = StdMutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning.
    ///
    /// Without the recovery, one failing env test would turn every later
    /// env-dependent test into a `PoisonError` panic.
    fn lock_env() -> StdMutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Overrides one environment variable for the guard's lifetime.
    ///
    /// On drop, the guard restores the previous value (or its absence), even
    /// when the test panics.
    ///
    /// Only create a guard while holding the guard from [`lock_env`], and
    /// declare it *after* that guard. Locals drop in reverse order, so the
    /// variable is then restored while the lock is still held.
    struct EnvVarGuard {
        key: &'static str,
        saved: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Set `key` to `value` until the guard is dropped.
        fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
            let saved = std::env::var_os(key);
            // SAFETY: the caller holds ENV_LOCK, which every env-dependent test
            // in this module takes, so no other test here touches the
            // environment concurrently.
            unsafe { std::env::set_var(key, value) };
            Self { key, saved }
        }

        /// Unset `key` until the guard is dropped.
        fn unset(key: &'static str) -> Self {
            let saved = std::env::var_os(key);
            // SAFETY: same as `EnvVarGuard::set` — ENV_LOCK is held by the caller.
            unsafe { std::env::remove_var(key) };
            Self { key, saved }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is declared after the ENV_LOCK guard, so it is
            // dropped first and the lock is still held here.
            match self.saved.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    #[test]
    fn test_estimate_time() {
        // estimate_time_ms divides by this throughput, so it must be positive.
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    // ── model_cache_dir ──────────────────────────────────────────────────────

    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
    ///
    /// Holds the module-wide [`ENV_LOCK`] because environment mutation is
    /// process-global and `std::env::set_var` is not thread-safe.
    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _env = lock_env();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);
        let result = EmbeddingEngine::model_cache_dir("org/my-model");

        let path = result.unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let _env = lock_env();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _env = lock_env();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // ── download_hf_file (cached early-return path) ──────────────────────────

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        // Filename includes the subdirectory path
        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // ── estimate_time_ms ─────────────────────────────────────────────────────

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    // ── ModelConfig API ───────────────────────────────────────────────────────

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // ── model_cache_dir edge cases ────────────────────────────────────────────

    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _env = lock_env();

        // Unset HOME and HF_HOME so we hit the /tmp fallback; both are restored
        // (to their prior values) when the guards drop, even on panic.
        let _home = EnvVarGuard::unset("HOME");
        let _hf_home = EnvVarGuard::unset("HF_HOME");

        let path = EmbeddingEngine::model_cache_dir("test/fallback-model").unwrap();
        // Should be under /tmp since HOME was unset
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let _env = lock_env();
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        // All slashes replaced with double-dash
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let _env = lock_env();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let _env = lock_env();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let _env = lock_env();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    // ── download_hf_file additional cache-hit variations ─────────────────────

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────

    /// Tests the code path through `download_model_files` (local Dakera cache
    /// hit) and into `new()`, which then fails trying to load the tokenizer from
    /// a fake file. No network access is required; fake files are pre-seeded.
    ///
    /// The async engine tests below hold the std `ENV_LOCK` guard across
    /// `.await` on purpose. Each `#[tokio::test]` runs on its own
    /// current-thread runtime, so blocking on the lock only stalls that test.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _env = lock_env();

        // Set up a fake Dakera model cache so download_model_files finds our files
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        // Placeholder weights file (not a real model)
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        // Invalid tokenizer.json — will cause TokenizationError in new()
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        // Must fail — tokenizer.json is invalid JSON
        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──

    /// If the embedding model is already cached on this machine, exercise the
    /// full `new()` path and test getters. On machines without a cached model
    /// the test passes silently — it is intentionally non-gating.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_getters_when_model_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                // Debug impl should not panic
                let _ = format!("{:?}", engine);
                // estimate_time_ms should return a non-negative value
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Model not in cache — skip; CI runner may or may not have it
            }
        }
    }

    /// When model is cached: embed an empty batch must return immediately with
    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_embed_empty_batch_when_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    // ── Session pool (DAK-5547) ──────────────────────────────────────────────

    #[test]
    fn test_session_pool_default_is_4() {
        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
        // DAKERA_ONNX_POOL_SIZE env var still allows override.
        let config = ModelConfig::default();
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    /// When model is cached: verify the session pool has the expected size.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    // ── next_session round-robin ──────────────────────────────────────────────

    /// Round-robin counter distributes batches across pool slots without panic.
    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    /// Pool size 1 degrades to single-session behavior without panicking.
    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    /// When model is cached: empty batch short-circuits before touching the pool.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            // Empty batch returns before fetch_add — counter stays at 0.
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }

    // ── GPU inference semaphore (DAK-6096) ────────────────────────────────────

    /// Default config has use_gpu=false; the resolved flag stored in the engine must match.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_use_gpu_defaults_to_false_when_not_configured() {
        let _env = lock_env();

        // Ensure DAKERA_USE_GPU is unset so we test the config.use_gpu=false path;
        // any pre-existing value is restored when the guard drops.
        let _gpu = EnvVarGuard::unset("DAKERA_USE_GPU");

        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(1);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert!(
                !engine.use_gpu,
                "use_gpu must be false when DAKERA_USE_GPU is unset"
            );
        }
    }

    /// DAKERA_USE_GPU=1 env var must override config.use_gpu=false and set use_gpu=true.
    #[test]
    fn test_engine_use_gpu_resolved_from_env_var() {
        let _env = lock_env();
        let _gpu = EnvVarGuard::set("DAKERA_USE_GPU", "1");

        let resolved = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);

        assert!(resolved, "DAKERA_USE_GPU=1 must resolve use_gpu=true");
    }

    /// Debug output includes use_gpu field so operators can confirm GPU mode in logs.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_engine_debug_includes_use_gpu() {
        let _env = lock_env();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let debug_str = format!("{:?}", engine);
            assert!(
                debug_str.contains("use_gpu"),
                "Debug output must include use_gpu field: {debug_str}"
            );
        }
    }
}