1use anyhow::{Context, Result};
5use candle_core::{Device, Tensor};
6use candle_nn::VarBuilder;
7use candle_transformers::models::bert::{BertModel, Config};
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use std::sync::Arc;
10use tokenizers::Tokenizer;
11
12use crate::config::EmbeddingModel;
13
/// Builds the text that is embedded for a stored item.
///
/// `title` and `content` are rendered through their `Display` impls and
/// joined with a single space, title first.
#[must_use]
pub fn embedding_document(
    title: impl std::fmt::Display,
    content: impl std::fmt::Display,
) -> String {
    let mut document = title.to_string();
    document.push(' ');
    document.push_str(&content.to_string());
    document
}
31
/// HuggingFace Hub repository id that `Embedder::download_via_hf_hub` fetches.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Vector width reported by `Embedder::dim` for the local embedder.
#[allow(dead_code)]
const MINILM_DIM: usize = 384;
/// `max_length` given to the tokenizer's truncation settings in `Embedder::new_local`.
const MAX_SEQ_LEN: usize = 256;

/// Wall-clock budget passed to `Embedder::download_within` for the hf-hub download.
const HF_DOWNLOAD_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(180);
/// Directory, relative to `$HOME`, that `Embedder::load_from_fallback` searches
/// for the three model files.
const FALLBACK_MODEL_SUBDIR: &str =
    ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main";

/// Model name used by `Embedder::new_ollama` and pulled by `Embedder::for_model`.
pub(crate) const NOMIC_OLLAMA_MODEL: &str = "nomic-embed-text";
/// Substring that `Embedder::model_requires_nomic_prefix` looks for in the
/// lower-cased model name.
const NOMIC_MODEL_FAMILY_NEEDLE: &str = "nomic-embed";
/// File name of the model configuration inside a model directory.
pub(crate) const HF_CONFIG_FILE: &str = "config.json";
/// File name of the serialized tokenizer inside a model directory.
pub(crate) const HF_TOKENIZER_FILE: &str = "tokenizer.json";
/// File name of the safetensors weights inside a model directory.
pub(crate) const HF_WEIGHTS_FILE: &str = "model.safetensors";
/// Vector width that `Embedder::new_ollama` records for `NOMIC_OLLAMA_MODEL`.
#[allow(dead_code)]
const NOMIC_DIM: usize = 768;
67
/// Prefix put in front of texts embedded in the [`EmbedRole::Document`] role.
const NOMIC_PREFIX_DOCUMENT: &str = "search_document: ";
/// Prefix put in front of texts embedded in the [`EmbedRole::Query`] role.
const NOMIC_PREFIX_QUERY: &str = "search_query: ";

/// Which side of a retrieval pair a text is being embedded for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedRole {
    /// The text is content that will be stored and searched over.
    Document,
    /// The text is a search query.
    Query,
}

impl EmbedRole {
    /// Returns the prefix string that belongs to this role.
    ///
    /// The match is kept exhaustive on purpose so that adding a role forces
    /// a decision about its prefix.
    #[must_use]
    pub fn nomic_prefix(self) -> &'static str {
        match self {
            Self::Query => NOMIC_PREFIX_QUERY,
            Self::Document => NOMIC_PREFIX_DOCUMENT,
        }
    }
}
100
/// Outcome of an attempt to embed one piece of content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbedStatus {
    /// A vector was produced.
    Indexed,
    /// Embedding was not attempted; the payload says why.
    Skipped(String),
    /// Embedding was attempted and failed; the payload says why.
    Failed(String),
}

impl EmbedStatus {
    /// Short machine-friendly label for the variant, without the reason.
    ///
    /// The label is always a string literal, so it is returned as
    /// `&'static str`: callers may keep it after the status value itself has
    /// been dropped (the same contract as `EmbedRole::nomic_prefix`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Indexed => "indexed",
            Self::Skipped(_) => "skipped",
            Self::Failed(_) => "failed",
        }
    }

    /// `true` for every variant except [`EmbedStatus::Indexed`].
    #[must_use]
    pub fn is_degraded(&self) -> bool {
        !matches!(self, Self::Indexed)
    }

    /// The reason carried by `Skipped` / `Failed`, or `""` for `Indexed`.
    #[must_use]
    pub fn reason(&self) -> &str {
        match self {
            Self::Indexed => "",
            Self::Skipped(why) | Self::Failed(why) => why.as_str(),
        }
    }
}

impl std::fmt::Display for EmbedStatus {
    /// Renders `indexed`, `skipped: <reason>` or `failed: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Indexed => write!(f, "indexed"),
            Self::Skipped(why) => write!(f, "skipped: {why}"),
            Self::Failed(why) => write!(f, "failed: {why}"),
        }
    }
}
172
/// Largest content size, in bytes, that is still handed to the embedder.
pub const EMBED_MAX_BYTES: usize = 64 * 1024;

