// pulsedb/embedding/onnx.rs
//! ONNX-based embedding generation.
//!
//! This module provides embedding generation using ONNX Runtime.
//! It requires the `builtin-embeddings` feature to be enabled.
//!
//! # Supported Models
//!
//! - **all-MiniLM-L6-v2** (384 dimensions) - Default, fast and compact
//! - **bge-base-en-v1.5** (768 dimensions) - Higher quality, larger
//!
//! # Example
//!
//! ```rust,no_run
//! use pulsedb::embedding::onnx::OnnxEmbedding;
//! use pulsedb::embedding::EmbeddingService;
//!
//! # fn main() -> pulsedb::Result<()> {
//! let service = OnnxEmbedding::new(None)?;  // Use default model
//! let embedding = service.embed("Hello, world!")?;
//! assert_eq!(embedding.len(), 384);
//! # Ok(())
//! # }
//! ```
//!
//! # Architecture
//!
//! The embedding pipeline mirrors what runs inside services like Ollama
//! or OpenAI's embedding endpoint, but executed locally:
//!
//! ```text
//! Text → Tokenize → ONNX Inference → Mean Pool → L2 Normalize → Embedding
//! ```
//!
//! # Performance Notes
//!
//! - Embedding generation is CPU-intensive
//! - Use `embed_batch()` for multiple texts (more efficient due to batched inference)
//! - Consider using `spawn_blocking` when called from async context

