Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
//! - Loading ONNX embedding models from HuggingFace Hub (INT8 quantized on CPU, FP32 when the CUDA execution provider is enabled)
5//! - Generating embeddings for single texts or batches
6//! - Automatic batching and parallel processing via ONNX Runtime
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49/// The main embedding engine for generating vector embeddings.
50///
51/// This struct is thread-safe and can be shared across async tasks.
52/// ORT sessions are mutex-guarded (run() takes &mut self) and held in a
53/// pool so concurrent callers can embed without head-of-line blocking.
54/// CPU-heavy inference is offloaded via `tokio::task::spawn_blocking`.
55pub struct EmbeddingEngine {
56    /// Pool of ONNX Runtime sessions — each guarded independently.
57    /// Concurrent callers round-robin across sessions via `next_session`,
58    /// eliminating Mutex head-of-line blocking for batch embedding workloads.
59    sessions: Vec<Arc<Mutex<Session>>>,
60    /// Round-robin counter: atomically incremented per batch, wraps via modulo.
61    next_session: AtomicUsize,
62    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
63    processor: Arc<BatchProcessor>,
64    /// Model configuration
65    config: ModelConfig,
66    /// Embedding dimension
67    dimension: usize,
68    /// Whether CUDA execution provider is active (resolved once at construction).
69    /// When true, each sub-batch acquires the global GPU_INFERENCE_SEMAPHORE before
70    /// spawning so at most one CUDA forward pass runs at a time, preventing VRAM exhaustion.
71    use_gpu: bool,
72}
73
74impl EmbeddingEngine {
75    /// Create a new embedding engine with the given configuration.
76    ///
77    /// Downloads the FP32 ONNX model (`model.onnx`) when CUDA EP is enabled, or the
78    /// INT8 quantized model (`model_quantized.onnx`) otherwise.
79    #[instrument(skip_all, fields(model = %config.model))]
80    pub async fn new(config: ModelConfig) -> Result<Self> {
81        // Resolve GPU mode before downloading model files — GPU requires the FP32 model.
82        // DAKERA_USE_GPU=1 overrides config.use_gpu (env var is the production knob).
83        let use_gpu = std::env::var("DAKERA_USE_GPU")
84            .map(|v| v == "1")
85            .unwrap_or(config.use_gpu);
86        if use_gpu {
87            info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
88        }
89
90        info!(
91            "Initializing ONNX embedding engine with model: {}",
92            config.model
93        );
94
95        // Download tokenizer and ONNX model files (FP32 for GPU, INT8 for CPU)
96        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
97
98        // Load tokenizer
99        info!("Loading tokenizer from {:?}", tokenizer_path);
100        let tokenizer = Tokenizer::from_file(&tokenizer_path)
101            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
102
103        // Build ONNX session pool — N independent sessions to serve concurrent callers.
104        // Each session has its own ORT context so pool members never block each other.
105        info!("Loading ONNX model from {:?}", onnx_path);
106        let num_threads = config.num_threads.unwrap_or(4);
107        let pool_size = config.session_pool_size.max(1);
108        let onnx_path_clone = onnx_path.clone();
109
110        let sessions: Vec<Arc<Mutex<Session>>> =
111            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
112                (0..pool_size)
113                    .map(|_| {
114                        let builder = Session::builder()
115                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
116                            .with_optimization_level(GraphOptimizationLevel::Level3)
117                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
118                            .with_intra_threads(num_threads)
119                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
120
121                        // DAK-6145: disable memory pattern pre-allocation on CPU.
122                        // With memory_pattern=true (default), ORT pre-allocates arena
123                        // capacity based on the max observed input shape.  On a
124                        // memory-constrained server this causes the BFCArena to over-
125                        // commit, leaving no headroom for the per-layer activations in
126                        // subsequent encoder layers.  Disabling it forces fresh
127                        // per-run allocation — ~5% slower but eliminates OOM crashes.
128                        let mut builder = if use_gpu {
129                            builder
130                                .with_execution_providers(
131                                    [CUDAExecutionProvider::default().build()],
132                                )
133                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
134                        } else {
135                            builder
136                                .with_memory_pattern(false)
137                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
138                        };
139
140                        let s = builder
141                            .commit_from_file(&onnx_path_clone)
142                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
143                        Ok(Arc::new(Mutex::new(s)))
144                    })
145                    .collect()
146            })
147            .await
148            .map_err(|e| {
149                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
150            })??;
151
152        let dimension = config.model.dimension();
153        let processor = Arc::new(BatchProcessor::new(
154            tokenizer,
155            config.model,
156            config.max_batch_size,
157        ));
158
159        info!(
160            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
161            config.model, dimension, num_threads, pool_size
162        );
163
164        Ok(Self {
165            sessions,
166            next_session: AtomicUsize::new(0),
167            processor,
168            config,
169            dimension,
170            use_gpu,
171        })
172    }
173
174    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
175    ///
176    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
177    /// - GPU: `onnx/model.onnx` (FP32) — no Memcpy round-trips, CUDA EP runs entirely on-device
178    /// - CPU: `onnx/model_quantized.onnx` (INT8) — fastest for CPU-only deployments
179    #[instrument(skip_all, fields(model = %config.model))]
180    async fn download_model_files(
181        config: &ModelConfig,
182        use_gpu: bool,
183    ) -> Result<(PathBuf, PathBuf)> {
184        let model_id = config.model.model_id();
185        let onnx_repo_id = config.model.onnx_repo_id();
186        let onnx_filename = if use_gpu {
187            config.model.onnx_filename_gpu()
188        } else {
189            config.model.onnx_filename()
190        };
191
192        info!(
193            "Resolving model files: tokenizer={}, onnx={}@{}",
194            model_id, onnx_filename, onnx_repo_id
195        );
196
197        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
198        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
199
200        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
201        let onnx_subdir = onnx_cache_dir.join("onnx");
202        std::fs::create_dir_all(&onnx_subdir)?;
203
204        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
205        // onnx_filename is e.g. "onnx/model_quantized.onnx" or "onnx/model.onnx" — use the basename
206        let onnx_basename = Path::new(onnx_filename)
207            .file_name()
208            .and_then(|s| s.to_str())
209            .unwrap_or("model_quantized.onnx");
210        let local_onnx = onnx_subdir.join(onnx_basename);
211
212        // Download missing files in a blocking thread
213        let tokenizer_needs_download = !local_tokenizer.exists();
214
215        // GPU FP32 model (BGE-large) is 1.27 GB. The old 500 MB download limit silently
216        // truncated it, leaving a corrupt ONNX file that ORT rejects with "Protobuf parsing
217        // failed" (DAK-5976). If the cached GPU model is ≤500 MB, it's the truncated artifact
218        // — delete it so download_hf_file fetches the complete file.
219        if use_gpu && local_onnx.exists() {
220            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
221            if cached_size <= 500_000_000 {
222                warn!(
223                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
224                     download limit. Deleting for complete re-download.",
225                    local_onnx, cached_size
226                );
227                let _ = std::fs::remove_file(&local_onnx);
228            }
229        }
230        let onnx_needs_download = !local_onnx.exists();
231
232        if tokenizer_needs_download || onnx_needs_download {
233            let model_id_owned = model_id.to_string();
234            let onnx_repo_id_owned = onnx_repo_id.to_string();
235            let onnx_filename_owned = onnx_filename.to_string();
236            let tokenizer_cache = tokenizer_cache_dir.clone();
237            let onnx_cache = onnx_cache_dir.clone();
238
239            tokio::task::spawn_blocking(move || {
240                if !tokenizer_cache.join("tokenizer.json").exists() {
241                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
242                        .map_err(|e| {
243                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
244                        })?;
245                }
246                if !onnx_cache.join(&onnx_filename_owned).exists() {
247                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
248                        .map_err(|e| {
249                            InferenceError::HubError(format!(
250                                "Failed to download ONNX model: {}",
251                                e
252                            ))
253                        })?;
254                }
255                Ok::<_, InferenceError>(())
256            })
257            .await
258            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
259        } else {
260            info!("All model files found in local cache");
261        }
262
263        // Re-derive paths (cache dir / onnx / basename)
264        let final_onnx = onnx_cache_dir.join(onnx_filename);
265
266        info!(
267            "Model files ready: tokenizer={:?}, onnx={:?}",
268            local_tokenizer, final_onnx
269        );
270        Ok((local_tokenizer, final_onnx))
271    }
272
273    /// Get or create the local model cache directory.
274    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
275        let base = std::env::var("HF_HOME")
276            .map(PathBuf::from)
277            .unwrap_or_else(|_| {
278                let home = std::env::var("HOME").unwrap_or_else(|_| {
279                    warn!("HOME environment variable not set, using /tmp for model cache");
280                    "/tmp".to_string()
281                });
282                PathBuf::from(home).join(".cache").join("huggingface")
283            });
284        let dir = base.join("dakera").join(model_id.replace('/', "--"));
285        std::fs::create_dir_all(&dir)?;
286        Ok(dir)
287    }
288
289    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
290    ///
291    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
292    ///
293    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
294    pub fn download_hf_file_pub(
295        model_id: &str,
296        filename: &str,
297        cache_dir: &Path,
298    ) -> std::result::Result<PathBuf, String> {
299        Self::download_hf_file(model_id, filename, cache_dir)
300    }
301
302    fn download_hf_file(
303        model_id: &str,
304        filename: &str,
305        cache_dir: &Path,
306    ) -> std::result::Result<PathBuf, String> {
307        // The file may be nested (e.g. "onnx/model_quantized.onnx")
308        let file_path = cache_dir.join(filename);
309        if file_path.exists() {
310            info!("Cached: {}/{}", model_id, filename);
311            return Ok(file_path);
312        }
313
314        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
315        if let Some(parent) = file_path.parent() {
316            std::fs::create_dir_all(parent)
317                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
318        }
319
320        let url = format!(
321            "https://huggingface.co/{}/resolve/main/{}",
322            model_id, filename
323        );
324        info!("Downloading: {}", url);
325
326        // Read HuggingFace token from env (required for gated models).
327        let hf_token = std::env::var("HF_TOKEN")
328            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
329            .ok();
330        if hf_token.is_some() {
331            info!("Using HuggingFace auth token for download");
332        }
333
334        // Disable automatic redirects so we can resolve relative Location headers ourselves.
335        let agent = ureq::AgentBuilder::new()
336            .redirects(0)
337            .timeout(std::time::Duration::from_secs(300))
338            .build();
339
340        let mut current_url = url.clone();
341        let mut redirects = 0;
342        let max_redirects = 10;
343
344        let response = loop {
345            let mut req = agent.get(&current_url);
346            if let Some(ref token) = hf_token {
347                req = req.set("Authorization", &format!("Bearer {}", token));
348            }
349            let resp = req.call();
350
351            let r = match resp {
352                Ok(r) => r,
353                Err(ureq::Error::Status(_status, r)) => r,
354                Err(e) => return Err(format!("{}: {}", filename, e)),
355            };
356
357            let status = r.status();
358            if (200..300).contains(&status) {
359                break r;
360            } else if (300..400).contains(&status) {
361                redirects += 1;
362                if redirects > max_redirects {
363                    return Err(format!("{}: too many redirects", filename));
364                }
365                let location = r
366                    .header("location")
367                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
368                    .to_string();
369
370                // Resolve relative redirects against the current URL's origin
371                current_url = if location.starts_with('/') {
372                    let parsed = url::Url::parse(&current_url)
373                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
374                    let host = parsed.host_str().ok_or_else(|| {
375                        format!("{}: redirect URL missing host: {}", filename, current_url)
376                    })?;
377                    format!("{}://{}{}", parsed.scheme(), host, location)
378                } else {
379                    location
380                };
381                info!("Redirect {} → {}", redirects, current_url);
382            } else {
383                return Err(format!("{}: HTTP {}", filename, status));
384            }
385        };
386
387        // Capture expected file size before consuming the response body.
388        // HuggingFace CDN reports LFS blob sizes via x-linked-size; standard
389        // HTTP servers use content-length. Either may be absent (chunked transfer).
390        let expected_bytes: Option<u64> = response
391            .header("x-linked-size")
392            .or_else(|| response.header("content-length"))
393            .and_then(|v| v.parse::<u64>().ok());
394
395        // 2 GiB limit — BGE-large FP32 is 1.27 GB; the old 500 MB limit silently
396        // truncated it, producing a corrupt ONNX that ORT rejects with
397        // "Protobuf parsing failed" (DAK-5976).
398        let mut bytes = Vec::new();
399        response
400            .into_reader()
401            .take(2_147_483_648)
402            .read_to_end(&mut bytes)
403            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
404
405        // Fail fast if the received payload is shorter than the declared size.
406        // This prevents writing a corrupt file that would pass the cache-hit check
407        // on the next boot but fail when ORT tries to parse it.
408        if let Some(expected) = expected_bytes {
409            let actual = bytes.len() as u64;
410            if actual < expected {
411                return Err(format!(
412                    "{}: download incomplete — received {} of {} bytes. \
413                     File may exceed 2 GiB or the connection was interrupted.",
414                    filename, actual, expected
415                ));
416            }
417        }
418
419        std::fs::write(&file_path, &bytes)
420            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
421
422        info!("Downloaded {} ({} bytes)", filename, bytes.len());
423        Ok(file_path)
424    }
425
426    /// Get the embedding dimension for the loaded model.
427    pub fn dimension(&self) -> usize {
428        self.dimension
429    }
430
431    /// Get the model being used.
432    pub fn model(&self) -> EmbeddingModel {
433        self.config.model
434    }
435
436    /// Get the number of parallel ONNX sessions in the pool.
437    pub fn pool_size(&self) -> usize {
438        self.sessions.len()
439    }
440
441    /// Embed a single query text.
442    ///
443    /// For models like E5, this automatically applies the query prefix.
444    #[instrument(skip(self, text), fields(text_len = text.len()))]
445    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
446        let texts = vec![text.to_string()];
447        let prepared = self.processor.prepare_texts(&texts, true);
448        let embeddings = self.embed_batch_internal(&prepared).await?;
449        embeddings.into_iter().next().ok_or_else(|| {
450            InferenceError::InferenceError("No embedding returned for query".to_string())
451        })
452    }
453
454    /// Embed multiple query texts.
455    ///
456    /// For models like E5, this automatically applies the query prefix.
457    #[instrument(skip(self, texts), fields(count = texts.len()))]
458    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
459        let prepared = self.processor.prepare_texts(texts, true);
460        self.embed_batch_internal(&prepared).await
461    }
462
463    /// Embed a single document/passage.
464    ///
465    /// For models like E5, this automatically applies the document prefix.
466    #[instrument(skip(self, text), fields(text_len = text.len()))]
467    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
468        let texts = vec![text.to_string()];
469        let prepared = self.processor.prepare_texts(&texts, false);
470        let embeddings = self.embed_batch_internal(&prepared).await?;
471        embeddings.into_iter().next().ok_or_else(|| {
472            InferenceError::InferenceError("No embedding returned for document".to_string())
473        })
474    }
475
476    /// Embed multiple documents/passages.
477    ///
478    /// For models like E5, this automatically applies the document prefix.
479    #[instrument(skip(self, texts), fields(count = texts.len()))]
480    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
481        let prepared = self.processor.prepare_texts(texts, false);
482        self.embed_batch_internal(&prepared).await
483    }
484
485    /// Embed texts without any prefix (raw embedding).
486    #[instrument(skip(self, texts), fields(count = texts.len()))]
487    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
488        self.embed_batch_internal(texts).await
489    }
490
491    /// Internal batch embedding implementation.
492    ///
493    /// Splits `texts` into sub-batches (≤ `batch_size` each) and distributes them
494    /// across the session pool via round-robin.  On GPU BFCArena OOM the batch size
495    /// is halved and the full input is retried (up to 5 halvings; DAK-6002/DAK-6145).  This
496    /// handles LME-style long documents whose attention tensor (batch × heads × seq²)
497    /// exceeds the VRAM headroom at the configured batch size but fits at batch/2.
498    ///
499    /// When GPU mode is active (`use_gpu=true`), each sub-batch acquires the global
500    /// `GPU_INFERENCE_SEMAPHORE` (1 permit) before spawning.  This serializes CUDA
501    /// forward passes so at most one batch occupies VRAM at a time.  The permit is
502    /// moved into the blocking task and dropped on completion, releasing the slot.
503    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
504        if texts.is_empty() {
505            return Ok(vec![]);
506        }
507
508        let pool_len = self.sessions.len();
509        let normalize = self.config.model.normalize_embeddings();
510        // Round-robin starting index: each concurrent caller gets a different slot so
511        // concurrent requests don't all contend on sessions[0].
512        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
513
514        let mut batch_size = self.config.max_batch_size.max(1);
515        let use_gpu = self.use_gpu;
516
517        // DAK-6145: up to 5 halvings: 32 → 16 → 8 → 4 → 2 → 1 → hard fail.
518        // Previous depth of 3 (stopping at batch=4) was insufficient: at batch=4 with
519        // seq_len≈503 and BGE-Large (hidden=1024), the query/Add buffer is still 8MB
520        // and the BFCArena fails when concurrent sessions have exhausted system RAM.
521        // At batch=1 the same buffer is ~2MB and reliably allocates even under pressure.
522        for attempt in 0_u32..=5 {
523            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
524
525            // Spawn all sub-batches; on GPU serialize via semaphore, on CPU run concurrently.
526            let mut handles = Vec::with_capacity(batches.len());
527            for (i, batch_owned) in batches.into_iter().enumerate() {
528                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
529                let processor = Arc::clone(&self.processor);
530
531                // GPU: acquire the global semaphore before spawning. The acquired permit is
532                // moved into the blocking closure and dropped when inference completes,
533                // releasing the slot for the next sub-batch. With 1 permit this guarantees
534                // at most one CUDA forward pass runs at a time across all ingest workers.
535                let gpu_permit = if use_gpu {
536                    Some(
537                        std::sync::Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
538                            .acquire_owned()
539                            .await
540                            .map_err(|_| {
541                                InferenceError::InferenceError(
542                                    "GPU inference semaphore unexpectedly closed".to_string(),
543                                )
544                            })?,
545                    )
546                } else {
547                    None
548                };
549
550                handles.push(tokio::task::spawn_blocking(move || {
551                    let _gpu_permit = gpu_permit; // held until this task completes
552                    let mut session_guard = session.lock();
553                    Self::process_batch_blocking(
554                        &batch_owned,
555                        &mut session_guard,
556                        &processor,
557                        normalize,
558                    )
559                }));
560            }
561
562            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
563            let mut oom: Option<InferenceError> = None;
564
565            for handle in handles {
566                match handle.await {
567                    Err(panic_err) => {
568                        return Err(InferenceError::InferenceError(format!(
569                            "Inference task panicked: {panic_err}"
570                        )));
571                    }
572                    Ok(Err(e)) => {
573                        if attempt < 5 && Self::is_gpu_oom(&e) {
574                            // First OOM wins; remaining in-flight tasks detach harmlessly.
575                            oom = Some(e);
576                            break;
577                        }
578                        return Err(e);
579                    }
580                    Ok(Ok(batch_embs)) => {
581                        all_embeddings.extend(batch_embs);
582                    }
583                }
584            }
585
586            if oom.is_some() {
587                let next_batch = (batch_size / 2).max(1);
588                warn!(
589                    "ONNX allocator OOM (attempt {}/5) — retrying with batch_size {} → {}",
590                    attempt + 1,
591                    batch_size,
592                    next_batch,
593                );
594                batch_size = next_batch;
595                continue;
596            }
597
598            return Ok(all_embeddings);
599        }
600
601        Err(InferenceError::InferenceError(format!(
602            "ONNX inference failed: GPU/CPU allocator OOM persists after 5 \
603             batch-halving attempts (final batch_size={batch_size})"
604        )))
605    }
606
607    /// Returns `true` when the error message indicates an allocator OOM that can be
608    /// recovered by reducing the inference batch size (DAK-6002).
609    ///
610    /// Covers CUDA BFCArena (`BFCArena`, `Failed to allocate memory … buffer of size`)
611    /// and CPU arena OOM patterns seen in ONNX Runtime 2.x.
612    fn is_gpu_oom(err: &InferenceError) -> bool {
613        let msg = err.to_string();
614        msg.contains("BFCArena")
615            || msg.contains("Failed to allocate memory")
616            || msg.contains("CUDA_OUT_OF_MEMORY")
617            || msg.contains("CUDA out of memory")
618            || (msg.contains("allocate") && msg.contains("buffer of size"))
619    }
620
621    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
622    ///
623    /// Designed to run inside `spawn_blocking` (takes no `&self`).
624    fn process_batch_blocking(
625        texts: &[String],
626        session: &mut Session,
627        processor: &BatchProcessor,
628        normalize: bool,
629    ) -> Result<Vec<Vec<f32>>> {
630        // Tokenize
631        let prepared = processor.tokenize_batch(texts)?;
632        let batch_size = prepared.batch_size;
633        let seq_len = prepared.seq_len;
634
635        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
636        let attention_mask_flat = prepared.attention_mask.clone();
637
638        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
639        let input_ids_tensor =
640            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
641                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
642        let attention_mask_tensor =
643            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
644                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
645        let token_type_ids_tensor =
646            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
647                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
648
649        // Run ONNX session
650        let outputs = session
651            .run(inputs![
652                "input_ids" => input_ids_tensor,
653                "attention_mask" => attention_mask_tensor,
654                "token_type_ids" => token_type_ids_tensor
655            ])
656            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
657
658        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
659        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
660        // Shape derefs to [i64], so index directly.
661        let (ort_shape, lhs_slice) = outputs[0]
662            .try_extract_tensor::<f32>()
663            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
664
665        if ort_shape.len() != 3 {
666            return Err(InferenceError::InferenceError(format!(
667                "Expected 3D last_hidden_state, got {} dims",
668                ort_shape.len()
669            )));
670        }
671        let hidden_size = ort_shape[2] as usize;
672
673        // Apply mean pooling using the saved attention mask copy
674        let mut embeddings = mean_pooling(
675            lhs_slice,
676            batch_size,
677            seq_len,
678            hidden_size,
679            &attention_mask_flat,
680        );
681
682        // L2 normalize if configured
683        if normalize {
684            normalize_embeddings(&mut embeddings);
685        }
686
687        debug!(
688            "Generated {} embeddings of dimension {}",
689            embeddings.len(),
690            embeddings.first().map(|e| e.len()).unwrap_or(0)
691        );
692
693        Ok(embeddings)
694    }
695
696    /// Estimate the time to embed a batch of texts (in milliseconds).
697    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
698        // Rough estimation based on model speed and text length (CPU path)
699        let tokens_per_text =
700            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
701        let total_tokens = tokens_per_text * text_count as f64;
702        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
703        (total_tokens / tokens_per_second) * 1000.0
704    }
705}
706
707impl std::fmt::Debug for EmbeddingEngine {
708    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
709        f.debug_struct("EmbeddingEngine")
710            .field("model", &self.config.model)
711            .field("dimension", &self.dimension)
712            .field("max_batch_size", &self.config.max_batch_size)
713            .field("session_pool_size", &self.sessions.len())
714            .field("use_gpu", &self.use_gpu)
715            .finish()
716    }
717}
718
719/// Builder for creating an EmbeddingEngine with fluent API.
720pub struct EmbeddingEngineBuilder {
721    config: ModelConfig,
722}
723
724impl EmbeddingEngineBuilder {
725    /// Create a new builder with default configuration.
726    pub fn new() -> Self {
727        Self {
728            config: ModelConfig::default(),
729        }
730    }
731
732    /// Set the embedding model to use.
733    pub fn model(mut self, model: EmbeddingModel) -> Self {
734        self.config.model = model;
735        self
736    }
737
738    /// Set the cache directory for model files.
739    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
740        self.config.cache_dir = Some(dir.into());
741        self
742    }
743
744    /// Set the maximum batch size.
745    pub fn max_batch_size(mut self, size: usize) -> Self {
746        self.config.max_batch_size = size;
747        self
748    }
749
750    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
751    pub fn use_gpu(mut self, enable: bool) -> Self {
752        self.config.use_gpu = enable;
753        self
754    }
755
756    /// Set the number of intra-op CPU threads for ORT inference.
757    pub fn num_threads(mut self, threads: usize) -> Self {
758        self.config.num_threads = Some(threads);
759        self
760    }
761
762    /// Set the number of parallel ONNX sessions in the pool.
763    pub fn session_pool_size(mut self, size: usize) -> Self {
764        self.config.session_pool_size = size.max(1);
765        self
766    }
767
768    /// Build the embedding engine.
769    pub async fn build(self) -> Result<EmbeddingEngine> {
770        EmbeddingEngine::new(self.config).await
771    }
772}
773
774impl Default for EmbeddingEngineBuilder {
775    fn default() -> Self {
776        Self::new()
777    }
778}
779
#[cfg(test)]
mod tests {
    use super::*;

