Skip to main content

inference/
engine.rs

1//! Core embedding engine for generating vector embeddings from text.
2//!
3//! The `EmbeddingEngine` provides a high-level interface for:
//! - Loading ONNX embedding models (INT8 quantized on CPU, FP32 when CUDA is enabled) from HuggingFace Hub
5//! - Generating embeddings for single texts or batches
6//! - Automatic batching and parallel processing via ONNX Runtime
7//!
8//! # Example
9//!
10//! ```no_run
11//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
12//!
13//! #[tokio::main]
14//! async fn main() {
15//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
16//!     let engine = EmbeddingEngine::new(config).await.unwrap();
17//!
18//!     // Embed a single query
19//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
20//!     println!("Embedding dimension: {}", embedding.len());
21//!
22//!     // Embed multiple documents
23//!     let docs = vec![
24//!         "Machine learning is a subset of AI.".to_string(),
25//!         "Deep learning uses neural networks.".to_string(),
26//!     ];
27//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
28//!     println!("Generated {} embeddings", embeddings.len());
29//! }
30//! ```
31
32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use ort::inputs;
36use ort::session::builder::GraphOptimizationLevel;
37use ort::session::Session;
38use ort::value::Tensor;
39use parking_lot::Mutex;
40use std::io::Read;
41use std::path::{Path, PathBuf};
42use std::sync::atomic::{AtomicUsize, Ordering};
43use std::sync::Arc;
44use tokenizers::Tokenizer;
45use tracing::{debug, info, instrument, warn};
46
47use ort::execution_providers::CUDAExecutionProvider;
48
49/// The main embedding engine for generating vector embeddings.
50///
51/// This struct is thread-safe and can be shared across async tasks.
52/// ORT sessions are mutex-guarded (run() takes &mut self) and held in a
53/// pool so concurrent callers can embed without head-of-line blocking.
54/// CPU-heavy inference is offloaded via `tokio::task::spawn_blocking`.
55pub struct EmbeddingEngine {
56    /// Pool of ONNX Runtime sessions — each guarded independently.
57    /// Concurrent callers round-robin across sessions via `next_session`,
58    /// eliminating Mutex head-of-line blocking for batch embedding workloads.
59    sessions: Vec<Arc<Mutex<Session>>>,
60    /// Round-robin counter: atomically incremented per batch, wraps via modulo.
61    next_session: AtomicUsize,
62    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
63    processor: Arc<BatchProcessor>,
64    /// Model configuration
65    config: ModelConfig,
66    /// Embedding dimension
67    dimension: usize,
68}
69
70impl EmbeddingEngine {
71    /// Create a new embedding engine with the given configuration.
72    ///
73    /// Downloads the FP32 ONNX model (`model.onnx`) when CUDA EP is enabled, or the
74    /// INT8 quantized model (`model_quantized.onnx`) otherwise.
75    #[instrument(skip_all, fields(model = %config.model))]
76    pub async fn new(config: ModelConfig) -> Result<Self> {
77        // Resolve GPU mode before downloading model files — GPU requires the FP32 model.
78        // DAKERA_USE_GPU=1 overrides config.use_gpu (env var is the production knob).
79        let use_gpu = std::env::var("DAKERA_USE_GPU")
80            .map(|v| v == "1")
81            .unwrap_or(config.use_gpu);
82        if use_gpu {
83            info!("CUDA execution provider enabled — using FP32 model (DAKERA_USE_GPU=1)");
84        }
85
86        info!(
87            "Initializing ONNX embedding engine with model: {}",
88            config.model
89        );
90
91        // Download tokenizer and ONNX model files (FP32 for GPU, INT8 for CPU)
92        let (tokenizer_path, onnx_path) = Self::download_model_files(&config, use_gpu).await?;
93
94        // Load tokenizer
95        info!("Loading tokenizer from {:?}", tokenizer_path);
96        let tokenizer = Tokenizer::from_file(&tokenizer_path)
97            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
98
99        // Build ONNX session pool — N independent sessions to serve concurrent callers.
100        // Each session has its own ORT context so pool members never block each other.
101        info!("Loading ONNX model from {:?}", onnx_path);
102        let num_threads = config.num_threads.unwrap_or(4);
103        let pool_size = config.session_pool_size.max(1);
104        let onnx_path_clone = onnx_path.clone();
105
106        let sessions: Vec<Arc<Mutex<Session>>> =
107            tokio::task::spawn_blocking(move || -> Result<Vec<Arc<Mutex<Session>>>> {
108                (0..pool_size)
109                    .map(|_| {
110                        let builder = Session::builder()
111                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
112                            .with_optimization_level(GraphOptimizationLevel::Level3)
113                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
114                            .with_intra_threads(num_threads)
115                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
116
117                        let mut builder = if use_gpu {
118                            builder
119                                .with_execution_providers(
120                                    [CUDAExecutionProvider::default().build()],
121                                )
122                                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
123                        } else {
124                            builder
125                        };
126
127                        let s = builder
128                            .commit_from_file(&onnx_path_clone)
129                            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
130                        Ok(Arc::new(Mutex::new(s)))
131                    })
132                    .collect()
133            })
134            .await
135            .map_err(|e| {
136                InferenceError::ModelLoadError(format!("Session pool init panicked: {}", e))
137            })??;
138
139        let dimension = config.model.dimension();
140        let processor = Arc::new(BatchProcessor::new(
141            tokenizer,
142            config.model,
143            config.max_batch_size,
144        ));
145
146        info!(
147            "ONNX embedding engine ready: model={}, dimension={}, threads={}, pool={}",
148            config.model, dimension, num_threads, pool_size
149        );
150
151        Ok(Self {
152            sessions,
153            next_session: AtomicUsize::new(0),
154            processor,
155            config,
156            dimension,
157        })
158    }
159
160    /// Resolve tokenizer and ONNX model files, downloading from HuggingFace if needed.
161    ///
162    /// - `tokenizer.json` — from the original model repo (sentence-transformers, BAAI, intfloat)
163    /// - GPU: `onnx/model.onnx` (FP32) — no Memcpy round-trips, CUDA EP runs entirely on-device
164    /// - CPU: `onnx/model_quantized.onnx` (INT8) — fastest for CPU-only deployments
165    #[instrument(skip_all, fields(model = %config.model))]
166    async fn download_model_files(
167        config: &ModelConfig,
168        use_gpu: bool,
169    ) -> Result<(PathBuf, PathBuf)> {
170        let model_id = config.model.model_id();
171        let onnx_repo_id = config.model.onnx_repo_id();
172        let onnx_filename = if use_gpu {
173            config.model.onnx_filename_gpu()
174        } else {
175            config.model.onnx_filename()
176        };
177
178        info!(
179            "Resolving model files: tokenizer={}, onnx={}@{}",
180            model_id, onnx_filename, onnx_repo_id
181        );
182
183        let tokenizer_cache_dir = Self::model_cache_dir(model_id)?;
184        let onnx_cache_dir = Self::model_cache_dir(onnx_repo_id)?;
185
186        // ONNX sub-directory mirrors the path within the repo (e.g. "onnx/")
187        let onnx_subdir = onnx_cache_dir.join("onnx");
188        std::fs::create_dir_all(&onnx_subdir)?;
189
190        let local_tokenizer = tokenizer_cache_dir.join("tokenizer.json");
191        // onnx_filename is e.g. "onnx/model_quantized.onnx" or "onnx/model.onnx" — use the basename
192        let onnx_basename = Path::new(onnx_filename)
193            .file_name()
194            .and_then(|s| s.to_str())
195            .unwrap_or("model_quantized.onnx");
196        let local_onnx = onnx_subdir.join(onnx_basename);
197
198        // Download missing files in a blocking thread
199        let tokenizer_needs_download = !local_tokenizer.exists();
200
201        // GPU FP32 model (BGE-large) is 1.27 GB. The old 500 MB download limit silently
202        // truncated it, leaving a corrupt ONNX file that ORT rejects with "Protobuf parsing
203        // failed" (DAK-5976). If the cached GPU model is ≤500 MB, it's the truncated artifact
204        // — delete it so download_hf_file fetches the complete file.
205        if use_gpu && local_onnx.exists() {
206            let cached_size = local_onnx.metadata().map(|m| m.len()).unwrap_or(0);
207            if cached_size <= 500_000_000 {
208                warn!(
209                    "Cached GPU ONNX at {:?} is {} bytes (≤500 MB) — likely truncated by old \
210                     download limit. Deleting for complete re-download.",
211                    local_onnx, cached_size
212                );
213                let _ = std::fs::remove_file(&local_onnx);
214            }
215        }
216        let onnx_needs_download = !local_onnx.exists();
217
218        if tokenizer_needs_download || onnx_needs_download {
219            let model_id_owned = model_id.to_string();
220            let onnx_repo_id_owned = onnx_repo_id.to_string();
221            let onnx_filename_owned = onnx_filename.to_string();
222            let tokenizer_cache = tokenizer_cache_dir.clone();
223            let onnx_cache = onnx_cache_dir.clone();
224
225            tokio::task::spawn_blocking(move || {
226                if !tokenizer_cache.join("tokenizer.json").exists() {
227                    Self::download_hf_file(&model_id_owned, "tokenizer.json", &tokenizer_cache)
228                        .map_err(|e| {
229                            InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
230                        })?;
231                }
232                if !onnx_cache.join(&onnx_filename_owned).exists() {
233                    Self::download_hf_file(&onnx_repo_id_owned, &onnx_filename_owned, &onnx_cache)
234                        .map_err(|e| {
235                            InferenceError::HubError(format!(
236                                "Failed to download ONNX model: {}",
237                                e
238                            ))
239                        })?;
240                }
241                Ok::<_, InferenceError>(())
242            })
243            .await
244            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
245        } else {
246            info!("All model files found in local cache");
247        }
248
249        // Re-derive paths (cache dir / onnx / basename)
250        let final_onnx = onnx_cache_dir.join(onnx_filename);
251
252        info!(
253            "Model files ready: tokenizer={:?}, onnx={:?}",
254            local_tokenizer, final_onnx
255        );
256        Ok((local_tokenizer, final_onnx))
257    }
258
259    /// Get or create the local model cache directory.
260    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
261        let base = std::env::var("HF_HOME")
262            .map(PathBuf::from)
263            .unwrap_or_else(|_| {
264                let home = std::env::var("HOME").unwrap_or_else(|_| {
265                    warn!("HOME environment variable not set, using /tmp for model cache");
266                    "/tmp".to_string()
267                });
268                PathBuf::from(home).join(".cache").join("huggingface")
269            });
270        let dir = base.join("dakera").join(model_id.replace('/', "--"));
271        std::fs::create_dir_all(&dir)?;
272        Ok(dir)
273    }
274
275    /// Download a single file from HuggingFace using ureq (sync, for spawn_blocking).
276    ///
277    /// Handles relative Location headers that ureq 2.x cannot resolve automatically.
278    ///
279    /// Public alias for use by other inference modules (e.g. GLiNER NER engine).
280    pub fn download_hf_file_pub(
281        model_id: &str,
282        filename: &str,
283        cache_dir: &Path,
284    ) -> std::result::Result<PathBuf, String> {
285        Self::download_hf_file(model_id, filename, cache_dir)
286    }
287
288    fn download_hf_file(
289        model_id: &str,
290        filename: &str,
291        cache_dir: &Path,
292    ) -> std::result::Result<PathBuf, String> {
293        // The file may be nested (e.g. "onnx/model_quantized.onnx")
294        let file_path = cache_dir.join(filename);
295        if file_path.exists() {
296            info!("Cached: {}/{}", model_id, filename);
297            return Ok(file_path);
298        }
299
300        // Ensure parent directory exists (for "onnx/model_quantized.onnx")
301        if let Some(parent) = file_path.parent() {
302            std::fs::create_dir_all(parent)
303                .map_err(|e| format!("Failed to create directory {:?}: {}", parent, e))?;
304        }
305
306        let url = format!(
307            "https://huggingface.co/{}/resolve/main/{}",
308            model_id, filename
309        );
310        info!("Downloading: {}", url);
311
312        // Read HuggingFace token from env (required for gated models).
313        let hf_token = std::env::var("HF_TOKEN")
314            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
315            .ok();
316        if hf_token.is_some() {
317            info!("Using HuggingFace auth token for download");
318        }
319
320        // Disable automatic redirects so we can resolve relative Location headers ourselves.
321        let agent = ureq::AgentBuilder::new()
322            .redirects(0)
323            .timeout(std::time::Duration::from_secs(300))
324            .build();
325
326        let mut current_url = url.clone();
327        let mut redirects = 0;
328        let max_redirects = 10;
329
330        let response = loop {
331            let mut req = agent.get(&current_url);
332            if let Some(ref token) = hf_token {
333                req = req.set("Authorization", &format!("Bearer {}", token));
334            }
335            let resp = req.call();
336
337            let r = match resp {
338                Ok(r) => r,
339                Err(ureq::Error::Status(_status, r)) => r,
340                Err(e) => return Err(format!("{}: {}", filename, e)),
341            };
342
343            let status = r.status();
344            if (200..300).contains(&status) {
345                break r;
346            } else if (300..400).contains(&status) {
347                redirects += 1;
348                if redirects > max_redirects {
349                    return Err(format!("{}: too many redirects", filename));
350                }
351                let location = r
352                    .header("location")
353                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
354                    .to_string();
355
356                // Resolve relative redirects against the current URL's origin
357                current_url = if location.starts_with('/') {
358                    let parsed = url::Url::parse(&current_url)
359                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
360                    let host = parsed.host_str().ok_or_else(|| {
361                        format!("{}: redirect URL missing host: {}", filename, current_url)
362                    })?;
363                    format!("{}://{}{}", parsed.scheme(), host, location)
364                } else {
365                    location
366                };
367                info!("Redirect {} → {}", redirects, current_url);
368            } else {
369                return Err(format!("{}: HTTP {}", filename, status));
370            }
371        };
372
373        // Capture expected file size before consuming the response body.
374        // HuggingFace CDN reports LFS blob sizes via x-linked-size; standard
375        // HTTP servers use content-length. Either may be absent (chunked transfer).
376        let expected_bytes: Option<u64> = response
377            .header("x-linked-size")
378            .or_else(|| response.header("content-length"))
379            .and_then(|v| v.parse::<u64>().ok());
380
381        // 2 GiB limit — BGE-large FP32 is 1.27 GB; the old 500 MB limit silently
382        // truncated it, producing a corrupt ONNX that ORT rejects with
383        // "Protobuf parsing failed" (DAK-5976).
384        let mut bytes = Vec::new();
385        response
386            .into_reader()
387            .take(2_147_483_648)
388            .read_to_end(&mut bytes)
389            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
390
391        // Fail fast if the received payload is shorter than the declared size.
392        // This prevents writing a corrupt file that would pass the cache-hit check
393        // on the next boot but fail when ORT tries to parse it.
394        if let Some(expected) = expected_bytes {
395            let actual = bytes.len() as u64;
396            if actual < expected {
397                return Err(format!(
398                    "{}: download incomplete — received {} of {} bytes. \
399                     File may exceed 2 GiB or the connection was interrupted.",
400                    filename, actual, expected
401                ));
402            }
403        }
404
405        std::fs::write(&file_path, &bytes)
406            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
407
408        info!("Downloaded {} ({} bytes)", filename, bytes.len());
409        Ok(file_path)
410    }
411
412    /// Get the embedding dimension for the loaded model.
413    pub fn dimension(&self) -> usize {
414        self.dimension
415    }
416
417    /// Get the model being used.
418    pub fn model(&self) -> EmbeddingModel {
419        self.config.model
420    }
421
422    /// Get the number of parallel ONNX sessions in the pool.
423    pub fn pool_size(&self) -> usize {
424        self.sessions.len()
425    }
426
427    /// Embed a single query text.
428    ///
429    /// For models like E5, this automatically applies the query prefix.
430    #[instrument(skip(self, text), fields(text_len = text.len()))]
431    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
432        let texts = vec![text.to_string()];
433        let prepared = self.processor.prepare_texts(&texts, true);
434        let embeddings = self.embed_batch_internal(&prepared).await?;
435        embeddings.into_iter().next().ok_or_else(|| {
436            InferenceError::InferenceError("No embedding returned for query".to_string())
437        })
438    }
439
440    /// Embed multiple query texts.
441    ///
442    /// For models like E5, this automatically applies the query prefix.
443    #[instrument(skip(self, texts), fields(count = texts.len()))]
444    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
445        let prepared = self.processor.prepare_texts(texts, true);
446        self.embed_batch_internal(&prepared).await
447    }
448
449    /// Embed a single document/passage.
450    ///
451    /// For models like E5, this automatically applies the document prefix.
452    #[instrument(skip(self, text), fields(text_len = text.len()))]
453    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
454        let texts = vec![text.to_string()];
455        let prepared = self.processor.prepare_texts(&texts, false);
456        let embeddings = self.embed_batch_internal(&prepared).await?;
457        embeddings.into_iter().next().ok_or_else(|| {
458            InferenceError::InferenceError("No embedding returned for document".to_string())
459        })
460    }
461
462    /// Embed multiple documents/passages.
463    ///
464    /// For models like E5, this automatically applies the document prefix.
465    #[instrument(skip(self, texts), fields(count = texts.len()))]
466    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
467        let prepared = self.processor.prepare_texts(texts, false);
468        self.embed_batch_internal(&prepared).await
469    }
470
471    /// Embed texts without any prefix (raw embedding).
472    #[instrument(skip(self, texts), fields(count = texts.len()))]
473    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
474        self.embed_batch_internal(texts).await
475    }
476
477    /// Internal batch embedding implementation.
478    ///
479    /// Splits `texts` into sub-batches (≤ `batch_size` each) and distributes them
480    /// across the session pool via round-robin.  On GPU BFCArena OOM the batch size
481    /// is halved and the full input is retried (up to 3 halvings; DAK-6002).  This
482    /// handles LME-style long documents whose attention tensor (batch × heads × seq²)
483    /// exceeds the VRAM headroom at the configured batch size but fits at batch/2.
484    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
485        if texts.is_empty() {
486            return Ok(vec![]);
487        }
488
489        let pool_len = self.sessions.len();
490        let normalize = self.config.model.normalize_embeddings();
491        // Round-robin starting index: each concurrent caller gets a different slot so
492        // concurrent requests don't all contend on sessions[0].
493        let start_idx = self.next_session.fetch_add(1, Ordering::Relaxed);
494
495        let mut batch_size = self.config.max_batch_size.max(1);
496
497        // Up to 3 halvings: 32 → 16 → 8 → 4 → hard fail.
498        for attempt in 0_u32..=3 {
499            let batches: Vec<Vec<String>> = texts.chunks(batch_size).map(|b| b.to_vec()).collect();
500
501            // Spawn all sub-batches concurrently; preserve insertion order for reassembly.
502            let mut handles = Vec::with_capacity(batches.len());
503            for (i, batch_owned) in batches.into_iter().enumerate() {
504                let session = Arc::clone(&self.sessions[(start_idx + i) % pool_len]);
505                let processor = Arc::clone(&self.processor);
506                handles.push(tokio::task::spawn_blocking(move || {
507                    let mut session_guard = session.lock();
508                    Self::process_batch_blocking(
509                        &batch_owned,
510                        &mut session_guard,
511                        &processor,
512                        normalize,
513                    )
514                }));
515            }
516
517            let mut all_embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
518            let mut oom: Option<InferenceError> = None;
519
520            for handle in handles {
521                match handle.await {
522                    Err(panic_err) => {
523                        return Err(InferenceError::InferenceError(format!(
524                            "Inference task panicked: {panic_err}"
525                        )));
526                    }
527                    Ok(Err(e)) => {
528                        if attempt < 3 && Self::is_gpu_oom(&e) {
529                            // First OOM wins; remaining in-flight tasks detach harmlessly.
530                            oom = Some(e);
531                            break;
532                        }
533                        return Err(e);
534                    }
535                    Ok(Ok(batch_embs)) => {
536                        all_embeddings.extend(batch_embs);
537                    }
538                }
539            }
540
541            if oom.is_some() {
542                let next_batch = (batch_size / 2).max(1);
543                warn!(
544                    "ONNX allocator OOM (attempt {}/3) — retrying with batch_size {} → {}",
545                    attempt + 1,
546                    batch_size,
547                    next_batch,
548                );
549                batch_size = next_batch;
550                continue;
551            }
552
553            return Ok(all_embeddings);
554        }
555
556        Err(InferenceError::InferenceError(format!(
557            "ONNX inference failed: GPU/CPU allocator OOM persists after 3 \
558             batch-halving attempts (final batch_size={batch_size})"
559        )))
560    }
561
562    /// Returns `true` when the error message indicates an allocator OOM that can be
563    /// recovered by reducing the inference batch size (DAK-6002).
564    ///
565    /// Covers CUDA BFCArena (`BFCArena`, `Failed to allocate memory … buffer of size`)
566    /// and CPU arena OOM patterns seen in ONNX Runtime 2.x.
567    fn is_gpu_oom(err: &InferenceError) -> bool {
568        let msg = err.to_string();
569        msg.contains("BFCArena")
570            || msg.contains("Failed to allocate memory")
571            || msg.contains("CUDA_OUT_OF_MEMORY")
572            || msg.contains("CUDA out of memory")
573            || (msg.contains("allocate") && msg.contains("buffer of size"))
574    }
575
576    /// Process a single batch: tokenize → ORT session → mean pool → normalize.
577    ///
578    /// Designed to run inside `spawn_blocking` (takes no `&self`).
579    fn process_batch_blocking(
580        texts: &[String],
581        session: &mut Session,
582        processor: &BatchProcessor,
583        normalize: bool,
584    ) -> Result<Vec<Vec<f32>>> {
585        // Tokenize
586        let prepared = processor.tokenize_batch(texts)?;
587        let batch_size = prepared.batch_size;
588        let seq_len = prepared.seq_len;
589
590        // Keep a copy of attention_mask for mean_pooling (consumed by Tensor below)
591        let attention_mask_flat = prepared.attention_mask.clone();
592
593        // Build ORT tensors — from_array requires (shape, Vec<T>) in ort rc.12
594        let input_ids_tensor =
595            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.input_ids))
596                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
597        let attention_mask_tensor =
598            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.attention_mask))
599                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
600        let token_type_ids_tensor =
601            Tensor::<i64>::from_array(([batch_size, seq_len], prepared.token_type_ids))
602                .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
603
604        // Run ONNX session
605        let outputs = session
606            .run(inputs![
607                "input_ids" => input_ids_tensor,
608                "attention_mask" => attention_mask_tensor,
609                "token_type_ids" => token_type_ids_tensor
610            ])
611            .map_err(|e: ort::Error| InferenceError::InferenceError(e.to_string()))?;
612
613        // Extract last_hidden_state: shape [batch, seq_len, hidden_size]
614        // ort rc.12: try_extract_tensor returns (&Shape, &[T])
615        // Shape derefs to [i64], so index directly.
616        let (ort_shape, lhs_slice) = outputs[0]
617            .try_extract_tensor::<f32>()
618            .map_err(|e| InferenceError::InferenceError(e.to_string()))?;
619
620        if ort_shape.len() != 3 {
621            return Err(InferenceError::InferenceError(format!(
622                "Expected 3D last_hidden_state, got {} dims",
623                ort_shape.len()
624            )));
625        }
626        let hidden_size = ort_shape[2] as usize;
627
628        // Apply mean pooling using the saved attention mask copy
629        let mut embeddings = mean_pooling(
630            lhs_slice,
631            batch_size,
632            seq_len,
633            hidden_size,
634            &attention_mask_flat,
635        );
636
637        // L2 normalize if configured
638        if normalize {
639            normalize_embeddings(&mut embeddings);
640        }
641
642        debug!(
643            "Generated {} embeddings of dimension {}",
644            embeddings.len(),
645            embeddings.first().map(|e| e.len()).unwrap_or(0)
646        );
647
648        Ok(embeddings)
649    }
650
651    /// Estimate the time to embed a batch of texts (in milliseconds).
652    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
653        // Rough estimation based on model speed and text length (CPU path)
654        let tokens_per_text =
655            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
656        let total_tokens = tokens_per_text * text_count as f64;
657        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
658        (total_tokens / tokens_per_second) * 1000.0
659    }
660}
661
662impl std::fmt::Debug for EmbeddingEngine {
663    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
664        f.debug_struct("EmbeddingEngine")
665            .field("model", &self.config.model)
666            .field("dimension", &self.dimension)
667            .field("max_batch_size", &self.config.max_batch_size)
668            .field("session_pool_size", &self.sessions.len())
669            .finish()
670    }
671}
672
673/// Builder for creating an EmbeddingEngine with fluent API.
674pub struct EmbeddingEngineBuilder {
675    config: ModelConfig,
676}
677
678impl EmbeddingEngineBuilder {
679    /// Create a new builder with default configuration.
680    pub fn new() -> Self {
681        Self {
682            config: ModelConfig::default(),
683        }
684    }
685
686    /// Set the embedding model to use.
687    pub fn model(mut self, model: EmbeddingModel) -> Self {
688        self.config.model = model;
689        self
690    }
691
692    /// Set the cache directory for model files.
693    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
694        self.config.cache_dir = Some(dir.into());
695        self
696    }
697
698    /// Set the maximum batch size.
699    pub fn max_batch_size(mut self, size: usize) -> Self {
700        self.config.max_batch_size = size;
701        self
702    }
703
704    /// Enable GPU acceleration (reserved for future use; ORT selects the execution provider).
705    pub fn use_gpu(mut self, enable: bool) -> Self {
706        self.config.use_gpu = enable;
707        self
708    }
709
710    /// Set the number of intra-op CPU threads for ORT inference.
711    pub fn num_threads(mut self, threads: usize) -> Self {
712        self.config.num_threads = Some(threads);
713        self
714    }
715
716    /// Set the number of parallel ONNX sessions in the pool.
717    pub fn session_pool_size(mut self, size: usize) -> Self {
718        self.config.session_pool_size = size.max(1);
719        self
720    }
721
722    /// Build the embedding engine.
723    pub async fn build(self) -> Result<EmbeddingEngine> {
724        EmbeddingEngine::new(self.config).await
725    }
726}
727
728impl Default for EmbeddingEngineBuilder {
729    fn default() -> Self {
730        Self::new()
731    }
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    // ── shared helpers ───────────────────────────────────────────────────────

    /// Serializes every test in this module that mutates process environment
    /// variables.
    ///
    /// This has to be a single module-level static: a `static` declared inside
    /// a test function is private to that function, so per-test locks would
    /// never exclude one another.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire [`ENV_LOCK`], recovering from poisoning.
    ///
    /// A failed assertion in one environment test must not cascade into
    /// `PoisonError` panics in the others; the guarded value is `()`, so there
    /// is no state that could have been left inconsistent.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Overrides (or removes) one environment variable and restores its
    /// previous state on drop — including when the test panics.
    ///
    /// Create it only while holding the guard returned by [`lock_env`], and
    /// declare it *after* that guard so it is dropped (and the variable
    /// restored) before the lock is released.
    struct ScopedEnvVar {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl ScopedEnvVar {
        /// Set `key` to `value`, remembering whatever was there before.
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: callers hold `ENV_LOCK`, which serializes all environment
            // mutation performed by this module; other threads are assumed to
            // touch the environment only through `std::env`.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        /// Remove `key`, remembering whatever was there before.
        fn unset(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `ScopedEnvVar::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: see `ScopedEnvVar::set`; the owning test still holds
            // `ENV_LOCK` as long as the declaration-order rule is followed.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Seed `filename` (which may include sub-directories) inside a scratch
    /// directory under the system temp dir, then assert that
    /// `download_hf_file` resolves to exactly that pre-seeded path.
    fn assert_download_is_cache_hit(
        scratch: &str,
        model_id: &str,
        filename: &str,
        contents: &[u8],
    ) {
        let cache_dir = std::env::temp_dir().join(scratch);
        let seeded = cache_dir.join(filename);
        let parent = seeded
            .parent()
            .expect("a path joined onto the cache dir always has a parent");
        std::fs::create_dir_all(parent).unwrap();
        std::fs::write(&seeded, contents).unwrap();

        let resolved = EmbeddingEngine::download_hf_file(model_id, filename, &cache_dir).unwrap();
        assert_eq!(resolved, seeded);
    }

    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }

    // ── model_cache_dir ──────────────────────────────────────────────────────

    /// Ensure `model_cache_dir` respects `HF_HOME` when set.
    ///
    /// Holds the module-wide environment lock because mutating the process
    /// environment is not thread-safe and other tests change `HF_HOME` too.
    #[test]
    fn test_model_cache_dir_with_hf_home() {
        let _env = lock_env();

        let tmp = std::env::temp_dir().join("dakera_test_hf_home");
        let _hf_home = ScopedEnvVar::set("HF_HOME", &tmp);

        let path = EmbeddingEngine::model_cache_dir("org/my-model").unwrap();
        assert!(
            path.starts_with(&tmp),
            "expected path under {tmp:?}, got {path:?}"
        );
        assert!(
            path.to_str().unwrap().contains("org--my-model"),
            "model_id separator not applied: {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_contains_dakera_subdir() {
        let path =
            EmbeddingEngine::model_cache_dir("sentence-transformers/all-MiniLM-L6-v2").unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("dakera"), "expected 'dakera' in path: {s}");
        assert!(
            s.contains("sentence-transformers--all-MiniLM-L6-v2"),
            "expected transformed model id in path: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_creates_directory() {
        let dir = EmbeddingEngine::model_cache_dir("test/cache-dir-creation-probe").unwrap();
        assert!(dir.exists(), "model_cache_dir should create the directory");
    }

    // ── download_hf_file (cached early-return path) ──────────────────────────

    #[test]
    fn test_download_hf_file_returns_path_when_already_cached() {
        assert_download_is_cache_hit(
            "dakera_test_cached_file",
            "test/model",
            "config.json",
            b"{}",
        );
    }

    #[test]
    fn test_download_hf_file_returns_correct_path_for_cached_onnx() {
        // Filename includes the subdirectory path
        assert_download_is_cache_hit(
            "dakera_test_cached_onnx",
            "Xenova/all-MiniLM-L6-v2",
            "onnx/model_quantized.onnx",
            b"fake_onnx_model",
        );
    }

    // ── EmbeddingEngineBuilder ───────────────────────────────────────────────

    #[test]
    fn test_builder_default_impl() {
        let b1 = EmbeddingEngineBuilder::new();
        let b2 = EmbeddingEngineBuilder::default();
        assert_eq!(b1.config.model, b2.config.model);
        assert_eq!(b1.config.max_batch_size, b2.config.max_batch_size);
    }

    #[test]
    fn test_builder_model_field() {
        let builder = EmbeddingEngineBuilder::new().model(EmbeddingModel::E5Small);
        assert_eq!(builder.config.model, EmbeddingModel::E5Small);
    }

    #[test]
    fn test_builder_cache_dir() {
        let builder = EmbeddingEngineBuilder::new().cache_dir("/tmp/my-models");
        assert_eq!(builder.config.cache_dir, Some("/tmp/my-models".to_string()));
    }

    #[test]
    fn test_builder_max_batch_size() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(128);
        assert_eq!(builder.config.max_batch_size, 128);
    }

    #[test]
    fn test_builder_use_gpu_true() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(true);
        assert!(builder.config.use_gpu);
    }

    #[test]
    fn test_builder_use_gpu_false() {
        let builder = EmbeddingEngineBuilder::new().use_gpu(false);
        assert!(!builder.config.use_gpu);
    }

    #[test]
    fn test_builder_num_threads() {
        let builder = EmbeddingEngineBuilder::new().num_threads(4);
        assert_eq!(builder.config.num_threads, Some(4));
    }

    #[test]
    fn test_builder_chain_all_fields() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .cache_dir("/cache")
            .max_batch_size(16)
            .use_gpu(false)
            .num_threads(2);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.cache_dir, Some("/cache".to_string()));
        assert_eq!(builder.config.max_batch_size, 16);
        assert!(!builder.config.use_gpu);
        assert_eq!(builder.config.num_threads, Some(2));
    }

    // ── estimate_time_ms ─────────────────────────────────────────────────────

    #[test]
    fn test_estimate_time_zero_count() {
        let tps = EmbeddingModel::MiniLM.tokens_per_second_cpu() as f64;
        let estimate = (0.0 / tps) * 1000.0;
        assert_eq!(estimate, 0.0);
    }

    #[test]
    fn test_estimate_time_formula_cpu() {
        // texts=10, avg_len=100 → tokens_per_text = min(25, 256) = 25
        // total_tokens = 250; tps = 5000; time = (250/5000)*1000 = 50ms
        let model = EmbeddingModel::MiniLM;
        let tokens_per_text = (100.0f64 / 4.0).min(model.max_seq_length() as f64);
        let total_tokens = tokens_per_text * 10.0;
        let estimate = (total_tokens / model.tokens_per_second_cpu() as f64) * 1000.0;
        assert!(
            (estimate - 50.0).abs() < 1e-6,
            "expected 50.0ms, got {estimate}"
        );
    }

    #[test]
    fn test_estimate_time_capped_at_max_seq_length() {
        let model = EmbeddingModel::MiniLM;
        let avg_len = 100_000;
        let tokens_per_text = (avg_len as f64 / 4.0).min(model.max_seq_length() as f64);
        assert_eq!(tokens_per_text, 256.0);
    }

    // ── ModelConfig API ───────────────────────────────────────────────────────

    #[test]
    fn test_model_config_new() {
        let cfg = ModelConfig::new(EmbeddingModel::BgeSmall);
        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
        assert!(cfg.cache_dir.is_none());
        assert!(cfg.num_threads.is_none());
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.model, EmbeddingModel::BgeLarge);
        assert_eq!(cfg.max_batch_size, 32);
        assert!(!cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_cache_dir() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_cache_dir("/tmp/models");
        assert_eq!(cfg.cache_dir, Some("/tmp/models".to_string()));
    }

    #[test]
    fn test_model_config_with_max_batch_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_max_batch_size(64);
        assert_eq!(cfg.max_batch_size, 64);
    }

    #[test]
    fn test_model_config_with_gpu() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_gpu(true);
        assert!(cfg.use_gpu);
    }

    #[test]
    fn test_model_config_with_num_threads() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
        assert_eq!(cfg.num_threads, Some(8));
    }

    #[test]
    fn test_model_config_chained_builder() {
        let cfg = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("/cache")
            .with_max_batch_size(16)
            .with_gpu(false)
            .with_num_threads(4);
        assert_eq!(cfg.model, EmbeddingModel::E5Small);
        assert_eq!(cfg.cache_dir, Some("/cache".to_string()));
        assert_eq!(cfg.max_batch_size, 16);
        assert!(!cfg.use_gpu);
        assert_eq!(cfg.num_threads, Some(4));
    }

    // ── model_cache_dir edge cases ────────────────────────────────────────────

    /// Test `model_cache_dir` when HOME is not set — should fall back to /tmp.
    #[test]
    fn test_model_cache_dir_no_home_fallback() {
        let _env = lock_env();

        // Remove HOME and HF_HOME so we hit the /tmp fallback; both are
        // restored to their previous values when the guards drop.
        let _home = ScopedEnvVar::unset("HOME");
        let _hf_home = ScopedEnvVar::unset("HF_HOME");

        let path = EmbeddingEngine::model_cache_dir("test/fallback-model").unwrap();
        // Should be under /tmp since HOME was unset
        assert!(
            path.starts_with("/tmp"),
            "expected path under /tmp, got {path:?}"
        );
    }

    #[test]
    fn test_model_cache_dir_deep_model_id() {
        let path = EmbeddingEngine::model_cache_dir("org/sub/model-name-with-dashes").unwrap();
        let s = path.to_str().unwrap();
        // All slashes replaced with double-dash
        assert!(
            s.contains("org--sub--model-name-with-dashes"),
            "expected transformed path, got: {s}"
        );
    }

    #[test]
    fn test_model_cache_dir_minilm_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::MiniLM.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("sentence-transformers--all-MiniLM-L6-v2"));
    }

    #[test]
    fn test_model_cache_dir_bge_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::BgeSmall.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("BAAI--bge-small-en-v1.5"));
    }

    #[test]
    fn test_model_cache_dir_e5_model_id() {
        let path = EmbeddingEngine::model_cache_dir(EmbeddingModel::E5Small.model_id()).unwrap();
        let s = path.to_str().unwrap();
        assert!(s.contains("intfloat--e5-small-v2"));
    }

    // ── download_hf_file additional cache-hit variations ─────────────────────

    #[test]
    fn test_download_hf_file_pytorch_bin_cached() {
        assert_download_is_cache_hit(
            "dakera_test_pytorch_bin",
            "test/model",
            "pytorch_model.bin",
            b"fake_pytorch_weights",
        );
    }

    #[test]
    fn test_download_hf_file_tokenizer_cached() {
        assert_download_is_cache_hit(
            "dakera_test_tokenizer_cached",
            "test/model",
            "tokenizer.json",
            br#"{"version":"1.0"}"#,
        );
    }

    #[test]
    fn test_download_hf_file_config_json_cached() {
        assert_download_is_cache_hit(
            "dakera_test_config_cached",
            "test/model",
            "config.json",
            b"{}",
        );
    }

    // ── EmbeddingEngine::new() failure path via fake local cache ─────────────

    /// Tests the code path through `download_model_files` (local Dakera cache hit)
    /// and into `new()` — which then fails trying to load the tokenizer from a
    /// fake file. No network access required; fake files are pre-seeded.
    ///
    /// NOTE(review): the seeded weights file is `model.safetensors`, while the
    /// module docs describe ONNX models — confirm the cache-hit path still
    /// applies; the assertion below only checks that `new()` returns `Err`.
    // The env lock is a std mutex held across `.await` on purpose: the
    // environment override must stay in place for the whole `new()` call.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_new_fails_with_invalid_tokenizer_json() {
        let _env = lock_env();

        // Set up a fake Dakera model cache so download_model_files finds our files
        let tmp = std::env::temp_dir().join("dakera_test_engine_new_fail_tok");
        let model_dir = tmp
            .join("dakera")
            .join("sentence-transformers--all-MiniLM-L6-v2");
        std::fs::create_dir_all(&model_dir).unwrap();
        // Placeholder weights file — its contents are not real model data
        std::fs::write(model_dir.join("model.safetensors"), b"not_real_weights").unwrap();
        // Invalid tokenizer.json — will cause TokenizationError in new()
        std::fs::write(model_dir.join("tokenizer.json"), b"NOT_VALID_JSON").unwrap();
        std::fs::write(model_dir.join("config.json"), b"{}").unwrap();

        // Restored to its previous value (or removed) when the guard drops.
        let _hf_home = ScopedEnvVar::set("HF_HOME", &tmp);

        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let result = EmbeddingEngine::new(config).await;

        // Must fail — tokenizer.json is invalid JSON
        assert!(
            result.is_err(),
            "expected Err from new() with invalid tokenizer, got Ok"
        );
    }

    // ── EmbeddingEngineBuilder additional coverage ────────────────────────────

    #[test]
    fn test_builder_with_all_models() {
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            let builder = EmbeddingEngineBuilder::new().model(model);
            assert_eq!(builder.config.model, model);
        }
    }

    #[test]
    fn test_builder_max_batch_size_one() {
        let builder = EmbeddingEngineBuilder::new().max_batch_size(1);
        assert_eq!(builder.config.max_batch_size, 1);
    }

    #[test]
    fn test_builder_num_threads_zero() {
        let builder = EmbeddingEngineBuilder::new().num_threads(0);
        assert_eq!(builder.config.num_threads, Some(0));
    }

    // ── EmbeddingEngine::new() / getters when model is cached (best-effort) ──

    /// If the embedding model is already cached on this machine, exercise the
    /// full `new()` path and test getters. On machines without a cached model
    /// the test passes silently — it is intentionally non-gating.
    #[tokio::test]
    async fn test_engine_getters_when_model_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        match EmbeddingEngine::new(config).await {
            Ok(engine) => {
                assert_eq!(engine.dimension(), EmbeddingModel::MiniLM.dimension());
                assert_eq!(engine.model(), EmbeddingModel::MiniLM);
                // Device should be CPU in test environments (device() removed in CE-3 ONNX migration)
                // Debug impl should not panic
                let _ = format!("{engine:?}");
                // estimate_time_ms should return a non-negative value
                let ms = engine.estimate_time_ms(10, 50);
                assert!(ms >= 0.0);
            }
            Err(_) => {
                // Model not in cache — skip; CI runner may or may not have it
            }
        }
    }

    /// When model is cached: embed an empty batch must return immediately with
    /// no embeddings (the `texts.is_empty()` fast-path in embed_batch_internal).
    #[tokio::test]
    async fn test_engine_embed_empty_batch_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        }
    }

    // ── Session pool (DAK-5547) ──────────────────────────────────────────────

    #[test]
    fn test_session_pool_default_is_4() {
        // Default pool size is 4 (DAK-5746: pool=1 regressed LME ingest ~4×; OOM causes fixed).
        // DAKERA_ONNX_POOL_SIZE env var still allows override.
        let config = ModelConfig::default();
        let expected = std::env::var("DAKERA_ONNX_POOL_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n >= 1)
            .unwrap_or(4);
        assert_eq!(config.session_pool_size, expected);
    }

    #[test]
    fn test_session_pool_size_builder_roundtrip() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(8);
        assert_eq!(builder.config.session_pool_size, 8);
    }

    #[test]
    fn test_session_pool_size_min_enforced() {
        let builder = EmbeddingEngineBuilder::new().session_pool_size(0);
        assert_eq!(
            builder.config.session_pool_size, 1,
            "pool size 0 must clamp to 1"
        );
    }

    #[test]
    fn test_model_config_with_session_pool_size() {
        let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        assert_eq!(cfg.session_pool_size, 2);
    }

    /// When model is cached: verify the session pool has the expected size.
    #[tokio::test]
    async fn test_engine_pool_size_matches_config_when_cached() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            assert_eq!(
                engine.pool_size(),
                2,
                "engine should hold exactly 2 sessions"
            );
        }
    }

    // ── next_session round-robin ──────────────────────────────────────────────

    /// Round-robin counter distributes batches across pool slots without panic.
    #[test]
    fn test_round_robin_index_stays_in_bounds() {
        let pool_len = 4_usize;
        let counter = AtomicUsize::new(0);
        for expected_idx in 0..100_usize {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            let slot = start % pool_len;
            assert!(slot < pool_len);
            assert_eq!(slot, expected_idx % pool_len);
        }
    }

    /// Pool size 1 degrades to single-session behavior without panicking.
    #[test]
    fn test_round_robin_pool_size_one() {
        let pool_len = 1_usize;
        let counter = AtomicUsize::new(0);
        for _ in 0..50 {
            let start = counter.fetch_add(1, Ordering::Relaxed);
            assert_eq!(start % pool_len, 0);
        }
    }

    /// When model is cached: empty batch short-circuits before touching the pool.
    #[tokio::test]
    async fn test_embed_empty_does_not_advance_pool_counter() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM).with_session_pool_size(2);
        if let Ok(engine) = EmbeddingEngine::new(config).await {
            let result = engine.embed_raw(&[]).await;
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
            // Empty batch returns before fetch_add — counter stays at 0.
            assert_eq!(engine.next_session.load(Ordering::Relaxed), 0);
        }
    }
}