1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use sha2::{Digest, Sha256};
5
6use crate::app_paths;
7
8pub trait Embedder: Send + Sync {
10 fn dimension(&self) -> usize;
12 fn embed(&self, text: &str) -> Result<Vec<f32>>;
14 #[allow(dead_code)]
19 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
20 texts.iter().map(|t| self.embed(t)).collect()
21 }
22}
23
/// Deterministic, model-free embedder used when real embeddings are disabled:
/// vectors are derived purely from a SHA-256 hash of the input text.
#[derive(Debug, Default, Clone)]
#[allow(dead_code)]
pub struct PlaceholderEmbedder;
27
impl Embedder for PlaceholderEmbedder {
    // A SHA-256 digest is 32 bytes and each byte becomes one component.
    fn dimension(&self) -> usize {
        32
    }

    // Infallible by construction; see `embedding_for_text`.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        Ok(embedding_for_text(text))
    }
}
37
38#[allow(dead_code)]
39pub(crate) fn embedding_for_text(input: &str) -> Vec<f32> {
40 let mut hasher = Sha256::new();
41 hasher.update(input.as_bytes());
42 let digest = hasher.finalize();
43 let mut vec: Vec<f32> = digest.iter().map(|b| *b as f32 / 255.0).collect();
44 normalize_embedding(&mut vec);
45 vec
46}
47
/// Scales `vec` in place to unit L2 norm.
///
/// A zero vector (and any vector whose norm is not strictly positive) is
/// left untouched to avoid dividing by zero.
pub(crate) fn normalize_embedding(vec: &mut [f32]) {
    let squared_sum: f32 = vec.iter().map(|v| v * v).sum();
    let norm = squared_sum.sqrt();
    if norm > 0.0 {
        vec.iter_mut().for_each(|value| *value /= norm);
    }
}
56
// Directory name under the app's model root where the default model lives.
#[cfg(feature = "real-embeddings")]
const MODEL_NAME: &str = "bge-small-en-v1.5-int8";
// Quantized BGE-small ONNX export hosted on Hugging Face.
#[cfg(feature = "real-embeddings")]
const MODEL_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/onnx/model_int8.onnx";
#[cfg(feature = "real-embeddings")]
const TOKENIZER_URL: &str =
    "https://huggingface.co/Xenova/bge-small-en-v1.5/resolve/main/tokenizer.json";

// Max number of (text-hash -> embedding) entries kept in the LRU cache.
#[cfg(feature = "real-embeddings")]
const EMBEDDING_CACHE_CAPACITY: std::num::NonZeroUsize = std::num::NonZeroUsize::new(2048).unwrap();

// Seconds of inactivity after which `try_unload_if_idle` drops the session.
#[cfg(feature = "real-embeddings")]
const IDLE_TIMEOUT_SECS: u64 = 600;

/// ONNX-backed embedder that downloads its model on demand, loads the
/// session lazily, caches embeddings by text hash, and can unload the
/// session again after a period of inactivity.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
pub struct OnnxEmbedder {
    // Local directory holding model.onnx / tokenizer.json (plus optional data file).
    model_dir: PathBuf,
    model_url: String,
    // Some models ship weights in a separate external-data file.
    model_data_url: Option<String>,
    tokenizer_url: String,
    // Number of leading output components kept per embedding.
    dimension: usize,
    // Name of the output tensor to read (e.g. "last_hidden_state").
    output_tensor_name: String,
    // Whether the model expects a token_type_ids input alongside ids/mask.
    use_token_type_ids: bool,
    // Lazily-initialized session + tokenizer; `None` until first use or after unload.
    runtime: std::sync::Mutex<Option<OnnxRuntime>>,
    // Unix seconds of last successful use; 0 means "never used".
    last_used: std::sync::atomic::AtomicU64,
    // LRU cache keyed by SHA-256 of the input text.
    cache: std::sync::Mutex<lru::LruCache<[u8; 32], Vec<f32>>>,
}
86
/// The loaded inference state: an ONNX session plus its tokenizer.
/// The session has its own mutex because `Session::run` needs `&mut`.
#[cfg(feature = "real-embeddings")]
#[derive(Debug)]
struct OnnxRuntime {
    session: std::sync::Mutex<ort::session::Session>,
    tokenizer: tokenizers::Tokenizer,
}
93
/// Resolved on-disk locations of the model artifacts inside `directory`.
#[cfg(feature = "real-embeddings")]
#[derive(Debug, Clone)]
struct ModelFiles {
    directory: PathBuf,
    model_path: PathBuf,
    // Present only for models with an external weights file.
    model_data_path: Option<PathBuf>,
    tokenizer_path: PathBuf,
}
102
#[cfg(feature = "real-embeddings")]
impl OnnxEmbedder {
    /// Builds the default embedder: BGE-small int8, 384 dimensions, pooled
    /// from the `last_hidden_state` output tensor.
    pub fn new() -> Result<Self> {
        Self::with_model(
            MODEL_NAME,
            MODEL_URL,
            TOKENIZER_URL,
            384,
            "last_hidden_state",
        )
    }