/// Explains why content of `byte_len` bytes is too large to embed.
///
/// Returns `None` when `byte_len` is at or below [`EMBED_MAX_BYTES`] (the cap
/// itself is allowed), otherwise a message naming both the size and the cap.
#[must_use]
pub fn oversize_embed_reason(byte_len: usize) -> Option<String> {
    if byte_len <= EMBED_MAX_BYTES {
        return None;
    }
    Some(format!(
        "content {byte_len} bytes exceeds embed cap {EMBED_MAX_BYTES} bytes"
    ))
}
192
193pub trait Embed: Send + Sync {
213 fn embed(&self, text: &str) -> Result<Vec<f32>>;
222
223 fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
234 self.embed(text)
235 }
236
237 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
245 texts.iter().map(|t| self.embed(t)).collect()
246 }
247
248 fn is_degraded(&self) -> bool {
254 false
255 }
256}
257
/// Handle to an embedding backend.
///
/// Cloning is cheap for the heavy parts: the model, tokenizer, client and the
/// degraded flag are all behind `Arc`, so clones share them.
#[derive(Clone)]
pub enum Embedder {
    /// In-process BERT model (built by `Embedder::new_local`).
    Local {
        /// Loaded BERT model used for the forward pass.
        model: Arc<BertModel>,
        /// Tokenizer configured with truncation and no padding.
        tokenizer: Arc<Tokenizer>,
        /// Device the input tensors are created on.
        device: Device,
    },
    /// Remote embedder reached through an `OllamaClient`.
    ///
    /// Despite the name, `Embedder::from_resolved` also builds this variant
    /// for OpenAI-compatible API backends.
    Ollama {
        /// Client used for the embed requests.
        client: Arc<crate::llm::OllamaClient>,
        /// Model name sent with every request.
        model_name: String,
        /// Vector width reported by `Embedder::dim`.
        dim: usize,
        /// Set after each remote embed call to whether that call failed.
        degraded: Arc<std::sync::atomic::AtomicBool>,
    },
}
295
/// Result of `Embedder::cosine_similarity_checked`.
///
/// Separates "the vectors were compared" from "the vectors have different
/// lengths and could not be compared".
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CosineComparison {
    /// Both vectors had the same length; carries the cosine score.
    Comparable(f32),
    /// The vectors had different lengths.
    DimensionMismatch {
        /// Length of the query vector.
        query_dim: usize,
        /// Length of the stored vector.
        stored_dim: usize,
    },
}
321
322impl Embedder {
/// Builds the default embedder, which is the local one.
///
/// # Errors
/// Returns whatever `Embedder::new_local` returns.
#[allow(dead_code)]
pub fn new() -> Result<Self> {
    Self::new_local()
}
329
330 pub fn new_local() -> Result<Self> {
332 let device = Device::Cpu;
333
334 let (config_path, tokenizer_path, weights_path) = if Self::remote_fetch_disabled() {
335 Self::load_from_fallback()?
345 } else {
346 match Self::download_within(HF_DOWNLOAD_TIMEOUT, Self::download_via_hf_hub) {
347 Ok(paths) => paths,
348 Err(e) => {
349 eprintln!("ai-memory: hf-hub download failed ({e}), trying fallback dir");
350 Self::load_from_fallback()?
351 }
352 }
353 };
354
355 let config_data =
356 std::fs::read_to_string(&config_path).context("failed to read config.json")?;
357 let config: Config =
358 serde_json::from_str(&config_data).context("failed to parse config.json")?;
359
360 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
361 .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
362
363 let truncation = tokenizers::TruncationParams {
364 max_length: MAX_SEQ_LEN,
365 ..Default::default()
366 };
367 tokenizer
368 .with_truncation(Some(truncation))
369 .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
370 tokenizer.with_padding(None);
371
372 let vb = unsafe {
373 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
374 .context("failed to load model weights")?
375 };
376 let model = BertModel::load(vb, &config).context("failed to build BertModel")?;
377
378 Ok(Self::Local {
379 model: Arc::new(model),
380 tokenizer: Arc::new(tokenizer),
381 device,
382 })
383 }
384
/// Builds a remote embedder for `NOMIC_OLLAMA_MODEL` with a vector width of
/// `NOMIC_DIM`, using the given client.
pub fn new_ollama(client: Arc<crate::llm::OllamaClient>) -> Self {
    Self::new_remote(client, NOMIC_OLLAMA_MODEL.to_string(), NOMIC_DIM)
}
391
392 #[must_use]
399 pub fn new_remote(
400 client: Arc<crate::llm::OllamaClient>,
401 model_name: String,
402 dim: usize,
403 ) -> Self {
404 Self::Ollama {
405 client,
406 model_name,
407 dim,
408 degraded: Arc::new(std::sync::atomic::AtomicBool::new(false)),
409 }
410 }
411
412 pub fn from_resolved(
436 resolved: &crate::config::ResolvedEmbeddings,
437 tier_model: Option<crate::config::EmbeddingModel>,
438 ) -> Result<Option<Self>> {
439 let Some(tier_model) = tier_model else {
440 return Ok(None);
442 };
443 if crate::config::is_api_embed_backend(&resolved.backend) {
444 let Some(dim) = resolved.embedding_dim else {
445 anyhow::bail!(
446 "embedding model {:?} (backend {:?}) has no known vector dim — \
447 pick a model from the known-dims table (override with the \
448 {} env var) or set the `[embeddings].dim` escape hatch in \
449 config.toml (#1598)",
450 resolved.model,
451 resolved.backend,
452 crate::config::ENV_EMBED_MODEL,
453 );
454 };
455 let api_key = resolved.api_key().unwrap_or_default();
459 let client = crate::llm::OllamaClient::new_openai_compatible(
460 &resolved.url,
461 &resolved.model,
462 api_key,
463 )
464 .context("failed to build OpenAI-compatible embed client (#1598)")?
465 .with_embed_dimensions(resolved.requested_dim);
469 return Ok(Some(Self::new_remote(
470 Arc::new(client),
471 resolved.model.clone(),
472 dim as usize,
473 )));
474 }
475 match tier_model {
476 crate::config::EmbeddingModel::MiniLmL6V2 => {
477 Self::for_model(tier_model, None).map(Some)
478 }
479 crate::config::EmbeddingModel::NomicEmbedV15 => {
480 let client =
481 crate::llm::OllamaClient::new_with_url(&resolved.url, NOMIC_OLLAMA_MODEL)
482 .context("failed to build Ollama embed client")?;
483 Self::for_model(tier_model, Some(Arc::new(client))).map(Some)
484 }
485 }
486 }
487
488 pub fn for_model(
493 model: EmbeddingModel,
494 ollama_client: Option<Arc<crate::llm::OllamaClient>>,
495 ) -> Result<Self> {
496 match model {
497 EmbeddingModel::MiniLmL6V2 => Self::new_local(),
498 EmbeddingModel::NomicEmbedV15 => {
499 let client = ollama_client.ok_or_else(|| {
500 anyhow::anyhow!("nomic-embed-text-v1.5 requires Ollama (smart tier or above)")
501 })?;
502 if let Err(e) = client.ensure_embed_model(NOMIC_OLLAMA_MODEL) {
504 eprintln!("ai-memory: warning: failed to pull nomic model: {e}");
505 }
506 Ok(Self::new_ollama(client))
507 }
508 }
509 }
510
511 #[allow(dead_code)]
513 pub fn dim(&self) -> usize {
514 match self {
515 Self::Local { .. } => MINILM_DIM,
516 Self::Ollama { dim, .. } => *dim,
517 }
518 }
519
520 #[must_use]
525 pub fn model_description(&self) -> String {
526 match self {
527 Self::Local { .. } => "all-MiniLM-L6-v2 (384-dim, local)".to_string(),
528 Self::Ollama {
529 model_name, dim, ..
530 } => format!("{model_name} ({dim}-dim, remote)"),
531 }
532 }
533
534 #[must_use]
541 pub fn is_degraded(&self) -> bool {
542 match self {
543 Self::Local { .. } => false,
544 Self::Ollama { degraded, .. } => degraded.load(std::sync::atomic::Ordering::Relaxed),
545 }
546 }
547
/// Embeds `text` in the [`EmbedRole::Document`] role.
///
/// # Errors
/// Returns whatever `Embedder::embed_with_role` returns.
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
    self.embed_with_role(text, EmbedRole::Document)
}
555
/// Embeds `text` in the [`EmbedRole::Query`] role.
///
/// # Errors
/// Returns whatever `Embedder::embed_with_role` returns.
pub fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
    self.embed_with_role(text, EmbedRole::Query)
}
565
566 pub fn embed_with_role(&self, text: &str, role: EmbedRole) -> Result<Vec<f32>> {
572 match self {
573 Self::Local {
574 model,
575 tokenizer,
576 device,
577 } => {
578 Self::embed_local(model, tokenizer, device, text)
584 }
585 Self::Ollama {
586 client,
587 model_name,
588 degraded,
589 ..
590 } => {
591 let result = if Self::model_requires_nomic_prefix(model_name) {
592 let prefixed = format!("{}{}", role.nomic_prefix(), text);
593 client.embed_text(&prefixed, model_name)
594 } else {
595 client.embed_text(text, model_name)
596 };
597 degraded.store(result.is_err(), std::sync::atomic::Ordering::Relaxed);
602 result
603 }
604 }
605 }
606
607 fn model_requires_nomic_prefix(model_name: &str) -> bool {
616 model_name
617 .to_ascii_lowercase()
618 .contains(NOMIC_MODEL_FAMILY_NEEDLE)
619 }
620
621 pub fn embed_with_status(&self, text: &str) -> (Option<Vec<f32>>, EmbedStatus) {
636 if text.is_empty() {
637 return (None, EmbedStatus::Skipped("empty content".to_string()));
638 }
639 if let Some(reason) = oversize_embed_reason(text.len()) {
640 return (None, EmbedStatus::Skipped(reason));
641 }
642 match self.embed(text) {
643 Ok(v) if v.is_empty() => (
644 None,
645 EmbedStatus::Failed("embedder returned empty vector".to_string()),
646 ),
647 Ok(v) => (Some(v), EmbedStatus::Indexed),
648 Err(e) => {
649 let reason = format!("{e:#}");
650 tracing::warn!(target: "embeddings.degrade", reason = %reason, "embed_with_status: embedder failed");
651 (None, EmbedStatus::Failed(reason))
652 }
653 }
654 }
655
656 fn embed_local(
657 model: &BertModel,
658 tokenizer: &Tokenizer,
659 device: &Device,
660 text: &str,
661 ) -> Result<Vec<f32>> {
662 let encoding = tokenizer
663 .encode(text, true)
664 .map_err(|e| anyhow::anyhow!("tokenisation failed: {e}"))?;
665
666 let input_ids = encoding.get_ids();
667 let attention_mask = encoding.get_attention_mask();
668 let token_type_ids = encoding.get_type_ids();
669 let seq_len = input_ids.len();
670
671 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
672 let attention_mask_tensor = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
673 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
674
675 let hidden = model
676 .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
677 .context("model forward pass failed")?;
678
679 let mask = attention_mask_tensor
680 .unsqueeze(2)?
681 .to_dtype(candle_core::DType::F32)?
682 .broadcast_as(hidden.shape())?;
683 let masked = hidden.mul(&mask)?;
684 let summed = masked.sum(1)?;
685 let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
686 let pooled = summed.div(&count)?;
687
688 let norm = pooled
689 .sqr()?
690 .sum_keepdim(1)?
691 .sqrt()?
692 .clamp(1e-12, f64::MAX)?;
693 let normalised = pooled.broadcast_div(&norm)?;
694
695 let embedding: Vec<f32> = normalised.squeeze(0)?.to_vec1()?;
696 Ok(embedding)
697 }
698
699 #[allow(dead_code)]
717 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
718 if texts.is_empty() {
719 return Ok(Vec::new());
720 }
721 match self {
722 Self::Local {
723 model,
724 tokenizer,
725 device,
726 } => Self::embed_local_batch(model, tokenizer, device, texts),
727 Self::Ollama {
737 client,
738 model_name,
739 degraded,
740 ..
741 } => {
742 let result = if Self::model_requires_nomic_prefix(model_name) {
743 let prefixed: Vec<String> = texts
744 .iter()
745 .map(|t| format!("{}{}", EmbedRole::Document.nomic_prefix(), t))
746 .collect();
747 let refs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
748 client.embed_texts(&refs, model_name)
749 } else {
750 client.embed_texts(texts, model_name)
751 };
752 degraded.store(result.is_err(), std::sync::atomic::Ordering::Relaxed);
755 result
756 }
757 }
758 }
759
760 fn embed_local_batch(
763 model: &BertModel,
764 tokenizer: &Tokenizer,
765 device: &Device,
766 texts: &[&str],
767 ) -> Result<Vec<Vec<f32>>> {
768 let inputs: Vec<&str> = texts.to_vec();
772 let encodings = tokenizer
773 .encode_batch(inputs, true)
774 .map_err(|e| anyhow::anyhow!("tokenisation batch failed: {e}"))?;
775
776 let max_len = encodings
778 .iter()
779 .map(tokenizers::Encoding::len)
780 .max()
781 .unwrap_or(0);
782 if max_len == 0 {
783 return Ok(texts.iter().map(|_| Vec::new()).collect());
786 }
787
788 let batch_size = encodings.len();
789
790 let mut input_ids_flat = Vec::with_capacity(batch_size * max_len);
795 let mut attention_mask_flat = Vec::with_capacity(batch_size * max_len);
796 let mut token_type_ids_flat = Vec::with_capacity(batch_size * max_len);
797 for enc in &encodings {
798 let ids = enc.get_ids();
799 let mask = enc.get_attention_mask();
800 let tt = enc.get_type_ids();
801 let len = ids.len();
802 input_ids_flat.extend_from_slice(ids);
803 attention_mask_flat.extend_from_slice(mask);
804 token_type_ids_flat.extend_from_slice(tt);
805 for _ in len..max_len {
807 input_ids_flat.push(0);
808 attention_mask_flat.push(0);
809 token_type_ids_flat.push(0);
810 }
811 }
812
813 let input_ids =
814 Tensor::new(input_ids_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
815 let attention_mask_tensor =
816 Tensor::new(attention_mask_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
817 let token_type_ids =
818 Tensor::new(token_type_ids_flat.as_slice(), device)?.reshape((batch_size, max_len))?;
819
820 let hidden = model
821 .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
822 .context("model forward pass (batched) failed")?;
823
824 let mask = attention_mask_tensor
826 .unsqueeze(2)?
827 .to_dtype(candle_core::DType::F32)?
828 .broadcast_as(hidden.shape())?;
829 let masked = hidden.mul(&mask)?;
830 let summed = masked.sum(1)?;
831 let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
832 let pooled = summed.div(&count)?;
833
834 let norm = pooled
835 .sqr()?
836 .sum_keepdim(1)?
837 .sqrt()?
838 .clamp(1e-12, f64::MAX)?;
839 let normalised = pooled.broadcast_div(&norm)?;
840
841 let mut out: Vec<Vec<f32>> = Vec::with_capacity(batch_size);
843 for i in 0..batch_size {
844 let row: Vec<f32> = normalised.get(i)?.to_vec1()?;
845 out.push(row);
846 }
847 Ok(out)
848 }
849
850 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
852 if a.len() != b.len() {
854 return 0.0;
855 }
856
857 let mut dot: f32 = 0.0;
865 let mut sq_a: f32 = 0.0;
866 let mut sq_b: f32 = 0.0;
867 for (&x, &y) in a.iter().zip(b.iter()) {
868 dot += x * y;
869 sq_a += x * x;
870 sq_b += y * y;
871 }
872 let denom = sq_a.sqrt() * sq_b.sqrt();
873 if denom < 1e-12 {
874 return 0.0;
875 }
876 let score = dot / denom;
877 if score.is_finite() { score } else { 0.0 }
885 }
886
887 #[must_use]
896 pub fn cosine_similarity_checked(query: &[f32], stored: &[f32]) -> CosineComparison {
897 if query.len() != stored.len() {
898 return CosineComparison::DimensionMismatch {
899 query_dim: query.len(),
900 stored_dim: stored.len(),
901 };
902 }
903 CosineComparison::Comparable(Self::cosine_similarity(query, stored))
904 }
905
/// Blends two vectors component-wise: `w * primary + (1 - w) * secondary`.
///
/// * `primary_weight` is clamped into `[0.0, 1.0]` to give `w`.
/// * When the lengths differ, or when `primary_weight` is NaN, a copy of
///   `primary` is returned unchanged.
///
/// The NaN case is handled explicitly because `f32::clamp` passes NaN
/// through. Every fused component would then be NaN, and
/// `Embedder::cosine_similarity` would score such a vector as `0.0` against
/// everything.
#[must_use]
pub fn fuse(primary: &[f32], secondary: &[f32], primary_weight: f32) -> Vec<f32> {
    if primary.len() != secondary.len() || primary_weight.is_nan() {
        return primary.to_vec();
    }
    let w = primary_weight.clamp(0.0, 1.0);
    let one_minus_w = 1.0 - w;
    primary
        .iter()
        .zip(secondary.iter())
        .map(|(p, s)| w * p + one_minus_w * s)
        .collect()
}
927
928 fn download_within<F>(
941 budget: std::time::Duration,
942 f: F,
943 ) -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
944 where
945 F: FnOnce() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
946 + Send
947 + 'static,
948 {
949 let (tx, rx) = std::sync::mpsc::channel();
950 std::thread::spawn(move || {
951 let _ = tx.send(f());
954 });
955 match rx.recv_timeout(budget) {
956 Ok(result) => result,
957 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => anyhow::bail!(
958 "hf-hub model download exceeded {}s budget",
959 budget.as_secs()
960 ),
961 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
962 anyhow::bail!("hf-hub model download thread terminated without a result")
963 }
964 }
965 }
966
967 fn download_via_hf_hub() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
968 {
969 let api = Api::new().context("failed to initialise HuggingFace Hub API")?;
970 let repo = api.repo(Repo::new(MINILM_MODEL_ID.to_string(), RepoType::Model));
971 let config_path = repo
972 .get(HF_CONFIG_FILE)
973 .context("failed to download config.json")?;
974 let tokenizer_path = repo
975 .get(HF_TOKENIZER_FILE)
976 .context("failed to download tokenizer.json")?;
977 let weights_path = repo
978 .get(HF_WEIGHTS_FILE)
979 .context("failed to download model.safetensors")?;
980 Ok((config_path, tokenizer_path, weights_path))
981 }
982
/// Whether model downloads are switched off through the environment.
///
/// Returns `true` when `AI_MEMORY_EMBED_OFFLINE` or `HF_HUB_OFFLINE` is set
/// to a truthy value: `1`, `true`, `yes` or `on`, compared after trimming
/// surrounding whitespace and ignoring ASCII case. Spellings such as
/// `True`, `YES` or `On` therefore count as truthy too. Unset, non-Unicode
/// and any other values count as "not disabled".
fn remote_fetch_disabled() -> bool {
    const OFFLINE_ENV_VARS: [&str; 2] = ["AI_MEMORY_EMBED_OFFLINE", "HF_HUB_OFFLINE"];
    const TRUTHY_VALUES: [&str; 4] = ["1", "true", "yes", "on"];

    OFFLINE_ENV_VARS.iter().any(|name| {
        std::env::var(name)
            .map(|raw| {
                let value = raw.trim();
                TRUTHY_VALUES
                    .iter()
                    .any(|truthy| value.eq_ignore_ascii_case(truthy))
            })
            .unwrap_or(false)
    })
}
996
997 fn load_from_fallback() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
998 {
999 let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
1000 let dir = std::path::PathBuf::from(home).join(FALLBACK_MODEL_SUBDIR);
1001 let dir = dir.as_path();
1002 let config = dir.join(HF_CONFIG_FILE);
1003 let tokenizer = dir.join(HF_TOKENIZER_FILE);
1004 let weights = dir.join(HF_WEIGHTS_FILE);
1005 if config.exists() && tokenizer.exists() && weights.exists() {
1006 Ok((config, tokenizer, weights))
1007 } else {
1008 anyhow::bail!(
1009 "model files not found in fallback dir: {}. Download them manually from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
1010 dir.display()
1011 )
1012 }
1013 }
1014}
1015
1016impl Embed for Embedder {
1023 fn embed(&self, text: &str) -> Result<Vec<f32>> {
1024 Self::embed(self, text)
1025 }
1026
1027 fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
1028 Self::embed_query(self, text)
1029 }
1030
1031 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
1032 Self::embed_batch(self, texts)
1033 }
1034
1035 fn is_degraded(&self) -> bool {
1036 Self::is_degraded(self)
1037 }
1038}
1039
/// Public alias for the local embedder's vector width (`MINILM_DIM`).
#[allow(dead_code)]
pub const EMBEDDING_DIM: usize = MINILM_DIM;
1043
/// Header byte marking a blob whose payload is little-endian `f32` values.
pub const EMBEDDING_HEADER_LE_F32: u8 = 0x01;

