1use anyhow::{Context, Result};
5use candle_core::{Device, Tensor};
6use candle_nn::VarBuilder;
7use candle_transformers::models::bert::{BertModel, Config};
8use hf_hub::{Repo, RepoType, api::sync::Api};
9use std::sync::{Arc, Mutex};
10use tokenizers::Tokenizer;
11
12use crate::config::EmbeddingModel;
13
/// Hugging Face Hub repository id fetched by `Embedder::download_via_hf_hub`.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";
/// Dimension reported by `Embedder::dim` for the local backend.
#[allow(dead_code)]
const MINILM_DIM: usize = 384;
/// Truncation limit (`max_length`) applied to the tokenizer in `Embedder::new_local`.
const MAX_SEQ_LEN: usize = 256;
/// Directory, relative to `$HOME`, that `Embedder::load_from_fallback` searches for model files.
const FALLBACK_MODEL_SUBDIR: &str =
    ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main";

/// Model name handed to the Ollama client by `Embedder::new_ollama` and `Embedder::for_model`.
const NOMIC_OLLAMA_MODEL: &str = "nomic-embed-text";
/// Dimension reported by `Embedder::dim` for the Ollama backend.
#[allow(dead_code)]
const NOMIC_DIM: usize = 768;
26
/// Text-embedding backend: either an in-process BERT model or a remote Ollama model.
///
/// The model, tokenizer and Ollama client are held behind `Arc`, so clones share them.
#[derive(Clone)]
pub enum Embedder {
    /// In-process backend built by `Embedder::new_local`.
    Local {
        /// BERT model behind a mutex; `embed` holds the lock for the whole call.
        model: Arc<Mutex<BertModel>>,
        /// Tokenizer configured with truncation and no padding.
        tokenizer: Arc<Tokenizer>,
        /// Device the input tensors are created on (`new_local` uses the CPU).
        device: Device,
    },
    /// Remote backend that delegates to an Ollama client.
    Ollama {
        /// Shared client used for `embed_text` calls.
        client: Arc<crate::llm::OllamaClient>,
        /// Model name passed to the client on every `embed` call.
        model_name: String,
    },
}
45
46impl Embedder {
 /// Builds the default embedder; this simply forwards to `Self::new_local`.
 ///
 /// # Errors
 /// Returns whatever error `Self::new_local` returns.
 #[allow(dead_code)]
 pub fn new() -> Result<Self> {
     Self::new_local()
 }
53
54 pub fn new_local() -> Result<Self> {
56 let device = Device::Cpu;
57
58 let (config_path, tokenizer_path, weights_path) = match Self::download_via_hf_hub() {
59 Ok(paths) => paths,
60 Err(e) => {
61 eprintln!("ai-memory: hf-hub download failed ({e}), trying fallback dir");
62 Self::load_from_fallback()?
63 }
64 };
65
66 let config_data =
67 std::fs::read_to_string(&config_path).context("failed to read config.json")?;
68 let config: Config =
69 serde_json::from_str(&config_data).context("failed to parse config.json")?;
70
71 let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
72 .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
73
74 let truncation = tokenizers::TruncationParams {
75 max_length: MAX_SEQ_LEN,
76 ..Default::default()
77 };
78 tokenizer
79 .with_truncation(Some(truncation))
80 .map_err(|e| anyhow::anyhow!("failed to set truncation: {e}"))?;
81 tokenizer.with_padding(None);
82
83 let vb = unsafe {
84 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
85 .context("failed to load model weights")?
86 };
87 let model = BertModel::load(vb, &config).context("failed to build BertModel")?;
88
89 Ok(Self::Local {
90 model: Arc::new(Mutex::new(model)),
91 tokenizer: Arc::new(tokenizer),
92 device,
93 })
94 }
95
96 pub fn new_ollama(client: Arc<crate::llm::OllamaClient>) -> Self {
100 Self::Ollama {
101 client,
102 model_name: NOMIC_OLLAMA_MODEL.to_string(),
103 }
104 }
105
106 pub fn for_model(
111 model: EmbeddingModel,
112 ollama_client: Option<Arc<crate::llm::OllamaClient>>,
113 ) -> Result<Self> {
114 match model {
115 EmbeddingModel::MiniLmL6V2 => Self::new_local(),
116 EmbeddingModel::NomicEmbedV15 => {
117 let client = ollama_client.ok_or_else(|| {
118 anyhow::anyhow!("nomic-embed-text-v1.5 requires Ollama (smart tier or above)")
119 })?;
120 if let Err(e) = client.ensure_embed_model(NOMIC_OLLAMA_MODEL) {
122 eprintln!("ai-memory: warning: failed to pull nomic model: {e}");
123 }
124 Ok(Self::new_ollama(client))
125 }
126 }
127 }
128
129 #[allow(dead_code)]
131 pub fn dim(&self) -> usize {
132 match self {
133 Self::Local { .. } => MINILM_DIM,
134 Self::Ollama { .. } => NOMIC_DIM,
135 }
136 }
137
138 pub fn model_description(&self) -> &str {
140 match self {
141 Self::Local { .. } => "all-MiniLM-L6-v2 (384-dim, local)",
142 Self::Ollama { .. } => "nomic-embed-text-v1.5 (768-dim, Ollama)",
143 }
144 }
145
146 pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
148 match self {
149 Self::Local {
150 model,
151 tokenizer,
152 device,
153 } => {
154 let model_guard = model
155 .lock()
156 .map_err(|e| anyhow::anyhow!("model lock poisoned: {e}"))?;
157 Self::embed_local(&model_guard, tokenizer, device, text)
158 }
159 Self::Ollama { client, model_name } => client.embed_text(text, model_name),
160 }
161 }
162
163 fn embed_local(
164 model: &BertModel,
165 tokenizer: &Tokenizer,
166 device: &Device,
167 text: &str,
168 ) -> Result<Vec<f32>> {
169 let encoding = tokenizer
170 .encode(text, true)
171 .map_err(|e| anyhow::anyhow!("tokenisation failed: {e}"))?;
172
173 let input_ids = encoding.get_ids();
174 let attention_mask = encoding.get_attention_mask();
175 let token_type_ids = encoding.get_type_ids();
176 let seq_len = input_ids.len();
177
178 let input_ids = Tensor::new(input_ids, device)?.reshape((1, seq_len))?;
179 let attention_mask_tensor = Tensor::new(attention_mask, device)?.reshape((1, seq_len))?;
180 let token_type_ids = Tensor::new(token_type_ids, device)?.reshape((1, seq_len))?;
181
182 let hidden = model
183 .forward(&input_ids, &token_type_ids, Some(&attention_mask_tensor))
184 .context("model forward pass failed")?;
185
186 let mask = attention_mask_tensor
187 .unsqueeze(2)?
188 .to_dtype(candle_core::DType::F32)?
189 .broadcast_as(hidden.shape())?;
190 let masked = hidden.mul(&mask)?;
191 let summed = masked.sum(1)?;
192 let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
193 let pooled = summed.div(&count)?;
194
195 let norm = pooled
196 .sqr()?
197 .sum_keepdim(1)?
198 .sqrt()?
199 .clamp(1e-12, f64::MAX)?;
200 let normalised = pooled.broadcast_div(&norm)?;
201
202 let embedding: Vec<f32> = normalised.squeeze(0)?.to_vec1()?;
203 Ok(embedding)
204 }
205
206 #[allow(dead_code)]
208 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
209 texts.iter().map(|t| self.embed(t)).collect()
210 }
211
212 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
214 if a.len() != b.len() {
216 return 0.0;
217 }
218
219 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
220 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
221 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
222 let denom = norm_a * norm_b;
223 if denom < 1e-12 { 0.0 } else { dot / denom }
224 }
225
226 #[must_use]
235 pub fn fuse(primary: &[f32], secondary: &[f32], primary_weight: f32) -> Vec<f32> {
236 if primary.len() != secondary.len() {
237 return primary.to_vec();
238 }
239 let w = primary_weight.clamp(0.0, 1.0);
240 let one_minus_w = 1.0 - w;
241 primary
242 .iter()
243 .zip(secondary.iter())
244 .map(|(p, s)| w * p + one_minus_w * s)
245 .collect()
246 }
247
248 fn download_via_hf_hub() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
249 {
250 let api = Api::new().context("failed to initialise HuggingFace Hub API")?;
251 let repo = api.repo(Repo::new(MINILM_MODEL_ID.to_string(), RepoType::Model));
252 let config_path = repo
253 .get("config.json")
254 .context("failed to download config.json")?;
255 let tokenizer_path = repo
256 .get("tokenizer.json")
257 .context("failed to download tokenizer.json")?;
258 let weights_path = repo
259 .get("model.safetensors")
260 .context("failed to download model.safetensors")?;
261 Ok((config_path, tokenizer_path, weights_path))
262 }
263
264 fn load_from_fallback() -> Result<(std::path::PathBuf, std::path::PathBuf, std::path::PathBuf)>
265 {
266 let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string());
267 let dir = std::path::PathBuf::from(home).join(FALLBACK_MODEL_SUBDIR);
268 let dir = dir.as_path();
269 let config = dir.join("config.json");
270 let tokenizer = dir.join("tokenizer.json");
271 let weights = dir.join("model.safetensors");
272 if config.exists() && tokenizer.exists() && weights.exists() {
273 Ok((config, tokenizer, weights))
274 } else {
275 anyhow::bail!(
276 "model files not found in fallback dir: {}. Download them manually from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
277 dir.display()
278 )
279 }
280 }
281}
282
/// Public alias for the local model's dimension (`MINILM_DIM`).
#[allow(dead_code)]
pub const EMBEDDING_DIM: usize = MINILM_DIM;