    // ── Shared environment-variable serialization ────────────────────────────
    //
    // `std::env::set_var` / `remove_var` mutate process-global state and race with
    // any concurrent read of the environment, while `cargo test` runs tests on
    // parallel threads. Every test in this module that mutates *or reads* env vars
    // consumed by the engine (HOME, HF_HOME, DAKERA_USE_GPU, DAKERA_ONNX_POOL_SIZE)
    // must hold this ONE lock. A function-local `static` would give each test its
    // own private mutex and serialize nothing.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning so that one failed test
    /// does not cascade into spurious failures in every later env-touching test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// RAII guard that sets or removes an env var and restores its previous value
    /// when dropped, including when the test panics.
    ///
    /// Only create one while holding [`env_lock`]. Declare it *after* the lock
    /// guard so that it is dropped (restored) before the lock is released.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Set `key` to `value`, remembering the prior value.
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the caller holds ENV_LOCK, which serializes env access
            // among this module's tests.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        /// Remove `key`, remembering the prior value.
        fn remove(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the caller holds ENV_LOCK, which serializes env access
            // among this module's tests.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is dropped before the ENV_LOCK guard it was
            // created under, so the lock is still held here.
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    // ── model_cache_dir ──────────────────────────────────────────────────────

    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
    ///
    /// Holds the shared env lock; HF_HOME is restored to its prior value on drop.
    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _lock = env_lock();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);

