//! mag/memory_core/embedder.rs
//!
//! Text-embedding backends: a deterministic SHA-256 placeholder embedder,
//! plus an optional ONNX-powered real embedder behind the `real-embeddings`
//! feature.
1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use sha2::{Digest, Sha256};
5
6use crate::app_paths;
7
/// Trait for generating text embeddings.
///
/// Implementations must be `Send + Sync` so a single embedder can be shared
/// across threads (e.g. behind an `Arc`).
pub trait Embedder: Send + Sync {
    /// Returns the embedding dimension.
    fn dimension(&self) -> usize;
    /// Generates an embedding vector for the given text.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
    /// Generates embedding vectors for multiple texts in a single call.
    /// The default implementation calls `embed()` in a loop; backends that
    /// support true batched inference (e.g. ONNX) override this for better
    /// throughput.
    #[allow(dead_code)] // not every build/call-site uses the batched path
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Collecting an iterator of Results short-circuits on the first error.
        texts.iter().map(|t| self.embed(t)).collect()
    }
}
23
/// Deterministic, model-free embedder: maps text to a 32-dimensional unit
/// vector derived from its SHA-256 digest. Equal inputs give equal vectors,
/// but there is no semantic similarity between nearby texts.
#[derive(Debug, Default, Clone)]
#[allow(dead_code)] // only constructed under some feature combinations
pub struct PlaceholderEmbedder;
27
28impl Embedder for PlaceholderEmbedder {
29    fn dimension(&self) -> usize {
30        32
31    }
32
33    fn embed(&self, text: &str) -> Result<Vec<f32>> {
34        Ok(embedding_for_text(text))
35    }
36}
37
38#[allow(dead_code)]
39pub(crate) fn embedding_for_text(input: &str) -> Vec<f32> {
40    let mut hasher = Sha256::new();
41    hasher.update(input.as_bytes());
42    let digest = hasher.finalize();
43    let mut vec: Vec<f32> = digest.iter().map(|b| *b as f32 / 255.0).collect();
44    normalize_embedding(&mut vec);
45    vec
46}
47
/// Rescales `vec` in place to unit L2 norm. A zero vector (norm == 0) is left
/// unchanged so we never divide by zero; an empty slice is a no-op.
pub(crate) fn normalize_embedding(vec: &mut [f32]) {
    let squared_sum: f32 = vec.iter().map(|component| component * component).sum();
    let norm = squared_sum.sqrt();
    if norm > 0.0 {
        vec.iter_mut().for_each(|component| *component /= norm);
    }
}
56
/// Cache-directory name (under the app model root) for the default model.
#[cfg(feature = "real-embeddings")]
const MODEL_NAME: &str = "bge-small-en-v1.5-int8";
/// Download URL for the int8-quantised ONNX graph of bge-small-en-v1.5.
#[cfg(feature = "real-embeddings")]
const MODEL_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/onnx/model_int8.onnx";
/// Download URL for the matching tokenizer definition.
#[cfg(feature = "real-embeddings")]
const TOKENIZER_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/tokenizer.json";

/// Maximum number of (text-hash -> embedding) entries held in the LRU cache.
#[cfg(feature = "real-embeddings")]
const EMBEDDING_CACHE_CAPACITY: std::num::NonZeroUsize = std::num::NonZeroUsize::new(2048).unwrap();

