1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use sha2::{Digest, Sha256};
5
6use crate::app_paths;
7
/// Text-to-vector embedding backend.
///
/// Implementations must be `Send + Sync` so one embedder instance can be
/// shared across threads (e.g. behind an `Arc`).
pub trait Embedder: Send + Sync {
    /// Number of components in every vector returned by `embed`.
    fn dimension(&self) -> usize;
    /// Embed a single text into a `dimension()`-length vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
    #[allow(dead_code)]
    /// Embed several texts; this default simply calls `embed` per text and
    /// short-circuits on the first error. Implementations may batch for real.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }
}
23
/// Deterministic stand-in embedder: vectors are derived from a SHA-256 hash
/// of the input, so equal inputs always map to equal embeddings. Useful when
/// no real model is available (see `embedding_for_text`).
#[derive(Debug, Default, Clone)]
#[allow(dead_code)]
pub struct PlaceholderEmbedder;

impl Embedder for PlaceholderEmbedder {
    // A SHA-256 digest is 32 bytes → one vector component per digest byte.
    fn dimension(&self) -> usize {
        32
    }

    // Hashing cannot fail, so this is always Ok.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        Ok(embedding_for_text(text))
    }
}
37
38#[allow(dead_code)]
39pub(crate) fn embedding_for_text(input: &str) -> Vec<f32> {
40 let mut hasher = Sha256::new();
41 hasher.update(input.as_bytes());
42 let digest = hasher.finalize();
43 let mut vec: Vec<f32> = digest.iter().map(|b| *b as f32 / 255.0).collect();
44 normalize_embedding(&mut vec);
45 vec
46}
47
/// Scale `vec` in place to unit L2 norm. A zero (or non-positive-norm)
/// vector is left untouched so we never divide by zero.
pub(crate) fn normalize_embedding(vec: &mut [f32]) {
    let squared_sum = vec.iter().fold(0.0f32, |acc, component| acc + component * component);
    let norm = squared_sum.sqrt();
    if norm > 0.0 {
        vec.iter_mut().for_each(|component| *component /= norm);
    }
}
56
// Identity and download locations for the bundled bge-small-en-v1.5 int8
// ONNX model hosted on Hugging Face.
#[cfg(feature = "real-embeddings")]
const MODEL_NAME: &str = "bge-small-en-v1.5-int8";
#[cfg(feature = "real-embeddings")]
const MODEL_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/onnx/model_int8.onnx";
#[cfg(feature = "real-embeddings")]
const TOKENIZER_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/tokenizer.json";

// Maximum number of entries in the per-embedder text→vector LRU cache.
#[cfg(feature = "real-embeddings")]
const EMBEDDING_CACHE_CAPACITY: std::num::NonZeroUsize = std::num::NonZeroUsize::new(2048).unwrap();

// Seconds of inactivity after which `try_unload_if_idle` may drop the
// loaded ONNX session to reclaim memory.
#[cfg(feature = "real-embeddings")]
const IDLE_TIMEOUT_SECS: u64 = 600;

