cqs/embedder/mod.rs
//! Embedding generation with ort + tokenizers

mod models;
mod provider;

pub use models::{EmbeddingConfig, ModelConfig, ModelInfo, DEFAULT_DIM, DEFAULT_MODEL_REPO};

use provider::ort_err;
pub(crate) use provider::{create_session, select_provider};

use lru::LruCache;
use ndarray::{Array2, Array3, Axis};
use once_cell::sync::OnceCell;
use ort::session::Session;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use thiserror::Error;

/// Retrieves the embedding model repository from the resolved ModelConfig.
///
/// Delegates to `ModelConfig::resolve(None, None)` which checks env var / defaults.
pub fn model_repo() -> String {
    ModelConfig::resolve(None, None).repo
}

// blake3 checksums — empty to skip validation (configurable models have different checksums)
const MODEL_BLAKE3: &str = "";
const TOKENIZER_BLAKE3: &str = "";

#[derive(Error, Debug)]
pub enum EmbedderError {
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    #[error("Tokenizer error: {0}")]
    Tokenizer(String),
    #[error("Inference failed: {0}")]
    InferenceFailed(String),
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        path: String,
        expected: String,
        actual: String,
    },
    #[error("Query cannot be empty")]
    EmptyQuery,
    #[error("HuggingFace Hub error: {0}")]
    HfHub(String),
}

// `ort_err` is defined in `provider.rs` (pub(super)) and imported above.

/// An L2-normalized embedding vector.
///
/// Dimension depends on the configured model (e.g., 1024 for BGE-large, 768 for E5-base).
/// Can be compared using cosine similarity (dot product for normalized vectors).
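///
/// Because vectors are L2-normalized, cosine similarity reduces to a plain dot
/// product. A minimal sketch (short vectors for illustration; real embeddings
/// are model-sized):
/// ```
/// use cqs::embedder::Embedding;
///
/// let a = Embedding::try_new(vec![1.0, 0.0]).unwrap();
/// let b = Embedding::try_new(vec![0.6, 0.8]).unwrap();
/// // Dot product of two unit vectors == cosine similarity.
/// let cos: f32 = a.as_slice().iter().zip(b.as_slice()).map(|(x, y)| x * y).sum();
/// assert!((cos - 0.6).abs() < 1e-6);
/// ```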
#[derive(Debug, Clone)]
pub struct Embedding(Vec<f32>);

/// Full embedding dimension -- re-exported from crate root
pub use crate::EMBEDDING_DIM;

/// Error returned when creating an embedding with invalid data (empty or non-finite)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingDimensionError {
    /// The actual dimension provided
    pub actual: usize,
    /// The expected minimum dimension
    pub expected: usize,
}

impl std::fmt::Display for EmbeddingDimensionError {
    /// Formats the error as "Invalid embedding dimension: expected N, got M".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Invalid embedding dimension: expected {}, got {}",
            self.expected, self.actual
        )
    }
}

impl std::error::Error for EmbeddingDimensionError {}

impl Embedding {
    /// Create a new embedding from raw vector data (unchecked).
    ///
    /// Accepts any dimension — the Embedder validates consistency via `detected_dim`.
    /// **Prefer `try_new()` for untrusted input** (external APIs, deserialized data).
    /// Use `new()` only when the data is known-good (e.g., fresh from ONNX inference).
    pub fn new(data: Vec<f32>) -> Self {
        Self(data)
    }

    /// Create a new embedding with validation.
    ///
    /// Returns `Err` if the vector is empty or contains non-finite values.
    /// Dimension is no longer validated here — the Embedder enforces consistency.
    ///
    /// # Example
    /// ```
    /// use cqs::embedder::Embedding;
    ///
    /// let valid = Embedding::try_new(vec![0.5; 768]);
    /// assert!(valid.is_ok());
    ///
    /// let also_valid = Embedding::try_new(vec![0.5; 384]);
    /// assert!(also_valid.is_ok());
    ///
    /// let empty = Embedding::try_new(vec![]);
    /// assert!(empty.is_err());
    /// ```
    pub fn try_new(data: Vec<f32>) -> Result<Self, EmbeddingDimensionError> {
        if data.is_empty() {
            return Err(EmbeddingDimensionError {
                actual: 0,
                expected: 1, // at least 1 dimension required
            });
        }
        if !data.iter().all(|v| v.is_finite()) {
            return Err(EmbeddingDimensionError {
                actual: data.len(),
                expected: data.len(),
            });
        }
        Ok(Self(data))
    }

    /// Get the embedding as a slice
    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }

    /// Get a reference to the inner Vec (needed for some APIs like hnsw_rs)
    pub fn as_vec(&self) -> &Vec<f32> {
        &self.0
    }

    /// Consume the embedding and return the inner vector
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Get the dimension of the embedding.
    ///
    /// Returns the number of dimensions (e.g., 1024 for BGE-large, 768 for E5-base).
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Check if the embedding is empty
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

/// Hardware execution provider for inference
#[derive(Debug, Clone, Copy)]
pub enum ExecutionProvider {
    /// NVIDIA CUDA (requires CUDA toolkit)
    CUDA { device_id: i32 },
    /// NVIDIA TensorRT (faster than CUDA, requires TensorRT)
    TensorRT { device_id: i32 },
    /// CPU fallback (always available)
    CPU,
}

impl std::fmt::Display for ExecutionProvider {
    /// Formats the provider as a human-readable string, e.g. "CUDA (device 0)".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecutionProvider::CUDA { device_id } => write!(f, "CUDA (device {})", device_id),
            ExecutionProvider::TensorRT { device_id } => {
                write!(f, "TensorRT (device {})", device_id)
            }
            ExecutionProvider::CPU => write!(f, "CPU"),
        }
    }
}

/// Text embedding generator using a configurable model (default: BGE-large-en-v1.5)
///
/// Automatically downloads the model from HuggingFace Hub on first use.
/// Detects GPU availability and uses CUDA/TensorRT when available.
///
/// # Example
///
/// ```no_run
/// use cqs::Embedder;
/// use cqs::embedder::ModelConfig;
///
/// let embedder = Embedder::new(ModelConfig::resolve(None, None))?;
/// let embedding = embedder.embed_query("parse configuration file")?;
/// println!("Embedding dimension: {}", embedding.len()); // 1024 for the default BGE-large model
/// # Ok::<(), anyhow::Error>(())
/// ```
pub struct Embedder {
    /// Lazy-loaded ONNX session (expensive ~500ms init, needs Mutex for run()).
    ///
    /// Persists for the lifetime of the Embedder. In long-running processes,
    /// this holds ~500MB of GPU/CPU memory. To release, call [`clear_session`]
    /// or drop the Embedder instance and create a new one when needed.
    session: Mutex<Option<Session>>,
    /// Lazy-loaded tokenizer
    tokenizer: OnceCell<tokenizers::Tokenizer>,
    /// Lazy-loaded model paths (avoids HuggingFace API calls until actually embedding)
    model_paths: OnceCell<(PathBuf, PathBuf)>,
    provider: ExecutionProvider,
    max_length: usize,
    /// LRU cache for query embeddings (avoids re-computing same queries)
    query_cache: Mutex<LruCache<String, Embedding>>,
    /// Detected embedding dimension from the model. Set on first inference.
    detected_dim: std::sync::OnceLock<usize>,
    /// Model configuration (repo, paths, prefixes, dimensions)
    model_config: ModelConfig,
}

/// Default query cache size (entries). Each entry is ~4KB (1024 floats + key).
const DEFAULT_QUERY_CACHE_SIZE: usize = 32;

impl Embedder {
    /// Create a new embedder with lazy model loading.
    ///
    /// Automatically detects GPU availability and uses CUDA/TensorRT when
    /// available, falling back to CPU if no GPU is found. For CPU-only
    /// inference, use [`Embedder::new_cpu`] -- preferable for single-query
    /// embedding, where CUDA context setup overhead outweighs the GPU speedup.
    ///
    /// Note: Model download and ONNX session are lazy-loaded on first
    /// embedding request. This avoids HuggingFace API calls for commands
    /// that don't need embeddings.
    pub fn new(model_config: ModelConfig) -> Result<Self, EmbedderError> {
        Self::new_with_provider(model_config, select_provider())
    }

    /// Create a CPU-only embedder with lazy model loading.
    ///
    /// CPU-pinned counterpart to `new()` -- use this for single-query embedding
    /// where CPU is faster than GPU due to CUDA context setup overhead.
    pub fn new_cpu(model_config: ModelConfig) -> Result<Self, EmbedderError> {
        Self::new_with_provider(model_config, ExecutionProvider::CPU)
    }

    /// Shared constructor for both GPU-auto and CPU-only embedders.
    fn new_with_provider(
        model_config: ModelConfig,
        provider: ExecutionProvider,
    ) -> Result<Self, EmbedderError> {
        let max_length = model_config.max_seq_length;

        let query_cache = Mutex::new(LruCache::new(
            NonZeroUsize::new(DEFAULT_QUERY_CACHE_SIZE)
                .expect("DEFAULT_QUERY_CACHE_SIZE is non-zero"),
        ));

        Ok(Self {
            session: Mutex::new(None),
            tokenizer: OnceCell::new(),
            model_paths: OnceCell::new(),
            provider,
            max_length,
            query_cache,
            detected_dim: std::sync::OnceLock::new(),
            model_config,
        })
    }

    /// Get the model configuration
    pub fn model_config(&self) -> &ModelConfig {
        &self.model_config
    }

    /// Get or initialize model paths (lazy download)
    fn model_paths(&self) -> Result<&(PathBuf, PathBuf), EmbedderError> {
        self.model_paths
            .get_or_try_init(|| ensure_model(&self.model_config))
    }

    /// Get or initialize the ONNX session
    fn session(&self) -> Result<std::sync::MutexGuard<'_, Option<Session>>, EmbedderError> {
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        if guard.is_none() {
            let _span = tracing::info_span!("embedder_session_init").entered();
            let (model_path, _) = self.model_paths()?;
            *guard = Some(create_session(model_path, self.provider)?);
            tracing::info!("Embedder session initialized");
        }
        Ok(guard)
    }

    /// Get or initialize the tokenizer
    fn tokenizer(&self) -> Result<&tokenizers::Tokenizer, EmbedderError> {
        let (_, tokenizer_path) = self.model_paths()?;
        self.tokenizer.get_or_try_init(|| {
            tokenizers::Tokenizer::from_file(tokenizer_path)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))
        })
    }

    /// Counts the number of tokens in the given text using the configured tokenizer.
    ///
    /// # Errors
    ///
    /// Returns `EmbedderError::Tokenizer` if the tokenizer is unavailable or if
    /// encoding the text fails.
    pub fn token_count(&self, text: &str) -> Result<usize, EmbedderError> {
        let encoding = self
            .tokenizer()?
            .encode(text, false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
        Ok(encoding.get_ids().len())
    }

    /// Count tokens for multiple texts in a single batch.
    ///
    /// Uses `encode_batch` for potentially better throughput than individual
    /// `token_count` calls when processing many texts.
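    ///
    /// A usage sketch (`no_run`: assumes a downloaded model):
    /// ```no_run
    /// use cqs::Embedder;
    /// use cqs::embedder::ModelConfig;
    ///
    /// let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None))?;
    /// let counts = embedder.token_counts_batch(&["hello world", "fn main() {}"])?;
    /// assert_eq!(counts.len(), 2); // one count per input text
    /// # Ok::<(), anyhow::Error>(())
    /// ```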
    pub fn token_counts_batch(&self, texts: &[&str]) -> Result<Vec<usize>, EmbedderError> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let encodings = self
            .tokenizer()?
            .encode_batch(texts.to_vec(), false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;
        Ok(encodings.iter().map(|e| e.get_ids().len()).collect())
    }

    /// Split text into overlapping windows of `max_tokens` with `overlap` tokens of context.
    /// Returns a Vec of (window_content, window_index).
    /// If the text fits in `max_tokens`, returns a single window with index 0.
    ///
    /// # Errors
    /// Returns an error if `overlap >= max_tokens / 2`, which would cause a
    /// superlinear window count (validated and rejected, not a panic).
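    ///
    /// # Example
    /// A sketch of the windowing arithmetic (`no_run`: assumes a downloaded model).
    /// With `max_tokens = 100` and `overlap = 20`, window starts advance by 80 tokens,
    /// so a 250-token text yields windows covering tokens [0, 100), [80, 180), [160, 250):
    /// ```no_run
    /// use cqs::Embedder;
    /// use cqs::embedder::ModelConfig;
    ///
    /// let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None))?;
    /// let windows = embedder.split_into_windows(&"word ".repeat(500), 100, 20)?;
    /// for (content, idx) in &windows {
    ///     println!("window {}: {} bytes", idx, content.len());
    /// }
    /// # Ok::<(), anyhow::Error>(())
    /// ```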
    pub fn split_into_windows(
        &self,
        text: &str,
        max_tokens: usize,
        overlap: usize,
    ) -> Result<Vec<(String, u32)>, EmbedderError> {
        if max_tokens == 0 {
            return Ok(vec![]);
        }

        // Validate overlap to prevent window count explosion.
        // overlap >= max_tokens/2 means step <= max_tokens/2, roughly doubling the
        // window count; as overlap approaches max_tokens, step shrinks toward 1 token = disaster.
        if overlap >= max_tokens / 2 {
            return Err(EmbedderError::Tokenizer(format!(
                "overlap ({overlap}) must be less than max_tokens/2 ({})",
                max_tokens / 2
            )));
        }

        let tokenizer = self.tokenizer()?;
        let encoding = tokenizer
            .encode(text, false)
            .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;

        let ids = encoding.get_ids();
        if ids.len() <= max_tokens {
            return Ok(vec![(text.to_string(), 0)]);
        }

        let mut windows = Vec::new();
        // Step size: tokens per window minus overlap.
        // The validation above guarantees step > max_tokens/2, ensuring a linear window count.
        let step = max_tokens - overlap;
        let mut start = 0;
        let mut window_idx = 0u32;

        while start < ids.len() {
            let end = (start + max_tokens).min(ids.len());
            let window_ids: Vec<u32> = ids[start..end].to_vec();

            // Decode back to text
            let window_text = tokenizer
                .decode(&window_ids, true)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?;

            windows.push((window_text, window_idx));
            window_idx += 1;

            if end >= ids.len() {
                break;
            }
            start += step;
        }

        Ok(windows)
    }

    /// Embed documents (code chunks). Adds model-specific document prefix.
    ///
    /// Large inputs are processed in batches of 64 to cap GPU memory usage.
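    ///
    /// A usage sketch (`no_run`: assumes a downloaded model):
    /// ```no_run
    /// use cqs::Embedder;
    /// use cqs::embedder::ModelConfig;
    ///
    /// let embedder = Embedder::new(ModelConfig::resolve(None, None))?;
    /// let chunks = ["fn add(a: i32, b: i32) -> i32 { a + b }", "struct Point;"];
    /// let embeddings = embedder.embed_documents(&chunks)?;
    /// assert_eq!(embeddings.len(), chunks.len());
    /// # Ok::<(), anyhow::Error>(())
    /// ```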
    pub fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedderError> {
        let _span = tracing::info_span!("embed_documents", count = texts.len()).entered();
        let prefix = &self.model_config.doc_prefix;
        const MAX_BATCH: usize = 64;
        if texts.len() <= MAX_BATCH {
            let prefixed: Vec<String> = texts.iter().map(|t| format!("{}{}", prefix, t)).collect();
            return self.embed_batch(&prefixed);
        }
        let mut all = Vec::with_capacity(texts.len());
        for chunk in texts.chunks(MAX_BATCH) {
            let prefixed: Vec<String> = chunk.iter().map(|t| format!("{}{}", prefix, t)).collect();
            all.extend(self.embed_batch(&prefixed)?);
        }
        Ok(all)
    }

    /// Maximum input bytes before truncation (RT-RES-5).
    /// The tokenizer will further truncate to max_seq_length tokens, but this
    /// prevents O(n) tokenization work on megabyte-sized inputs.
    const MAX_QUERY_BYTES: usize = 32 * 1024;

    /// Embed a query. Adds the model-specific query prefix (e.g., "query: " for E5).
    /// Uses an LRU cache for repeated queries.
    ///
    /// # Concurrency Note
    /// Intentionally releases the lock during embedding computation (~100ms) to allow parallel
    /// queries. This means two simultaneous queries for the same text may both compute
    /// embeddings, but this is preferable to serializing all queries through a single lock.
    /// The duplicate work is rare and the cache update is idempotent.
    pub fn embed_query(&self, text: &str) -> Result<Embedding, EmbedderError> {
        let _span = tracing::info_span!("embed_query").entered();
        let text = text.trim();
        if text.is_empty() {
            return Err(EmbedderError::EmptyQuery);
        }
        // RT-RES-5: Truncate oversized input before tokenization to bound CPU work.
        let text = if text.len() > Self::MAX_QUERY_BYTES {
            tracing::warn!(
                len = text.len(),
                max = Self::MAX_QUERY_BYTES,
                "Query text truncated before embedding"
            );
            // Truncate at a char boundary
            let mut end = Self::MAX_QUERY_BYTES;
            while !text.is_char_boundary(end) && end > 0 {
                end -= 1;
            }
            &text[..end]
        } else {
            text
        };

        // Check cache first (lock released after check to allow parallel computation)
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            if let Some(cached) = cache.get(text) {
                tracing::trace!(query = text, "Embedding cache hit");
                return Ok(cached.clone());
            }
            tracing::trace!(query = text, "Embedding cache miss");
        }

        // Compute embedding (outside lock - allows parallel queries)
        let prefixed = format!("{}{}", self.model_config.query_prefix, text);
        let results = self.embed_batch(&[prefixed])?;
        let embedding = results.into_iter().next().ok_or_else(|| {
            EmbedderError::InferenceFailed("embed_batch returned empty result".to_string())
        })?;

        // Store in cache (idempotent - duplicate puts for same key are harmless)
        {
            let mut cache = self.query_cache.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("Query cache lock poisoned (prior panic), recovering");
                poisoned.into_inner()
            });
            cache.put(text.to_string(), embedding.clone());
            tracing::trace!(query = text, cache_len = cache.len(), "Embedding cached");
        }

        Ok(embedding)
    }

    /// Get the execution provider being used
    pub fn provider(&self) -> ExecutionProvider {
        self.provider
    }

    /// Clear the ONNX session to free memory (~500MB).
    ///
    /// The session will be lazily re-initialized on the next embedding request.
    /// Use this in long-running processes during idle periods to reduce memory footprint.
    ///
    /// # Safety constraint
    /// Must only be called during idle periods -- not while embedding is in progress.
    /// Watch mode guarantees single-threaded access.
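    ///
    /// A usage sketch (`no_run`: assumes a downloaded model):
    /// ```no_run
    /// use cqs::Embedder;
    /// use cqs::embedder::ModelConfig;
    ///
    /// let embedder = Embedder::new_cpu(ModelConfig::resolve(None, None))?;
    /// let _ = embedder.embed_query("warmup")?; // initializes the session
    /// embedder.clear_session();                // frees it; next embed re-initializes
    /// # Ok::<(), anyhow::Error>(())
    /// ```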
    pub fn clear_session(&self) {
        let mut guard = self.session.lock().unwrap_or_else(|p| p.into_inner());
        *guard = None;
        // Also clear query cache -- stale embeddings from old session would be wrong
        // if model config changes before session is re-created.
        let mut cache = self.query_cache.lock().unwrap_or_else(|p| p.into_inner());
        cache.clear();
        tracing::info!("Embedder session and query cache cleared");
    }

    /// Warm up the model with a dummy inference
    pub fn warm(&self) -> Result<(), EmbedderError> {
        let _ = self.embed_query("warmup")?;
        Ok(())
    }

    /// Returns the embedding dimension detected from the model.
    /// Falls back to the model config's declared dimension if no inference has been run yet.
    pub fn embedding_dim(&self) -> usize {
        let dim = *self.detected_dim.get().unwrap_or(&self.model_config.dim);
        if dim == 0 {
            EMBEDDING_DIM
        } else {
            dim
        }
    }

    /// Generates embeddings for a batch of text inputs.
    ///
    /// Tokenizes the inputs, pads them to the longest sequence in the batch (capped
    /// at the model's configured maximum length), runs ONNX inference, then
    /// mean-pools and L2-normalizes the hidden states into one embedding per text.
    ///
    /// # Errors
    ///
    /// Returns `EmbedderError::Tokenizer` if tokenization of the batch fails, or
    /// `EmbedderError::InferenceFailed` if the model output is missing or malformed.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Embedding>, EmbedderError> {
        use ort::value::Tensor;

        let _span = tracing::info_span!("embed_batch", count = texts.len()).entered();

        if texts.is_empty() {
            return Ok(vec![]);
        }

        // Tokenize (lazy init tokenizer)
        // PERF-36: `encode_batch` requires `Vec<EncodeInput>` (owned), so `texts.to_vec()` is
        // unavoidable — the tokenizer API does not accept `&[impl AsRef<str>]`.
        let encodings = {
            let _tokenize = tracing::debug_span!("tokenize").entered();
            self.tokenizer()?
                .encode_batch(texts.to_vec(), true)
                .map_err(|e| EmbedderError::Tokenizer(e.to_string()))?
        };

        // Prepare inputs - INT64 (i64) for ONNX model
        let input_ids: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_ids().iter().map(|&id| id as i64).collect())
            .collect();
        let attention_mask: Vec<Vec<i64>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().iter().map(|&m| m as i64).collect())
            .collect();

        // Pad to max length in batch
        let max_len = input_ids
            .iter()
            .map(|v| v.len())
            .max()
            .unwrap_or(0)
            .min(self.max_length);

        // Create padded arrays
        let input_ids_arr = pad_2d_i64(&input_ids, max_len, 0);
        let attention_mask_arr = pad_2d_i64(&attention_mask, max_len, 0);
        // token_type_ids: all zeros, same shape as input_ids
        let token_type_ids_arr = Array2::<i64>::zeros((texts.len(), max_len));

        // Create tensors
        let input_ids_tensor = Tensor::from_array(input_ids_arr).map_err(ort_err)?;
        let attention_mask_tensor = Tensor::from_array(attention_mask_arr).map_err(ort_err)?;
        let token_type_ids_tensor = Tensor::from_array(token_type_ids_arr).map_err(ort_err)?;

        // Run inference (lazy init session)
        let mut guard = self.session()?;
        let session = guard
            .as_mut()
            .expect("session() guarantees initialized after Ok return");
        let _inference = tracing::debug_span!("inference", max_len).entered();
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_tensor,
                "attention_mask" => attention_mask_tensor,
                "token_type_ids" => token_type_ids_tensor,
            ])
            .map_err(ort_err)?;

        // Get the last_hidden_state output: shape [batch, seq_len, dim]
        let output = outputs.get("last_hidden_state").ok_or_else(|| {
            EmbedderError::InferenceFailed(format!(
                "ONNX model has no 'last_hidden_state' output. Available: {:?}",
                outputs.keys().collect::<Vec<_>>()
            ))
        })?;
        let (shape, data) = output.try_extract_tensor::<f32>().map_err(ort_err)?;

        // Validate tensor shape: expect [batch_size, seq_len, dim]
        let batch_size = texts.len();
        let seq_len = max_len;
        if shape.len() != 3 {
            return Err(EmbedderError::InferenceFailed(format!(
                "Unexpected tensor shape: expected 3 dimensions [batch, seq, dim], got {} dimensions",
                shape.len()
            )));
        }
        let embedding_dim = shape[2] as usize;
        // Set or validate embedding dimension from model output
        match self.detected_dim.get() {
            Some(&expected) if expected != embedding_dim => {
                return Err(EmbedderError::InferenceFailed(format!(
                    "Embedding dimension changed: expected {expected}, got {embedding_dim}"
                )));
            }
            None => {
                let _ = self.detected_dim.set(embedding_dim);
                tracing::info!(
                    dim = embedding_dim,
                    "Detected embedding dimension from model"
                );
            }
            _ => {} // matches expected — OK
        }
        if shape[0] as usize != batch_size {
            return Err(EmbedderError::InferenceFailed(format!(
                "Tensor batch size mismatch: expected {}, got {}",
                batch_size, shape[0]
            )));
        }
        // Mean-pooling via ndarray (vectorized, SIMD-friendly)
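        // For each row i: pooled[i] = sum_j(mask[i][j] * hidden[i][j]) / sum_j(mask[i][j]),
        // i.e. the average of hidden states over non-padding tokens only.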
        let hidden = Array3::from_shape_vec((batch_size, seq_len, embedding_dim), data.to_vec())
            .map_err(|e| EmbedderError::InferenceFailed(format!("tensor reshape failed: {e}")))?;

        // Build mask: [batch, seq, 1] for broadcasting
        let mask_2d = Array2::from_shape_fn((batch_size, seq_len), |(i, j)| {
            attention_mask[i].get(j).copied().unwrap_or(0) as f32
        });
        let mask_3d = mask_2d.clone().insert_axis(Axis(2));

        // Masked sum: (hidden * mask).sum(axis=1) / mask.sum(axis=1)
        let masked = &hidden * &mask_3d;
        let summed = masked.sum_axis(Axis(1)); // [batch, dim]
        let counts = mask_2d.sum_axis(Axis(1)).insert_axis(Axis(1)); // [batch, 1]

        let results = (0..batch_size)
            .map(|i| {
                let count = counts[[i, 0]];
                let row = summed.row(i);
                let pooled: Vec<f32> = if count > 0.0 {
                    row.iter().map(|v| v / count).collect()
                } else {
                    tracing::warn!(batch_idx = i, "Zero attention mask — producing zero vector");
                    vec![0.0f32; embedding_dim]
                };
                Embedding::new(normalize_l2(pooled))
            })
            .collect();

        Ok(results)
    }
}

/// Download model and tokenizer from HuggingFace Hub
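///
/// Expected `CQS_ONNX_DIR` layout, a sketch assuming a config with
/// `onnx_path = "onnx/model.onnx"` and `tokenizer_path = "tokenizer.json"`
/// (a flat layout with both files directly in the directory also works):
///
/// ```text
/// $CQS_ONNX_DIR/
/// ├── onnx/
/// │   └── model.onnx
/// └── tokenizer.json
/// ```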
fn ensure_model(config: &ModelConfig) -> Result<(PathBuf, PathBuf), EmbedderError> {
    // CQS_ONNX_DIR: bypass HF download, load from local directory.
    // Directory must contain model.onnx and tokenizer.json.
    if let Ok(dir) = std::env::var("CQS_ONNX_DIR") {
        let dir = dunce::canonicalize(PathBuf::from(&dir)).unwrap_or_else(|_| PathBuf::from(dir));
        let model_path = dir.join(&config.onnx_path);
        let tokenizer_path = dir.join(&config.tokenizer_path);
        if model_path.exists() && tokenizer_path.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory");
            return Ok((model_path, tokenizer_path));
        }
        // Try flat layout (model.onnx + tokenizer.json in same dir)
        let flat_model = dir.join("model.onnx");
        let flat_tok = dir.join("tokenizer.json");
        if flat_model.exists() && flat_tok.exists() {
            tracing::info!(dir = %dir.display(), "Using local ONNX model directory (flat)");
            return Ok((flat_model, flat_tok));
        }
        tracing::warn!(dir = %dir.display(), "CQS_ONNX_DIR set but model files not found, falling back to HF download");
    }

    use hf_hub::api::sync::Api;

    let api = Api::new().map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let repo = api.model(config.repo.clone());

    let model_path = repo
        .get(&config.onnx_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;
    let tokenizer_path = repo
        .get(&config.tokenizer_path)
        .map_err(|e| EmbedderError::HfHub(e.to_string()))?;

    // Verify checksums (skip if already verified via marker file)
    if !MODEL_BLAKE3.is_empty() || !TOKENIZER_BLAKE3.is_empty() {
        let marker = model_path
            .parent()
            .unwrap_or(Path::new("."))
            .join(".cqs_verified");
        let expected_marker = format!("{}\n{}", MODEL_BLAKE3, TOKENIZER_BLAKE3);
        let already_verified = std::fs::read_to_string(&marker)
            .map(|s| s == expected_marker)
            .unwrap_or(false);

        if !already_verified {
            if !MODEL_BLAKE3.is_empty() {
                verify_checksum(&model_path, MODEL_BLAKE3)?;
            }
            if !TOKENIZER_BLAKE3.is_empty() {
                verify_checksum(&tokenizer_path, TOKENIZER_BLAKE3)?;
            }
            // Write marker after successful verification
            let _ = std::fs::write(&marker, &expected_marker);
        }
    }

    Ok((model_path, tokenizer_path))
}

/// Verify file checksum using blake3
fn verify_checksum(path: &Path, expected: &str) -> Result<(), EmbedderError> {
    let mut file =
        std::fs::File::open(path).map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let mut hasher = blake3::Hasher::new();
    std::io::copy(&mut file, &mut hasher)
        .map_err(|e| EmbedderError::ModelNotFound(e.to_string()))?;
    let actual = hasher.finalize().to_hex().to_string();

    if actual != expected {
        return Err(EmbedderError::ChecksumMismatch {
            path: path.display().to_string(),
            expected: expected.to_string(),
            actual,
        });
    }
    Ok(())
}

/// Pad 2D sequences to a fixed length
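///
/// For example, `pad_2d_i64(&[vec![1, 2, 3], vec![4, 5]], 4, 0)` yields rows
/// `[1, 2, 3, 0]` and `[4, 5, 0, 0]`; sequences longer than `max_len` are
/// truncated (see the tests below).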
pub(crate) fn pad_2d_i64(inputs: &[Vec<i64>], max_len: usize, pad_value: i64) -> Array2<i64> {
    let batch_size = inputs.len();
    let mut arr = Array2::from_elem((batch_size, max_len), pad_value);
    for (i, seq) in inputs.iter().enumerate() {
        for (j, &val) in seq.iter().take(max_len).enumerate() {
            arr[[i, j]] = val;
        }
    }
    arr
}

/// L2 normalize a vector (single-pass, in-place)
fn normalize_l2(mut v: Vec<f32>) -> Vec<f32> {
    let norm_sq: f32 = v.iter().fold(0.0, |acc, &x| acc + x * x);
    if norm_sq > 0.0 {
        let inv_norm = 1.0 / norm_sq.sqrt();
        v.iter_mut().for_each(|x| *x *= inv_norm);
    }
    v
}

#[cfg(test)]
mod tests {
    use super::*;

    // ===== Embedding tests =====

    #[test]
    fn test_embedding_new() {
        let data = vec![0.5; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_slice(), &data);
    }

    #[test]
    fn test_embedding_len() {
        let emb = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert_eq!(emb.len(), EMBEDDING_DIM);
    }

    #[test]
    fn test_embedding_is_empty() {
        let empty = Embedding::new(vec![]);
        assert!(empty.is_empty());

        let non_empty = Embedding::new(vec![1.0; EMBEDDING_DIM]);
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_embedding_into_inner() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.into_inner(), data);
    }

    #[test]
    fn test_embedding_as_vec() {
        let data = vec![1.0; EMBEDDING_DIM];
        let emb = Embedding::new(data.clone());
        assert_eq!(emb.as_vec(), &data);
    }

    // ===== Embedding::try_new tests (TC-33) =====

    #[test]
    fn tc33_try_new_empty_vec_errors() {
        let result = Embedding::try_new(vec![]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.actual, 0);
        assert_eq!(err.expected, 1);
    }

    #[test]
    fn tc33_try_new_nan_errors() {
        let result = Embedding::try_new(vec![1.0, f32::NAN, 3.0]);
        assert!(result.is_err(), "NaN should be rejected by try_new");
    }

    #[test]
    fn tc33_try_new_inf_errors() {
        let result = Embedding::try_new(vec![1.0, f32::INFINITY, 3.0]);
        assert!(result.is_err(), "Infinity should be rejected by try_new");

        let result = Embedding::try_new(vec![f32::NEG_INFINITY]);
        assert!(result.is_err(), "Negative infinity should be rejected");
    }

    #[test]
    fn tc33_try_new_valid_ok() {
        let data = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let result = Embedding::try_new(data.clone());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_slice(), &data);
    }

    // ===== normalize_l2 tests =====

    #[test]
    fn test_normalize_l2_unit_vector() {
        let v = normalize_l2(vec![1.0, 0.0, 0.0]);
        assert!((v[0] - 1.0).abs() < 1e-6);
        assert!((v[1] - 0.0).abs() < 1e-6);
        assert!((v[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_produces_unit_vector() {
        let v = normalize_l2(vec![3.0, 4.0]);
        // Should produce [0.6, 0.8] (3-4-5 triangle)
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);

        // Verify it's a unit vector (magnitude = 1)
        let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_l2_zero_vector() {
        // Zero vector should remain zero (no division by zero)
        let v = normalize_l2(vec![0.0, 0.0, 0.0]);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_normalize_l2_empty_vector() {
        let v = normalize_l2(vec![]);
        assert!(v.is_empty());
    }

    // ===== ExecutionProvider tests =====

    #[test]
    fn test_execution_provider_display() {
        assert_eq!(format!("{}", ExecutionProvider::CPU), "CPU");
        assert_eq!(
            format!("{}", ExecutionProvider::CUDA { device_id: 0 }),
            "CUDA (device 0)"
        );
        assert_eq!(
            format!("{}", ExecutionProvider::TensorRT { device_id: 1 }),
            "TensorRT (device 1)"
        );
    }

    // ===== Constants tests =====

    #[test]
    fn test_model_dimensions() {
        assert_eq!(EMBEDDING_DIM, 1024);
    }

    // ===== pad_2d_i64 tests =====

    #[test]
    fn test_pad_2d_i64_basic() {
        let inputs = vec![vec![1, 2, 3], vec![4, 5]];
        let result = pad_2d_i64(&inputs, 4, 0);
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
        assert_eq!(result[[0, 3]], 0); // padded
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 5);
        assert_eq!(result[[1, 2]], 0); // padded
        assert_eq!(result[[1, 3]], 0); // padded
    }

    #[test]
    fn test_pad_2d_i64_truncates() {
        let inputs = vec![vec![1, 2, 3, 4, 5]];
        let result = pad_2d_i64(&inputs, 3, 0);
        assert_eq!(result.shape(), &[1, 3]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 2);
        assert_eq!(result[[0, 2]], 3);
        // 4 and 5 are truncated
    }

    #[test]
    fn test_pad_2d_i64_empty_input() {
        let inputs: Vec<Vec<i64>> = vec![];
        let result = pad_2d_i64(&inputs, 5, 0);
        assert_eq!(result.shape(), &[0, 5]);
    }

    #[test]
    fn test_pad_2d_i64_custom_pad_value() {
        let inputs = vec![vec![1]];
        let result = pad_2d_i64(&inputs, 3, -1);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], -1);
        assert_eq!(result[[0, 2]], -1);
    }

    // ===== EmbedderError tests =====

    #[test]
    fn test_embedder_error_display() {
        let err = EmbedderError::EmptyQuery;
        assert_eq!(format!("{}", err), "Query cannot be empty");

        let err = EmbedderError::ModelNotFound("model.onnx".to_string());
        assert!(format!("{}", err).contains("model.onnx"));

        let err = EmbedderError::Tokenizer("invalid token".to_string());
        assert!(format!("{}", err).contains("invalid token"));

        let err = EmbedderError::ChecksumMismatch {
            path: "/path/to/file".to_string(),
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(format!("{}", err).contains("abc123"));
        assert!(format!("{}", err).contains("def456"));
    }

    #[test]
    fn test_embedder_error_from_ort() {
        // Test that ort::Error converts to EmbedderError::InferenceFailed
        // We can't easily create an ort::Error, but we can verify the variant exists
        let err: EmbedderError = EmbedderError::InferenceFailed("test error".to_string());
        assert!(matches!(err, EmbedderError::InferenceFailed(_)));
    }

    // ===== Property-based tests =====

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// Property: normalize_l2 produces unit vectors (magnitude ~= 1) or zero vectors
            #[test]
            fn prop_normalize_l2_unit_or_zero(v in prop::collection::vec(-1e6f32..1e6f32, 1..100)) {
                let normalized = normalize_l2(v.clone());

                // Compute magnitude
                let magnitude: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();

                // Check: either zero vector (input was zero) or unit vector
                let input_is_zero = v.iter().all(|&x| x == 0.0);
                if input_is_zero {
                    prop_assert!(magnitude < 1e-6, "Zero input should give zero output");
                } else {
                    prop_assert!(
                        (magnitude - 1.0).abs() < 1e-4,
                        "Non-zero input should give unit vector, got magnitude {}",
                        magnitude
                    );
                }
            }

            /// Property: normalize_l2 preserves vector direction (dot product with original > 0)
            #[test]
            fn prop_normalize_l2_preserves_direction(v in prop::collection::vec(1.0f32..100.0, 1..50)) {
                let normalized = normalize_l2(v.clone());

                // Dot product with original should be positive (same direction)
                let dot: f32 = v.iter().zip(normalized.iter()).map(|(a, b)| a * b).sum();
                prop_assert!(dot > 0.0, "Direction should be preserved");
            }

            /// Property: Embedding length is preserved through operations
            #[test]
            fn prop_embedding_length_preserved(use_model_dim in proptest::bool::ANY) {
                let _ = use_model_dim; // single dimension now
                let emb = Embedding::new(vec![0.5; EMBEDDING_DIM]);
                prop_assert_eq!(emb.len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_slice().len(), EMBEDDING_DIM);
                prop_assert_eq!(emb.as_vec().len(), EMBEDDING_DIM);
            }
        }
    }

    // ===== clear_session tests =====

    #[test]
    #[ignore] // Requires model
    fn test_clear_session_and_reinit() {
        let embedder = Embedder::new(ModelConfig::e5_base()).unwrap();
        // Force session init by embedding something
        let _ = embedder.embed_query("test");
        // Clear and re-embed
        embedder.clear_session();
        let result = embedder.embed_query("test again");
        assert!(result.is_ok());
    }

    #[test]
    fn test_clear_session_idempotent() {
        let embedder = Embedder::new_cpu(ModelConfig::e5_base()).unwrap();
        embedder.clear_session(); // clear before init -- should not panic
        embedder.clear_session(); // clear again -- should not panic
    }

    // ===== Integration tests (require model) =====

    mod integration {
        use super::*;

        #[test]
        #[ignore] // Requires model - run with: cargo test --lib integration -- --ignored
        fn test_token_count_empty() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let count = embedder.token_count("").expect("token_count failed");
            assert_eq!(count, 0);
        }

        #[test]
        #[ignore]
        fn test_token_count_simple() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let count = embedder
                .token_count("hello world")
                .expect("token_count failed");
            // E5-base-v2 tokenizer: "hello" and "world" are single tokens
            assert!(
                (2..=4).contains(&count),
                "Expected 2-4 tokens, got {}",
                count
            );
        }

        #[test]
        #[ignore]
        fn test_token_count_code() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let code = "fn main() { println!(\"Hello\"); }";
            let count = embedder.token_count(code).expect("token_count failed");
            // Code typically tokenizes to more tokens than words
            assert!(count > 5, "Expected >5 tokens for code, got {}", count);
        }

        #[test]
        #[ignore]
        fn test_token_count_unicode() {
            let embedder =
                Embedder::new(ModelConfig::e5_base()).expect("Failed to create embedder");
            let text = "\u{3053}\u{3093}\u{306b}\u{3061}\u{306f}\u{4e16}\u{754c}"; // "Hello world" in Japanese
            let count = embedder.token_count(text).expect("token_count failed");
            // Unicode text may tokenize differently
            assert!(count > 0, "Expected >0 tokens for unicode, got {}", count);
        }
    }

    // ===== TC-45: ensure_model / CQS_ONNX_DIR path tests =====

    mod ensure_model_tests {
        use super::*;
        use std::sync::Mutex;

        /// Mutex to serialize tests that manipulate CQS_ONNX_DIR env var.
        static ONNX_DIR_MUTEX: Mutex<()> = Mutex::new(());

        fn test_model_config() -> ModelConfig {
            ModelConfig {
                name: "test".to_string(),
                repo: "test/model".to_string(),
                onnx_path: "onnx/model.onnx".to_string(),
                tokenizer_path: "tokenizer.json".to_string(),
                dim: 768,
                max_seq_length: 512,
                query_prefix: String::new(),
                doc_prefix: String::new(),
            }
        }

        #[test]
        fn cqs_onnx_dir_structured_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            let onnx_dir = dir.path().join("onnx");
            std::fs::create_dir_all(&onnx_dir).unwrap();
            std::fs::write(onnx_dir.join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();

            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }

        #[test]
        fn cqs_onnx_dir_flat_layout() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            std::fs::write(dir.path().join("model.onnx"), b"fake").unwrap();
            std::fs::write(dir.path().join("tokenizer.json"), b"fake").unwrap();

            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            let (model, tok) = result.unwrap();
            assert!(
                model.to_string_lossy().ends_with("model.onnx"),
                "Expected model path ending in model.onnx, got {:?}",
                model
            );
            assert!(
                tok.to_string_lossy().ends_with("tokenizer.json"),
                "Expected tokenizer path ending in tokenizer.json, got {:?}",
                tok
            );
        }

        #[test]
        fn cqs_onnx_dir_missing_files_falls_through() {
            let _lock = ONNX_DIR_MUTEX.lock().unwrap();
            let dir = tempfile::TempDir::new().unwrap();
            // Empty dir -- neither structured nor flat layout

            std::env::set_var("CQS_ONNX_DIR", dir.path().to_str().unwrap());
            let result = ensure_model(&test_model_config());
            std::env::remove_var("CQS_ONNX_DIR");

            // Falls through to HF download -- which will fail in test env,
            // but the point is it didn't return the CQS_ONNX_DIR path
            assert!(
                result.is_err() || !result.as_ref().unwrap().0.starts_with(dir.path()),
                "Should not return paths from empty CQS_ONNX_DIR"
            );
        }
    }
}