1use std::collections::HashMap;
28use std::path::{Path, PathBuf};
29
30use super::tokenize;
31
/// Converts text into a dense `f32` vector for similarity comparison.
///
/// The `Send + Sync` supertraits let a single embedder be shared between threads.
pub trait Embedder: Send + Sync {
    /// Encodes `text` as a dense vector, expected to have [`Embedder::dim`] components.
    fn embed(&self, text: &str) -> Vec<f32>;

    /// Length of every vector produced by [`Embedder::embed`].
    fn dim(&self) -> usize;

    /// Human-readable identifier of the embedding method.
    fn name(&self) -> &str;

    /// Encodes every string in `texts`, returning one vector per input in the same order.
    ///
    /// The default implementation calls [`Embedder::embed`] once per text;
    /// implementors with a cheaper batched path may override it.
    fn embed_batch(&self, texts: &[String]) -> Vec<Vec<f32>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text));
        }
        vectors
    }
}
57
/// 64-bit FNV-1a hash of `bytes`.
///
/// `seed` is XOR-ed into the standard FNV offset basis, so different seeds
/// give independent-looking hash functions over the same input. A seed of 0
/// yields plain FNV-1a.
fn fnv1a(bytes: &[u8], seed: u64) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;
    bytes.iter().fold(seed ^ FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
72
/// Embedder that hashes word tokens and character trigrams into a fixed
/// number of signed buckets.
pub struct LexicalEmbedder {
    /// Identifier reported through the `Embedder::name` method.
    name: String,
    /// Number of buckets, which is also the output vector length.
    dim: usize,
}
84
85impl LexicalEmbedder {
86 pub fn new(dim: usize) -> Self {
90 Self {
91 dim: dim.max(16),
92 name: "lexical-hash".to_string(),
93 }
94 }
95
96 fn add_feature(&self, vec: &mut [f32], feature: &str, weight: f32) {
97 let h = fnv1a(feature.as_bytes(), 0);
98 let bucket = (h % self.dim as u64) as usize;
99 let sign = if fnv1a(feature.as_bytes(), 0x9e37_79b9_7f4a_7c15) & 1 == 0 {
101 1.0
102 } else {
103 -1.0
104 };
105 vec[bucket] += sign * weight;
106 }
107}
108
109impl Default for LexicalEmbedder {
110 fn default() -> Self {
111 Self::new(256)
112 }
113}
114
115impl Embedder for LexicalEmbedder {
116 fn embed(&self, text: &str) -> Vec<f32> {
117 let mut vec = vec![0.0f32; self.dim];
118 for token in tokenize::word_tokens(text) {
119 self.add_feature(&mut vec, &token, 1.0);
120 }
121 for gram in tokenize::char_ngrams(text, 3) {
122 self.add_feature(&mut vec, &gram, 0.35);
124 }
125 l2_normalize(&mut vec);
126 vec
127 }
128
129 fn dim(&self) -> usize {
130 self.dim
131 }
132
133 fn name(&self) -> &str {
134 &self.name
135 }
136}
137
138pub struct StaticEmbedder {
157 dim: usize,
158 vectors: HashMap<String, Vec<f32>>,
159 name: String,
160 fallback: LexicalEmbedder,
164}
165
166impl StaticEmbedder {
167 pub fn from_asset_dir(asset_dir: &Path) -> Result<Self, String> {
174 let path = asset_dir.join("static-embeddings.json");
175 let raw = std::fs::read_to_string(&path)
176 .map_err(|e| format!("static embedding asset {} unreadable: {e}", path.display()))?;
177 Self::from_json(&raw)
178 }
179
180 pub fn from_json(raw: &str) -> Result<Self, String> {
183 let doc: AssetDoc = parse_asset(raw)?;
187 if doc.vectors.is_empty() {
188 return Err("static embedding asset has no vectors".to_string());
189 }
190 for (tok, v) in &doc.vectors {
191 if v.len() != doc.dim {
192 return Err(format!(
193 "static embedding vector for `{tok}` has length {} but dim is {}",
194 v.len(),
195 doc.dim
196 ));
197 }
198 }
199 Ok(Self {
200 dim: doc.dim,
201 vectors: doc.vectors,
202 name: "static-model2vec".to_string(),
203 fallback: LexicalEmbedder::new(doc.dim),
204 })
205 }
206}
207
208impl Embedder for StaticEmbedder {
209 fn embed(&self, text: &str) -> Vec<f32> {
210 let mut acc = vec![0.0f32; self.dim];
211 let mut hits = 0usize;
212 for token in tokenize::word_tokens(text) {
213 if let Some(v) = self.vectors.get(&token) {
214 for (a, x) in acc.iter_mut().zip(v.iter()) {
215 *a += x;
216 }
217 hits += 1;
218 }
219 }
220 if hits == 0 {
221 return self.fallback.embed(text);
224 }
225 let inv = 1.0 / hits as f32;
226 for a in acc.iter_mut() {
227 *a *= inv;
228 }
229 l2_normalize(&mut acc);
230 acc
231 }
232
233 fn dim(&self) -> usize {
234 self.dim
235 }
236
237 fn name(&self) -> &str {
238 &self.name
239 }
240}
241
/// Rescales `vec` in place to unit Euclidean length.
///
/// A vector whose norm is not strictly positive is left untouched. This
/// covers the all-zero vector and any vector containing NaN.
pub(crate) fn l2_normalize(vec: &mut [f32]) {
    let length = vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
    if length > 0.0 {
        let scale = length.recip();
        vec.iter_mut().for_each(|x| *x *= scale);
    }
}
253
/// Raw contents of a static-embedding asset, before vector lengths are validated.
struct AssetDoc {
    /// Token → vector entries read from the `vectors` object.
    vectors: HashMap<String, Vec<f32>>,
    /// Declared vector length from the top-level `dim` field.
    dim: usize,
}
260
261fn parse_asset(raw: &str) -> Result<AssetDoc, String> {
265 let dim = extract_int(raw, "\"dim\"")
268 .ok_or_else(|| "static embedding asset missing integer `dim`".to_string())?;
269 if dim == 0 {
270 return Err("static embedding `dim` must be > 0".to_string());
271 }
272 let vectors = extract_vectors(raw)?;
273 Ok(AssetDoc {
274 dim: dim as usize,
275 vectors,
276 })
277}
278
/// Finds `key` used as an object key (i.e. followed by `:`) and parses the
/// integer that follows it.
///
/// Occurrences of `key` that are not followed by `:` are skipped. This covers
/// the same text appearing as a string value. Occurrences whose value is not
/// an integer are also skipped, such as a vocabulary token named like the key
/// whose value is an array. Returns `None` if no occurrence yields an integer.
fn extract_int(raw: &str, key: &str) -> Option<i64> {
    raw.match_indices(key).find_map(|(idx, _)| {
        let value = raw[idx + key.len()..]
            .trim_start()
            .strip_prefix(':')?
            .trim_start();
        let end = value
            .find(|c: char| !c.is_ascii_digit() && c != '-')
            .unwrap_or(value.len());
        value[..end].parse::<i64>().ok()
    })
}
289
290fn extract_vectors(raw: &str) -> Result<HashMap<String, Vec<f32>>, String> {
291 let key = "\"vectors\"";
292 let idx = raw
293 .find(key)
294 .ok_or_else(|| "static embedding asset missing `vectors`".to_string())?;
295 let after = &raw[idx + key.len()..];
296 let open = after
297 .find('{')
298 .ok_or_else(|| "`vectors` is not an object".to_string())?;
299 let body = &after[open + 1..];
300 let mut map = HashMap::new();
301 let bytes = body.as_bytes();
302 let mut i = 0usize;
303 while i < bytes.len() {
304 match bytes[i] {
306 b'}' => break,
307 b'"' => {
308 let (k, next) = parse_string(body, i)?;
309 i = next;
310 while i < bytes.len() && bytes[i] != b':' {
312 i += 1;
313 }
314 i += 1;
315 while i < bytes.len() && bytes[i] != b'[' {
317 i += 1;
318 }
319 let (vec, next) = parse_float_array(body, i)?;
320 i = next;
321 map.insert(k, vec);
322 }
323 _ => i += 1,
324 }
325 }
326 Ok(map)
327}
328
/// Parses the JSON string literal whose opening `"` is at byte `start` of `s`.
///
/// Returns the decoded string and the byte index just past the closing `"`.
/// Unescaped text is copied through as UTF-8, so multi-byte tokens survive
/// intact. The JSON escapes `\b \f \n \r \t \uXXXX` are decoded, including
/// surrogate pairs such as `\uD83D\uDE00`. An unpaired surrogate becomes
/// U+FFFD. Any other escaped character (`\"`, `\\`, `\/`, or a non-standard
/// one) is kept literally.
///
/// # Errors
/// Returns `Err` if the string is unterminated or a `\u` escape is not
/// followed by four hex digits.
fn parse_string(s: &str, start: usize) -> Result<(String, usize), String> {
    let unterminated = || "unterminated string in static embedding asset".to_string();
    let bytes = s.as_bytes();
    debug_assert_eq!(bytes.get(start), Some(&b'"'));
    let mut out = String::new();
    let mut i = start + 1;
    // Start of the pending run of unescaped bytes. Runs begin and end next to
    // an ASCII `"`, `\`, or escape sequence, so slicing `s` there always lands
    // on a char boundary.
    let mut run = i;
    while i < bytes.len() {
        match bytes[i] {
            b'"' => {
                out.push_str(&s[run..i]);
                return Ok((out, i + 1));
            }
            b'\\' => {
                out.push_str(&s[run..i]);
                let esc = s[i + 1..].chars().next().ok_or_else(unterminated)?;
                i += 1 + esc.len_utf8();
                match esc {
                    'b' => out.push('\u{8}'),
                    'f' => out.push('\u{c}'),
                    'n' => out.push('\n'),
                    'r' => out.push('\r'),
                    't' => out.push('\t'),
                    'u' => {
                        let (ch, next) = decode_unicode_escape(s, i)?;
                        out.push(ch);
                        i = next;
                    }
                    other => out.push(other),
                }
                run = i;
            }
            _ => i += 1,
        }
    }
    Err(unterminated())
}

/// Decodes the four hex digits of a `\u` escape starting at byte `at` (just
/// past the `u`).
///
/// A high surrogate directly followed by a `\u` low surrogate is combined into
/// a single char. Unpaired surrogates decode to U+FFFD. Returns the char and
/// the index just past everything consumed.
fn decode_unicode_escape(s: &str, at: usize) -> Result<(char, usize), String> {
    let hex4 = |from: usize| {
        s.get(from..from + 4)
            .filter(|digits| digits.bytes().all(|b| b.is_ascii_hexdigit()))
            .and_then(|digits| u32::from_str_radix(digits, 16).ok())
    };
    let high =
        hex4(at).ok_or_else(|| "bad \\u escape in static embedding asset".to_string())?;
    let mut next = at + 4;
    let mut code = high;
    if (0xD800..0xDC00).contains(&high) && s[next..].starts_with("\\u") {
        if let Some(low) = hex4(next + 2).filter(|low| (0xDC00..0xE000).contains(low)) {
            code = 0x1_0000 + ((high - 0xD800) << 10) + (low - 0xDC00);
            next += 6;
        }
    }
    Ok((char::from_u32(code).unwrap_or('\u{FFFD}'), next))
}
349
/// Parses a JSON array of numbers whose `[` is at byte `start` of `s`.
///
/// Returns the values and the byte index just past the closing `]`. ASCII
/// whitespace is ignored, and empty slots such as `[1,,2]` are skipped.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - `s[start]` is not `[`;
/// - the array is never closed;
/// - an element does not parse as `f32`;
/// - an element is not finite (`NaN`, `inf`, or a literal too large for `f32`).
fn parse_float_array(s: &str, start: usize) -> Result<(Vec<f32>, usize), String> {
    let bytes = s.as_bytes();
    if bytes.get(start) != Some(&b'[') {
        return Err("expected float array in static embedding asset".to_string());
    }
    let mut out = Vec::new();
    let mut num = String::new();
    // Parses the pending element (if any) into `out`, then clears it.
    let flush = |num: &mut String, out: &mut Vec<f32>| -> Result<(), String> {
        let t = num.trim();
        if !t.is_empty() {
            let value = t
                .parse::<f32>()
                .map_err(|_| format!("bad float `{t}` in static embedding asset"))?;
            // `f32::from_str` accepts `NaN`/`inf` and saturates huge literals
            // to infinity. A single non-finite component would turn every
            // pooled embedding and cosine score that touches it into NaN.
            if !value.is_finite() {
                return Err(format!("non-finite float `{t}` in static embedding asset"));
            }
            out.push(value);
        }
        num.clear();
        Ok(())
    };
    for (i, &b) in bytes.iter().enumerate().skip(start + 1) {
        match b {
            b']' => {
                flush(&mut num, &mut out)?;
                return Ok((out, i + 1));
            }
            b',' => flush(&mut num, &mut out)?,
            b if b.is_ascii_whitespace() => {}
            b => num.push(char::from(b)),
        }
    }
    Err("unterminated float array in static embedding asset".to_string())
}
388
/// Locates a directory containing `static-embeddings.json`.
///
/// `override_dir` is checked first and returned as-is if it holds the file.
/// Otherwise `<data_dir>/embeddings/<model>` is checked. Returns `None` when
/// neither candidate holds the file or neither directory was supplied.
pub fn resolve_asset_dir(
    override_dir: Option<&Path>,
    data_dir: Option<&Path>,
    model: &str,
) -> Option<PathBuf> {
    fn holds_asset(dir: &Path) -> bool {
        dir.join("static-embeddings.json").is_file()
    }

    if let Some(dir) = override_dir.filter(|dir| holds_asset(dir)) {
        return Some(dir.to_path_buf());
    }
    data_dir
        .map(|base| base.join("embeddings").join(model))
        .filter(|candidate| holds_asset(candidate))
}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lexical_identical_text_is_self_similar() {
        let e = LexicalEmbedder::default();
        let v = e.embed("rate limiter middleware");
        assert_eq!(v.len(), 256);
        let sim = super::super::similarity::cosine(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5, "self-sim was {sim}");
    }

    #[test]
    fn lexical_related_beats_unrelated() {
        let e = LexicalEmbedder::default();
        let query = e.embed("rate limiter for the API");
        let related = e.embed("RateLimiter API throttle");
        let unrelated = e.embed("parse markdown table renderer");
        let s_rel = super::super::similarity::cosine(&query, &related);
        let s_unrel = super::super::similarity::cosine(&query, &unrelated);
        assert!(
            s_rel > s_unrel,
            "related {s_rel} should beat unrelated {s_unrel}"
        );
    }

    #[test]
    fn lexical_empty_is_zero_vector() {
        let e = LexicalEmbedder::default();
        let v = e.embed("");
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn lexical_is_l2_normalized() {
        let e = LexicalEmbedder::default();
        let v = e.embed("hello world embedding test");
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm was {norm}");
    }

    #[test]
    fn lexical_is_deterministic_cross_run() {
        let e = LexicalEmbedder::default();
        // The FNV seeds are compile-time constants (no per-process randomness),
        // so an in-process equality check stands in for cross-run stability.
        assert_eq!(e.embed("getUserById"), e.embed("getUserById"));
    }

    #[test]
    fn static_embedder_pools_known_tokens() {
        let json = r#"{ "dim": 2, "vectors": {
            "rate": [1.0, 0.0],
            "limit": [0.0, 1.0],
            "throttle": [0.7071, 0.7071]
        } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        assert_eq!(e.dim(), 2);
        let v = e.embed("rate limit");
        // The mean of (1, 0) and (0, 1) normalizes to (1/sqrt(2), 1/sqrt(2)).
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((v[0] - expected).abs() < 1e-3, "{v:?}");
        assert!((v[1] - expected).abs() < 1e-3, "{v:?}");
        let sim = super::super::similarity::cosine(&v, &e.embed("throttle"));
        // "throttle" points in the same direction as the pooled vector.
        assert!(sim > 0.99, "throttle sim {sim}");
    }

    #[test]
    fn static_embedder_falls_back_for_unknown_tokens() {
        let json = r#"{ "dim": 2, "vectors": { "rate": [1.0, 0.0] } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        let v = e.embed("zzz totally unknown words");
        // No token is in the table, so the lexical fallback supplies the vector.
        assert!(v.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn static_embedder_rejects_malformed_asset() {
        assert!(StaticEmbedder::from_json("not json").is_err());
        assert!(StaticEmbedder::from_json(r#"{ "dim": 2, "vectors": {} }"#).is_err());
        // Declared dim 3 but the vector has length 2.
        assert!(
            StaticEmbedder::from_json(r#"{ "dim": 3, "vectors": { "x": [1.0, 2.0] } }"#).is_err()
        );
    }

    #[test]
    fn resolve_asset_dir_respects_override_and_absence() {
        let tmp = std::env::temp_dir().join("embed-resolve-test-absent-xyz");
        let _ = std::fs::remove_dir_all(&tmp);
        assert_eq!(resolve_asset_dir(Some(&tmp), None, "potion"), None);
        assert_eq!(resolve_asset_dir(None, Some(&tmp), "potion"), None);
    }

    #[test]
    fn parse_handles_negative_and_scientific_floats() {
        let json = r#"{ "dim": 3, "vectors": { "x": [-1.5, 1e-3, 2.5E+2] } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        assert_eq!(e.dim(), 3);
        assert_eq!(e.vectors["x"], vec![-1.5, 1e-3, 250.0]);
    }
}