/// Idle time after which `try_unload_if_idle` drops the ONNX session.
#[cfg(feature = "real-embeddings")]
const IDLE_TIMEOUT_SECS: u64 = 600; // 10 minutes
71
/// Embedder backed by a locally cached ONNX model.
///
/// The session/tokenizer pair is created lazily (downloading model files on
/// first use), can be unloaded after an idle period to reclaim memory, and
/// results are memoised in an LRU cache keyed by SHA-256 of the input text.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
pub struct OnnxEmbedder {
    // Directory holding model.onnx / tokenizer.json (created on demand).
    model_dir: PathBuf,
    // Download URL for the ONNX graph.
    model_url: String,
    // Optional URL for an external-weights data file (large models).
    model_data_url: Option<String>,
    // Download URL for tokenizer.json.
    tokenizer_url: String,
    // Embedding size exposed to callers; model output may be wider and is truncated.
    dimension: usize,
    // Name of the output tensor to read (e.g. "last_hidden_state").
    output_tensor_name: String,
    // Whether the model graph expects a third `token_type_ids` input.
    use_token_type_ids: bool,
    // Lazily initialised session + tokenizer; `None` while unloaded.
    runtime: std::sync::Mutex<Option<OnnxRuntime>>,
    // Unix seconds of last init/use; 0 means "never loaded".
    last_used: std::sync::atomic::AtomicU64,
    // LRU cache: SHA-256(text) -> embedding vector.
    cache: std::sync::Mutex<lru::LruCache<[u8; 32], Vec<f32>>>,
}
86
/// Loaded inference state: the ONNX session plus its tokenizer. The session
/// has its own mutex because running inference requires exclusive access.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
struct OnnxRuntime {
    session: std::sync::Mutex<ort::session::Session>,
    tokenizer: tokenizers::Tokenizer,
}
93
/// Resolved on-disk locations of the model artefacts inside `directory`.
#[cfg(feature = "real-embeddings")]
#[derive(Debug, Clone)]
struct ModelFiles {
    directory: PathBuf,
    model_path: PathBuf,
    // Present only for models that ship an external weights file.
    model_data_path: Option<PathBuf>,
    tokenizer_path: PathBuf,
}
102
103#[cfg(feature = "real-embeddings")]
104impl OnnxEmbedder {
    /// Creates an embedder for the default bge-small-en-v1.5 (int8) model,
    /// reading 384-dimensional embeddings from its `last_hidden_state` output.
    pub fn new() -> Result<Self> {
        Self::with_model(
            MODEL_NAME,
            MODEL_URL,
            TOKENIZER_URL,
            384,
            "last_hidden_state",
        )
    }
114
    /// Creates an embedder for a custom single-file model using the common
    /// BERT-style input layout (includes `token_type_ids`, no external
    /// weights file).
    pub fn with_model(
        name: &str,
        model_url: &str,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
    ) -> Result<Self> {
        Self::with_model_and_data(
            name,
            model_url,
            None,
            tokenizer_url,
            dimension,
            output_tensor_name,
            true,
        )
    }
132
    /// Fully-general constructor.
    ///
    /// * `name` — cache subdirectory under the app's model root.
    /// * `model_data_url` — optional external-weights file downloaded next to
    ///   the model graph.
    /// * `dimension` — embedding size exposed to callers; wider model outputs
    ///   are truncated at inference time.
    /// * `use_token_type_ids` — whether the graph expects a `token_type_ids`
    ///   input tensor.
    ///
    /// Performs no network or session I/O here; downloads and session
    /// creation are deferred to first use (or `warmup`).
    pub fn with_model_and_data(
        name: &str,
        model_url: &str,
        model_data_url: Option<&str>,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
        use_token_type_ids: bool,
    ) -> Result<Self> {
        let model_dir = app_paths::resolve_app_paths()?.model_root.join(name);
        Ok(Self {
            model_dir,
            model_url: model_url.to_string(),
            model_data_url: model_data_url.map(str::to_string),
            tokenizer_url: tokenizer_url.to_string(),
            dimension,
            output_tensor_name: output_tensor_name.to_string(),
            use_token_type_ids,
            runtime: std::sync::Mutex::new(None),
            // 0 is the sentinel for "never loaded" (see try_unload_if_idle).
            last_used: std::sync::atomic::AtomicU64::new(0),
            cache: std::sync::Mutex::new(lru::LruCache::new(EMBEDDING_CACHE_CAPACITY)),
        })
    }
156
157    fn epoch_secs() -> u64 {
158        std::time::SystemTime::now()
159            .duration_since(std::time::UNIX_EPOCH)
160            .map(|d| d.as_secs())
161            .unwrap_or(0)
162    }
163
    /// Records "now" as the last-used timestamp. Relaxed ordering is fine:
    /// the value only feeds the advisory idle-unload heuristic.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }
168
    /// Eagerly load the ONNX session so the first `embed()` call doesn't pay
    /// the cold-start penalty. Must be called from an async context (uses
    /// `spawn_blocking` internally since ONNX init creates a mini runtime).
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        // Fast path: already initialised. Scoped so the lock is dropped
        // before spawning the blocking task.
        {
            let guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_some() {
                return Ok(());
            }
        }
        let this = std::sync::Arc::clone(self);
        tokio::task::spawn_blocking(move || {
            // Re-check under the lock: another caller may have initialised
            // the runtime between the fast path above and this task running.
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }
198
    /// Drops the ONNX session if it has been idle for longer than the timeout,
    /// freeing ~240 MB RSS. The LRU embedding cache is preserved. The session
    /// is re-initialised transparently on the next `embed()` call.
    ///
    /// Returns `true` only when a loaded session was actually dropped.
    pub fn try_unload_if_idle(&self) -> bool {
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if last == 0 {
            return false; // never loaded
        }
        // saturating_sub guards against a clock that moved backwards.
        if Self::epoch_secs().saturating_sub(last) < IDLE_TIMEOUT_SECS {
            return false;
        }
        // A poisoned mutex is treated as "nothing to unload" rather than an error.
        if let Ok(mut guard) = self.runtime.lock()
            && guard.is_some()
        {
            *guard = None;
            tracing::info!("unloaded idle ONNX session after {IDLE_TIMEOUT_SECS}s");
            return true;
        }
        false
    }
219
    /// Periodic maintenance entry-point. Call from a tokio interval timer.
    /// Runs the idle check off the async executor (the runtime mutex can
    /// block behind a long-running inference); join errors are deliberately
    /// ignored — maintenance is best-effort.
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let _ = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
    }
228
    /// Ensures the model files exist on disk (downloading if needed), then
    /// builds the ONNX session and tokenizer. Blocking — performs file,
    /// network, and session-compilation work; call from a blocking context.
    fn init_runtime(&self) -> Result<OnnxRuntime> {
        let files = ensure_model_files_blocking(
            self.model_dir.clone(),
            &self.model_url,
            self.model_data_url.as_deref(),
            &self.tokenizer_url,
        )?;
        // Force CPU-only execution (skip CoreML/Metal which leak memory on
        // long-running macOS processes) and disable the CPU memory arena to
        // reduce RSS by ~50 MB.
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
        // bge-small-en-v1.5 supports max 512 tokens. Truncate longer inputs to avoid
        // ONNX positional-encoding broadcast errors.
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure tokenizer truncation: {e}"))?;
        Ok(OnnxRuntime {
            session: std::sync::Mutex::new(session),
            tokenizer,
        })
    }
266}
267
268#[cfg(feature = "real-embeddings")]
269impl Embedder for OnnxEmbedder {
    fn dimension(&self) -> usize {
        // The dimension requested at construction, not the raw model width.
        self.dimension
    }
273
    /// Embeds a single text: LRU cache lookup first, then tokenize + single
    /// ONNX forward pass, mean-pool (for token-level outputs), truncate to
    /// `self.dimension`, and L2-normalise. A poisoned cache mutex degrades to
    /// uncached operation rather than failing the call.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Cache lookup by SHA256 hash of input text
        let mut hasher = Sha256::new();
        hasher.update(text.as_bytes());
        let key: [u8; 32] = hasher.finalize().into();

        match self.cache.lock() {
            Ok(mut cache) => {
                if let Some(cached) = cache.get(&key) {
                    return Ok(cached.clone());
                }
            }
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        // Cache miss — compute embedding.
        // Acquire the runtime, initialising on demand if it was unloaded.
        // Scoped so all ONNX borrows are released before caching.
        let pooled = {
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            let encoding = runtime
                .tokenizer
                .encode(text, true)
                .map_err(|e| anyhow!("tokenization failed: {e}"))?;
            // ONNX expects i64 tensors; the tokenizer yields u32 ids/mask.
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&m| m as i64)
                .collect();
            if input_ids.is_empty() || input_ids.len() != attention_mask.len() {
                return Err(anyhow!("invalid tokenization output for embedding"));
            }

            // Batch dimension is fixed at 1 for the single-text path.
            let seq_len = input_ids.len();
            let input_ids_value = ort::value::Value::from_array(([1_usize, seq_len], input_ids))
                .context("failed to create ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([1_usize, seq_len], attention_mask))
                    .context("failed to create ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // Single-segment input: all token_type_ids are zero.
                let token_type_ids_value =
                    ort::value::Value::from_array(([1_usize, seq_len], vec![0_i64; seq_len]))
                        .context("failed to create ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract ONNX output tensor")?;

            // Support both 2D pre-pooled tensors (shape [1, hidden]) and
            // 3D token-level tensors (shape [1, seq_len, hidden]) that need mean-pooling.
            let mut pooled = if shape.len() == 2 {
                // Pre-pooled output (e.g. onnx-community/voyage-4-nano-ONNX `pooler_output`).
                // Shape: [1, hidden_size] — already mean-pooled and L2-normalised by the model.
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                // Matryoshka truncation: slice to requested dimension then re-normalise.
                output[..self.dimension].to_vec()
            } else if shape.len() == 3 {
                // Token-level output — apply mean pooling with attention mask.
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let output_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} smaller than requested dim {}",
                        self.dimension
                    ));
                }
                if output_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                // Guard against the model returning fewer positions than the input.
                let effective_len = output_seq_len.min(seq_len);
                let mut pooled = vec![0.0f32; self.dimension];
                let mut mask_sum = 0.0f32;
                for token_idx in 0..effective_len {
                    #[allow(clippy::cast_precision_loss)]
                    let mask_value = encoding.get_attention_mask()[token_idx] as f32;
                    if mask_value <= 0.0 {
                        continue;
                    }
                    mask_sum += mask_value;
                    for (d, pooled_value) in pooled.iter_mut().enumerate() {
                        // Row-major [1, seq, hidden] layout: token stride is hidden_size.
                        let flat_index = token_idx * hidden_size + d;
                        *pooled_value += output[flat_index] * mask_value;
                    }
                }
                if mask_sum <= 0.0 {
                    return Err(anyhow!("attention mask sum is zero during mean pooling"));
                }
                for value in &mut pooled {
                    *value /= mask_sum;
                }
                pooled
            } else {
                return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
            };
            normalize_embedding(&mut pooled);
            pooled
        };
        self.touch_last_used();

        // Cache the result before returning
        let result = pooled.clone();
        match self.cache.lock() {
            Ok(mut cache) => {
                cache.put(key, pooled);
            }
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        Ok(result)
    }