40use std::path::{Path, PathBuf};
41use std::sync::Mutex;
42
43use ndarray::Array2;
44use ort::session::builder::GraphOptimizationLevel;
45use ort::session::Session;
46use tokenizers::Tokenizer;
47use tracing::{debug, info};
48
49use crate::embedding::EmbeddingService;
50use crate::error::{PulseDBError, Result};
51use crate::types::Embedding;
52
53// ---------------------------------------------------------------------------
54// Model configuration constants
55// ---------------------------------------------------------------------------
56
57/// Default model: all-MiniLM-L6-v2 (384 dimensions, 256 max tokens)
58const DEFAULT_MODEL_NAME: &str = "all-MiniLM-L6-v2";
59const DEFAULT_DIMENSION: usize = 384;
60const DEFAULT_MAX_LENGTH: usize = 256;
61
62/// Alternative model: bge-base-en-v1.5 (768 dimensions, 512 max tokens)
63const BGE_MODEL_NAME: &str = "bge-base-en-v1.5";
64const BGE_MAX_LENGTH: usize = 512;
65
66/// File names expected in each model directory
67const MODEL_FILENAME: &str = "model.onnx";
68const TOKENIZER_FILENAME: &str = "tokenizer.json";
69
70// ---------------------------------------------------------------------------
71// OnnxEmbedding struct
72// ---------------------------------------------------------------------------
73
74/// ONNX-based embedding service.
75///
76/// Generates embeddings locally using an ONNX model via ONNX Runtime.
77/// The model and tokenizer are loaded eagerly at construction time for
78/// fail-fast behavior — if the model files are missing, you'll get an
79/// error at `PulseDB::open()`, not at the first `record_experience()`.
80///
81/// # Thread Safety
82///
83/// `OnnxEmbedding` is `Send + Sync`. ONNX Runtime's `Session` handles
84/// internal synchronization for concurrent inference requests.
85pub struct OnnxEmbedding {
86    /// ONNX Runtime session (the loaded model, ready for inference).
87    /// Wrapped in Mutex because `Session::run()` requires `&mut self`,
88    /// but our `EmbeddingService` trait uses `&self` for concurrent access.
89    session: Mutex<Session>,
90
91    /// HuggingFace tokenizer (converts text to token IDs).
92    /// Tokenizer is immutable after loading so no Mutex needed.
93    tokenizer: Tokenizer,
94
95    /// Embedding dimension produced by this model (e.g., 384 or 768).
96    dimension: usize,
97
98    /// Maximum sequence length the model accepts.
99    max_length: usize,
100}
101
102impl OnnxEmbedding {
103    /// Creates a new ONNX embedding service with the default model (all-MiniLM-L6-v2, 384d).
104    ///
105    /// # Arguments
106    ///
107    /// * `model_path` - Optional path to a model directory containing `model.onnx`
108    ///   and `tokenizer.json`. If `None`, looks in the default cache directory
109    ///   (`~/.cache/pulsedb/models/all-MiniLM-L6-v2/`).
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if model files are not found or cannot be loaded.
114    ///
115    /// # Example
116    ///
117    /// ```rust,no_run
118    /// use pulsedb::embedding::onnx::OnnxEmbedding;
119    ///
120    /// # fn main() -> pulsedb::Result<()> {
121    /// // Use default model from cache
122    /// let service = OnnxEmbedding::new(None)?;
123    ///
124    /// // Use custom model directory
125    /// let service = OnnxEmbedding::new(Some("./models/my-model".into()))?;
126    /// # Ok(())
127    /// # }
128    /// ```
129    pub fn new(model_path: Option<PathBuf>) -> Result<Self> {
130        Self::with_dimension(model_path, DEFAULT_DIMENSION)
131    }
132
133    /// Creates an ONNX embedding service with a specific dimension.
134    ///
135    /// The dimension determines which default model to use:
136    /// - `384` → all-MiniLM-L6-v2 (max 256 tokens)
137    /// - `768` → bge-base-en-v1.5 (max 512 tokens)
138    /// - Other → requires `model_path` to be provided
139    ///
140    /// # Arguments
141    ///
142    /// * `model_path` - Optional path to a model directory
143    /// * `dimension` - Expected embedding dimension
144    pub fn with_dimension(model_path: Option<PathBuf>, dimension: usize) -> Result<Self> {
145        let max_length = match dimension {
146            DEFAULT_DIMENSION => DEFAULT_MAX_LENGTH,
147            768 => BGE_MAX_LENGTH,
148            _ => DEFAULT_MAX_LENGTH,
149        };
150
151        let model_dir = resolve_model_dir(model_path.as_deref(), dimension)?;
152
153        info!(
154            model_dir = %model_dir.display(),
155            dimension,
156            max_length,
157            "Loading ONNX embedding model"
158        );
159
160        Self::load_from_dir(&model_dir, dimension, max_length)
161    }
162
163    /// Downloads the default model files to the cache directory.
164    ///
165    /// Downloads `model.onnx` and `tokenizer.json` from HuggingFace Hub
166    /// to `~/.cache/pulsedb/models/{model_name}/`.
167    ///
168    /// # Arguments
169    ///
170    /// * `dimension` - Which model to download:
171    ///   - `384` → all-MiniLM-L6-v2
172    ///   - `768` → bge-base-en-v1.5
173    ///
174    /// # Returns
175    ///
176    /// The path to the model directory.
177    pub fn download_default_model(dimension: usize) -> Result<PathBuf> {
178        let (model_name, model_url, tokenizer_url) = match dimension {
179            DEFAULT_DIMENSION => (
180                DEFAULT_MODEL_NAME,
181                "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx",
182                "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json",
183            ),
184            768 => (
185                BGE_MODEL_NAME,
186                "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx",
187                "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json",
188            ),
189            _ => {
190                return Err(PulseDBError::embedding(format!(
191                    "No default model for dimension {dimension}. \
192                     Supported: 384 (all-MiniLM-L6-v2), 768 (bge-base-en-v1.5)"
193                )));
194            }
195        };
196
197        let cache_dir = default_cache_dir(model_name);
198
199        // Create directory
200        std::fs::create_dir_all(&cache_dir).map_err(|e| {
201            PulseDBError::embedding(format!(
202                "Failed to create model cache directory {}: {e}",
203                cache_dir.display()
204            ))
205        })?;
206
207        // Acquire exclusive file lock to prevent concurrent download races.
208        // Multiple threads/processes may call this simultaneously on first run.
209        let lock_path = cache_dir.join(".download.lock");
210        let lock_file = std::fs::File::create(&lock_path)
211            .map_err(|e| PulseDBError::embedding(format!("Failed to create download lock: {e}")))?;
212        use fs2::FileExt;
213        lock_file.lock_exclusive().map_err(|e| {
214            PulseDBError::embedding(format!("Failed to acquire download lock: {e}"))
215        })?;
216
217        let model_path = cache_dir.join(MODEL_FILENAME);
218        let tokenizer_path = cache_dir.join(TOKENIZER_FILENAME);
219
220        // Double-check after acquiring lock — another process may have downloaded while we waited
221        if model_path.exists() && tokenizer_path.exists() {
222            info!(dir = %cache_dir.display(), "Model files already downloaded by another process");
223            return Ok(cache_dir);
224        }
225
226        // Download model if not present
227        if !model_path.exists() {
228            info!(url = model_url, dest = %model_path.display(), "Downloading ONNX model");
229            download_file(model_url, &model_path)?;
230        }
231
232        // Download tokenizer if not present
233        if !tokenizer_path.exists() {
234            info!(url = tokenizer_url, dest = %tokenizer_path.display(), "Downloading tokenizer");
235            download_file(tokenizer_url, &tokenizer_path)?;
236        }
237
238        info!(dir = %cache_dir.display(), "Model files ready");
239        Ok(cache_dir)
240    }
241
242    /// Loads the model and tokenizer from a directory.
243    fn load_from_dir(model_dir: &Path, dimension: usize, max_length: usize) -> Result<Self> {
244        let model_path = model_dir.join(MODEL_FILENAME);
245        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
246
247        // Validate files exist
248        if !model_path.exists() {
249            return Err(PulseDBError::embedding(format!(
250                "Model file not found: {}. \
251                 Download with OnnxEmbedding::download_default_model({dimension}) \
252                 or provide a directory containing '{MODEL_FILENAME}'",
253                model_path.display()
254            )));
255        }
256        if !tokenizer_path.exists() {
257            return Err(PulseDBError::embedding(format!(
258                "Tokenizer file not found: {}. \
259                 The model directory must contain '{TOKENIZER_FILENAME}'",
260                tokenizer_path.display()
261            )));
262        }
263
264        let session = create_session(&model_path)?;
265        let tokenizer = load_tokenizer(&tokenizer_path, max_length)?;
266
267        debug!(dimension, max_length, "ONNX embedding model loaded");
268
269        Ok(Self {
270            session: Mutex::new(session),
271            tokenizer,
272            dimension,
273            max_length,
274        })
275    }
276}
277
278impl EmbeddingService for OnnxEmbedding {
279    fn embed(&self, text: &str) -> Result<Embedding> {
280        if text.is_empty() {
281            return Err(PulseDBError::embedding("Cannot embed empty text"));
282        }
283
284        // 1. Tokenize: text → token IDs + attention mask
285        let encoding = self
286            .tokenizer
287            .encode(text, true)
288            .map_err(|e| PulseDBError::embedding(format!("Tokenization failed: {e}")))?;
289
290        let ids = encoding.get_ids();
291        let mask = encoding.get_attention_mask();
292
293        // 2. Truncate to model's max sequence length
294        let len = ids.len().min(self.max_length);
295
296        // 3. Build input tensors [1, seq_len]
297        let input_ids: Vec<i64> = ids[..len].iter().map(|&x| x as i64).collect();
298        let attention_mask: Vec<i64> = mask[..len].iter().map(|&x| x as i64).collect();
299        let token_type_ids: Vec<i64> = vec![0i64; len];
300
301        let ids_array = Array2::from_shape_vec((1, len), input_ids)
302            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
303        let mask_array = Array2::from_shape_vec((1, len), attention_mask.clone())
304            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
305        let type_array = Array2::from_shape_vec((1, len), token_type_ids)
306            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
307
308        // 4. Create ONNX tensor values from ndarray
309        let ids_tensor = ort::value::Tensor::from_array(ids_array)
310            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
311        let mask_tensor = ort::value::Tensor::from_array(mask_array)
312            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
313        let type_tensor = ort::value::Tensor::from_array(type_array)
314            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
315
316        // 5. Run ONNX inference (lock session for mutable access)
317        let mut session = self
318            .session
319            .lock()
320            .map_err(|e| PulseDBError::embedding(format!("Session lock poisoned: {e}")))?;
321        let outputs = session
322            .run(ort::inputs![
323                "input_ids" => ids_tensor,
324                "attention_mask" => mask_tensor,
325                "token_type_ids" => type_tensor,
326            ])
327            .map_err(|e| PulseDBError::embedding(format!("ONNX inference failed: {e}")))?;
328
329        // 6. Extract token embeddings [1, seq_len, dim]
330        let token_embeddings = outputs[0]
331            .try_extract_tensor::<f32>()
332            .map_err(|e| PulseDBError::embedding(format!("Output extraction failed: {e}")))?;
333
334        // Convert attention mask for pooling
335        let mask_u32: Vec<u32> = attention_mask.iter().map(|&x| x as u32).collect();
336
337        // 7. Mean pool → [dim], then L2 normalize
338        let pooled = mean_pool_raw(token_embeddings.1, &mask_u32, self.dimension, len);
339        Ok(l2_normalize(&pooled))
340    }
341
342    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
343        if texts.is_empty() {
344            return Ok(vec![]);
345        }
346        if texts.len() == 1 {
347            return Ok(vec![self.embed(texts[0])?]);
348        }
349
350        // 1. Tokenize all texts
351        let encodings: Vec<_> = texts
352            .iter()
353            .map(|t| self.tokenizer.encode(*t, true))
354            .collect::<std::result::Result<Vec<_>, _>>()
355            .map_err(|e| PulseDBError::embedding(format!("Batch tokenization failed: {e}")))?;
356
357        // 2. Pad to longest sequence in batch (not max_length — saves compute)
358        let max_len = encodings
359            .iter()
360            .map(|enc| enc.get_ids().len().min(self.max_length))
361            .max()
362            .unwrap_or(0);
363
364        let batch_size = texts.len();
365
366        // 3. Build padded tensors [batch_size, max_len]
367        let mut input_ids = vec![0i64; batch_size * max_len];
368        let mut attention_mask = vec![0i64; batch_size * max_len];
369        let token_type_ids = vec![0i64; batch_size * max_len];
370
371        for (i, enc) in encodings.iter().enumerate() {
372            let ids = enc.get_ids();
373            let mask = enc.get_attention_mask();
374            let len = ids.len().min(self.max_length);
375
376            for j in 0..len {
377                input_ids[i * max_len + j] = ids[j] as i64;
378                attention_mask[i * max_len + j] = mask[j] as i64;
379            }
380        }
381
382        let ids_array = Array2::from_shape_vec((batch_size, max_len), input_ids)
383            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
384        let mask_array = Array2::from_shape_vec((batch_size, max_len), attention_mask.clone())
385            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
386        let type_array = Array2::from_shape_vec((batch_size, max_len), token_type_ids)
387            .map_err(|e| PulseDBError::embedding(format!("Tensor shape error: {e}")))?;
388
389        // 4. Create ONNX tensor values
390        let ids_tensor = ort::value::Tensor::from_array(ids_array)
391            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
392        let mask_tensor = ort::value::Tensor::from_array(mask_array)
393            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
394        let type_tensor = ort::value::Tensor::from_array(type_array)
395            .map_err(|e| PulseDBError::embedding(format!("Tensor creation failed: {e}")))?;
396
397        // 5. Run batched inference (lock session for mutable access)
398        let mut session = self
399            .session
400            .lock()
401            .map_err(|e| PulseDBError::embedding(format!("Session lock poisoned: {e}")))?;
402        let outputs = session
403            .run(ort::inputs![
404                "input_ids" => ids_tensor,
405                "attention_mask" => mask_tensor,
406                "token_type_ids" => type_tensor,
407            ])
408            .map_err(|e| PulseDBError::embedding(format!("ONNX inference failed: {e}")))?;
409
410        // 6. Extract [batch_size, max_len, dim]
411        let token_embeddings = outputs[0]
412            .try_extract_tensor::<f32>()
413            .map_err(|e| PulseDBError::embedding(format!("Output extraction failed: {e}")))?;
414
415        let (_shape, data) = token_embeddings;
416
417        // 7. Per-text mean pooling + L2 normalization
418        let mut results = Vec::with_capacity(batch_size);
419        for i in 0..batch_size {
420            let text_mask: Vec<u32> = (0..max_len)
421                .map(|j| attention_mask[i * max_len + j] as u32)
422                .collect();
423
424            // Extract this text's token embeddings from the flat data
425            let offset = i * max_len * self.dimension;
426            let text_data = &data[offset..offset + max_len * self.dimension];
427
428            let pooled = mean_pool_raw(text_data, &text_mask, self.dimension, max_len);
429            results.push(l2_normalize(&pooled));
430        }
431
432        Ok(results)
433    }
434
435    fn dimension(&self) -> usize {
436        self.dimension
437    }
438}
439
440// ---------------------------------------------------------------------------
441// Helper functions
442// ---------------------------------------------------------------------------
443
444/// Creates an ONNX Runtime session with optimized settings.
445fn create_session(model_path: &Path) -> Result<Session> {
446    Session::builder()
447        .map_err(|e| PulseDBError::embedding(format!("Failed to create session builder: {e}")))?
448        // Level3: all optimizations (operator fusion, constant folding, etc.)
449        .with_optimization_level(GraphOptimizationLevel::Level3)
450        .map_err(|e| PulseDBError::embedding(format!("Failed to set optimization level: {e}")))?
451        .commit_from_file(model_path)
452        .map_err(|e| {
453            PulseDBError::embedding(format!(
454                "Failed to load ONNX model from {}: {e}",
455                model_path.display()
456            ))
457        })
458}
459
460/// Loads a HuggingFace tokenizer from a tokenizer.json file.
461fn load_tokenizer(tokenizer_path: &Path, max_length: usize) -> Result<Tokenizer> {
462    let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
463        PulseDBError::embedding(format!(
464            "Failed to load tokenizer from {}: {e}",
465            tokenizer_path.display()
466        ))
467    })?;
468
469    // Configure truncation to model's max sequence length
470    tokenizer
471        .with_truncation(Some(tokenizers::TruncationParams {
472            max_length,
473            strategy: tokenizers::TruncationStrategy::LongestFirst,
474            ..Default::default()
475        }))
476        .map_err(|e| PulseDBError::embedding(format!("Failed to set truncation: {e}")))?;
477
478    // Disable padding — we handle padding manually in embed_batch()
479    // for smart padding (pad to longest in batch, not max_length)
480    tokenizer.with_padding(None);
481
482    Ok(tokenizer)
483}
484
485/// Resolves the model directory from an optional user path or default cache.
486fn resolve_model_dir(model_path: Option<&Path>, dimension: usize) -> Result<PathBuf> {
487    match model_path {
488        Some(path) => {
489            if !path.exists() {
490                return Err(PulseDBError::embedding(format!(
491                    "Model directory not found: {}",
492                    path.display()
493                )));
494            }
495            Ok(path.to_path_buf())
496        }
497        None => {
498            // Determine model name from dimension
499            let model_name = match dimension {
500                DEFAULT_DIMENSION => DEFAULT_MODEL_NAME,
501                768 => BGE_MODEL_NAME,
502                _ => {
503                    return Err(PulseDBError::embedding(format!(
504                        "No default model for dimension {dimension}. \
505                         Provide a model_path for custom dimensions, \
506                         or use 384 (all-MiniLM-L6-v2) or 768 (bge-base-en-v1.5)"
507                    )));
508                }
509            };
510
511            let cache_dir = default_cache_dir(model_name);
512
513            if !cache_dir.join(MODEL_FILENAME).exists() {
514                return Err(PulseDBError::embedding(format!(
515                    "Model not found at {}. \
516                     Download with: OnnxEmbedding::download_default_model({dimension})",
517                    cache_dir.display()
518                )));
519            }
520
521            Ok(cache_dir)
522        }
523    }
524}
525
526/// Returns the default cache directory for a model.
527///
528/// Platform-specific:
529/// - Linux: `~/.cache/pulsedb/models/{name}/`
530/// - macOS: `~/Library/Caches/pulsedb/models/{name}/`
531/// - Windows: `{LOCALAPPDATA}/pulsedb/models/{name}/`
532fn default_cache_dir(model_name: &str) -> PathBuf {
533    dirs::cache_dir()
534        .unwrap_or_else(|| PathBuf::from(".cache"))
535        .join("pulsedb")
536        .join("models")
537        .join(model_name)
538}
539
540/// Mean pooling over token embeddings from flat data.
541///
542/// Computes the attention-weighted average of token embeddings to produce
543/// a single sentence embedding. Only tokens with mask=1 contribute.
544///
545/// The data is laid out as `[seq_len * dim]` in row-major order, where
546/// each contiguous block of `dim` floats is one token's embedding.
547///
548/// # Arguments
549///
550/// * `data` - Flat f32 slice of shape `[seq_len, dim]`
551/// * `attention_mask` - Shape `[seq_len]`, 1 for real tokens, 0 for padding
552/// * `dim` - Embedding dimension
553/// * `seq_len` - Number of tokens
554fn mean_pool_raw(data: &[f32], attention_mask: &[u32], dim: usize, seq_len: usize) -> Vec<f32> {
555    let mut pooled = vec![0.0f32; dim];
556    let mut mask_sum = 0.0f32;
557
558    for (t, &mask_val) in attention_mask.iter().enumerate().take(seq_len) {
559        let weight = mask_val as f32;
560        mask_sum += weight;
561        let offset = t * dim;
562        for d in 0..dim {
563            pooled[d] += data[offset + d] * weight;
564        }
565    }
566
567    // Divide by number of real tokens (avoid division by zero)
568    if mask_sum > 0.0 {
569        for val in &mut pooled {
570            *val /= mask_sum;
571        }
572    }
573
574    pooled
575}
576
577/// L2 normalizes a vector to unit length.
578///
579/// After normalization, the vector has magnitude 1.0, which means
580/// cosine similarity can be computed as a simple dot product:
581/// `cos(a, b) = a · b` when `|a| = |b| = 1`.
582fn l2_normalize(v: &[f32]) -> Vec<f32> {
583    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
584    if norm > 0.0 {
585        v.iter().map(|x| x / norm).collect()
586    } else {
587        v.to_vec()
588    }
589}
590
591/// Downloads a file from a URL to a local path.
592///
593/// Uses atomic write (temp file + rename) to prevent partial downloads
594/// from leaving corrupted files that block future retry attempts.
595fn download_file(url: &str, dest: &Path) -> Result<()> {
596    let response = ureq::get(url)
597        .call()
598        .map_err(|e| PulseDBError::embedding(format!("Download failed for {url}: {e}")))?;
599
600    // Write to temp file first — rename on success prevents partial corruption
601    let temp = dest.with_extension("tmp");
602    let mut reader = response.into_body().into_reader();
603    let mut file = std::fs::File::create(&temp).map_err(|e| {
604        PulseDBError::embedding(format!("Failed to create file {}: {e}", temp.display()))
605    })?;
606
607    if let Err(e) = std::io::copy(&mut reader, &mut file) {
608        let _ = std::fs::remove_file(&temp);
609        return Err(PulseDBError::embedding(format!(
610            "Failed to write to {}: {e}",
611            dest.display()
612        )));
613    }
614
615    // Atomic rename — only the complete file appears at the destination
616    std::fs::rename(&temp, dest).map_err(|e| {
617        let _ = std::fs::remove_file(&temp);
618        PulseDBError::embedding(format!(
619            "Failed to finalize download {}: {e}",
620            dest.display()
621        ))
622    })?;
623
624    Ok(())
625}
626
627// ---------------------------------------------------------------------------
628// Tests
629// ---------------------------------------------------------------------------
630
631#[cfg(test)]
632mod tests {
633    use super::*;
634
635    // --- L2 normalization tests ---
636
637    #[test]
638    fn test_l2_normalize_basic() {
639        let v = vec![3.0, 4.0];
640        let normalized = l2_normalize(&v);
641        // norm = sqrt(9 + 16) = 5
642        assert!((normalized[0] - 0.6).abs() < 1e-6);
643        assert!((normalized[1] - 0.8).abs() < 1e-6);
644
645        // Verify unit length
646        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
647        assert!((norm - 1.0).abs() < 1e-6);
648    }
649
650    #[test]
651    fn test_l2_normalize_zero_vector() {
652        let v = vec![0.0, 0.0, 0.0];
653        let normalized = l2_normalize(&v);
654        // Zero vector stays zero (no division by zero)
655        assert_eq!(normalized, vec![0.0, 0.0, 0.0]);
656    }
657
658    #[test]
659    fn test_l2_normalize_already_unit() {
660        let v = vec![1.0, 0.0, 0.0];
661        let normalized = l2_normalize(&v);
662        assert!((normalized[0] - 1.0).abs() < 1e-6);
663        assert!((normalized[1] - 0.0).abs() < 1e-6);
664    }
665
666    // --- Mean pooling tests ---
667
668    #[test]
669    fn test_mean_pool_uniform_mask() {
670        // All tokens are real (mask = all ones)
671        // 2 tokens, 3 dimensions → average of both
672        let data = vec![
673            1.0, 2.0, 3.0, // token 0
674            5.0, 6.0, 7.0, // token 1
675        ];
676        let mask = vec![1u32, 1];
677
678        let pooled = mean_pool_raw(&data, &mask, 3, 2);
679        // Average: [(1+5)/2, (2+6)/2, (3+7)/2] = [3, 4, 5]
680        assert!((pooled[0] - 3.0).abs() < 1e-6);
681        assert!((pooled[1] - 4.0).abs() < 1e-6);
682        assert!((pooled[2] - 5.0).abs() < 1e-6);
683    }
684
685    #[test]
686    fn test_mean_pool_partial_mask() {
687        // Only first token is real, second is padding
688        let data = vec![
689            1.0, 2.0, 3.0, // token 0 (real)
690            99.0, 99.0, 99.0, // token 1 (padding — should be ignored)
691        ];
692        let mask = vec![1u32, 0]; // Only token 0 counts
693
694        let pooled = mean_pool_raw(&data, &mask, 3, 2);
695        // Only token 0 contributes: [1, 2, 3]
696        assert!((pooled[0] - 1.0).abs() < 1e-6);
697        assert!((pooled[1] - 2.0).abs() < 1e-6);
698        assert!((pooled[2] - 3.0).abs() < 1e-6);
699    }
700
701    #[test]
702    fn test_mean_pool_zero_mask() {
703        // Edge case: all tokens masked (shouldn't happen in practice)
704        let data = vec![99.0, 99.0, 99.0];
705        let mask = vec![0u32];
706
707        let pooled = mean_pool_raw(&data, &mask, 3, 1);
708        // All zeros (no tokens contribute)
709        assert_eq!(pooled, vec![0.0, 0.0, 0.0]);
710    }
711
712    // --- Path resolution tests ---
713
714    #[test]
715    fn test_resolve_model_dir_custom_path_missing() {
716        let result = resolve_model_dir(Some(Path::new("/nonexistent/path")), 384);
717        assert!(result.is_err());
718        let err = result.unwrap_err().to_string();
719        assert!(err.contains("not found"), "Error: {err}");
720    }
721
722    #[test]
723    fn test_resolve_model_dir_unsupported_dimension() {
724        let result = resolve_model_dir(None, 999);
725        assert!(result.is_err());
726        let err = result.unwrap_err().to_string();
727        assert!(err.contains("No default model"), "Error: {err}");
728    }
729
730    #[test]
731    fn test_default_cache_dir_format() {
732        let dir = default_cache_dir("test-model");
733        // Should end with pulsedb/models/test-model
734        let path_str = dir.to_string_lossy();
735        assert!(path_str.contains("pulsedb"), "Path: {path_str}");
736        assert!(path_str.contains("models"), "Path: {path_str}");
737        assert!(path_str.contains("test-model"), "Path: {path_str}");
738    }
739
740    // --- Thread safety ---
741
742    #[test]
743    fn test_onnx_embedding_is_send_sync() {
744        fn assert_send_sync<T: Send + Sync>() {}
745        assert_send_sync::<OnnxEmbedding>();
746    }
747}