/// Header byte written by `encode_embedding_blob`: payload is little-endian `f32` values.
pub const EMBEDDING_HEADER_LE_F32: u8 = 0x01;

/// Header byte for big-endian `f32` payloads; `decode_embedding_blob` rejects it
/// with `EmbeddingFormatError::BigEndianUnsupported`.
pub const EMBEDDING_HEADER_BE_F32: u8 = 0x02;
308
/// Reasons `decode_embedding_blob` can reject a stored blob.
#[derive(Debug)]
pub enum EmbeddingFormatError {
    /// The header byte (carried here) is not one of the known header constants.
    UnknownHeader(u8),
    /// The blob carries the big-endian header, which the decoder does not handle.
    BigEndianUnsupported,
    /// The blob length (carried here) fits neither the headed nor the unheaded layout.
    MalformedLength(usize),
}
327
328impl std::fmt::Display for EmbeddingFormatError {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 match self {
331 Self::UnknownHeader(b) => write!(f, "unknown embedding header byte: 0x{b:02x}"),
332 Self::BigEndianUnsupported => write!(
333 f,
334 "big-endian f32 embeddings (header 0x02) are not supported until v0.7"
335 ),
336 Self::MalformedLength(n) => {
337 write!(f, "embedding payload length {n} is not a multiple of 4")
338 }
339 }
340 }
341}
342
// Marker impl: makes the type usable as a standard error; all trait methods keep their defaults.
impl std::error::Error for EmbeddingFormatError {}
344
345#[must_use]
353pub fn encode_embedding_blob(embedding: &[f32]) -> Vec<u8> {
354 let mut out = Vec::with_capacity(1 + embedding.len() * 4);
355 out.push(EMBEDDING_HEADER_LE_F32);
356 for f in embedding {
357 out.extend_from_slice(&f.to_le_bytes());
358 }
359 out
360}
361
362pub fn decode_embedding_blob(bytes: &[u8]) -> Result<Vec<f32>, EmbeddingFormatError> {
378 if bytes.is_empty() {
379 return Ok(Vec::new());
380 }
381
382 if bytes.len() % 4 == 1 {
384 let header = bytes[0];
385 return match header {
386 EMBEDDING_HEADER_LE_F32 => {
387 let payload = &bytes[1..];
388 Ok(payload
389 .chunks_exact(4)
390 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
391 .collect())
392 }
393 EMBEDDING_HEADER_BE_F32 => Err(EmbeddingFormatError::BigEndianUnsupported),
394 other => Err(EmbeddingFormatError::UnknownHeader(other)),
395 };
396 }
397
398 if bytes.len() % 4 == 0 {
400 return Ok(bytes
401 .chunks_exact(4)
402 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
403 .collect());
404 }
405
406 Err(EmbeddingFormatError::MalformedLength(bytes.len()))
407}
408
/// Number of `f32` values a blob of this length would hold, computed from the
/// length alone: a length with remainder 1 (mod 4) is treated as carrying one
/// header byte; otherwise the length is divided by 4 as-is.
#[must_use]
pub fn decoded_dim(bytes: &[u8]) -> usize {
    let payload_len = match bytes.len() {
        0 => 0,
        n if n % 4 == 1 => n - 1,
        n => n,
    };
    payload_len / 4
}
422
// Unit tests for the pure helpers: cosine similarity, the blob codec, and `fuse`.
#[cfg(test)]
mod tests {
    use super::*;