432
    /// Batched inference override: probes the cache for every input, then
    /// runs one padded ONNX forward pass over the misses only, preserving the
    /// caller's input order in the returned vector.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        // Single-element batch: delegate to the optimised single-text path.
        if texts.len() == 1 {
            return Ok(vec![self.embed(texts[0])?]);
        }

        // --- Cache probe: split into hits and misses ---
        let mut keys: Vec<[u8; 32]> = Vec::with_capacity(texts.len());
        for text in texts {
            let mut hasher = Sha256::new();
            hasher.update(text.as_bytes());
            keys.push(hasher.finalize().into());
        }

        // `results[i]` = Some(embedding) if cached, None if needs compute.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // Indices (into `texts`) that still need ONNX inference.
        let mut miss_indices: Vec<usize> = Vec::new();

        match self.cache.lock() {
            Ok(mut cache) => {
                for (i, key) in keys.iter().enumerate() {
                    if let Some(cached) = cache.get(key) {
                        results[i] = Some(cached.clone());
                    } else {
                        miss_indices.push(i);
                    }
                }
            }
            Err(_) => {
                // Poisoned cache: recompute everything rather than fail.
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                miss_indices.extend(0..texts.len());
            }
        }

        // All hits — return immediately.
        if miss_indices.is_empty() {
            return results
                .into_iter()
                .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in cache-hit path")))
                .collect();
        }

        // --- Batched ONNX inference for cache misses ---
        let computed = {
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            // Tokenize all miss texts.
            let miss_texts: Vec<&str> = miss_indices.iter().map(|&i| texts[i]).collect();
            let encodings: Vec<tokenizers::Encoding> = miss_texts
                .iter()
                .map(|t| {
                    runtime
                        .tokenizer
                        .encode(*t, true)
                        .map_err(|e| anyhow!("tokenization failed: {e}"))
                })
                .collect::<Result<Vec<_>>>()?;

            // Find the maximum sequence length across the batch for padding.
            let max_len = encodings
                .iter()
                .map(|enc| enc.get_ids().len())
                .max()
                .ok_or_else(|| anyhow!("empty encodings in batch"))?;
            if max_len == 0 {
                return Err(anyhow!("all tokenizations produced zero-length sequences"));
            }

            let batch_size = encodings.len();

            // Build padded flat tensors: [batch_size * max_len].
            let mut flat_input_ids = vec![0_i64; batch_size * max_len];
            let mut flat_attention_mask = vec![0_i64; batch_size * max_len];

            for (b, enc) in encodings.iter().enumerate() {
                let ids = enc.get_ids();
                let mask = enc.get_attention_mask();
                let seq_len = ids.len();
                if seq_len != mask.len() {
                    return Err(anyhow!(
                        "tokenization ids/mask length mismatch for batch item {b}"
                    ));
                }
                let offset = b * max_len;
                for j in 0..seq_len {
                    flat_input_ids[offset + j] = ids[j] as i64;
                    flat_attention_mask[offset + j] = mask[j] as i64;
                }
                // Remaining positions stay 0 (padding).
            }

            let input_ids_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                    .context("failed to create batched ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                    .context("failed to create batched ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // Single-segment inputs: token_type_ids are all zero.
                let token_type_ids_value = ort::value::Value::from_array((
                    [batch_size, max_len],
                    vec![0_i64; batch_size * max_len],
                ))
                .context("failed to create batched ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("batched ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("batched ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract batched ONNX output tensor")?;

            // Support both 2D pre-pooled tensors (shape [batch, hidden]) and
            // 3D token-level tensors (shape [batch, seq_len, hidden]).
            let mut batch_embeddings: Vec<Vec<f32>> = Vec::with_capacity(batch_size);
            if shape.len() == 2 {
                // Pre-pooled output — shape [batch, hidden_size].
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                for b in 0..batch_size {
                    let row_start = b * hidden_size;
                    // Matryoshka truncation + re-normalise.
                    let mut pooled = output[row_start..row_start + self.dimension].to_vec();
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            } else if shape.len() == 3 {
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let out_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "unexpected embedding dimension: got {hidden_size}, expected >= {}",
                        self.dimension
                    ));
                }
                if out_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                // Mean-pool each item in the batch using its own attention mask.
                for (b, enc) in encodings.iter().enumerate() {
                    let seq_len = enc.get_ids().len();
                    let effective_len = out_seq_len.min(seq_len);
                    let mut pooled = vec![0.0f32; self.dimension];
                    let mut mask_sum = 0.0f32;

                    for token_idx in 0..effective_len {
                        // Attention mask values are 0 or 1, so cast to f32 is lossless.
                        #[allow(clippy::cast_precision_loss)]
                        let mask_value = enc.get_attention_mask()[token_idx] as f32;
                        if mask_value <= 0.0 {
                            continue;
                        }
                        mask_sum += mask_value;
                        // Row-major [batch, seq, hidden] layout.
                        let row_offset = b * out_seq_len * hidden_size + token_idx * hidden_size;
                        for (d, pooled_value) in pooled.iter_mut().enumerate() {
                            *pooled_value += output[row_offset + d] * mask_value;
                        }
                    }

                    if mask_sum <= 0.0 {
                        return Err(anyhow!(
                            "attention mask sum is zero during mean pooling for batch item {b}"
                        ));
                    }
                    for value in &mut pooled {
                        *value /= mask_sum;
                    }
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            } else {
                return Err(anyhow!("unexpected batched ONNX output shape: {shape:?}"));
            }
            batch_embeddings
        };
        self.touch_last_used();

        // --- Populate cache and assemble final result vector ---
        match self.cache.lock() {
            Ok(mut cache) => {
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    cache.put(keys[orig_idx], embedding.clone());
                    results[orig_idx] = Some(embedding);
                }
            }
            Err(_) => {
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    results[orig_idx] = Some(embedding);
                }
            }
        }

        results
            .into_iter()
            .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in batch result")))
            .collect()
    }