#[cfg(feature = "real-embeddings")]
/// Embedder backed by a local ONNX model, downloaded on demand and loaded
/// lazily; results are cached and the session can be unloaded when idle.
#[derive(Debug)]
pub struct OnnxEmbedder {
    /// Directory under the app model root holding model/tokenizer files.
    model_dir: PathBuf,
    /// Source URL for the .onnx model when missing on disk.
    model_url: String,
    /// Optional source URL for an external weights/data file.
    model_data_url: Option<String>,
    /// Source URL for tokenizer.json when missing on disk.
    tokenizer_url: String,
    /// Number of leading output components kept per embedding.
    dimension: usize,
    /// Name of the output tensor embeddings are read from.
    output_tensor_name: String,
    /// Whether the model expects a `token_type_ids` input tensor.
    use_token_type_ids: bool,
    /// Lazily-initialized session + tokenizer; None until first use or after
    /// an idle unload.
    runtime: std::sync::Mutex<Option<OnnxRuntime>>,
    /// Unix seconds of last use; 0 means "never used" (see try_unload_if_idle).
    last_used: std::sync::atomic::AtomicU64,
    /// LRU cache keyed by the SHA-256 digest of the input text.
    cache: std::sync::Mutex<lru::LruCache<[u8; 32], Vec<f32>>>,
}
86
#[cfg(feature = "real-embeddings")]
/// Loaded inference state: the ONNX session plus its matching tokenizer.
#[derive(Debug)]
struct OnnxRuntime {
    /// Session runs require `&mut`, so it sits behind its own mutex.
    session: std::sync::Mutex<ort::session::Session>,
    tokenizer: tokenizers::Tokenizer,
}
93
#[cfg(feature = "real-embeddings")]
/// Resolved on-disk locations for one model's files.
#[derive(Debug, Clone)]
struct ModelFiles {
    directory: PathBuf,
    model_path: PathBuf,
    /// Present only for models that ship an external weights/data file.
    model_data_path: Option<PathBuf>,
    tokenizer_path: PathBuf,
}
102
103#[cfg(feature = "real-embeddings")]
104impl OnnxEmbedder {
    /// Create an embedder for the default bundled model
    /// (bge-small-en-v1.5 int8, 384 dims, `last_hidden_state` output).
    pub fn new() -> Result<Self> {
        Self::with_model(
            MODEL_NAME,
            MODEL_URL,
            TOKENIZER_URL,
            384,
            "last_hidden_state",
        )
    }
114
    /// Create an embedder for a custom model that has no external weights
    /// file; `token_type_ids` is enabled (BERT-style input set).
    pub fn with_model(
        name: &str,
        model_url: &str,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
    ) -> Result<Self> {
        Self::with_model_and_data(
            name,
            model_url,
            None,
            tokenizer_url,
            dimension,
            output_tensor_name,
            true,
        )
    }
132
    /// Fully-parameterized constructor.
    ///
    /// * `name` — subdirectory under the app model root used as the cache dir
    /// * `model_url` / `model_data_url` / `tokenizer_url` — download sources
    /// * `dimension` — number of leading output components to keep
    /// * `output_tensor_name` — output tensor embeddings are read from
    /// * `use_token_type_ids` — whether the model takes a token_type_ids input
    ///
    /// Nothing is downloaded or loaded here; the runtime is created lazily on
    /// first use (or eagerly via `warmup`).
    pub fn with_model_and_data(
        name: &str,
        model_url: &str,
        model_data_url: Option<&str>,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
        use_token_type_ids: bool,
    ) -> Result<Self> {
        let model_dir = app_paths::resolve_app_paths()?.model_root.join(name);
        Ok(Self {
            model_dir,
            model_url: model_url.to_string(),
            model_data_url: model_data_url.map(str::to_string),
            tokenizer_url: tokenizer_url.to_string(),
            dimension,
            output_tensor_name: output_tensor_name.to_string(),
            use_token_type_ids,
            runtime: std::sync::Mutex::new(None),
            // 0 = "never used"; try_unload_if_idle treats it as nothing loaded.
            last_used: std::sync::atomic::AtomicU64::new(0),
            cache: std::sync::Mutex::new(lru::LruCache::new(EMBEDDING_CACHE_CAPACITY)),
        })
    }
156
157 fn epoch_secs() -> u64 {
158 std::time::SystemTime::now()
159 .duration_since(std::time::UNIX_EPOCH)
160 .map(|d| d.as_secs())
161 .unwrap_or(0)
162 }
163
    /// Record "now" as the last-used timestamp; read by `try_unload_if_idle`.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }
168
    /// Eagerly load the ONNX session + tokenizer on a blocking thread so the
    /// first `embed` call doesn't pay model-load (and possibly download) cost.
    /// Returns immediately if the runtime is already loaded.
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        // Fast path: peek under the lock, then drop the guard before spawning.
        {
            let guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_some() {
                return Ok(());
            }
        }
        let this = std::sync::Arc::clone(self);
        // init_runtime may download files and compile the model, so keep it
        // off the async executor. The is_none re-check under the lock makes
        // concurrent warmups initialize at most once.
        tokio::task::spawn_blocking(move || {
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }
198
    /// Drop the loaded session if it has been idle for `IDLE_TIMEOUT_SECS`.
    /// Returns true only when a session was actually unloaded.
    pub fn try_unload_if_idle(&self) -> bool {
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if last == 0 {
            // Never used, so nothing is loaded worth unloading.
            return false;
        }
        if Self::epoch_secs().saturating_sub(last) < IDLE_TIMEOUT_SECS {
            return false;
        }
        // A poisoned lock is treated as "cannot unload" rather than a panic.
        if let Ok(mut guard) = self.runtime.lock()
            && guard.is_some()
        {
            *guard = None;
            tracing::info!("unloaded idle ONNX session after {IDLE_TIMEOUT_SECS}s");
            return true;
        }
        false
    }
219
    /// Periodic housekeeping hook: runs the idle-unload check on a blocking
    /// thread. Join errors are deliberately ignored — this is best-effort.
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let _ = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
    }
228
    /// Build the ONNX session and tokenizer, first downloading any model
    /// files missing on disk. Blocking; intended for blocking threads.
    fn init_runtime(&self) -> Result<OnnxRuntime> {
        let files = ensure_model_files_blocking(
            self.model_dir.clone(),
            &self.model_url,
            self.model_data_url.as_deref(),
            &self.tokenizer_url,
        )?;
        // NOTE(review): arena allocator disabled on the CPU EP — presumably so
        // memory is actually returned when the session is unloaded; confirm.
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
        // Cap inputs at 512 tokens (assumed model max sequence length —
        // confirm for non-default models).
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure tokenizer truncation: {e}"))?;
        Ok(OnnxRuntime {
            session: std::sync::Mutex::new(session),
            tokenizer,
        })
    }
266}
267
268#[cfg(feature = "real-embeddings")]
269impl Embedder for OnnxEmbedder {
    /// Configured embedding width (components kept from the model output).
    fn dimension(&self) -> usize {
        self.dimension
    }
273
    /// Embed one text: look up the cache by SHA-256 of the input; on a miss,
    /// lazily (re)load the runtime, tokenize, run ONNX inference, reduce the
    /// output to `self.dimension` floats, L2-normalize, then cache.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Cache key is the SHA-256 digest of the raw input text.
        let mut hasher = Sha256::new();
        hasher.update(text.as_bytes());
        let key: [u8; 32] = hasher.finalize().into();

        match self.cache.lock() {
            Ok(mut cache) => {
                if let Some(cached) = cache.get(&key) {
                    return Ok(cached.clone());
                }
            }
            // Poisoned cache lock degrades to "compute every time".
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        let pooled = {
            // The session may have been unloaded while idle; recreate lazily.
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            let encoding = runtime
                .tokenizer
                .encode(text, true)
                .map_err(|e| anyhow!("tokenization failed: {e}"))?;
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&m| m as i64)
                .collect();
            if input_ids.is_empty() || input_ids.len() != attention_mask.len() {
                return Err(anyhow!("invalid tokenization output for embedding"));
            }

            // Single-item batch: tensors are shaped [1, seq_len].
            let seq_len = input_ids.len();
            let input_ids_value = ort::value::Value::from_array(([1_usize, seq_len], input_ids))
                .context("failed to create ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([1_usize, seq_len], attention_mask))
                    .context("failed to create ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // BERT-style models take a third, all-zero token_type_ids tensor.
                let token_type_ids_value =
                    ort::value::Value::from_array(([1_usize, seq_len], vec![0_i64; seq_len]))
                        .context("failed to create ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract ONNX output tensor")?;

            let mut pooled = if shape.len() == 2 {
                // [1, hidden]: already pooled by the model — keep the first
                // `dimension` components.
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                output[..self.dimension].to_vec()
            } else if shape.len() == 3 {
                // [1, seq, hidden]: attention-mask-weighted mean pooling over
                // the token axis.
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let output_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} smaller than requested dim {}",
                        self.dimension
                    ));
                }
                if output_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                let effective_len = output_seq_len.min(seq_len);
                let mut pooled = vec![0.0f32; self.dimension];
                let mut mask_sum = 0.0f32;
                for token_idx in 0..effective_len {
                    #[allow(clippy::cast_precision_loss)]
                    let mask_value = encoding.get_attention_mask()[token_idx] as f32;
                    if mask_value <= 0.0 {
                        // Masked/padding token contributes nothing to the mean.
                        continue;
                    }
                    mask_sum += mask_value;
                    for (d, pooled_value) in pooled.iter_mut().enumerate() {
                        let flat_index = token_idx * hidden_size + d;
                        *pooled_value += output[flat_index] * mask_value;
                    }
                }
                if mask_sum <= 0.0 {
                    return Err(anyhow!("attention mask sum is zero during mean pooling"));
                }
                for value in &mut pooled {
                    *value /= mask_sum;
                }
                pooled
            } else {
                return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
            };
            normalize_embedding(&mut pooled);
            pooled
        };
        self.touch_last_used();

        // One clone for the caller; the owned vector goes into the cache.
        let result = pooled.clone();
        match self.cache.lock() {
            Ok(mut cache) => {
                cache.put(key, pooled);
            }
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        Ok(result)
    }