    // A vector compared with itself has similarity 1.
    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = Embedder::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // Perpendicular unit vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    // Vectors pointing in opposite directions have similarity -1.
    #[test]
    fn cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    // A zero vector takes the small-denominator branch and yields exactly 0.
    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // Different lengths short-circuit to 0 (here `b` is the shorter one).
    #[test]
    fn cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0]; let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // The encoder writes one header byte followed by 4 bytes per value.
    #[test]
    fn encode_embedding_blob_prefixes_le_header() {
        let v = vec![1.0_f32, 2.0_f32];
        let blob = encode_embedding_blob(&v);
        assert_eq!(blob.len(), 1 + 8);
        assert_eq!(blob[0], EMBEDDING_HEADER_LE_F32);
    }

    // Encoding then decoding returns the original values.
    #[test]
    fn decode_embedding_blob_round_trip_v17() {
        let v = vec![1.5_f32, -0.25, 0.0];
        let blob = encode_embedding_blob(&v);
        let back = decode_embedding_blob(&blob).expect("round-trips");
        assert_eq!(back, v);
    }

    // A blob without a header (length a multiple of 4) decodes as raw little-endian f32.
    #[test]
    fn decode_embedding_blob_legacy_unheaded_le_f32() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let raw: Vec<u8> = v.iter().flat_map(|f| f.to_le_bytes()).collect();
        let back = decode_embedding_blob(&raw).expect("legacy decodes");
        assert_eq!(back, v);
    }

    // The big-endian header is recognised but rejected.
    #[test]
    fn decode_embedding_blob_rejects_be_header() {
        let mut blob = vec![EMBEDDING_HEADER_BE_F32];
        blob.extend_from_slice(&1.0_f32.to_be_bytes());
        let err = decode_embedding_blob(&blob).expect_err("BE rejected");
        assert!(matches!(err, EmbeddingFormatError::BigEndianUnsupported));
    }

    // Any other header byte is reported back in the error.
    #[test]
    fn decode_embedding_blob_rejects_unknown_header() {
        let mut blob = vec![0xff_u8];
        blob.extend_from_slice(&1.0_f32.to_le_bytes());
        let err = decode_embedding_blob(&blob).expect_err("unknown header rejected");
        assert!(matches!(err, EmbeddingFormatError::UnknownHeader(0xff)));
    }

    // Length 6 has remainder 2 (mod 4): neither headed nor unheaded.
    #[test]
    fn decode_embedding_blob_rejects_malformed_length() {
        let blob = vec![0u8; 6];
        let err = decode_embedding_blob(&blob).expect_err("malformed length rejected");
        assert!(matches!(err, EmbeddingFormatError::MalformedLength(6)));
    }

    // Covers the empty, unheaded (16 bytes -> 4) and headed (1 + 12 bytes -> 3) cases.
    #[test]
    fn decoded_dim_handles_all_three_paths() {
        assert_eq!(decoded_dim(&[]), 0);
        let raw: Vec<u8> = vec![0u8; 16];
        assert_eq!(decoded_dim(&raw), 4);
        let mut headed = vec![EMBEDDING_HEADER_LE_F32];
        headed.extend_from_slice(&[0u8; 12]);
        assert_eq!(decoded_dim(&headed), 3);
    }

    // With weight 0.7 the result is 0.7 * primary + 0.3 * secondary per element.
    #[test]
    fn fuse_weighted_sum() {
        let p = vec![1.0, 0.0, 0.0];
        let s = vec![0.0, 1.0, 0.0];
        let f = Embedder::fuse(&p, &s, 0.7);
        assert!((f[0] - 0.7).abs() < 1e-6);
        assert!((f[1] - 0.3).abs() < 1e-6);
        assert!((f[2] - 0.0).abs() < 1e-6);
    }

    // Weights outside [0, 1] are clamped: 2.0 acts as 1.0 and -0.5 acts as 0.0.
    #[test]
    fn fuse_primary_weight_clamped() {
        let p = vec![1.0, 1.0];
        let s = vec![0.0, 0.0];
        let f = Embedder::fuse(&p, &s, 2.0);
        assert!((f[0] - 1.0).abs() < 1e-6);
        assert!((f[1] - 1.0).abs() < 1e-6);

        let f = Embedder::fuse(&p, &s, -0.5);
        assert!((f[0] - 0.0).abs() < 1e-6);
        assert!((f[1] - 0.0).abs() < 1e-6);
    }

    // When the lengths differ, `fuse` returns a copy of the primary vector.
    #[test]
    fn fuse_dimension_mismatch_returns_primary() {
        let p = vec![1.0, 2.0, 3.0];
        let s = vec![4.0, 5.0]; let f = Embedder::fuse(&p, &s, 0.7);
        assert_eq!(f, p);
    }

    // The fused vector stays closest to the primary but gains similarity to the secondary.
    #[test]
    fn fuse_cosine_pulls_toward_context() {
        let q = vec![1.0_f32, 0.0];
        let ctx = vec![0.0_f32, 1.0];
        let fused = Embedder::fuse(&q, &ctx, 0.7);
        let sim_q = Embedder::cosine_similarity(&fused, &q);
        let sim_ctx = Embedder::cosine_similarity(&fused, &ctx);
        assert!(sim_q > sim_ctx);
        assert!(sim_q > 0.9); assert!(sim_ctx > 0.3); }

    // Weight 1.0 reproduces the primary vector element by element.
    #[test]
    fn test_fuse_with_weight_one_returns_primary() {
        let primary = vec![0.6_f32, -0.8, 0.0]; let secondary = vec![0.0_f32, 0.0, 1.0];
        let fused = Embedder::fuse(&primary, &secondary, 1.0);
        assert_eq!(fused.len(), primary.len());
        for (i, (f, p)) in fused.iter().zip(primary.iter()).enumerate() {
            assert!(
                (f - p).abs() < 1e-6,
                "fuse weight=1 idx {i}: fused {} != primary {}",
                f,
                p
            );
        }

        let sim = Embedder::cosine_similarity(&fused, &primary);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "cos(fuse(p,s,1.0), p) must be 1.0"
        );
    }

    // Pins the current behaviour: `fuse` does NOT normalise its output.
    // 0.5 * (3, 0, 0) + 0.5 * (0, 4, 0) = (1.5, 2, 0), whose norm is 2.5.
    #[test]
    fn test_fuse_is_l2_normalized() {
        let primary = vec![3.0_f32, 0.0, 0.0]; let secondary = vec![0.0_f32, 4.0, 0.0]; let fused = Embedder::fuse(&primary, &secondary, 0.5);
        let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 2.5).abs() < 1e-5,
            "fuse currently returns un-normalized vec; norm should be 2.5, got {norm}"
        );

        // Normalising by hand gives a unit vector with the same direction.
        let normalized: Vec<f32> = fused.iter().map(|x| x / norm).collect();
        let renorm = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (renorm - 1.0).abs() < 1e-5,
            "renormalized fused must have unit norm, got {renorm}"
        );
        let sim = Embedder::cosine_similarity(&fused, &normalized);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cos(raw_fuse, normalize(raw_fuse)) must be 1.0, got {sim}"
        );
    }
}
646
/// Test-only helpers: a deterministic embedder that needs no model files or network.
#[cfg(test)]
#[allow(
    clippy::unused_self,
    clippy::unnecessary_wraps,
    clippy::needless_pass_by_value,
    clippy::wildcard_imports
)]
pub mod test_support {
    use super::*;