    /// Builds an embedder for a custom model that has no external weights
    /// file and does take `token_type_ids`.
    pub fn with_model(
        name: &str,
        model_url: &str,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
    ) -> Result<Self> {
        Self::with_model_and_data(
            name,
            model_url,
            None,
            tokenizer_url,
            dimension,
            output_tensor_name,
            true,
        )
    }

    /// Fully-parameterized constructor. Nothing is downloaded or loaded
    /// here; the runtime is initialized lazily on first embed/warmup.
    pub fn with_model_and_data(
        name: &str,
        model_url: &str,
        model_data_url: Option<&str>,
        tokenizer_url: &str,
        dimension: usize,
        output_tensor_name: &str,
        use_token_type_ids: bool,
    ) -> Result<Self> {
        let model_dir = app_paths::resolve_app_paths()?.model_root.join(name);
        Ok(Self {
            model_dir,
            model_url: model_url.to_string(),
            model_data_url: model_data_url.map(str::to_string),
            tokenizer_url: tokenizer_url.to_string(),
            dimension,
            output_tensor_name: output_tensor_name.to_string(),
            use_token_type_ids,
            runtime: std::sync::Mutex::new(None),
            // 0 is the "never used" sentinel checked by try_unload_if_idle.
            last_used: std::sync::atomic::AtomicU64::new(0),
            cache: std::sync::Mutex::new(lru::LruCache::new(EMBEDDING_CACHE_CAPACITY)),
        })
    }