432
    /// Batched variant of `embed`: resolve cache hits first, tokenize the
    /// misses, right-pad them into a single [batch, max_len] tensor set, run
    /// one ONNX inference, pool/normalize per item, then back-fill cache and
    /// result slots in original input order.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if texts.len() == 1 {
            // Single item: scalar path avoids padding bookkeeping.
            return Ok(vec![self.embed(texts[0])?]);
        }

        // Cache keys: SHA-256 digest per input text.
        let mut keys: Vec<[u8; 32]> = Vec::with_capacity(texts.len());
        for text in texts {
            let mut hasher = Sha256::new();
            hasher.update(text.as_bytes());
            keys.push(hasher.finalize().into());
        }

        // results[i] stays None until filled from cache or inference.
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut miss_indices: Vec<usize> = Vec::new();

        match self.cache.lock() {
            Ok(mut cache) => {
                for (i, key) in keys.iter().enumerate() {
                    if let Some(cached) = cache.get(key) {
                        results[i] = Some(cached.clone());
                    } else {
                        miss_indices.push(i);
                    }
                }
            }
            Err(_) => {
                // Poisoned cache: compute everything this call.
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                miss_indices.extend(0..texts.len());
            }
        }

        if miss_indices.is_empty() {
            return results
                .into_iter()
                .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in cache-hit path")))
                .collect();
        }

        let computed = {
            // The session may have been unloaded while idle; recreate lazily.
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            let miss_texts: Vec<&str> = miss_indices.iter().map(|&i| texts[i]).collect();
            let encodings: Vec<tokenizers::Encoding> = miss_texts
                .iter()
                .map(|t| {
                    runtime
                        .tokenizer
                        .encode(*t, true)
                        .map_err(|e| anyhow!("tokenization failed: {e}"))
                })
                .collect::<Result<Vec<_>>>()?;

            let max_len = encodings
                .iter()
                .map(|enc| enc.get_ids().len())
                .max()
                .ok_or_else(|| anyhow!("empty encodings in batch"))?;
            if max_len == 0 {
                return Err(anyhow!("all tokenizations produced zero-length sequences"));
            }

            let batch_size = encodings.len();

            // Right-pad every sequence to max_len with id 0 / mask 0.
            let mut flat_input_ids = vec![0_i64; batch_size * max_len];
            let mut flat_attention_mask = vec![0_i64; batch_size * max_len];

            for (b, enc) in encodings.iter().enumerate() {
                let ids = enc.get_ids();
                let mask = enc.get_attention_mask();
                let seq_len = ids.len();
                if seq_len != mask.len() {
                    return Err(anyhow!(
                        "tokenization ids/mask length mismatch for batch item {b}"
                    ));
                }
                let offset = b * max_len;
                for j in 0..seq_len {
                    flat_input_ids[offset + j] = ids[j] as i64;
                    flat_attention_mask[offset + j] = mask[j] as i64;
                }
            }

            let input_ids_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                    .context("failed to create batched ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                    .context("failed to create batched ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // BERT-style models take a third, all-zero token_type_ids tensor.
                let token_type_ids_value = ort::value::Value::from_array((
                    [batch_size, max_len],
                    vec![0_i64; batch_size * max_len],
                ))
                .context("failed to create batched ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("batched ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("batched ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract batched ONNX output tensor")?;

            let mut batch_embeddings: Vec<Vec<f32>> = Vec::with_capacity(batch_size);
            if shape.len() == 2 {
                // [batch, hidden]: already pooled — slice each row.
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                for b in 0..batch_size {
                    let row_start = b * hidden_size;
                    let mut pooled = output[row_start..row_start + self.dimension].to_vec();
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            } else if shape.len() == 3 {
                // [batch, seq, hidden]: attention-mask-weighted mean pooling
                // per batch item.
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let out_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "unexpected embedding dimension: got {hidden_size}, expected >= {}",
                        self.dimension
                    ));
                }
                if out_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                for (b, enc) in encodings.iter().enumerate() {
                    let seq_len = enc.get_ids().len();
                    let effective_len = out_seq_len.min(seq_len);
                    let mut pooled = vec![0.0f32; self.dimension];
                    let mut mask_sum = 0.0f32;

                    for token_idx in 0..effective_len {
                        #[allow(clippy::cast_precision_loss)]
                        let mask_value = enc.get_attention_mask()[token_idx] as f32;
                        if mask_value <= 0.0 {
                            // Padding token contributes nothing to the mean.
                            continue;
                        }
                        mask_sum += mask_value;
                        let row_offset = b * out_seq_len * hidden_size + token_idx * hidden_size;
                        for (d, pooled_value) in pooled.iter_mut().enumerate() {
                            *pooled_value += output[row_offset + d] * mask_value;
                        }
                    }

                    if mask_sum <= 0.0 {
                        return Err(anyhow!(
                            "attention mask sum is zero during mean pooling for batch item {b}"
                        ));
                    }
                    for value in &mut pooled {
                        *value /= mask_sum;
                    }
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            } else {
                return Err(anyhow!("unexpected batched ONNX output shape: {shape:?}"));
            }
            batch_embeddings
        };
        self.touch_last_used();

        // Back-fill the cache and result slots for the computed misses.
        match self.cache.lock() {
            Ok(mut cache) => {
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    cache.put(keys[orig_idx], embedding.clone());
                    results[orig_idx] = Some(embedding);
                }
            }
            Err(_) => {
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    results[orig_idx] = Some(embedding);
                }
            }
        }

        results
            .into_iter()
            .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in batch result")))
            .collect()
    }