/// Header byte reserved for big-endian `f32` payloads; decoding rejects it.
pub const EMBEDDING_HEADER_BE_F32: u8 = 0x02;

/// Reasons an embedding blob cannot be decoded.
#[derive(Debug)]
pub enum EmbeddingFormatError {
    /// The header byte is neither of the known header values.
    UnknownHeader(u8),
    /// The blob carries the big-endian header, which is not supported.
    BigEndianUnsupported,
    /// The blob length (carried here) fits neither the headed nor the
    /// unheaded layout.
    MalformedLength(usize),
}

impl std::fmt::Display for EmbeddingFormatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MalformedLength(n) => {
                write!(f, "embedding payload length {n} is not a multiple of 4")
            }
            Self::BigEndianUnsupported => write!(
                f,
                "big-endian f32 embeddings (header 0x02) are not supported until v0.7"
            ),
            Self::UnknownHeader(b) => write!(f, "unknown embedding header byte: 0x{b:02x}"),
        }
    }
}

impl std::error::Error for EmbeddingFormatError {}

/// Serializes an embedding as one header byte ([`EMBEDDING_HEADER_LE_F32`])
/// followed by each value as 4 little-endian bytes.
#[must_use]
pub fn encode_embedding_blob(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(1 + embedding.len() * 4);
    blob.push(EMBEDDING_HEADER_LE_F32);
    blob.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    blob
}

/// Reads consecutive 4-byte groups of `payload` as little-endian `f32`s.
/// Callers pass payloads whose length is a multiple of 4.
fn le_f32_values(payload: &[u8]) -> Vec<f32> {
    payload
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}

