use crate::error::Result;
/// Converts text into `f32` vectors.
///
/// Implementors must be `Send + Sync`.
pub trait Embedder: Send + Sync {
    /// Number of dimensions this embedder reports.
    ///
    /// Presumably the length of every vector returned by [`Embedder::embed`]; the trait
    /// itself does not enforce that.
    fn dimensions(&self) -> usize;

    /// Embeds a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per entry in input order.
    ///
    /// The default implementation calls [`Embedder::embed`] once per entry.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Embedder::embed`]; entries after the
    /// failing one are not embedded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
/// Embedder that hashes whitespace-separated tokens into a fixed number of signed buckets.
///
/// No model or vocabulary is involved: the output depends only on the input text and the
/// configured dimension.
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    /// Length of every produced vector; `new` rejects zero.
    dims: usize,
}

impl HashEmbedder {
    /// Creates an embedder that produces `dims`-dimensional vectors.
    ///
    /// # Panics
    ///
    /// Panics if `dims` is zero.
    pub fn new(dims: usize) -> Self {
        assert!(dims > 0, "HashEmbedder dimension must be > 0");
        Self { dims }
    }

    /// 64-bit FNV-1a hash of `token`, with ASCII uppercase bytes folded to lowercase.
    ///
    /// Folding byte by byte yields the same hash as hashing
    /// `token.to_ascii_lowercase()`, but without allocating a temporary `String`.
    /// Non-ASCII bytes are hashed unchanged.
    #[inline]
    fn token_hash(token: &str) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0100_0000_01b3;

        token.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte.to_ascii_lowercase())).wrapping_mul(FNV_PRIME)
        })
    }
}

impl Embedder for HashEmbedder {
    /// Returns the length of the vectors produced by `embed`.
    fn dimensions(&self) -> usize {
        self.dims
    }

    /// Embeds `text` as an L2-normalised vector of signed token counts.
    ///
    /// The text is split on whitespace. Each token is hashed (ASCII case-insensitively);
    /// the hash modulo the dimension selects a bucket and bit 32 of the hash selects
    /// whether the bucket is incremented or decremented. The result is scaled to unit
    /// length unless every bucket is zero (for example when `text` has no tokens), in
    /// which case the all-zero vector is returned.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let mut v = vec![0.0f32; self.dims];
        // Hoisted out of the loop; widening `usize` to `u64` does not lose bits on
        // targets whose pointers are at most 64 bits wide.
        let buckets = self.dims as u64;

        for token in text.split_whitespace() {
            let h = Self::token_hash(token);
            // `h % buckets` is strictly below `self.dims`, so it fits in `usize`.
            let idx = (h % buckets) as usize;
            // The sign comes from a bit in the upper half of the hash, while the bucket
            // index comes from the remainder.
            let sign = if ((h >> 32) & 1) == 0 { 1.0 } else { -1.0 };
            v[idx] += sign;
        }

        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Skip the division for an all-zero vector to avoid producing NaNs.
        if norm > 0.0 {
            for x in &mut v {
                *x /= norm;
            }
        }
        Ok(v)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors, truncated to the shorter one.
    fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(l, r)| l * r).sum()
    }

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Embedding the same text twice gives identical unit-length vectors of the
    /// configured size.
    #[test]
    fn deterministic_and_normalized() {
        let embedder = HashEmbedder::new(64);
        let first = embedder.embed("the quick brown fox").unwrap();
        let second = embedder.embed("the quick brown fox").unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
        assert!((l2_norm(&first) - 1.0).abs() < 1e-5);
    }

    /// Texts sharing tokens score higher against each other than texts sharing none.
    #[test]
    fn overlap_scores_higher_than_disjoint() {
        let embedder = HashEmbedder::new(256);
        let base = embedder.embed("machine learning vector database").unwrap();
        let near = embedder.embed("machine learning vector search").unwrap();
        let far = embedder.embed("unrelated cooking recipe content").unwrap();

        let near_score = dot(&base, &near);
        let far_score = dot(&base, &far);
        assert!(near_score > far_score);
    }

    /// ASCII letter case does not affect the embedding.
    #[test]
    fn case_insensitive_tokens() {
        let embedder = HashEmbedder::new(64);
        let mixed = embedder.embed("Hello World").unwrap();
        let lower = embedder.embed("hello world").unwrap();
        assert_eq!(mixed, lower);
    }

    /// Whitespace-only input has no tokens and embeds to the zero vector.
    #[test]
    fn empty_text_is_zero_vector() {
        let embedder = HashEmbedder::new(32);
        let embedded = embedder.embed(" ").unwrap();
        assert_eq!(embedded, vec![0.0f32; 32]);
    }

    /// The default batch implementation agrees with embedding each text individually.
    #[test]
    fn batch_matches_single() {
        let embedder = HashEmbedder::new(48);
        let batch = embedder.embed_batch(&["alpha beta", "gamma"]).unwrap();

        let first = embedder.embed("alpha beta").unwrap();
        let second = embedder.embed("gamma").unwrap();
        assert_eq!(batch[0], first);
        assert_eq!(batch[1], second);
    }
}