    // Current Unix time in seconds; clock-before-epoch degrades to 0.
    fn epoch_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    // Records "used now" for the idle-unload bookkeeping.
    fn touch_last_used(&self) {
        self.last_used
            .store(Self::epoch_secs(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Eagerly initializes the ONNX runtime on a blocking thread so the
    /// first `embed` call does not pay the model-load cost.
    ///
    /// Double-checked: a quick locked peek first, then the init happens
    /// under the same lock inside `spawn_blocking` so concurrent warmups
    /// initialize at most once.
    pub async fn warmup(self: &std::sync::Arc<Self>) -> Result<()> {
        {
            let guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_some() {
                return Ok(());
            }
        }
        let this = std::sync::Arc::clone(self);
        tokio::task::spawn_blocking(move || {
            let mut guard = this
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if guard.is_none() {
                let rt = this.init_runtime()?;
                *guard = Some(rt);
                this.touch_last_used();
            }
            Ok::<_, anyhow::Error>(())
        })
        .await
        .context("spawn_blocking join error")?
    }

    /// Drops the loaded session (freeing its memory) if the embedder has
    /// been idle for at least `IDLE_TIMEOUT_SECS`. Returns whether an
    /// unload actually happened. A poisoned mutex is treated as "no".
    pub fn try_unload_if_idle(&self) -> bool {
        let last = self.last_used.load(std::sync::atomic::Ordering::Relaxed);
        if last == 0 {
            // Never used: nothing is loaded, nothing to unload.
            return false;
        }
        if Self::epoch_secs().saturating_sub(last) < IDLE_TIMEOUT_SECS {
            return false;
        }
        if let Ok(mut guard) = self.runtime.lock()
            && guard.is_some()
        {
            *guard = None;
            tracing::info!("unloaded idle ONNX session after {IDLE_TIMEOUT_SECS}s");
            return true;
        }
        false
    }

    /// Periodic housekeeping hook: runs the idle-unload check off the
    /// async executor (dropping a session can be slow).
    pub async fn maintenance_tick(self: &std::sync::Arc<Self>) {
        let this = std::sync::Arc::clone(self);
        let _ = tokio::task::spawn_blocking(move || {
            this.try_unload_if_idle();
        })
        .await;
    }

    // Downloads any missing model files, then builds the CPU session and
    // tokenizer (truncating inputs to 512 tokens). Blocking; callers run it
    // under a lock or via spawn_blocking.
    fn init_runtime(&self) -> Result<OnnxRuntime> {
        let files = ensure_model_files_blocking(
            self.model_dir.clone(),
            &self.model_url,
            self.model_data_url.as_deref(),
            &self.tokenizer_url,
        )?;
        // Arena allocator disabled so freed session memory returns to the OS.
        let cpu_ep = ort::ep::CPU::default().with_arena_allocator(false).build();
        let session = ort::session::Session::builder()?
            .with_execution_providers([cpu_ep])?
            .with_intra_threads(num_cpus::get())?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
            .commit_from_file(&files.model_path)
            .with_context(|| {
                format!(
                    "failed to create ONNX session from {}",
                    files.model_path.display()
                )
            })?;
        let mut tokenizer = tokenizers::Tokenizer::from_file(&files.tokenizer_path)
            .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
        tokenizer
            .with_truncation(Some(tokenizers::TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| anyhow!("failed to configure tokenizer truncation: {e}"))?;
        Ok(OnnxRuntime {
            session: std::sync::Mutex::new(session),
            tokenizer,
        })
    }
}
267
#[cfg(feature = "real-embeddings")]
impl Embedder for OnnxEmbedder {
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Embeds one text: checks the LRU cache first, lazily (re)initializes
    /// the runtime on a miss, runs inference, pools, normalizes, and caches
    /// the result. A poisoned cache mutex degrades to cacheless operation.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Cache key is the SHA-256 of the raw input text.
        let mut hasher = Sha256::new();
        hasher.update(text.as_bytes());
        let key: [u8; 32] = hasher.finalize().into();

        match self.cache.lock() {
            Ok(mut cache) => {
                if let Some(cached) = cache.get(&key) {
                    return Ok(cached.clone());
                }
            }
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        // The runtime lock is held for the whole tokenize+infer+pool block;
        // single-text inference is therefore serialized across threads.
        let pooled = {
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                // Session was never loaded or was idle-unloaded; reload it.
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            let encoding = runtime
                .tokenizer
                .encode(text, true)
                .map_err(|e| anyhow!("tokenization failed: {e}"))?;
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&m| m as i64)
                .collect();
            if input_ids.is_empty() || input_ids.len() != attention_mask.len() {
                return Err(anyhow!("invalid tokenization output for embedding"));
            }

            // Inputs are shaped [batch=1, seq_len].
            let seq_len = input_ids.len();
            let input_ids_value = ort::value::Value::from_array(([1_usize, seq_len], input_ids))
                .context("failed to create ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([1_usize, seq_len], attention_mask))
                    .context("failed to create ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // Single-segment input: token_type_ids is all zeros.
                let token_type_ids_value =
                    ort::value::Value::from_array(([1_usize, seq_len], vec![0_i64; seq_len]))
                        .context("failed to create ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract ONNX output tensor")?;

            // 2-D output [1, hidden]: already pooled by the model — take the
            // first `dimension` components directly.
            let mut pooled = if shape.len() == 2 {
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                output[..self.dimension].to_vec()
            // 3-D output [1, seq, hidden]: mask-weighted mean pooling over
            // the token axis.
            } else if shape.len() == 3 {
                if shape[0] != 1 {
                    return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
                }
                let output_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} smaller than requested dim {}",
                        self.dimension
                    ));
                }
                if output_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                // Guard against the model emitting fewer tokens than we fed.
                let effective_len = output_seq_len.min(seq_len);
                let mut pooled = vec![0.0f32; self.dimension];
                let mut mask_sum = 0.0f32;
                for token_idx in 0..effective_len {
                    #[allow(clippy::cast_precision_loss)]
                    let mask_value = encoding.get_attention_mask()[token_idx] as f32;
                    if mask_value <= 0.0 {
                        continue;
                    }
                    mask_sum += mask_value;
                    for (d, pooled_value) in pooled.iter_mut().enumerate() {
                        let flat_index = token_idx * hidden_size + d;
                        *pooled_value += output[flat_index] * mask_value;
                    }
                }
                if mask_sum <= 0.0 {
                    return Err(anyhow!("attention mask sum is zero during mean pooling"));
                }
                for value in &mut pooled {
                    *value /= mask_sum;
                }
                pooled
            } else {
                return Err(anyhow!("unexpected ONNX output shape: {shape:?}"));
            };
            // Unit-normalize so dot product equals cosine similarity.
            normalize_embedding(&mut pooled);
            pooled
        };
        self.touch_last_used();

        // One clone: the cache owns `pooled`, the caller gets `result`.
        let result = pooled.clone();
        match self.cache.lock() {
            Ok(mut cache) => {
                cache.put(key, pooled);
            }
            Err(_) => tracing::warn!("embedding cache mutex poisoned, bypassing cache"),
        }

        Ok(result)
    }

    /// Batched embedding: cache hits are served directly, and all misses are
    /// padded to a common length and run through the session in one call.
    /// Output order matches the input order.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if texts.len() == 1 {
            // No padding benefit for one item; reuse the single-text path.
            return Ok(vec![self.embed(texts[0])?]);
        }