/// Decodes an embedding blob. The layout is chosen by length alone:
///
/// * length ≡ 0 (mod 4), including empty — unheaded little-endian `f32`s;
/// * length ≡ 1 (mod 4) — one header byte followed by the payload;
/// * anything else — malformed.
///
/// # Errors
/// * [`EmbeddingFormatError::BigEndianUnsupported`] for the big-endian header.
/// * [`EmbeddingFormatError::UnknownHeader`] for any other unknown header byte.
/// * [`EmbeddingFormatError::MalformedLength`] when the length fits neither layout.
pub fn decode_embedding_blob(bytes: &[u8]) -> Result<Vec<f32>, EmbeddingFormatError> {
    match bytes.len() % 4 {
        0 => Ok(le_f32_values(bytes)),
        1 => match bytes[0] {
            EMBEDDING_HEADER_LE_F32 => Ok(le_f32_values(&bytes[1..])),
            EMBEDDING_HEADER_BE_F32 => Err(EmbeddingFormatError::BigEndianUnsupported),
            other => Err(EmbeddingFormatError::UnknownHeader(other)),
        },
        _ => Err(EmbeddingFormatError::MalformedLength(bytes.len())),
    }
}
1165
/// Number of `f32` values a blob of this length holds, without decoding it.
///
/// A length of the form `4k + 1` is treated as one header byte plus `k`
/// values; every other length `n` yields `n / 4` (integer division). The
/// blob contents are never inspected, so no validation happens here.
#[must_use]
pub fn decoded_dim(bytes: &[u8]) -> usize {
    match bytes.len() {
        0 => 0,
        headed if headed % 4 == 1 => (headed - 1) / 4,
        plain => plain / 4,
    }
}
1179
1180#[cfg(test)]
1181mod tests {
1182 use super::*;
1183
1184 #[test]
1185 fn cosine_similarity_identical() {
1186 let v = vec![1.0, 0.0, 0.0];
1187 let sim = Embedder::cosine_similarity(&v, &v);
1188 assert!((sim - 1.0).abs() < 1e-6);
1189 }
1190
1191 #[test]
1192 fn embed_role_maps_to_nomic_prefix() {
1193 assert_eq!(EmbedRole::Document.nomic_prefix(), NOMIC_PREFIX_DOCUMENT);
1196 assert_eq!(EmbedRole::Query.nomic_prefix(), NOMIC_PREFIX_QUERY);
1197 assert_ne!(
1198 EmbedRole::Document.nomic_prefix(),
1199 EmbedRole::Query.nomic_prefix()
1200 );
1201 assert!(NOMIC_PREFIX_DOCUMENT.ends_with(' '));
1203 assert!(NOMIC_PREFIX_QUERY.ends_with(' '));
1204 }
1205
1206 #[test]
1207 fn nomic_prefix_gating_is_model_scoped() {
1208 assert!(Embedder::model_requires_nomic_prefix(NOMIC_OLLAMA_MODEL));
1213 assert!(Embedder::model_requires_nomic_prefix(&format!(
1214 "{NOMIC_OLLAMA_MODEL}:v1.5"
1215 )));
1216 let other_embed_models = ["mxbai-embed-large", "all-minilm"];
1219 for model in other_embed_models {
1220 assert!(!Embedder::model_requires_nomic_prefix(model));
1221 }
1222 }
1223
1224 fn offline_openai_compatible_client() -> Arc<crate::llm::OllamaClient> {
1229 Arc::new(
1230 crate::llm::OllamaClient::new_openai_compatible(
1231 "http://127.0.0.1:1",
1232 "test-embed-model",
1233 "",
1234 )
1235 .expect("client builds without network"),
1236 )
1237 }
1238
1239 #[test]
1240 fn new_remote_carries_dynamic_dim_and_truthful_description_1598() {
1241 let embedder = Embedder::new_remote(
1242 offline_openai_compatible_client(),
1243 "google/gemini-embedding-2".to_string(),
1244 3072,
1245 );
1246 assert_eq!(embedder.dim(), 3072);
1247 assert_eq!(
1248 embedder.model_description(),
1249 "google/gemini-embedding-2 (3072-dim, remote)"
1250 );
1251 assert!(!embedder.is_degraded());
1252 }
1253
1254 #[test]
1255 fn new_ollama_preserves_nomic_defaults_1598() {
1256 let embedder = Embedder::new_ollama(offline_openai_compatible_client());
1257 assert_eq!(embedder.dim(), NOMIC_DIM);
1258 let desc = embedder.model_description();
1259 assert!(desc.contains(NOMIC_OLLAMA_MODEL), "desc: {desc}");
1260 assert!(desc.contains("768"), "desc: {desc}");
1261 assert!(!embedder.is_degraded());
1262 }
1263
1264 #[test]
1265 fn remote_embed_failure_latches_degraded_flag_1598() {
1266 let embedder = Embedder::new_remote(
1269 offline_openai_compatible_client(),
1270 "test-embed-model".to_string(),
1271 8,
1272 );
1273 assert!(!embedder.is_degraded());
1274 let err = embedder.embed("hello");
1275 assert!(err.is_err(), "embed against a closed port must error");
1276 assert!(embedder.is_degraded());
1277 }
1278
1279 #[test]
1280 fn local_embedder_is_never_degraded_via_trait_default_1598() {
1281 let mock = crate::embeddings::test_support::MockEmbedder::new_ollama();
1284 let as_trait: &dyn Embed = &mock;
1285 assert!(!as_trait.is_degraded());
1286 }
1287
1288 #[test]
1289 fn from_resolved_keyword_tier_yields_none_1598() {
1290 let resolved = crate::config::ResolvedEmbeddings::from_parts(
1291 "openrouter".to_string(),
1292 "https://openrouter.ai/api/v1".to_string(),
1293 "google/gemini-embedding-2".to_string(),
1294 Some(3072),
1295 None,
1296 );
1297 let built = Embedder::from_resolved(&resolved, None).expect("keyword tier is Ok(None)");
1298 assert!(built.is_none());
1299 }
1300
1301 #[test]
1302 fn from_resolved_api_backend_unknown_dim_bails_with_escape_hatch_1598() {
1303 let resolved = crate::config::ResolvedEmbeddings::from_parts(
1304 "openrouter".to_string(),
1305 "https://openrouter.ai/api/v1".to_string(),
1306 "some/unknown-embed-model".to_string(),
1307 None,
1308 None,
1309 );
1310 let result = Embedder::from_resolved(
1311 &resolved,
1312 Some(crate::config::EmbeddingModel::NomicEmbedV15),
1313 );
1314 let Err(err) = result else {
1315 panic!("unknown dim on an API backend must fail closed");
1316 };
1317 let msg = format!("{err:#}");
1318 assert!(msg.contains("dim"), "error must name the dim gap: {msg}");
1319 assert!(
1320 msg.contains("[embeddings].dim"),
1321 "error must name the config escape hatch: {msg}"
1322 );
1323 assert!(
1324 msg.contains(crate::config::ENV_EMBED_MODEL),
1325 "error must name the model env var: {msg}"
1326 );
1327 }
1328
1329 #[test]
1330 fn from_resolved_api_backend_builds_remote_embedder_1598() {
1331 let resolved = crate::config::ResolvedEmbeddings::from_parts(
1335 "openrouter".to_string(),
1336 "https://openrouter.ai/api/v1".to_string(),
1337 "google/gemini-embedding-2".to_string(),
1338 Some(3072),
1339 None,
1340 );
1341 let built = Embedder::from_resolved(
1342 &resolved,
1343 Some(crate::config::EmbeddingModel::NomicEmbedV15),
1344 )
1345 .expect("API-backend construction succeeds")
1346 .expect("tier gates embeddings on");
1347 assert!(matches!(built, Embedder::Ollama { .. }));
1348 assert_eq!(built.dim(), 3072);
1349 assert_eq!(
1350 built.model_description(),
1351 "google/gemini-embedding-2 (3072-dim, remote)"
1352 );
1353 }
1354
1355 #[test]
1356 fn nomic_prefix_gating_covers_hf_id_and_case_forms_1598() {
1357 assert!(Embedder::model_requires_nomic_prefix(
1360 "nomic-ai/nomic-embed-text-v1.5"
1361 ));
1362 assert!(Embedder::model_requires_nomic_prefix(
1363 "nomic-embed-text-v1.5"
1364 ));
1365 assert!(Embedder::model_requires_nomic_prefix(
1366 "Nomic-AI/Nomic-Embed-Text-v1.5"
1367 ));
1368 assert!(!Embedder::model_requires_nomic_prefix(
1370 "google/gemini-embedding-2"
1371 ));
1372 assert!(!Embedder::model_requires_nomic_prefix(
1373 "ibm-granite/granite-embedding-125m-english"
1374 ));
1375 }
1376
1377 #[test]
1378 fn cosine_similarity_orthogonal() {
1379 let a = vec![1.0, 0.0, 0.0];
1380 let b = vec![0.0, 1.0, 0.0];
1381 let sim = Embedder::cosine_similarity(&a, &b);
1382 assert!(sim.abs() < 1e-6);
1383 }
1384
1385 #[test]
1386 fn cosine_similarity_opposite() {
1387 let a = vec![1.0, 0.0];
1388 let b = vec![-1.0, 0.0];
1389 let sim = Embedder::cosine_similarity(&a, &b);
1390 assert!((sim + 1.0).abs() < 1e-6);
1391 }
1392
1393 #[test]
1394 fn cosine_similarity_zero_vector() {
1395 let a = vec![0.0, 0.0, 0.0];
1396 let b = vec![1.0, 2.0, 3.0];
1397 let sim = Embedder::cosine_similarity(&a, &b);
1398 assert_eq!(sim, 0.0);
1399 }
1400
1401 #[test]
1402 fn cosine_similarity_dimension_mismatch() {
1403 let a = vec![1.0, 0.0, 0.0];
1404 let b = vec![1.0, 0.0]; let sim = Embedder::cosine_similarity(&a, &b);
1406 assert_eq!(sim, 0.0);
1407 }
1408
1409 #[test]
1412 fn cosine_similarity_checked_comparable_matches_plain_cosine() {
1413 let a = vec![1.0, 2.0, 3.0];
1414 let b = vec![2.0, 1.0, 0.5];
1415 let plain = Embedder::cosine_similarity(&a, &b);
1416 match Embedder::cosine_similarity_checked(&a, &b) {
1417 CosineComparison::Comparable(c) => assert!((c - plain).abs() < 1e-6),
1418 CosineComparison::DimensionMismatch { .. } => {
1419 panic!("equal-length vectors must compare as Comparable")
1420 }
1421 }
1422 }
1423
1424 #[test]
1425 fn cosine_similarity_checked_flags_dimension_mismatch() {
1426 let query = vec![0.0_f32; 5];
1430 let stored = vec![0.0_f32; 3];
1431 match Embedder::cosine_similarity_checked(&query, &stored) {
1432 CosineComparison::DimensionMismatch {
1433 query_dim,
1434 stored_dim,
1435 } => {
1436 assert_eq!(query_dim, 5);
1437 assert_eq!(stored_dim, 3);
1438 }
1439 CosineComparison::Comparable(_) => {
1440 panic!("differing-length vectors must report DimensionMismatch")
1441 }
1442 }
1443 }
1444
1445 #[test]
1448 fn encode_embedding_blob_prefixes_le_header() {
1449 let v = vec![1.0_f32, 2.0_f32];
1450 let blob = encode_embedding_blob(&v);
1451 assert_eq!(blob.len(), 1 + 8);
1452 assert_eq!(blob[0], EMBEDDING_HEADER_LE_F32);
1453 }
1454
1455 #[test]
1456 fn decode_embedding_blob_round_trip_v17() {
1457 let v = vec![1.5_f32, -0.25, 0.0];
1458 let blob = encode_embedding_blob(&v);
1459 let back = decode_embedding_blob(&blob).expect("round-trips");
1460 assert_eq!(back, v);
1461 }
1462
1463 #[test]
1464 fn decode_embedding_blob_legacy_unheaded_le_f32() {
1465 let v = vec![1.0_f32, 2.0, 3.0];
1467 let raw: Vec<u8> = v.iter().flat_map(|f| f.to_le_bytes()).collect();
1468 let back = decode_embedding_blob(&raw).expect("legacy decodes");
1469 assert_eq!(back, v);
1470 }
1471
1472 #[test]
1473 fn decode_embedding_blob_rejects_be_header() {
1474 let mut blob = vec![EMBEDDING_HEADER_BE_F32];
1475 blob.extend_from_slice(&1.0_f32.to_be_bytes());
1476 let err = decode_embedding_blob(&blob).expect_err("BE rejected");
1477 assert!(matches!(err, EmbeddingFormatError::BigEndianUnsupported));
1478 }
1479
1480 #[test]
1481 fn decode_embedding_blob_rejects_unknown_header() {
1482 let mut blob = vec![0xff_u8];
1483 blob.extend_from_slice(&1.0_f32.to_le_bytes());
1484 let err = decode_embedding_blob(&blob).expect_err("unknown header rejected");
1485 assert!(matches!(err, EmbeddingFormatError::UnknownHeader(0xff)));
1486 }
1487
1488 #[test]
1489 fn decode_embedding_blob_rejects_malformed_length() {
1490 let blob = vec![0u8; 6];
1492 let err = decode_embedding_blob(&blob).expect_err("malformed length rejected");
1493 assert!(matches!(err, EmbeddingFormatError::MalformedLength(6)));
1494 }
1495
1496 #[test]
1497 fn decoded_dim_handles_all_three_paths() {
1498 assert_eq!(decoded_dim(&[]), 0);
1500 let raw: Vec<u8> = vec![0u8; 16];
1502 assert_eq!(decoded_dim(&raw), 4);
1503 let mut headed = vec![EMBEDDING_HEADER_LE_F32];
1505 headed.extend_from_slice(&[0u8; 12]);
1506 assert_eq!(decoded_dim(&headed), 3);
1507 }
1508
1509 #[test]
1512 fn fuse_weighted_sum() {
1513 let p = vec![1.0, 0.0, 0.0];
1514 let s = vec![0.0, 1.0, 0.0];
1515 let f = Embedder::fuse(&p, &s, 0.7);
1516 assert!((f[0] - 0.7).abs() < 1e-6);
1517 assert!((f[1] - 0.3).abs() < 1e-6);
1518 assert!((f[2] - 0.0).abs() < 1e-6);
1519 }
1520
1521 #[test]
1522 fn fuse_primary_weight_clamped() {
1523 let p = vec![1.0, 1.0];
1524 let s = vec![0.0, 0.0];
1525 let f = Embedder::fuse(&p, &s, 2.0);
1526 assert!((f[0] - 1.0).abs() < 1e-6);
1528 assert!((f[1] - 1.0).abs() < 1e-6);
1529
1530 let f = Embedder::fuse(&p, &s, -0.5);
1531 assert!((f[0] - 0.0).abs() < 1e-6);
1533 assert!((f[1] - 0.0).abs() < 1e-6);
1534 }
1535
1536 #[test]
1537 fn fuse_dimension_mismatch_returns_primary() {
1538 let p = vec![1.0, 2.0, 3.0];
1539 let s = vec![4.0, 5.0]; let f = Embedder::fuse(&p, &s, 0.7);
1541 assert_eq!(f, p);
1542 }
1543
1544 #[test]
1545 fn fuse_cosine_pulls_toward_context() {
1546 let q = vec![1.0_f32, 0.0];
1549 let ctx = vec![0.0_f32, 1.0];
1550 let fused = Embedder::fuse(&q, &ctx, 0.7);
1551 let sim_q = Embedder::cosine_similarity(&fused, &q);
1553 let sim_ctx = Embedder::cosine_similarity(&fused, &ctx);
1554 assert!(sim_q > sim_ctx);
1555 assert!(sim_q > 0.9); assert!(sim_ctx > 0.3); }
1558
1559 #[test]
1564 fn test_fuse_with_weight_one_returns_primary() {
1565 let primary = vec![0.6_f32, -0.8, 0.0]; let secondary = vec![0.0_f32, 0.0, 1.0];
1570 let fused = Embedder::fuse(&primary, &secondary, 1.0);
1571 assert_eq!(fused.len(), primary.len());
1572 for (i, (f, p)) in fused.iter().zip(primary.iter()).enumerate() {
1573 assert!(
1574 (f - p).abs() < 1e-6,
1575 "fuse weight=1 idx {i}: fused {} != primary {}",
1576 f,
1577 p
1578 );
1579 }
1580
1581 let sim = Embedder::cosine_similarity(&fused, &primary);
1584 assert!(
1585 (sim - 1.0).abs() < 1e-6,
1586 "cos(fuse(p,s,1.0), p) must be 1.0"
1587 );
1588 }
1589
1590 #[test]
1595 fn embed_status_as_str_each_variant() {
1596 assert_eq!(EmbedStatus::Indexed.as_str(), "indexed");
1597 assert_eq!(
1598 EmbedStatus::Skipped("too big".to_string()).as_str(),
1599 "skipped"
1600 );
1601 assert_eq!(
1602 EmbedStatus::Failed("ollama down".to_string()).as_str(),
1603 "failed"
1604 );
1605 }
1606
1607 #[test]
1610 fn oversize_embed_reason_boundary_1595() {
1611 assert_eq!(oversize_embed_reason(0), None);
1612 assert_eq!(
1613 oversize_embed_reason(EMBED_MAX_BYTES),
1614 None,
1615 "cap itself is allowed"
1616 );
1617 let reason = oversize_embed_reason(EMBED_MAX_BYTES + 1).expect("over-cap must skip");
1618 assert!(
1619 reason.contains(&(EMBED_MAX_BYTES + 1).to_string())
1620 && reason.contains(&EMBED_MAX_BYTES.to_string()),
1621 "reason must name size + cap, got: {reason}"
1622 );
1623 }
1624
1625 #[test]
1626 fn embed_status_is_degraded_only_for_non_indexed() {
1627 assert!(!EmbedStatus::Indexed.is_degraded());
1628 assert!(EmbedStatus::Skipped("x".to_string()).is_degraded());
1629 assert!(EmbedStatus::Failed("x".to_string()).is_degraded());
1630 }
1631
1632 #[test]
1633 fn embed_status_reason_helper() {
1634 assert_eq!(EmbedStatus::Indexed.reason(), "");
1635 assert_eq!(EmbedStatus::Skipped("r1".to_string()).reason(), "r1");
1636 assert_eq!(EmbedStatus::Failed("r2".to_string()).reason(), "r2");
1637 }
1638
1639 #[test]
1640 fn embed_status_display_includes_reason() {
1641 assert_eq!(format!("{}", EmbedStatus::Indexed), "indexed");
1642 assert_eq!(
1643 format!("{}", EmbedStatus::Skipped("oversize".to_string())),
1644 "skipped: oversize"
1645 );
1646 assert_eq!(
1647 format!("{}", EmbedStatus::Failed("timeout".to_string())),
1648 "failed: timeout"
1649 );
1650 }
1651
1652 #[test]
1653 fn embedding_format_error_display_each_variant() {
1654 let unk = EmbeddingFormatError::UnknownHeader(0xab);
1655 assert!(unk.to_string().contains("0xab"));
1656 let be = EmbeddingFormatError::BigEndianUnsupported;
1657 assert!(be.to_string().contains("big-endian"));
1658 let ml = EmbeddingFormatError::MalformedLength(7);
1659 assert!(ml.to_string().contains("7"));
1660 }
1661
1662 #[test]
1663 fn embedding_format_error_is_std_error() {
1664 let e: Box<dyn std::error::Error> = Box::new(EmbeddingFormatError::BigEndianUnsupported);
1667 assert!(e.source().is_none());
1670 }
1671
1672 #[test]
1673 fn decode_embedding_blob_empty_returns_empty_vec() {
1674 let v = decode_embedding_blob(&[]).expect("empty decodes to empty");
1675 assert!(v.is_empty());
1676 }
1677
1678 #[test]
1679 fn test_fuse_is_l2_normalized() {
1680 let primary = vec![3.0_f32, 0.0, 0.0]; let secondary = vec![0.0_f32, 4.0, 0.0]; let fused = Embedder::fuse(&primary, &secondary, 0.5);
1689 let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
1691 assert!(
1693 (norm - 2.5).abs() < 1e-5,
1694 "fuse currently returns un-normalized vec; norm should be 2.5, got {norm}"
1695 );
1696
1697 let normalized: Vec<f32> = fused.iter().map(|x| x / norm).collect();
1700 let renorm = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
1701 assert!(
1702 (renorm - 1.0).abs() < 1e-5,
1703 "renormalized fused must have unit norm, got {renorm}"
1704 );
1705 let sim = Embedder::cosine_similarity(&fused, &normalized);
1707 assert!(
1708 (sim - 1.0).abs() < 1e-5,
1709 "cos(raw_fuse, normalize(raw_fuse)) must be 1.0, got {sim}"
1710 );
1711 }
1712}
1713
/// Test doubles for code that needs something implementing [`Embed`] without
/// loading a model or talking to a server.
// The lint allowances cover signatures that are deliberately "wasteful" for a
// mock (e.g. `new_local` returns `Result` although it cannot fail) —
// presumably to mirror the real embedder's API; confirm before tightening.
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Deterministic embedder stand-in; the variant selects the vector width.
    pub enum MockEmbedder {
        /// Produces `MINILM_DIM`-wide vectors.
        Local,
        /// Produces `NOMIC_DIM`-wide vectors.
        Ollama,
    }

    impl MockEmbedder {
        /// Builds the `Local` mock. Always succeeds.
        pub fn new_local() -> Result<Self> {
            Ok(Self::Local)
        }

        /// Builds the `Ollama` mock.
        pub fn new_ollama() -> Self {
            Self::Ollama
        }

        /// Returns a vector of `self.dim()` values derived from `text`.
        ///
        /// The same text always yields the same vector: a multiply-by-31
        /// rolling hash of the bytes picks a base offset in `[0, 1)`, and a
        /// small position-dependent term is added to every component.
        pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let mut hash = 0u32;
            for byte in text.bytes() {
                hash = hash.wrapping_mul(31).wrapping_add(u32::from(byte));
            }
            let base = ((hash % 1000) as f32) / 1000.0;

            Ok((0..self.dim())
                .map(|i| base + ((i as f32) * 0.0001).sin().abs())
                .collect())
        }

        /// Embeds every text in order, stopping at the first error.
        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut rows = Vec::with_capacity(texts.len());
            for text in texts {
                rows.push(self.embed(text)?);
            }
            Ok(rows)
        }

        /// Width of the vectors this mock produces.
        pub fn dim(&self) -> usize {
            match self {
                Self::Ollama => NOMIC_DIM,
                Self::Local => MINILM_DIM,
            }
        }

        /// Human-readable label for the mocked model.
        pub fn model_description(&self) -> &str {
            match self {
                Self::Ollama => "mock-nomic-embed-text-v1.5 (768-dim, Ollama)",
                Self::Local => "mock-all-MiniLM-L6-v2 (384-dim, local)",
            }
        }
    }

    /// Trait view of the mock; both methods forward to the inherent
    /// implementations above.
    impl Embed for MockEmbedder {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            MockEmbedder::embed(self, text)
        }

        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            MockEmbedder::embed_batch(self, texts)
        }
    }

    /// Embedder whose every call fails; used to exercise error paths.
    pub struct FailingEmbedder;

    impl Embed for FailingEmbedder {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            anyhow::bail!("test: synthetic embed failure")
        }

        fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            anyhow::bail!("test: synthetic embed_batch failure")
        }
    }
}
1816
/// Sanity checks for the `MockEmbedder` test double itself.
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use super::*;

    /// The local mock constructor reports success.
    #[test]
    fn mock_local_new() {
        assert!(MockEmbedder::new_local().is_ok());
    }

    /// The Ollama mock constructor yields the `Ollama` variant.
    #[test]
    fn mock_ollama_new() {
        let mock = MockEmbedder::new_ollama();
        assert!(
            matches!(mock, MockEmbedder::Ollama),
            "expected Ollama variant"
        );
    }

    /// The local mock advertises the MiniLM width.
    #[test]
    fn mock_local_dim() {
        let mock = MockEmbedder::new_local().unwrap();
        assert_eq!(mock.dim(), MINILM_DIM);
    }

    /// The Ollama mock advertises the nomic width.
    #[test]
    fn mock_ollama_dim() {
        assert_eq!(MockEmbedder::new_ollama().dim(), NOMIC_DIM);
    }

    /// Embedding the same text twice gives identical vectors.
    #[test]
    fn mock_embed_local_deterministic() {
        let mock = MockEmbedder::new_local().unwrap();
        let first = mock.embed("test").unwrap();
        let second = mock.embed("test").unwrap();
        assert_eq!(first, second);
    }

    /// A local embedding has exactly `MINILM_DIM` components.
    #[test]
    fn mock_embed_local_dimension() {
        let mock = MockEmbedder::new_local().unwrap();
        let vector = mock.embed("hello world").unwrap();
        assert_eq!(vector.len(), MINILM_DIM);
    }

    /// An Ollama embedding has exactly `NOMIC_DIM` components.
    #[test]
    fn mock_embed_ollama_dimension() {
        let mock = MockEmbedder::new_ollama();
        let vector = mock.embed("hello world").unwrap();
        assert_eq!(vector.len(), NOMIC_DIM);
    }

    /// Batch embedding returns one `MINILM_DIM`-wide row per input.
    #[test]
    fn mock_embed_batch_local() {
        let mock = MockEmbedder::new_local().unwrap();
        let inputs = vec!["text1", "text2", "text3"];
        let rows = mock.embed_batch(&inputs).unwrap();

        assert_eq!(rows.len(), 3);
        for row in &rows {
            assert_eq!(row.len(), MINILM_DIM);
        }
    }

    /// Batch embedding returns one `NOMIC_DIM`-wide row per input.
    #[test]
    fn mock_embed_batch_ollama() {
        let mock = MockEmbedder::new_ollama();
        let inputs = vec!["text1", "text2"];
        let rows = mock.embed_batch(&inputs).unwrap();

        assert_eq!(rows.len(), 2);
        for row in &rows {
            assert_eq!(row.len(), NOMIC_DIM);
        }
    }

    /// The local description names the model family and its width.
    #[test]
    fn mock_local_model_description() {
        let mock = MockEmbedder::new_local().unwrap();
        let description = mock.model_description();
        assert!(description.contains("MiniLM"));
        assert!(description.contains("384"));
    }

    /// The Ollama description names the model family and its width.
    #[test]
    fn mock_ollama_model_description() {
        let mock = MockEmbedder::new_ollama();
        let description = mock.model_description();
        assert!(description.contains("nomic"));
        assert!(description.contains("768"));
    }

    /// Two different texts differ at least in the first component.
    #[test]
    fn mock_embed_different_texts_different_vectors() {
        let mock = MockEmbedder::new_local().unwrap();
        let one = mock.embed("text one").unwrap();
        let two = mock.embed("text two").unwrap();
        assert_ne!(one[0], two[0]);
    }
}
1918
/// Checks `Embedder::cosine_similarity` on two non-unit vectors against a
/// hand-computed value.
// NOTE(review): the function name refers to cache eviction, but the body only
// exercises cosine similarity — confirm whether the name is a leftover.
#[test]
fn cache_evicts_least_recently_used() {
    let lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![4.0, 5.0, 6.0];

    // dot = 4 + 10 + 18 = 32, |lhs|^2 = 14, |rhs|^2 = 77.
    let norm_product = 14.0_f32.sqrt() * 77.0_f32.sqrt();
    let expected = 32.0 / norm_product;

    let actual = Embedder::cosine_similarity(&lhs, &rhs);
    assert!((actual - expected).abs() < 1e-5);
}
1933
/// Edge-case coverage for `Embedder::for_model`, `Embedder::new`,
/// `cosine_similarity`, `fuse`, and the pinned dimension constants.
#[cfg(test)]
mod w12h_extra_tests {
    use super::*;