684}
685
#[cfg(feature = "real-embeddings")]
/// Download any missing files for the default bge-small model and return the
/// directory containing them (no-op when everything is already on disk).
pub async fn download_bge_small_model() -> Result<PathBuf> {
    let model_dir = default_model_dir()?;
    let files = ensure_model_files_async(model_dir, MODEL_URL, None, TOKENIZER_URL).await?;
    Ok(files.directory)
}
692
#[cfg(feature = "real-embeddings")]
/// Cache directory for the default model: `<app model root>/<MODEL_NAME>`.
pub fn model_dir() -> Result<PathBuf> {
    Ok(app_paths::resolve_app_paths()?.model_root.join(MODEL_NAME))
}

#[cfg(feature = "real-embeddings")]
/// Internal alias for `model_dir`; kept for readability at call sites.
fn default_model_dir() -> Result<PathBuf> {
    model_dir()
}
702
#[cfg(feature = "real-embeddings")]
/// Blocking counterpart of `ensure_model_files_async` for synchronous callers
/// (e.g. `init_runtime` running on a blocking thread).
///
/// Fast-paths to Ok when all files already exist; otherwise builds a
/// throwaway single-threaded tokio runtime just for the download.
/// NOTE(review): `block_on` panics if invoked from inside an existing tokio
/// runtime — callers must stay on blocking threads; confirm all call sites.
fn ensure_model_files_blocking(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    if model_files_exist(&model_dir, model_data_url) {
        return Ok(model_files_for_dir(model_dir, model_data_url));
    }

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .context("failed to create temporary tokio runtime for model download")?;
    let model_data_url_owned = model_data_url.map(str::to_string);
    runtime.block_on(ensure_model_files_async(
        model_dir,
        model_url,
        model_data_url_owned.as_deref(),
        tokenizer_url,
    ))
}
730
#[cfg(feature = "real-embeddings")]
/// Ensure model.onnx, tokenizer.json and (optionally) the external weights
/// file exist in `model_dir`, creating the directory and downloading each
/// missing file individually.
async fn ensure_model_files_async(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir, model_data_url);
    if model_files_exist(&files.directory, model_data_url) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create model directory {}",
                files.directory.display()
            )
        })?;

    // Each file is checked and fetched independently, so a previous partial
    // setup only re-downloads what is still missing.
    if !tokio::fs::try_exists(&files.model_path)
        .await
        .context("failed to check model.onnx path")?
    {
        download_file(model_url, &files.model_path).await?;
    }
    if let (Some(data_url), Some(data_path)) = (model_data_url, &files.model_data_path)
        && !tokio::fs::try_exists(data_path)
            .await
            .context("failed to check model data path")?
    {
        download_file(data_url, data_path).await?;
    }
    if !tokio::fs::try_exists(&files.tokenizer_path)
        .await
        .context("failed to check tokenizer.json path")?
    {
        download_file(tokenizer_url, &files.tokenizer_path).await?;
    }

    Ok(files)
}
774
#[cfg(feature = "real-embeddings")]
/// True when every file the model needs is present on disk: model.onnx,
/// tokenizer.json, and — only if a data URL is configured — the external
/// weights file as well.
fn model_files_exist(model_dir: &Path, model_data_url: Option<&str>) -> bool {
    let files = model_files_for_dir(model_dir.to_path_buf(), model_data_url);
    if !(files.model_path.exists() && files.tokenizer_path.exists()) {
        return false;
    }
    model_data_url.is_none() || files.model_data_path.as_ref().is_some_and(|p| p.exists())
}
785
#[cfg(feature = "real-embeddings")]
/// Map a model directory (plus an optional data-file URL) onto the concrete
/// paths of model.onnx, tokenizer.json and the optional external data file.
/// The data file keeps the last path segment of its URL as its filename; a
/// URL ending in '/' yields no data path.
fn model_files_for_dir(model_dir: PathBuf, model_data_url: Option<&str>) -> ModelFiles {
    let model_data_path = model_data_url
        .and_then(|url| url.rsplit('/').next())
        .filter(|filename| !filename.is_empty())
        .map(|filename| model_dir.join(filename));
    ModelFiles {
        model_path: model_dir.join("model.onnx"),
        model_data_path,
        tokenizer_path: model_dir.join("tokenizer.json"),
        directory: model_dir,
    }
}
803
#[cfg(feature = "real-embeddings")]
/// Download `url` into `path`, buffering the entire body in memory, writing
/// it to a `<name>.part` temp file and renaming into place so a partial
/// download never appears at the final path.
pub(crate) async fn download_file(url: &str, path: &Path) -> Result<()> {
    let client = reqwest::Client::builder()
        // Generous overall timeout: model files can be hundreds of MB.
        .timeout(std::time::Duration::from_secs(300))
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .context("failed to build HTTP client")?;
    let response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("failed to download {url}"))?
        .error_for_status()
        .with_context(|| format!("download request failed for {url}"))?;
    let bytes = response
        .bytes()
        .await
        .with_context(|| format!("failed to read body from {url}"))?;

    // Stage in "<file>.part", then rename; the rename makes the final path
    // appear all-at-once so readers never observe a truncated file.
    let mut part_name = path.file_name().unwrap_or_default().to_os_string();
    part_name.push(".part");
    let part_path = path.with_file_name(part_name);
    tokio::fs::write(&part_path, &bytes)
        .await
        .with_context(|| format!("failed to write temporary file {}", part_path.display()))?;
    tokio::fs::rename(&part_path, path).await.with_context(|| {
        format!(
            "failed to rename {} to {}",
            part_path.display(),
            path.display()
        )
    })
}
839
840#[cfg(test)]
841mod tests {
842 use super::*;
843 use crate::memory_core::storage::sqlite::dot_product;
844
    // PlaceholderEmbedder always reports the SHA-256-sized width (32).
    #[test]
    fn test_placeholder_embedder_dimension() {
        let embedder = PlaceholderEmbedder;
        assert_eq!(embedder.dimension(), 32);
    }

    // Hash-based embeddings: the same input must always yield the same vector.
    #[test]
    fn test_placeholder_embedder_deterministic() {
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed("hello world").unwrap();
        let second = embedder.embed("hello world").unwrap();
        assert_eq!(first, second);
    }

    // Distinct inputs hash to distinct vectors.
    #[test]
    fn test_placeholder_embedder_different_inputs() {
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed("hello world").unwrap();
        let second = embedder.embed("different text").unwrap();
        assert_ne!(first, second);
    }

    // Output vectors come back L2-normalized (unit length).
    #[test]
    fn test_placeholder_embedder_normalized() {
        let embedder = PlaceholderEmbedder;
        let embedding = embedder.embed("normalized").unwrap();
        let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // Empty input is valid and still produces a full-width vector.
    #[test]
    fn test_placeholder_embedder_empty_input() {
        let embedder = PlaceholderEmbedder;
        let embedding = embedder.embed("").unwrap();
        assert_eq!(embedding.len(), 32);
    }
881
    // A unit vector dotted with itself scores 1.
    #[test]
    fn test_dot_product_identical() {
        let a = vec![0.5_f32, 0.5, 0.5, 0.5];
        let score = dot_product(&a, &a);
        assert!((score - 1.0).abs() < 1e-6);
    }

    // Orthogonal vectors score ~0.
    #[test]
    fn test_dot_product_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        let score = dot_product(&a, &b);
        assert!(score.abs() < 1e-6);
    }

    // Length mismatch yields 0 rather than panicking.
    #[test]
    fn test_dot_product_different_lengths() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 0.0];
        let score = dot_product(&a, &b);
        assert_eq!(score, 0.0);
    }