    /// Stand-in for `Embedder` whose output depends only on the input text and the variant.
    pub enum MockEmbedder {
        /// Produces `MINILM_DIM`-long vectors.
        Local,
        /// Produces `NOMIC_DIM`-long vectors.
        Ollama,
    }

    impl MockEmbedder {
        /// Mirrors `Embedder::new_local`; never fails.
        pub fn new_local() -> Result<Self> {
            Ok(Self::Local)
        }

        /// Mirrors `Embedder::new_ollama` without needing a client.
        pub fn new_ollama() -> Self {
            Self::Ollama
        }

        /// Returns a vector derived from a simple 31-multiplier hash of the text's bytes.
        pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let len = match self {
                Self::Ollama => NOMIC_DIM,
                Self::Local => MINILM_DIM,
            };

            let mut hash = 0u32;
            for byte in text.bytes() {
                hash = hash.wrapping_mul(31).wrapping_add(u32::from(byte));
            }
            // The hash only shifts every component by a common offset in [0, 1).
            let base = ((hash % 1000) as f32) / 1000.0;

            let mut embedding = Vec::with_capacity(len);
            for i in 0..len {
                embedding.push(base + ((i as f32) * 0.0001).sin().abs());
            }
            Ok(embedding)
        }

        /// Embeds every text in order, stopping at the first error.
        pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut vectors = Vec::with_capacity(texts.len());
            for text in texts {
                vectors.push(self.embed(text)?);
            }
            Ok(vectors)
        }

        /// Vector length produced by `embed` for this variant.
        pub fn dim(&self) -> usize {
            match self {
                Self::Ollama => NOMIC_DIM,
                Self::Local => MINILM_DIM,
            }
        }

        /// Fixed label for this variant, prefixed with `mock-`.
        pub fn model_description(&self) -> &str {
            match self {
                Self::Ollama => "mock-nomic-embed-text-v1.5 (768-dim, Ollama)",
                Self::Local => "mock-all-MiniLM-L6-v2 (384-dim, local)",
            }
        }
    }
}
715
// Tests for `test_support::MockEmbedder` itself.
#[cfg(test)]
mod mock_tests {
    use super::test_support::*;
    use super::*;