    /// Requesting the nomic model without an Ollama client must fail with a
    /// message that mentions Ollama or nomic.
    #[test]
    fn for_model_nomic_without_ollama_client_errors() {
        let outcome = Embedder::for_model(EmbeddingModel::NomicEmbedV15, None);
        let Err(error) = outcome else {
            panic!("expected NomicEmbedV15 without client to error");
        };

        let err = error.to_string();
        assert!(
            err.contains("Ollama") || err.contains("nomic"),
            "expected ollama error msg, got: {err}"
        );
    }

    /// Two all-zero vectors have similarity exactly zero.
    #[test]
    fn cosine_similarity_both_zero_returns_zero() {
        let (lhs, rhs) = (vec![0.0_f32; 3], vec![0.0_f32; 3]);
        assert_eq!(Embedder::cosine_similarity(&lhs, &rhs), 0.0);
    }

    /// A vector and its negation have similarity minus one.
    #[test]
    fn cosine_similarity_negative_values() {
        let forward = vec![1.0_f32, 2.0, 3.0];
        let backward = vec![-1.0_f32, -2.0, -3.0];
        let sim = Embedder::cosine_similarity(&forward, &backward);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    /// Two empty vectors have similarity exactly zero.
    #[test]
    fn cosine_similarity_empty_vectors() {
        let lhs = Vec::<f32>::new();
        let rhs = Vec::<f32>::new();
        assert_eq!(Embedder::cosine_similarity(&lhs, &rhs), 0.0);
    }

    /// Weight zero returns the secondary vector's components.
    #[test]
    fn fuse_zero_weight_returns_pure_secondary() {
        let primary = vec![1.0_f32, 0.0];
        let secondary = vec![0.0_f32, 1.0];
        let fused = Embedder::fuse(&primary, &secondary, 0.0);

        for (idx, want) in [0.0_f32, 1.0].into_iter().enumerate() {
            assert!((fused[idx] - want).abs() < 1e-6);
        }
    }

    /// Fusing two empty vectors yields an empty vector.
    #[test]
    fn fuse_empty_vectors_returns_empty() {
        let primary = Vec::<f32>::new();
        let secondary = Vec::<f32>::new();
        assert!(Embedder::fuse(&primary, &secondary, 0.5).is_empty());
    }

    /// The dimension constants keep their expected values.
    #[test]
    fn embedding_dim_constant_pinned() {
        assert_eq!(MINILM_DIM, 384);
        assert_eq!(NOMIC_DIM, 768);
        assert_eq!(EMBEDDING_DIM, MINILM_DIM);
    }

    /// When the secondary vector is longer than the primary, `fuse` returns
    /// the primary unchanged.
    #[test]
    fn fuse_dimension_mismatch_secondary_longer() {
        let primary = vec![1.0_f32, 2.0];
        // One element longer than `primary`.
        let secondary = vec![3.0_f32, 4.0, 5.0];
        assert_eq!(Embedder::fuse(&primary, &secondary, 0.5), primary);
    }

    /// Vectors of different lengths have similarity exactly zero, here with
    /// the shorter vector passed first.
    #[test]
    fn cosine_similarity_dimension_mismatch_inverse() {
        let shorter = vec![1.0_f32, 0.0];
        let longer = vec![1.0_f32, 0.0, 0.0];
        assert_eq!(Embedder::cosine_similarity(&shorter, &longer), 0.0);
    }

    /// `for_model(MiniLmL6V2, None)` either produces a 384-wide MiniLM
    /// embedder or fails with a recognisable model-loading error. Both
    /// outcomes are accepted.
    #[test]
    fn pr9i_for_model_minilm_dispatches_to_new_local() {
        match Embedder::for_model(EmbeddingModel::MiniLmL6V2, None) {
            Ok(embedder) => {
                assert_eq!(embedder.dim(), 384);
                assert!(embedder.model_description().contains("MiniLM"));
            }
            Err(error) => {
                let msg = error.to_string();
                // Any one of these substrings marks a model-loading failure.
                let recognised = ["model", "config", "tokenizer", "fallback", "HuggingFace"]
                    .into_iter()
                    .any(|needle| msg.contains(needle));
                assert!(recognised, "unexpected new_local error: {msg}");
            }
        }
    }

    /// `Embedder::new` either produces a 384-wide embedder or fails with a
    /// non-empty message.
    #[test]
    fn pr9i_embedder_new_alias_is_new_local() {
        match Embedder::new() {
            Ok(embedder) => assert_eq!(embedder.dim(), 384),
            Err(error) => assert!(!error.to_string().is_empty()),
        }
    }
}
2073
/// `Embedder::load_from_fallback` may succeed (the files happen to exist on
/// this machine) or fail; when it fails, the message must point at the
/// missing files or the fallback location.
#[test]
fn embedder_returns_unreachable_when_model_path_missing() {
    // Success is accepted as-is: there is nothing further to check.
    if let Err(error) = Embedder::load_from_fallback() {
        let err_msg = error.to_string();
        let mentions_missing_files = err_msg.contains("not found") || err_msg.contains("fallback");
        assert!(
            mentions_missing_files,
            "error should mention missing model files: {err_msg}"
        );
    }
}
2095
/// Serializes the tests below that rewrite process-wide environment variables.
///
/// A `static` declared inside a test function only excludes that test from
/// itself. Both tests below rewrite `HOME`, so they must share one lock to be
/// safe when the test harness runs them on different threads.
#[cfg(test)]
static ENV_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Sets an environment variable and puts the previous value back when dropped.
///
/// Restoring in `Drop` means the original value returns even if the code
/// under test panics. The previous value is captured with `var_os`, so a
/// non-UTF-8 value is restored rather than silently removed.
#[cfg(test)]
struct ScopedEnvVar {
    key: &'static str,
    previous: Option<std::ffi::OsString>,
}

#[cfg(test)]
impl ScopedEnvVar {
    /// Overrides `key` with `value` until the returned guard is dropped.
    ///
    /// Callers are expected to hold `ENV_TEST_LOCK` for the guard's lifetime.
    fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
        let previous = std::env::var_os(key);
        // SAFETY: mutating the environment is only sound while no other thread
        // is touching it; the tests using this guard hold `ENV_TEST_LOCK`,
        // which serializes every environment writer in this group.
        unsafe {
            std::env::set_var(key, value);
        }
        Self { key, previous }
    }
}