904
    // Empty batch input returns an empty result without error.
    #[test]
    fn test_placeholder_embed_batch_empty() {
        let embedder = PlaceholderEmbedder;
        let results = embedder.embed_batch(&[]).unwrap();
        assert!(results.is_empty());
    }

    // A one-element batch matches the scalar embed result.
    #[test]
    fn test_placeholder_embed_batch_single() {
        let embedder = PlaceholderEmbedder;
        let single = embedder.embed("hello").unwrap();
        let batch = embedder.embed_batch(&["hello"]).unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], single);
    }

    // Each batch slot matches the corresponding scalar embed, in order.
    #[test]
    fn test_placeholder_embed_batch_multiple() {
        let embedder = PlaceholderEmbedder;
        let texts = ["alpha", "beta", "gamma"];
        let batch = embedder.embed_batch(&texts).unwrap();
        assert_eq!(batch.len(), 3);
        for (i, text) in texts.iter().enumerate() {
            let individual = embedder.embed(text).unwrap();
            assert_eq!(batch[i], individual);
        }
    }

    // Every batched embedding is L2-normalized.
    #[test]
    fn test_placeholder_embed_batch_normalized() {
        let embedder = PlaceholderEmbedder;
        let batch = embedder.embed_batch(&["one", "two", "three"]).unwrap();
        for emb in &batch {
            let norm = emb.iter().map(|v| v * v).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6);
        }
    }

    // Batched embedding is deterministic across calls.
    #[test]
    fn test_placeholder_embed_batch_deterministic() {
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed_batch(&["a", "b"]).unwrap();
        let second = embedder.embed_batch(&["a", "b"]).unwrap();
        assert_eq!(first, second);
    }
953
954 #[cfg(feature = "real-embeddings")]
955 #[test]
956 fn model_dir_returns_expected_path() {
957 crate::test_helpers::with_temp_home(|home| {
958 let expected = home
959 .join(".mag")
960 .join("models")
961 .join("bge-small-en-v1.5-int8");
962 let actual = crate::memory_core::embedder::model_dir()
963 .expect("model_dir() should succeed with a valid HOME");
964 assert_eq!(actual, expected);
965 });
966 }
967}