    // The local mock constructor always succeeds.
    #[test]
    fn mock_local_new() {
        let embedder = MockEmbedder::new_local();
        assert!(embedder.is_ok());
    }

    // The Ollama mock constructor yields the Ollama variant.
    #[test]
    fn mock_ollama_new() {
        let embedder = MockEmbedder::new_ollama();
        match embedder {
            MockEmbedder::Ollama => {}
            _ => panic!("expected Ollama variant"),
        }
    }

    // The local mock reports the MiniLM dimension.
    #[test]
    fn mock_local_dim() {
        let embedder = MockEmbedder::new_local().unwrap();
        assert_eq!(embedder.dim(), MINILM_DIM);
    }

    // The Ollama mock reports the Nomic dimension.
    #[test]
    fn mock_ollama_dim() {
        let embedder = MockEmbedder::new_ollama();
        assert_eq!(embedder.dim(), NOMIC_DIM);
    }

    // Embedding the same text twice gives identical vectors.
    #[test]
    fn mock_embed_local_deterministic() {
        let embedder = MockEmbedder::new_local().unwrap();
        let e1 = embedder.embed("test").unwrap();
        let e2 = embedder.embed("test").unwrap();
        assert_eq!(e1, e2);
    }

    // Local mock vectors have MINILM_DIM elements.
    #[test]
    fn mock_embed_local_dimension() {
        let embedder = MockEmbedder::new_local().unwrap();
        let embedding = embedder.embed("hello world").unwrap();
        assert_eq!(embedding.len(), MINILM_DIM);
    }

