// normalize_semantic/embedder.rs
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use std::path::Path;
8
/// Model name used when the caller does not specify one; must be a key
/// accepted by `resolve_model` / `dims_for_model`.
pub const DEFAULT_MODEL: &str = "nomic-embed-text-v1.5";
12
/// A loaded fastembed text-embedding model together with its name and
/// probed output dimensionality.
pub struct Embedder {
    // Underlying fastembed model; its `embed` method takes `&mut self`,
    // which is why the embed_* methods below require `&mut`.
    model: TextEmbedding,
    /// The name this model was loaded under (as passed to `load`).
    pub model_name: String,
    /// Length of each embedding vector, determined by a probe at load time.
    pub dimensions: usize,
}
19
20impl Embedder {
21 pub fn load(model_name: &str, cache_dir: Option<&Path>) -> anyhow::Result<Self> {
26 let embedding_model = resolve_model(model_name)?;
27 let mut opts = InitOptions::new(embedding_model);
28 if let Some(dir) = cache_dir {
29 opts = opts.with_cache_dir(dir.to_path_buf());
30 }
31 let mut model = TextEmbedding::try_new(opts).map_err(|e| {
32 anyhow::anyhow!("Failed to load embedding model '{}': {}", model_name, e)
33 })?;
34
35 let probe = model
37 .embed(vec![""], None)
38 .map_err(|e| anyhow::anyhow!("Failed to probe embedding dimensions: {}", e))?;
39 let dimensions = probe.first().map(|v| v.len()).unwrap_or(768);
40
41 Ok(Self {
42 model,
43 model_name: model_name.to_string(),
44 dimensions,
45 })
46 }
47
48 pub fn embed_batch(&mut self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
50 self.model
51 .embed(texts, None)
52 .map_err(|e| anyhow::anyhow!("Embedding failed: {}", e))
53 }
54
55 pub fn embed_one(&mut self, text: &str) -> anyhow::Result<Vec<f32>> {
57 let batch = self.embed_batch(&[text])?;
58 batch
59 .into_iter()
60 .next()
61 .ok_or_else(|| anyhow::anyhow!("Embedder returned empty result for single text"))
62 }
63}
64
/// Serialize a slice of f32 values into a flat little-endian byte buffer
/// (4 bytes per element), suitable for blob storage.
pub fn encode_vector(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|x| x.to_le_bytes()).collect()
}
73
/// Deserialize a little-endian byte buffer (as produced by `encode_vector`)
/// back into f32 values. Trailing bytes that do not form a full 4-byte
/// chunk are ignored.
pub fn decode_vector(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        values.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    values
}
81
/// Cosine similarity between two equal-length vectors, clamped to [-1, 1].
/// Returns 0.0 if either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: vector length mismatch"
    );
    // Accumulate the dot product and both squared norms in one pass;
    // per-accumulator summation order matches the separate-pass version.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
}
98
/// Known output dimensionality for each supported model name; `None`
/// for unrecognized names (callers fall back to probing the model).
pub fn dims_for_model(name: &str) -> Option<usize> {
    match name {
        "nomic-embed-text-v1.5" => Some(768),
        "all-MiniLM-L6-v2" | "all-MiniLM-L12-v2" => Some(384),
        _ => None,
    }
}
111
112fn resolve_model(name: &str) -> anyhow::Result<EmbeddingModel> {
114 match name {
115 "nomic-embed-text-v1.5" => Ok(EmbeddingModel::NomicEmbedTextV15),
116 "all-MiniLM-L6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
117 "all-MiniLM-L12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
118 other => Err(anyhow::anyhow!(
119 "Unknown embedding model '{}'. Supported: nomic-embed-text-v1.5, all-MiniLM-L6-v2, all-MiniLM-L12-v2",
120 other
121 )),
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    // encode_vector and decode_vector must be mutual inverses for f32 data.
    #[test]
    fn test_encode_decode_roundtrip() {
        let original = vec![1.0_f32, -0.5, 0.25, 0.0];
        let bytes = encode_vector(&original);
        let decoded = decode_vector(&bytes);
        for (a, b) in original.iter().zip(decoded.iter()) {
            assert!((a - b).abs() < 1e-7, "roundtrip mismatch: {} vs {}", a, b);
        }
    }

    // A vector compared against itself has similarity 1 (up to float error).
    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "identical vectors should have sim=1"
        );
    }

    // Perpendicular vectors have zero dot product, hence similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![0.0_f32, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6, "orthogonal vectors should have sim=0");
    }

    // Zero-magnitude input takes the explicit guard path and returns exactly 0.
    #[test]
    fn test_cosine_zero_vector() {
        let a = vec![0.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }
}