684}
685
/// Ensures the default bge-small model artefacts exist on disk (downloading
/// any that are missing) and returns the directory containing them.
#[cfg(feature = "real-embeddings")]
pub async fn download_bge_small_model() -> Result<PathBuf> {
    let model_dir = default_model_dir()?;
    let files = ensure_model_files_async(model_dir, MODEL_URL, None, TOKENIZER_URL).await?;
    Ok(files.directory)
}
692
/// Returns the on-disk directory that caches the default model's artefacts.
#[cfg(feature = "real-embeddings")]
pub fn model_dir() -> Result<PathBuf> {
    let paths = app_paths::resolve_app_paths()?;
    Ok(paths.model_root.join(MODEL_NAME))
}
697
/// Alias for `model_dir`, kept so call sites read as "the built-in model".
#[cfg(feature = "real-embeddings")]
fn default_model_dir() -> Result<PathBuf> {
    model_dir()
}
702
/// Synchronous wrapper around `ensure_model_files_async` for use on blocking
/// threads (e.g. inside `spawn_blocking`). Returns the resolved artefact
/// paths, downloading any missing files first.
#[cfg(feature = "real-embeddings")]
fn ensure_model_files_blocking(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    if model_files_exist(&model_dir, model_data_url) {
        return Ok(model_files_for_dir(model_dir, model_data_url));
    }

    // Create a dedicated single-threaded runtime for model download.
    // We avoid block_in_place because embed() runs inside spawn_blocking
    // threads where block_in_place panics. A lightweight current-thread
    // runtime is safe and sufficient for the download I/O.
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .context("failed to create temporary tokio runtime for model download")?;
    // `model_data_url` already borrows data that outlives the block_on call,
    // so it can be passed straight through — no owned copy needed.
    runtime.block_on(ensure_model_files_async(
        model_dir,
        model_url,
        model_data_url,
        tokenizer_url,
    ))
}
730
/// Downloads any missing model artefacts into `model_dir` and returns the
/// resolved paths. Each file is checked and fetched independently, so an
/// interrupted run resumes without re-downloading files already on disk.
#[cfg(feature = "real-embeddings")]
async fn ensure_model_files_async(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir, model_data_url);
    // Fast path: everything already present.
    if model_files_exist(&files.directory, model_data_url) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create model directory {}",
                files.directory.display()
            )
        })?;

    if !tokio::fs::try_exists(&files.model_path)
        .await
        .context("failed to check model.onnx path")?
    {
        download_file(model_url, &files.model_path).await?;
    }
    // The external-weights file is only fetched when a data URL is configured.
    if let (Some(data_url), Some(data_path)) = (model_data_url, &files.model_data_path)
        && !tokio::fs::try_exists(data_path)
            .await
            .context("failed to check model data path")?
    {
        download_file(data_url, data_path).await?;
    }
    if !tokio::fs::try_exists(&files.tokenizer_path)
        .await
        .context("failed to check tokenizer.json path")?
    {
        download_file(tokenizer_url, &files.tokenizer_path).await?;
    }

    Ok(files)
}
774
/// Reports whether every required model artifact is already on disk.
///
/// `model.onnx` and `tokenizer.json` are always required; the external-data
/// file is additionally required only when `model_data_url` is provided.
#[cfg(feature = "real-embeddings")]
fn model_files_exist(model_dir: &Path, model_data_url: Option<&str>) -> bool {
    let files = model_files_for_dir(model_dir.to_path_buf(), model_data_url);
    // Either core file missing means the set is incomplete.
    if !files.model_path.exists() || !files.tokenizer_path.exists() {
        return false;
    }
    match model_data_url {
        // Data file is only mandatory when a data URL was supplied.
        Some(_) => files.model_data_path.as_ref().is_some_and(|p| p.exists()),
        None => true,
    }
}
785
/// Builds the expected on-disk layout for a model directory.
///
/// `model.onnx` and `tokenizer.json` always live directly inside `model_dir`.
/// When `model_data_url` is given, its final path segment becomes the
/// external-data filename; a URL ending in `/` yields no data path.
#[cfg(feature = "real-embeddings")]
fn model_files_for_dir(model_dir: PathBuf, model_data_url: Option<&str>) -> ModelFiles {
    // Derive the external-data filename from the last URL segment, if any.
    let model_data_path = match model_data_url {
        Some(url) => match url.rsplit('/').next() {
            Some(name) if !name.is_empty() => Some(model_dir.join(name)),
            _ => None,
        },
        None => None,
    };
    let model_path = model_dir.join("model.onnx");
    let tokenizer_path = model_dir.join("tokenizer.json");
    ModelFiles {
        model_path,
        model_data_path,
        tokenizer_path,
        directory: model_dir,
    }
}
803
/// Downloads `url` to `path`, streaming the body to disk.
///
/// The payload is written chunk-by-chunk to a sibling `<name>.part` file and
/// atomically renamed into place on success, so an interrupted download never
/// leaves a truncated file at `path`. Streaming (instead of buffering the
/// whole body with `bytes()`) keeps peak memory flat even for model files
/// that are tens or hundreds of megabytes.
///
/// # Errors
/// Returns an error if the HTTP request fails, the server responds with a
/// non-success status, or any filesystem operation fails. A stale `.part`
/// file may remain after an error; it is truncated and overwritten by the
/// next attempt, so it is harmless.
#[cfg(feature = "real-embeddings")]
pub(crate) async fn download_file(url: &str, path: &Path) -> Result<()> {
    // Function-scope import: brings the write_all/flush trait methods into
    // scope without touching the file-level import block.
    use tokio::io::AsyncWriteExt;

    let client = reqwest::Client::builder()
        // Generous overall timeout: model files are large.
        .timeout(std::time::Duration::from_secs(300))
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .context("failed to build HTTP client")?;
    let mut response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("failed to download {url}"))?
        .error_for_status()
        .with_context(|| format!("download request failed for {url}"))?;

    // Write to a temporary .part file, then atomically rename to avoid
    // leaving corrupt files on interrupted downloads.
    let mut part_name = path.file_name().unwrap_or_default().to_os_string();
    part_name.push(".part");
    let part_path = path.with_file_name(part_name);

    let mut file = tokio::fs::File::create(&part_path)
        .await
        .with_context(|| format!("failed to write temporary file {}", part_path.display()))?;
    // Stream the body instead of collecting it into one large in-memory buffer.
    while let Some(chunk) = response
        .chunk()
        .await
        .with_context(|| format!("failed to read body from {url}"))?
    {
        file.write_all(&chunk)
            .await
            .with_context(|| format!("failed to write temporary file {}", part_path.display()))?;
    }
    file.flush()
        .await
        .with_context(|| format!("failed to write temporary file {}", part_path.display()))?;
    // Close the handle before renaming (required on Windows; harmless elsewhere).
    drop(file);

    tokio::fs::rename(&part_path, path).await.with_context(|| {
        format!(
            "failed to rename {} to {}",
            part_path.display(),
            path.display()
        )
    })
}
839
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_core::storage::sqlite::dot_product;

    /// L2 norm of an embedding vector, shared by the normalization tests.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().fold(0.0_f32, |acc, x| acc + x * x).sqrt()
    }

    #[test]
    fn test_placeholder_embedder_dimension() {
        assert_eq!(PlaceholderEmbedder.dimension(), 32);
    }

    #[test]
    fn test_placeholder_embedder_deterministic() {
        let embedder = PlaceholderEmbedder;
        let a = embedder.embed("hello world").unwrap();
        let b = embedder.embed("hello world").unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn test_placeholder_embedder_different_inputs() {
        let embedder = PlaceholderEmbedder;
        assert_ne!(
            embedder.embed("hello world").unwrap(),
            embedder.embed("different text").unwrap()
        );
    }

    #[test]
    fn test_placeholder_embedder_normalized() {
        let embedding = PlaceholderEmbedder.embed("normalized").unwrap();
        assert!((l2_norm(&embedding) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_placeholder_embedder_empty_input() {
        // Even an empty string hashes to a full-width vector.
        assert_eq!(PlaceholderEmbedder.embed("").unwrap().len(), 32);
    }

    #[test]
    fn test_dot_product_identical() {
        let a = vec![0.5_f32, 0.5, 0.5, 0.5];
        assert!((dot_product(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        assert!(dot_product(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_different_lengths() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 0.0];
        assert_eq!(dot_product(&a, &b), 0.0);
    }

    // --- embed_batch tests (PlaceholderEmbedder / default impl) ---

    #[test]
    fn test_placeholder_embed_batch_empty() {
        assert!(PlaceholderEmbedder.embed_batch(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_placeholder_embed_batch_single() {
        let embedder = PlaceholderEmbedder;
        let batch = embedder.embed_batch(&["hello"]).unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], embedder.embed("hello").unwrap());
    }

    #[test]
    fn test_placeholder_embed_batch_multiple() {
        let embedder = PlaceholderEmbedder;
        let texts = ["alpha", "beta", "gamma"];
        let batch = embedder.embed_batch(&texts).unwrap();
        assert_eq!(batch.len(), 3);
        // Batch results must agree with one-at-a-time embedding.
        for (result, text) in batch.iter().zip(texts.iter()) {
            assert_eq!(result, &embedder.embed(text).unwrap());
        }
    }

    #[test]
    fn test_placeholder_embed_batch_normalized() {
        let batch = PlaceholderEmbedder
            .embed_batch(&["one", "two", "three"])
            .unwrap();
        for emb in &batch {
            assert!((l2_norm(emb) - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_placeholder_embed_batch_deterministic() {
        let embedder = PlaceholderEmbedder;
        assert_eq!(
            embedder.embed_batch(&["a", "b"]).unwrap(),
            embedder.embed_batch(&["a", "b"]).unwrap()
        );
    }

    #[cfg(feature = "real-embeddings")]
    #[test]
    fn model_dir_returns_expected_path() {
        crate::test_helpers::with_temp_home(|home| {
            let expected = home
                .join(".mag")
                .join("models")
                .join("bge-small-en-v1.5-int8");
            let actual = crate::memory_core::embedder::model_dir()
                .expect("model_dir() should succeed with a valid HOME");
            assert_eq!(actual, expected);
        });
    }
}