    // Ollama mock vectors have NOMIC_DIM elements.
    #[test]
    fn mock_embed_ollama_dimension() {
        let embedder = MockEmbedder::new_ollama();
        let embedding = embedder.embed("hello world").unwrap();
        assert_eq!(embedding.len(), NOMIC_DIM);
    }

    // Batch embedding returns one correctly sized vector per input (local).
    #[test]
    fn mock_embed_batch_local() {
        let embedder = MockEmbedder::new_local().unwrap();
        let texts = vec!["text1", "text2", "text3"];
        let embeddings = embedder.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 3);
        for emb in embeddings {
            assert_eq!(emb.len(), MINILM_DIM);
        }
    }

    // Batch embedding returns one correctly sized vector per input (Ollama).
    #[test]
    fn mock_embed_batch_ollama() {
        let embedder = MockEmbedder::new_ollama();
        let texts = vec!["text1", "text2"];
        let embeddings = embedder.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 2);
        for emb in embeddings {
            assert_eq!(emb.len(), NOMIC_DIM);
        }
    }

    // The local description names the model family and its dimension.
    #[test]
    fn mock_local_model_description() {
        let embedder = MockEmbedder::new_local().unwrap();
        let desc = embedder.model_description();
        assert!(desc.contains("MiniLM"));
        assert!(desc.contains("384"));
    }

    // The Ollama description names the model family and its dimension.
    #[test]
    fn mock_ollama_model_description() {
        let embedder = MockEmbedder::new_ollama();
        let desc = embedder.model_description();
        assert!(desc.contains("nomic"));
        assert!(desc.contains("768"));
    }

    // Different texts hash to different offsets, so the first components differ.
    #[test]
    fn mock_embed_different_texts_different_vectors() {
        let embedder = MockEmbedder::new_local().unwrap();
        let e1 = embedder.embed("text one").unwrap();
        let e2 = embedder.embed("text two").unwrap();
        assert_ne!(e1[0], e2[0]);
    }
}
817
// Checks cosine similarity against a hand-computed value.
// NOTE(review): the test name mentions cache eviction, but the body only exercises
// `Embedder::cosine_similarity`; consider renaming it.
#[test]
fn cache_evicts_least_recently_used() {
    let a = [1.0_f32, 2.0, 3.0];
    let b = [4.0_f32, 5.0, 6.0];

    // dot = 4 + 10 + 18 = 32, |a|^2 = 14, |b|^2 = 77.
    let expected = 32.0 / (14.0_f32.sqrt() * 77.0_f32.sqrt());
    let actual = Embedder::cosine_similarity(&a, &b);

    assert!((actual - expected).abs() < 1e-5);
}
832
// Additional edge-case tests for `for_model`, `cosine_similarity`, `fuse` and the dimension constants.
#[cfg(test)]
mod w12h_extra_tests {
    use super::*;