        // Cache keys: SHA-256 of each text, in input order.
        let mut keys: Vec<[u8; 32]> = Vec::with_capacity(texts.len());
        for text in texts {
            let mut hasher = Sha256::new();
            hasher.update(text.as_bytes());
            keys.push(hasher.finalize().into());
        }

        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        let mut miss_indices: Vec<usize> = Vec::new();

        match self.cache.lock() {
            Ok(mut cache) => {
                for (i, key) in keys.iter().enumerate() {
                    if let Some(cached) = cache.get(key) {
                        results[i] = Some(cached.clone());
                    } else {
                        miss_indices.push(i);
                    }
                }
            }
            Err(_) => {
                // Poisoned cache: recompute everything.
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                miss_indices.extend(0..texts.len());
            }
        }

        if miss_indices.is_empty() {
            return results
                .into_iter()
                .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in cache-hit path")))
                .collect();
        }

        // Run all cache misses through the model in a single batched call.
        let computed = {
            let mut rt_guard = self
                .runtime
                .lock()
                .map_err(|_| anyhow!("onnx runtime mutex poisoned"))?;
            if rt_guard.is_none() {
                *rt_guard = Some(self.init_runtime()?);
                self.touch_last_used();
            }
            let runtime = rt_guard
                .as_ref()
                .ok_or_else(|| anyhow!("runtime missing after init"))?;

            let miss_texts: Vec<&str> = miss_indices.iter().map(|&i| texts[i]).collect();
            let encodings: Vec<tokenizers::Encoding> = miss_texts
                .iter()
                .map(|t| {
                    runtime
                        .tokenizer
                        .encode(*t, true)
                        .map_err(|e| anyhow!("tokenization failed: {e}"))
                })
                .collect::<Result<Vec<_>>>()?;

            // All sequences are right-padded with zeros up to the longest.
            let max_len = encodings
                .iter()
                .map(|enc| enc.get_ids().len())
                .max()
                .ok_or_else(|| anyhow!("empty encodings in batch"))?;
            if max_len == 0 {
                return Err(anyhow!("all tokenizations produced zero-length sequences"));
            }

            let batch_size = encodings.len();

            // Row-major [batch, max_len] buffers, zero-initialized so padding
            // positions keep id 0 / mask 0.
            let mut flat_input_ids = vec![0_i64; batch_size * max_len];
            let mut flat_attention_mask = vec![0_i64; batch_size * max_len];

            for (b, enc) in encodings.iter().enumerate() {
                let ids = enc.get_ids();
                let mask = enc.get_attention_mask();
                let seq_len = ids.len();
                if seq_len != mask.len() {
                    return Err(anyhow!(
                        "tokenization ids/mask length mismatch for batch item {b}"
                    ));
                }
                let offset = b * max_len;
                for j in 0..seq_len {
                    flat_input_ids[offset + j] = ids[j] as i64;
                    flat_attention_mask[offset + j] = mask[j] as i64;
                }
            }

            let input_ids_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_input_ids))
                    .context("failed to create batched ONNX input_ids value")?;
            let attention_mask_value =
                ort::value::Value::from_array(([batch_size, max_len], flat_attention_mask))
                    .context("failed to create batched ONNX attention_mask value")?;

            let mut session = runtime
                .session
                .lock()
                .map_err(|_| anyhow!("onnx session mutex poisoned"))?;
            let outputs = if self.use_token_type_ids {
                // Single-segment input: token_type_ids is all zeros.
                let token_type_ids_value = ort::value::Value::from_array((
                    [batch_size, max_len],
                    vec![0_i64; batch_size * max_len],
                ))
                .context("failed to create batched ONNX token_type_ids value")?;
                session
                    .run(ort::inputs![
                        input_ids_value,
                        attention_mask_value,
                        token_type_ids_value
                    ])
                    .context("batched ONNX inference failed")?
            } else {
                session
                    .run(ort::inputs![input_ids_value, attention_mask_value])
                    .context("batched ONNX inference failed")?
            };
            let first_output = outputs
                .get(self.output_tensor_name.as_str())
                .ok_or_else(|| {
                    anyhow!("missing ONNX output tensor '{}'", self.output_tensor_name)
                })?;
            let (shape, output) = first_output
                .try_extract_tensor::<f32>()
                .context("failed to extract batched ONNX output tensor")?;

            let mut batch_embeddings: Vec<Vec<f32>> = Vec::with_capacity(batch_size);
            // 2-D output [batch, hidden]: model already pooled each row.
            if shape.len() == 2 {
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let hidden_size =
                    usize::try_from(shape[1]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "ONNX output dim {hidden_size} is smaller than requested dim {}",
                        self.dimension
                    ));
                }
                for b in 0..batch_size {
                    let row_start = b * hidden_size;
                    let mut pooled = output[row_start..row_start + self.dimension].to_vec();
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            // 3-D output [batch, seq, hidden]: mask-weighted mean pooling
            // per row, so padding tokens never contribute.
            } else if shape.len() == 3 {
                let out_batch =
                    usize::try_from(shape[0]).context("invalid output batch dimension")?;
                let out_seq_len =
                    usize::try_from(shape[1]).context("invalid output sequence length")?;
                let hidden_size =
                    usize::try_from(shape[2]).context("invalid output hidden size")?;
                if out_batch != batch_size {
                    return Err(anyhow!(
                        "output batch size mismatch: got {out_batch}, expected {batch_size}"
                    ));
                }
                if hidden_size < self.dimension {
                    return Err(anyhow!(
                        "unexpected embedding dimension: got {hidden_size}, expected >= {}",
                        self.dimension
                    ));
                }
                if out_seq_len == 0 {
                    return Err(anyhow!("ONNX output sequence length is zero"));
                }
                for (b, enc) in encodings.iter().enumerate() {
                    let seq_len = enc.get_ids().len();
                    let effective_len = out_seq_len.min(seq_len);
                    let mut pooled = vec![0.0f32; self.dimension];
                    let mut mask_sum = 0.0f32;

                    for token_idx in 0..effective_len {
                        #[allow(clippy::cast_precision_loss)]
                        let mask_value = enc.get_attention_mask()[token_idx] as f32;
                        if mask_value <= 0.0 {
                            continue;
                        }
                        mask_sum += mask_value;
                        let row_offset = b * out_seq_len * hidden_size + token_idx * hidden_size;
                        for (d, pooled_value) in pooled.iter_mut().enumerate() {
                            *pooled_value += output[row_offset + d] * mask_value;
                        }
                    }

                    if mask_sum <= 0.0 {
                        return Err(anyhow!(
                            "attention mask sum is zero during mean pooling for batch item {b}"
                        ));
                    }
                    for value in &mut pooled {
                        *value /= mask_sum;
                    }
                    normalize_embedding(&mut pooled);
                    batch_embeddings.push(pooled);
                }
            } else {
                return Err(anyhow!("unexpected batched ONNX output shape: {shape:?}"));
            }
            batch_embeddings
        };
        self.touch_last_used();

        // Scatter computed embeddings back to their original positions and
        // populate the cache (best-effort if the mutex is poisoned).
        match self.cache.lock() {
            Ok(mut cache) => {
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    cache.put(keys[orig_idx], embedding.clone());
                    results[orig_idx] = Some(embedding);
                }
            }
            Err(_) => {
                tracing::warn!("embedding cache mutex poisoned, bypassing cache");
                for (embedding, &orig_idx) in computed.into_iter().zip(miss_indices.iter()) {
                    results[orig_idx] = Some(embedding);
                }
            }
        }

        results
            .into_iter()
            .map(|opt| opt.ok_or_else(|| anyhow!("unexpected None in batch result")))
            .collect()
    }
}
685
/// Ensures the default BGE-small model files exist locally, downloading any
/// that are missing, and returns the directory that contains them.
#[cfg(feature = "real-embeddings")]
pub async fn download_bge_small_model() -> Result<PathBuf> {
    let dir = default_model_dir()?;
    let files = ensure_model_files_async(dir, MODEL_URL, None, TOKENIZER_URL).await?;
    Ok(files.directory)
}
692
/// Directory under the app's model root reserved for the default model.
#[cfg(feature = "real-embeddings")]
fn default_model_dir() -> Result<PathBuf> {
    let paths = app_paths::resolve_app_paths()?;
    Ok(paths.model_root.join(MODEL_NAME))
}
697
/// Synchronous wrapper around [`ensure_model_files_async`].
///
/// Fast path: if every required file is already on disk, returns without
/// creating a runtime. Otherwise spins up a temporary single-threaded tokio
/// runtime just for the download.
///
/// NOTE(review): `Runtime::block_on` panics when called from a thread that
/// is already inside a tokio runtime. `warmup`/`maintenance_tick` reach this
/// via `spawn_blocking`, which is safe, but `embed()` can call it directly —
/// confirm those call sites never run on an async worker thread.
#[cfg(feature = "real-embeddings")]
fn ensure_model_files_blocking(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    if model_files_exist(&model_dir, model_data_url) {
        return Ok(model_files_for_dir(model_dir, model_data_url));
    }

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .context("failed to create temporary tokio runtime for model download")?;
    let model_data_url_owned = model_data_url.map(str::to_string);
    runtime.block_on(ensure_model_files_async(
        model_dir,
        model_url,
        model_data_url_owned.as_deref(),
        tokenizer_url,
    ))
}
725
/// Ensures model.onnx, the optional external-data file, and tokenizer.json
/// exist under `model_dir`, downloading only the files that are missing.
#[cfg(feature = "real-embeddings")]
async fn ensure_model_files_async(
    model_dir: PathBuf,
    model_url: &str,
    model_data_url: Option<&str>,
    tokenizer_url: &str,
) -> Result<ModelFiles> {
    let files = model_files_for_dir(model_dir, model_data_url);
    // Everything already present: nothing to do.
    if model_files_exist(&files.directory, model_data_url) {
        return Ok(files);
    }

    tokio::fs::create_dir_all(&files.directory)
        .await
        .with_context(|| {
            format!(
                "failed to create model directory {}",
                files.directory.display()
            )
        })?;

    // Each file is checked and fetched individually, so a partially
    // completed earlier run only re-downloads what is still missing.
    if !tokio::fs::try_exists(&files.model_path)
        .await
        .context("failed to check model.onnx path")?
    {
        download_file(model_url, &files.model_path).await?;
    }
    if let (Some(data_url), Some(data_path)) = (model_data_url, &files.model_data_path)
        && !tokio::fs::try_exists(data_path)
            .await
            .context("failed to check model data path")?
    {
        download_file(data_url, data_path).await?;
    }
    if !tokio::fs::try_exists(&files.tokenizer_path)
        .await
        .context("failed to check tokenizer.json path")?
    {
        download_file(tokenizer_url, &files.tokenizer_path).await?;
    }

    Ok(files)
}
769
/// Returns whether all required model artifacts already exist in `model_dir`.
/// The external-data file is only required when `model_data_url` is set.
#[cfg(feature = "real-embeddings")]
fn model_files_exist(model_dir: &Path, model_data_url: Option<&str>) -> bool {
    let files = model_files_for_dir(model_dir.to_path_buf(), model_data_url);
    if !files.model_path.exists() || !files.tokenizer_path.exists() {
        return false;
    }
    match model_data_url {
        Some(_) => files.model_data_path.as_ref().is_some_and(|p| p.exists()),
        None => true,
    }
}
780
/// Maps a model directory (plus an optional external-data URL) to the
/// expected on-disk artifact paths.
///
/// The external-data filename is taken from the last path segment of its
/// URL; an empty segment (e.g. a trailing slash) yields `None`.
#[cfg(feature = "real-embeddings")]
fn model_files_for_dir(model_dir: PathBuf, model_data_url: Option<&str>) -> ModelFiles {
    let model_data_path = model_data_url
        .and_then(|url| url.rsplit('/').next())
        .filter(|segment| !segment.is_empty())
        .map(|segment| model_dir.join(segment));
    ModelFiles {
        model_path: model_dir.join("model.onnx"),
        model_data_path,
        tokenizer_path: model_dir.join("tokenizer.json"),
        directory: model_dir,
    }
}
798
/// Downloads `url` to `path`, writing to a `.part` sibling first and then
/// renaming, so an interrupted download never leaves a truncated file at the
/// final path.
///
/// The whole response body is buffered in memory before writing.
/// NOTE(review): concurrent downloads of the same file would share one
/// `.part` path — confirm callers serialize per-file downloads.
#[cfg(feature = "real-embeddings")]
pub(crate) async fn download_file(url: &str, path: &Path) -> Result<()> {
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(300))
        .connect_timeout(std::time::Duration::from_secs(30))
        .build()
        .context("failed to build HTTP client")?;
    let response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("failed to download {url}"))?
        .error_for_status()
        .with_context(|| format!("download request failed for {url}"))?;
    let bytes = response
        .bytes()
        .await
        .with_context(|| format!("failed to read body from {url}"))?;

    // "<final-name>.part" in the same directory as the destination.
    let mut part_name = path.file_name().unwrap_or_default().to_os_string();
    part_name.push(".part");
    let part_path = path.with_file_name(part_name);
    tokio::fs::write(&part_path, &bytes)
        .await
        .with_context(|| format!("failed to write temporary file {}", part_path.display()))?;
    tokio::fs::rename(&part_path, path).await.with_context(|| {
        format!(
            "failed to rename {} to {}",
            part_path.display(),
            path.display()
        )
    })
}
834
#[cfg(test)]
mod tests {
    //! Unit tests for the placeholder embedder, normalization, and the
    //! `dot_product` helper they are scored with.
    use super::*;
    use crate::memory_core::storage::sqlite::dot_product;