        let path = EmbeddingEngine::model_cache_dir("org/my-model").unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        // Reads HOME/HF_HOME, so it must not run while another test mutates them.
        let _lock = env_lock();
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let _lock = env_lock();
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // ── download_hf_file (cached early-return path) ──────────────────────────

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_file");
        std::fs::create_dir_all(&tmp).unwrap();
        let file_path = tmp.join("config.json");
        std::fs::write(&file_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), file_path);
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        let tmp = std::env::temp_dir().join("dakera_test_cached_onnx");
        let onnx_dir = tmp.join("onnx");
        std::fs::create_dir_all(&onnx_dir).unwrap();
        let onnx_path = onnx_dir.join("model_quantized.onnx");
        std::fs::write(&onnx_path, b"fake_onnx_model").unwrap();

        // Filename includes the subdirectory path
        let result = EmbeddingEngine::download_hf_file(
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            &tmp,
        );
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), onnx_path);
    }

    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // ── estimate_time_ms ─────────────────────────────────────────────────────

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    // ── ModelConfig API ───────────────────────────────────────────────────────

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // ── model_cache_dir edge cases ────────────────────────────────────────────

    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
    ///
    /// HOME and HF_HOME are restored to their prior values on drop, even if an
    /// assertion fails.
    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _lock = env_lock();

        // Remove HOME and HF_HOME so we hit the /tmp fallback
        let _home = EnvVarGuard::remove("HOME");
        let _hf_home = EnvVarGuard::remove("HF_HOME");

        let path = EmbeddingEngine::model_cache_dir("test/fallback-model").unwrap();
        // Should be under /tmp since HOME was unset
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        // All slashes replaced with double-dash
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let _lock = env_lock();
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    // ── download_hf_file additional cache-hit variations ─────────────────────

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_pytorch_bin");
        std::fs::create_dir_all(&tmp).unwrap();
        let model_path = tmp.join("pytorch_model.bin");
        std::fs::write(&model_path, b"fake_pytorch_weights").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "pytorch_model.bin", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), model_path);
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_tokenizer_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let tok_path = tmp.join("tokenizer.json");
        std::fs::write(&tok_path, br#"{"version":"1.0"}"#).unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "tokenizer.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), tok_path);
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        let tmp = std::env::temp_dir().join("dakera_test_config_cached");
        std::fs::create_dir_all(&tmp).unwrap();
        let cfg_path = tmp.join("config.json");
        std::fs::write(&cfg_path, b"{}").unwrap();

        let result = EmbeddingEngine::download_hf_file("test/model", "config.json", &tmp);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), cfg_path);
    }

    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────

    /// Tests the code path through `download_model_files` (local Dakera cache hit)
    /// and into `new()` — which then fails trying to load the tokenizer from a
    /// fake file. No network access required; fake files are pre-seeded.
    ///
    /// NOTE(review): this only asserts that `new()` errors. With the ONNX
    /// pipeline it may fail on the missing `onnx/model_quantized.onnx` before
    /// `tokenizer.json` is ever parsed — confirm which path is exercised.
    #[tokio::test]
    // Holding a std mutex across `.await` is deliberate: env access must stay
    // serialized for the whole `new()` call, and `#[tokio::test]` drives this
    // future on a current-thread runtime dedicated to this test.
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _lock = env_lock();

        // Set up a fake Dakera model cache so download_model_files finds our files
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        // Placeholder model weights — not a loadable model, which is fine here
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        // Invalid tokenizer.json — intended to cause a TokenizationError in new()
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        let _hf_home = EnvVarGuard::set("HF_HOME", &tmp);

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        // Must fail — tokenizer.json is invalid JSON
        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──

    /// If the embedding model is already cached on this machine, exercise the
    /// full `new()` path and test getters. On machines without a cached model
    /// the test passes silently — it is intentionally non-gating.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // env lock must span new(); see test_new_fails_with_invalid_tokenizer_json
    async fn test_engine_getters_when_model_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                // device() was removed in the CE-3 ONNX migration, so there is no device getter to check.
                // Debug impl should not panic
                let _ = format!("{:?}", engine);
                // estimate_time_ms should return a non-negative value
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Model not in cache — skip; CI runner may or may not have it
            }
        }
    }

    /// When model is cached: embed an empty batch must return immediately with
    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // env lock must span new()
    async fn test_engine_embed_empty_batch_when_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    // ── Session pool (DAK-5547) ──────────────────────────────────────────────

    #[test]
    fn test_session_pool_default_is_4() {
        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
        // DAKERA_ONNX_POOL_SIZE env var still allows override.
        let _lock = env_lock();
        let config = ModelConfig::default();
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    /// When model is cached: verify the session pool has the expected size.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // env lock must span new()
    async fn test_engine_pool_size_matches_config_when_cached() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    // ── next_session round-robin ──────────────────────────────────────────────

    /// Round-robin counter distributes batches across pool slots without panic.
    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    /// Pool size 1 degrades to single-session behavior without panicking.
    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    /// When model is cached: empty batch short-circuits before touching the pool.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // env lock must span new()
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            // Empty batch returns before fetch_add — counter stays at 0.
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }

    // ── GPU inference semaphore (DAK-6096) ────────────────────────────────────

    /// Default config has use_gpu=false; the resolved flag stored in the engine must match.
    #[allow(clippy::await_holding_lock)] // env lock must span new()
    #[tokio::test]
    async fn test_engine_use_gpu_defaults_to_false_when_not_configured() {
        let _lock = env_lock();

        // Ensure DAKERA_USE_GPU is unset so we test the config.use_gpu=false path;
        // the previous value is restored when the guard drops.
        let _use_gpu = EnvVarGuard::remove("DAKERA_USE_GPU");

        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(1);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert!(
                !engine.use_gpu,
                "use_gpu must be false when DAKERA_USE_GPU is unset"
            );
        }
    }

    /// DAKERA_USE_GPU=1 env var must override config.use_gpu=false and set use_gpu=true.
    ///
    /// NOTE(review): this checks only the `== "1"` parsing expression, not the
    /// engine's own resolution logic. Consider extracting that logic into a pure
    /// helper so it can be tested directly.
    #[test]
    fn test_engine_use_gpu_resolved_from_env_var() {
        let _lock = env_lock();

        let _use_gpu = EnvVarGuard::set("DAKERA_USE_GPU", "1");
        let resolved = std::env::var("DAKERA_USE_GPU")
            .map(|v| v == "1")
            .unwrap_or(false);

        assert!(resolved, "DAKERA_USE_GPU=1 must resolve use_gpu=true");
    }

    /// Debug output includes use_gpu field so operators can confirm GPU mode in logs.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)] // env lock must span new()
    async fn test_engine_debug_includes_use_gpu() {
        let _lock = env_lock();
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let debug_str = format!("{:?}", engine);
            assert!(
                debug_str.contains("use_gpu"),
                "Debug output must include use_gpu field: {debug_str}"
            );
        }
    }
}