#[cfg(test)]
impl Drop for ScopedEnvVar {
    fn drop(&mut self) {
        // SAFETY: same invariant as in `set` — the owning test still holds
        // `ENV_TEST_LOCK` while its guards are dropped.
        unsafe {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }
}

/// With `HOME` pointing at a directory that contains placeholder
/// `config.json`, `tokenizer.json` and `model.safetensors` files under the
/// fallback snapshot path, `Embedder::load_from_fallback` returns those three
/// paths.
#[test]
fn load_from_fallback_succeeds_when_files_present() {
    // Declared first so it is released last, after the env guard is gone.
    let _serial = ENV_TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    let tmp = std::env::temp_dir().join(format!("ai-memory-w12h-fallback-{}", std::process::id()));
    let model_dir = tmp.join(
        ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main",
    );
    std::fs::create_dir_all(&model_dir).expect("mk model dir");
    for name in ["config.json", "tokenizer.json", "model.safetensors"] {
        std::fs::write(model_dir.join(name), b"{}").expect("write placeholder");
    }

    // `HOME` is overridden only for the duration of the call under test.
    let result = {
        let _home = ScopedEnvVar::set("HOME", &tmp);
        Embedder::load_from_fallback()
    };

    // Best-effort cleanup before asserting, so a failure does not leak the dir.
    let _ = std::fs::remove_dir_all(&tmp);

    let (cfg, tok, w) = result.expect("placeholder files satisfy load_from_fallback");
    assert!(cfg.ends_with("config.json"));
    assert!(tok.ends_with("tokenizer.json"));
    assert!(w.ends_with("model.safetensors"));
}

/// With `AI_MEMORY_EMBED_OFFLINE=1` and an empty `HOME`,
/// `Embedder::remote_fetch_disabled` must report true and
/// `Embedder::new_local` must fail with a message pointing at the missing
/// fallback files.
#[test]
fn offline_env_skips_network_and_errors_fast_on_empty_cache() {
    // Declared first so it is released last, after the env guards are gone.
    let _serial = ENV_TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    let tmp = std::env::temp_dir().join(format!(
        "ai-memory-1501-offline-{}-{}",
        std::process::id(),
        uuid::Uuid::new_v4()
    ));
    std::fs::create_dir_all(&tmp).expect("mk empty home");

    // `None` means the offline knob was not honored. In that case `new_local`
    // is not called at all, so the test cannot end up waiting on the network.
    let outcome = {
        let _home = ScopedEnvVar::set("HOME", &tmp);
        let _offline = ScopedEnvVar::set("AI_MEMORY_EMBED_OFFLINE", "1");
        if Embedder::remote_fetch_disabled() {
            Some(Embedder::new_local())
        } else {
            None
        }
    };

    // Environment is restored at this point; clean up before any assertion.
    let _ = std::fs::remove_dir_all(&tmp);

    let result = outcome.expect("offline knob must be honored");
    let msg = match result {
        Ok(_) => panic!("empty cache + offline must error (degrades to keyword)"),
        Err(e) => e.to_string(),
    };
    assert!(
        msg.contains("not found") || msg.contains("fallback"),
        "offline empty-cache error should point at the fallback dir: {msg}"
    );
}
2191
/// Tests for the Ollama-backed `Embedder` variant (driven against a local
/// wiremock server), the `Embed` trait plumbing, and the
/// `Embedder::download_within` watchdog.
#[cfg(test)]
#[allow(clippy::too_many_lines)]
mod c5_ollama_variant_tests {
    use super::*;
    use crate::llm::OllamaClient;
    use serde_json::json;
    use std::sync::Arc;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers `/api/tags`, `/api/pull` and
    /// `/api/embed` with HTTP 200, and returns a client pointed at it.
    ///
    /// Every `/api/embed` call answers with a single vector of
    /// `embedding_dim` floats (`0.000, 0.001, 0.002, ...`). The server is
    /// returned too, so the caller can keep it alive for the test's duration.
    async fn ollama_with_embed_response(embedding_dim: usize) -> (Arc<OllamaClient>, MockServer) {
        let server = MockServer::start().await;

        let tags_ok = ResponseTemplate::new(200).set_body_json(json!({"models": []}));
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(tags_ok)
            .mount(&server)
            .await;

        let pull_ok = ResponseTemplate::new(200).set_body_string("");
        Mock::given(method("POST"))
            .and(path("/api/pull"))
            .respond_with(pull_ok)
            .mount(&server)
            .await;

        let vec_of_floats: Vec<f32> = (0..embedding_dim).map(|i| (i as f32) * 0.001).collect();
        let embed_ok = ResponseTemplate::new(200).set_body_json(json!({
            "embeddings": [vec_of_floats],
        }));
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(embed_ok)
            .mount(&server)
            .await;

        // The client is built on a blocking thread — presumably because it
        // performs blocking I/O that must not run on the async executor.
        let base_url = server.uri();
        let build_client = move || {
            OllamaClient::new_with_url(&base_url, "test-model").expect("ollama client builds")
        };
        let client = tokio::task::spawn_blocking(build_client)
            .await
            .expect("spawn blocking completes");

        (Arc::new(client), server)
    }