    #[test]
    fn test_placeholder_embedder_dimension() {
        let embedder = PlaceholderEmbedder;
        assert_eq!(embedder.dimension(), 32);
    }

    #[test]
    fn test_placeholder_embedder_deterministic() {
        // Same text must always hash to the same vector.
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed("hello world").unwrap();
        let second = embedder.embed("hello world").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_placeholder_embedder_different_inputs() {
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed("hello world").unwrap();
        let second = embedder.embed("different text").unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_placeholder_embedder_normalized() {
        // Embeddings are unit vectors so dot product == cosine similarity.
        let embedder = PlaceholderEmbedder;
        let embedding = embedder.embed("normalized").unwrap();
        let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_placeholder_embedder_empty_input() {
        // Empty text still yields a full-width vector (hash of "").
        let embedder = PlaceholderEmbedder;
        let embedding = embedder.embed("").unwrap();
        assert_eq!(embedding.len(), 32);
    }

    #[test]
    fn test_dot_product_identical() {
        let a = vec![0.5_f32, 0.5, 0.5, 0.5];
        let score = dot_product(&a, &a);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        let score = dot_product(&a, &b);
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_different_lengths() {
        // Mismatched lengths are defined to score 0.
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 0.0];
        let score = dot_product(&a, &b);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_placeholder_embed_batch_empty() {
        let embedder = PlaceholderEmbedder;
        let results = embedder.embed_batch(&[]).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_placeholder_embed_batch_single() {
        let embedder = PlaceholderEmbedder;
        let single = embedder.embed("hello").unwrap();
        let batch = embedder.embed_batch(&["hello"]).unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0], single);
    }

    #[test]
    fn test_placeholder_embed_batch_multiple() {
        // Batch results must match per-text embeds, in order.
        let embedder = PlaceholderEmbedder;
        let texts = ["alpha", "beta", "gamma"];
        let batch = embedder.embed_batch(&texts).unwrap();
        assert_eq!(batch.len(), 3);
        for (i, text) in texts.iter().enumerate() {
            let individual = embedder.embed(text).unwrap();
            assert_eq!(batch[i], individual);
        }
    }

    #[test]
    fn test_placeholder_embed_batch_normalized() {
        let embedder = PlaceholderEmbedder;
        let batch = embedder.embed_batch(&["one", "two", "three"]).unwrap();
        for emb in &batch {
            let norm = emb.iter().map(|v| v * v).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_placeholder_embed_batch_deterministic() {
        let embedder = PlaceholderEmbedder;
        let first = embedder.embed_batch(&["a", "b"]).unwrap();
        let second = embedder.embed_batch(&["a", "b"]).unwrap();
        assert_eq!(first, second);
    }
}