coding_agent_search/search/
hash_embedder.rs1use super::embedder::{Embedder, EmbedderError, EmbedderResult};
36use frankensearch::{
37 HashAlgorithm as FsHashAlgorithm, HashEmbedder as FsHashEmbedder, ModelCategory, ModelTier,
38};
39
/// Output vector length used by [`HashEmbedder::default_dimension`] and by
/// `HashEmbedder::default()`.
pub const DEFAULT_DIMENSION: usize = 384_usize;

/// Tokens with fewer than this many `char`s are dropped during tokenization.
const MIN_TOKEN_LEN: usize = 2_usize;
45
46#[derive(Debug, Clone)]
52pub struct HashEmbedder {
53 dimension: usize,
54 id: String,
55 delegate: FsHashEmbedder,
56}
57
58impl HashEmbedder {
59 pub fn new(dimension: usize) -> Self {
70 assert!(dimension > 0, "dimension must be positive");
71 Self {
72 dimension,
73 id: format!("fnv1a-{dimension}"),
74 delegate: FsHashEmbedder::new(dimension, FsHashAlgorithm::FnvModular),
75 }
76 }
77
78 pub fn default_dimension() -> Self {
80 Self::new(DEFAULT_DIMENSION)
81 }
82
83 fn tokenize(text: &str) -> Vec<String> {
89 text.to_lowercase()
90 .split(|c: char| !c.is_alphanumeric())
91 .filter(|s| s.chars().count() >= MIN_TOKEN_LEN)
92 .map(String::from)
93 .collect()
94 }
95
96 fn uniform_fallback(&self) -> Vec<f32> {
97 let mut embedding = vec![1.0f32; self.dimension];
98 let norm = (self.dimension as f32).sqrt();
99 for value in &mut embedding {
100 *value /= norm;
101 }
102 embedding
103 }
104}
105
106impl Default for HashEmbedder {
107 fn default() -> Self {
108 Self::default_dimension()
109 }
110}
111
112impl Embedder for HashEmbedder {
113 fn embed_sync(&self, text: &str) -> EmbedderResult<Vec<f32>> {
114 if text.is_empty() {
115 return Err(EmbedderError::InvalidConfig {
116 field: "input_text".to_string(),
117 value: "(empty)".to_string(),
118 reason: "empty text".to_string(),
119 });
120 }
121
122 let tokens = Self::tokenize(text);
123
124 if tokens.is_empty() {
126 return Ok(self.uniform_fallback());
127 }
128
129 let canonical = tokens.join(" ");
132 let embedding = self.delegate.embed_sync(&canonical);
133 if embedding.len() != self.dimension {
134 return Err(EmbedderError::EmbeddingFailed {
135 model: self.id.clone(),
136 source: Box::new(std::io::Error::other(format!(
137 "delegate dimension mismatch: expected {}, got {}",
138 self.dimension,
139 embedding.len()
140 ))),
141 });
142 }
143 Ok(embedding)
144 }
145
146 fn embed_batch_sync(&self, texts: &[&str]) -> EmbedderResult<Vec<Vec<f32>>> {
147 texts.iter().map(|t| self.embed_sync(t)).collect()
148 }
149
150 fn dimension(&self) -> usize {
151 self.dimension
152 }
153
154 fn id(&self) -> &str {
155 &self.id
156 }
157
158 fn is_semantic(&self) -> bool {
159 false
160 }
161
162 fn category(&self) -> ModelCategory {
163 ModelCategory::HashEmbedder
164 }
165
166 fn tier(&self) -> ModelTier {
167 ModelTier::Fast
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Asserts that `embedding` has unit L2 norm (within float tolerance).
    fn assert_unit_norm(embedding: &[f32]) {
        let norm = l2_norm(embedding);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm should be ~1.0, got {norm}"
        );
    }

    #[test]
    fn test_hash_embedder_basic() {
        let embedder = HashEmbedder::new(256);
        let embedding = embedder.embed_sync("hello world").unwrap();

        assert_eq!(embedding.len(), 256);
        assert_eq!(embedder.id(), "fnv1a-256");
        assert!(!embedder.is_semantic());
    }

    #[test]
    fn test_hash_embedder_default() {
        let embedder = HashEmbedder::default();

        assert_eq!(embedder.dimension(), DEFAULT_DIMENSION);
        assert_eq!(embedder.id(), format!("fnv1a-{DEFAULT_DIMENSION}"));
    }

    #[test]
    fn test_hash_embedder_deterministic() {
        let embedder = HashEmbedder::new(256);

        let text = "deterministic embedding test with some words";
        let embedding1 = embedder.embed_sync(text).unwrap();
        let embedding2 = embedder.embed_sync(text).unwrap();
        assert_eq!(embedding1, embedding2);

        // A freshly constructed embedder must agree as well (no per-instance state).
        let fresh = HashEmbedder::new(256).embed_sync(text).unwrap();
        assert_eq!(embedding1, fresh, "separate instances must agree");
    }

    #[test]
    fn test_hash_embedder_l2_normalized() {
        let embedder = HashEmbedder::new(256);
        let embedding = embedder.embed_sync("normalize this vector").unwrap();

        assert_unit_norm(&embedding);
    }

    #[test]
    fn test_hash_embedder_different_texts_different_embeddings() {
        let embedder = HashEmbedder::new(256);

        let embedding1 = embedder.embed_sync("hello world").unwrap();
        let embedding2 = embedder.embed_sync("goodbye world").unwrap();

        assert_ne!(embedding1, embedding2);
    }

    #[test]
    fn test_hash_embedder_empty_input_error() {
        let embedder = HashEmbedder::new(256);
        let result = embedder.embed_sync("");

        assert!(
            matches!(result, Err(EmbedderError::InvalidConfig { .. })),
            "empty input should be rejected as InvalidConfig"
        );
    }

    #[test]
    fn test_hash_embedder_punctuation_only() {
        let embedder = HashEmbedder::new(256);

        let embedding = embedder.embed_sync("!@#$%^&*()").unwrap();

        assert_eq!(embedding.len(), 256);
        assert_unit_norm(&embedding);
    }

    #[test]
    fn test_no_usable_tokens_uses_uniform_fallback() {
        let embedder = HashEmbedder::new(256);
        // 1 / sqrt(256) = 1/16 is exactly representable in f32.
        let expected = vec![1.0f32 / 16.0; 256];

        // Whitespace-only and all-single-char inputs are non-empty, so they are
        // not errors, but they yield no tokens and must hit the fallback path.
        for text in ["   ", "\n\t", "a b c", "!@#$%^&*()"] {
            let embedding = embedder.embed_sync(text).unwrap();
            assert_eq!(embedding, expected, "input {text:?} should use the fallback");
        }
    }

    #[test]
    fn test_hash_embedder_batch() {
        let embedder = HashEmbedder::new(256);
        let texts = &["hello world", "goodbye world", "test batch"];

        let embeddings = embedder.embed_batch_sync(texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 256);
            assert_unit_norm(embedding);
        }
    }

    #[test]
    fn test_hash_embedder_batch_empty_error() {
        let embedder = HashEmbedder::new(256);
        let texts = &["hello", "", "world"];

        let result = embedder.embed_batch_sync(texts);
        assert!(result.is_err());
    }

    #[test]
    fn test_tokenize() {
        let tokens = HashEmbedder::tokenize("Hello, World! This is a TEST-123.");

        for expected in ["hello", "world", "this", "test", "123", "is"] {
            assert!(
                tokens.iter().any(|candidate| candidate == expected),
                "expected token {expected:?} in {tokens:?}"
            );
        }

        assert!(
            !tokens.iter().any(|candidate| candidate == "a"),
            "single-character token should be filtered: {tokens:?}"
        );
    }

    #[test]
    fn test_tokenize_includes_len_2() {
        let tokens = HashEmbedder::tokenize("is it ok");

        assert!(tokens.contains(&"is".to_string()));
        assert!(tokens.contains(&"it".to_string()));
        assert!(tokens.contains(&"ok".to_string()));
    }

    #[test]
    fn test_case_insensitivity() {
        let embedder = HashEmbedder::new(256);

        let embedding1 = embedder.embed_sync("Hello World").unwrap();
        let embedding2 = embedder.embed_sync("hello world").unwrap();
        let embedding3 = embedder.embed_sync("HELLO WORLD").unwrap();

        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding2, embedding3);
    }

    #[test]
    fn test_whitespace_insensitivity() {
        let embedder = HashEmbedder::new(256);

        // Each input differs from the first only in the amount/kind of whitespace.
        let embedding1 = embedder.embed_sync("hello world").unwrap();
        let embedding2 = embedder.embed_sync("hello   world").unwrap();
        let embedding3 = embedder.embed_sync("hello\n\tworld").unwrap();
        let embedding4 = embedder.embed_sync("  hello world  ").unwrap();

        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding1, embedding3);
        assert_eq!(embedding1, embedding4);
    }

    #[test]
    fn test_separator_insensitivity() {
        let embedder = HashEmbedder::new(256);

        // Any non-alphanumeric run acts as a separator, so these canonicalize
        // to the same token string as "hello world".
        let spaced = embedder.embed_sync("hello world").unwrap();
        let dashed = embedder.embed_sync("hello---world").unwrap();
        let wrapped = embedder.embed_sync("(hello), [world]!").unwrap();

        assert_eq!(spaced, dashed);
        assert_eq!(spaced, wrapped);
    }

    #[test]
    #[should_panic(expected = "dimension must be positive")]
    fn test_zero_dimension_panics() {
        let _ = HashEmbedder::new(0);
    }

    #[test]
    fn test_large_dimension() {
        let embedder = HashEmbedder::new(4096);
        let embedding = embedder.embed_sync("test large dimension").unwrap();

        assert_eq!(embedding.len(), 4096);
        assert_unit_norm(&embedding);
    }

    #[test]
    fn test_unicode_text() {
        let embedder = HashEmbedder::new(256);

        let embedding = embedder.embed_sync("café résumé naïve").unwrap();
        assert_eq!(embedding.len(), 256);
        assert_unit_norm(&embedding);
    }

    #[test]
    fn test_embedding_similarity() {
        let embedder = HashEmbedder::new(256);

        let emb_dog = embedder.embed_sync("the quick brown dog").unwrap();
        let emb_fox = embedder.embed_sync("the quick brown fox").unwrap();
        let emb_unrelated = embedder.embed_sync("quantum physics equations").unwrap();

        let dot = |a: &[f32], b: &[f32]| -> f32 { a.iter().zip(b).map(|(x, y)| x * y).sum() };
        let sim_dog_fox = dot(&emb_dog, &emb_fox);
        let sim_dog_unrelated = dot(&emb_dog, &emb_unrelated);

        assert!(
            sim_dog_fox > sim_dog_unrelated,
            "similar texts should have higher cosine similarity: \
             dog_fox={sim_dog_fox}, dog_unrelated={sim_dog_unrelated}"
        );
    }

    #[test]
    fn test_sync_embedder_adapter_bridge() {
        use frankensearch::SyncEmbedderAdapter;

        let embedder = HashEmbedder::new(256);
        let adapted = SyncEmbedderAdapter(embedder);

        assert_eq!(frankensearch::Embedder::dimension(&adapted), 256);
        assert_eq!(frankensearch::Embedder::id(&adapted), "fnv1a-256");
        assert!(!frankensearch::Embedder::is_semantic(&adapted));
    }
}