    /// `Embedder::new_ollama` yields the `Ollama` variant.
    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_new_ollama_constructs_with_expected_model_name() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);
        let is_ollama = matches!(embedder, Embedder::Ollama { .. });
        assert!(is_ollama);
    }

    /// `for_model(NomicEmbedV15, Some(client))` yields an Ollama embedder
    /// with the nomic dimension and description.
    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_for_model_nomic_with_client_succeeds() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;

        let build = move || {
            Embedder::for_model(EmbeddingModel::NomicEmbedV15, Some(client))
                .expect("for_model NomicEmbedV15 with ollama client")
        };
        let embedder = tokio::task::spawn_blocking(build).await.unwrap();

        assert!(matches!(embedder, Embedder::Ollama { .. }));
        assert_eq!(embedder.dim(), NOMIC_DIM);

        let description = embedder.model_description();
        assert!(description.contains("nomic"));
        assert!(description.contains("768"));
    }

    /// `embed` returns the vector served by the mock `/api/embed` endpoint.
    #[tokio::test(flavor = "multi_thread")]
    async fn embedder_ollama_embed_returns_vector_from_wiremock() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);

        let embedded = tokio::task::spawn_blocking(move || embedder.embed("hello"))
            .await
            .unwrap();
        let vector = embedded.expect("embed_text via wiremock");

        assert_eq!(vector.len(), NOMIC_DIM);
    }

    /// Empty content is skipped: no vector, and a `Skipped` status whose
    /// reason mentions "empty".
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_skipped_on_empty_content() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);

        let (vector, status) = embedder.embed_with_status("");

        assert!(vector.is_none());
        assert!(matches!(status, EmbedStatus::Skipped(_)));
        assert_eq!(status.as_str(), "skipped");
        assert!(status.reason().contains("empty"));
    }

    /// Content one byte over `EMBED_MAX_BYTES` is skipped with a reason that
    /// mentions the embed cap.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_skipped_on_oversized_content() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);

        let oversized = "a".repeat(EMBED_MAX_BYTES + 1);
        let (vector, status) = embedder.embed_with_status(&oversized);

        assert!(vector.is_none());
        match status {
            EmbedStatus::Skipped(reason) => {
                assert!(reason.contains("exceeds embed cap"), "got: {reason}");
            }
            other => panic!("expected Skipped, got: {other:?}"),
        }
    }

    /// Ordinary content is embedded: a `NOMIC_DIM`-wide vector plus an
    /// `Indexed`, non-degraded status.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_indexed_on_happy_path() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);

        let run = move || embedder.embed_with_status("hello world");
        let (vector, status) = tokio::task::spawn_blocking(run).await.unwrap();

        assert!(vector.is_some());
        assert_eq!(status, EmbedStatus::Indexed);
        assert!(!status.is_degraded());
        assert_eq!(vector.unwrap().len(), NOMIC_DIM);
    }

    /// When `/api/embed` answers HTTP 500, the status is `Failed` with a
    /// non-empty reason and no vector is returned.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_with_status_failed_when_embedder_errors() {
        let server = MockServer::start().await;

        let tags_ok = ResponseTemplate::new(200).set_body_json(json!({"models": []}));
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(tags_ok)
            .mount(&server)
            .await;

        let embed_broken = ResponseTemplate::new(500).set_body_string("server error");
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(embed_broken)
            .mount(&server)
            .await;

        let base_url = server.uri();
        let build = move || {
            let client = OllamaClient::new_with_url(&base_url, "test-model").unwrap();
            Embedder::new_ollama(Arc::new(client))
        };
        let embedder = tokio::task::spawn_blocking(build).await.unwrap();

        let run = move || embedder.embed_with_status("hello");
        let (vector, status) = tokio::task::spawn_blocking(run).await.unwrap();

        assert!(vector.is_none());
        match status {
            EmbedStatus::Failed(reason) => assert!(!reason.is_empty()),
            other => panic!("expected Failed(_), got {other:?}"),
        }
    }

    /// An empty batch produces an empty result rather than an error.
    #[test]
    fn perf_5_embed_batch_empty_input_returns_empty_vec() {
        use super::test_support::MockEmbedder;

        let mock = MockEmbedder::new_local().expect("mock local");
        let rows = mock.embed_batch(&[]).expect("empty batch ok");

        assert!(
            rows.is_empty(),
            "PERF-5: empty input must yield empty output (got {} rows)",
            rows.len(),
        );
    }

    /// The inherent `embed_batch` returns one `NOMIC_DIM`-wide vector per
    /// input text.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_batch_via_inherent_impl_returns_one_vec_per_input() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let embedder = Embedder::new_ollama(client);

        let run = move || embedder.embed_batch(&["one", "two", "three"]);
        let rows = tokio::task::spawn_blocking(run)
            .await
            .unwrap()
            .expect("batch embed succeeds");

        assert_eq!(rows.len(), 3);
        for row in &rows {
            assert_eq!(row.len(), NOMIC_DIM);
        }
    }

    /// Calling through `Box<dyn Embed>` gives results of the same shape as
    /// the inherent methods do.
    #[tokio::test(flavor = "multi_thread")]
    async fn embed_trait_for_embedder_delegates_to_inherent_impl() {
        let (client, _server) = ollama_with_embed_response(NOMIC_DIM).await;
        let boxed: Box<dyn Embed> = Box::new(Embedder::new_ollama(client));

        let (single, batch) = tokio::task::spawn_blocking(move || {
            let single = boxed.embed("alpha").expect("single embed");
            let batch = boxed.embed_batch(&["beta", "gamma"]).expect("batch embed");
            (single, batch)
        })
        .await
        .unwrap();

        assert_eq!(single.len(), NOMIC_DIM);
        assert_eq!(batch.len(), 2);
        for row in &batch {
            assert_eq!(row.len(), NOMIC_DIM);
        }
    }

    /// A type that implements only `Embed::embed` still gets a working
    /// `embed_batch` from the trait's default method.
    #[test]
    fn embed_trait_default_batch_default_impl_runs_for_external_impls() {
        /// Returns the same three-element vector for every input.
        struct ConstEmbedder;

        impl Embed for ConstEmbedder {
            fn embed(&self, _text: &str) -> Result<Vec<f32>> {
                Ok(vec![1.0_f32, 2.0_f32, 3.0_f32])
            }
        }

        let batch = ConstEmbedder
            .embed_batch(&["a", "b"])
            .expect("default batch path");

        assert_eq!(batch.len(), 2);
        for row in &batch {
            assert_eq!(*row, vec![1.0_f32, 2.0_f32, 3.0_f32]);
        }
    }

    /// A closure that stalls for 30 s under a 50 ms budget must yield an
    /// error mentioning the budget, and the call must return within 5 s.
    #[test]
    fn download_within_times_out_on_stalled_closure() {
        use std::path::PathBuf;
        use std::time::{Duration, Instant};

        let started = Instant::now();
        let res = Embedder::download_within(Duration::from_millis(50), || {
            // Far longer than the budget; the call must not wait for this.
            std::thread::sleep(Duration::from_secs(30));
            Ok((PathBuf::new(), PathBuf::new(), PathBuf::new()))
        });
        let elapsed = started.elapsed();

        assert!(res.is_err(), "stalled download must error, not hang");
        assert!(
            res.unwrap_err().to_string().contains("budget"),
            "error should explain the timeout budget"
        );
        assert!(
            elapsed < Duration::from_secs(5),
            "watchdog must return promptly after the budget, not wait for the closure: {elapsed:?}"
        );
    }

    /// A closure that finishes well inside the budget has its result passed
    /// through unchanged.
    #[test]
    fn download_within_passes_through_fast_result() {
        use std::path::PathBuf;
        use std::time::Duration;

        let fetched = Embedder::download_within(Duration::from_secs(5), || {
            Ok((
                PathBuf::from("config.json"),
                PathBuf::from("tokenizer.json"),
                PathBuf::from("model.safetensors"),
            ))
        });
        let res = fetched.expect("fast closure must pass through");

        assert_eq!(res.0, PathBuf::from("config.json"));
        assert_eq!(res.2, PathBuf::from("model.safetensors"));
    }
}