    // Requesting the Nomic model without an Ollama client must fail with a descriptive message.
    #[test]
    fn for_model_nomic_without_ollama_client_errors() {
        let res = Embedder::for_model(EmbeddingModel::NomicEmbedV15, None);
        match res {
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Ollama") || err.contains("nomic"),
                    "expected ollama error msg, got: {err}"
                );
            }
            Ok(_) => panic!("expected NomicEmbedV15 without client to error"),
        }
    }

    // Two zero vectors take the small-denominator branch and yield 0.
    #[test]
    fn cosine_similarity_both_zero_returns_zero() {
        let a = vec![0.0_f32; 3];
        let b = vec![0.0_f32; 3];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // A vector and its negation have similarity -1.
    #[test]
    fn cosine_similarity_negative_values() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![-1.0_f32, -2.0, -3.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    // Empty inputs have equal length and a zero denominator, so the result is 0.
    #[test]
    fn cosine_similarity_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // Weight 0 returns the secondary vector.
    #[test]
    fn fuse_zero_weight_returns_pure_secondary() {
        let p = vec![1.0_f32, 0.0];
        let s = vec![0.0_f32, 1.0];
        let f = Embedder::fuse(&p, &s, 0.0);
        assert!((f[0] - 0.0).abs() < 1e-6);
        assert!((f[1] - 1.0).abs() < 1e-6);
    }

    // Fusing two empty vectors yields an empty vector.
    #[test]
    fn fuse_empty_vectors_returns_empty() {
        let p: Vec<f32> = vec![];
        let s: Vec<f32> = vec![];
        let f = Embedder::fuse(&p, &s, 0.5);
        assert!(f.is_empty());
    }

    // Pins the numeric values of the dimension constants.
    #[test]
    fn embedding_dim_constant_pinned() {
        assert_eq!(EMBEDDING_DIM, MINILM_DIM);
        assert_eq!(MINILM_DIM, 384);
        assert_eq!(NOMIC_DIM, 768);
    }

    // Length mismatch with the secondary longer than the primary still returns the primary.
    #[test]
    fn fuse_dimension_mismatch_secondary_longer() {
        let p = vec![1.0_f32, 2.0];
        let s = vec![3.0_f32, 4.0, 5.0]; let f = Embedder::fuse(&p, &s, 0.5);
        assert_eq!(f, p);
    }

    // Length mismatch with the first argument shorter also yields 0.
    #[test]
    fn cosine_similarity_dimension_mismatch_inverse() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0, 0.0];
        let sim = Embedder::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // Environment-dependent: accepts either a working local embedder or an error
    // whose message points at model loading, since model files may be unavailable.
    #[test]
    fn pr9i_for_model_minilm_dispatches_to_new_local() {
        let res = Embedder::for_model(EmbeddingModel::MiniLmL6V2, None);
        match res {
            Ok(e) => {
                assert_eq!(e.dim(), 384);
                let desc = e.model_description();
                assert!(desc.contains("MiniLM"));
            }
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("model")
                        || msg.contains("config")
                        || msg.contains("tokenizer")
                        || msg.contains("fallback")
                        || msg.contains("HuggingFace"),
                    "unexpected new_local error: {msg}"
                );
            }
        }
    }

    // Environment-dependent: `new` either builds a 384-dim embedder or fails with a non-empty message.
    #[test]
    fn pr9i_embedder_new_alias_is_new_local() {
        let res = Embedder::new();
        match res {
            Ok(e) => {
                assert_eq!(e.dim(), 384);
            }
            Err(e) => {
                let msg = e.to_string();
                assert!(!msg.is_empty());
            }
        }
    }
}
972
// Exercises `load_from_fallback` against the real environment.
// Success is accepted as-is (the files may be present on this machine); only the
// wording of the error is checked when the lookup fails.
// NOTE(review): the name says "returns_unreachable", but nothing of that kind is
// asserted here; consider renaming.
#[test]
fn embedder_returns_unreachable_when_model_path_missing() {
    if let Err(e) = Embedder::load_from_fallback() {
        let err_msg = e.to_string();
        assert!(
            err_msg.contains("not found") || err_msg.contains("fallback"),
            "error should mention missing model files: {err_msg}"
        );
    }
}
994
// Happy path for `load_from_fallback`: points HOME at a temp directory containing
// placeholder model files and checks that the three expected paths are returned.
#[test]
fn load_from_fallback_succeeds_when_files_present() {
    use std::sync::Mutex;
    // NOTE(review): this lock is local to this function, so it cannot serialise
    // against other tests that read HOME while it is overridden — confirm that is acceptable.
    static LOCK: Mutex<()> = Mutex::new(());
    // A poisoned lock is still usable here; only mutual exclusion matters.
    let _guard = LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    // The process id keeps the temp directory distinct between concurrent test processes.
    let tmp = std::env::temp_dir().join(format!("ai-memory-w12h-fallback-{}", std::process::id()));
    let model_dir = tmp.join(
        ".cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/main",
    );
    std::fs::create_dir_all(&model_dir).expect("mk model dir");
    // `load_from_fallback` only checks that the files exist, so placeholder contents suffice.
    for name in ["config.json", "tokenizer.json", "model.safetensors"] {
        std::fs::write(model_dir.join(name), b"{}").expect("write placeholder");
    }
    let prev = std::env::var("HOME").ok();
    // SAFETY (assumed): mutating the environment is only sound if no other thread
    // reads or writes it at the same time — TODO confirm for parallel test runs.
    unsafe {
        std::env::set_var("HOME", &tmp);
    }
    let result = Embedder::load_from_fallback();
    // SAFETY (assumed): same requirement as above; restores the previous HOME value.
    unsafe {
        match prev {
            Some(p) => std::env::set_var("HOME", p),
            None => std::env::remove_var("HOME"),
        }
    }
    // Clean up before asserting so a failure does not leave the temp directory behind.
    let _ = std::fs::remove_dir_all(&tmp);
    let (cfg, tok, w) = result.expect("placeholder files satisfy load_from_fallback");
    assert!(cfg.ends_with("config.json"));
    assert!(tok.ends_with("tokenizer.json"));
    assert!(w.ends_with("model.safetensors"));
}