use anyhow::Result;
use async_trait::async_trait;
/// Vector length produced by [`DeterministicEmbedder::new`] (and therefore by
/// `DeterministicEmbedder::default()`).
pub const DEFAULT_EMBEDDING_DIM: usize = 1_024;
/// Role of the texts handed to [`Embedder::embed`].
///
/// Backends are free to embed documents and queries differently;
/// [`DeterministicEmbedder`] ignores this value entirely.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InputType {
    /// Texts embedded as documents (presumably the corpus that queries are
    /// matched against — confirm with callers).
    Document,
    /// Texts embedded as search queries.
    Query,
}
/// A backend that turns batches of text into fixed-length `f32` vectors.
///
/// Implementations must be shareable across threads (`Sync + Send`).
#[async_trait]
pub trait Embedder: Sync + Send {
    /// Length of the vectors this embedder produces.
    ///
    /// NOTE(review): the trait does not enforce that `embed` honours this
    /// length; [`DeterministicEmbedder`] does.
    fn dim(&self) -> usize;

    /// Embeds `texts`, treating them according to `input_type`.
    ///
    /// Implementations are expected to return one vector per input text, in
    /// input order (as [`DeterministicEmbedder`] does).
    ///
    /// # Errors
    ///
    /// Whatever the backend reports; the deterministic implementation never
    /// fails.
    async fn embed(&self, texts: &[String], input_type: InputType) -> Result<Vec<Vec<f32>>>;
}
/// Offline embedder built on signed feature hashing.
///
/// The same text always maps to the same vector, so it needs no model or
/// network access.
#[derive(Clone, Debug)]
pub struct DeterministicEmbedder {
    /// Length of every produced vector.
    dim: usize,
}
impl DeterministicEmbedder {
    /// Creates an embedder producing [`DEFAULT_EMBEDDING_DIM`]-long vectors.
    #[must_use]
    pub fn new() -> Self {
        Self {
            dim: DEFAULT_EMBEDDING_DIM,
        }
    }

    /// Creates an embedder producing `dim`-long vectors.
    ///
    /// A `dim` of `0` is accepted: every embedding is then an empty vector.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }

    /// 64-bit FNV-1a hash of the token's UTF-8 bytes.
    fn hash_token(token: &str) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        token.bytes().fold(FNV_OFFSET_BASIS, |hash, b| {
            (hash ^ u64::from(b)).wrapping_mul(FNV_PRIME)
        })
    }

    /// Embeds one text into a `self.dim`-long, L2-normalised vector.
    ///
    /// The text is lowercased and split on every non-alphanumeric character;
    /// empty pieces are dropped. Each remaining token is hashed once, and the
    /// hash picks two slots (low and high 32 bits modulo `dim`) plus a `±1`
    /// sign for each. A text with no tokens yields the all-zero vector, which
    /// is left unnormalised.
    fn embed_one(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        // With `dim == 0` the slot computation below would be `h % 0`, which
        // panics; a zero-dimensional embedding is simply empty.
        if self.dim == 0 {
            return v;
        }
        let dim = self.dim as u64;
        let lower = text.to_lowercase();
        let tokens = lower
            .split(|c: char| !c.is_alphanumeric())
            .filter(|t| !t.is_empty());
        for token in tokens {
            let h = Self::hash_token(token);
            // Both indices are reduced modulo `dim`, so they fit in `usize`.
            let idx_a = (h % dim) as usize;
            let idx_b = ((h >> 32) % dim) as usize;
            // NOTE(review): when `dim` is even, bit 0 of `h` also fixes the
            // parity of `idx_a`, so a given `idx_a` slot always receives the same
            // sign. Changing this would alter every produced vector, so it is
            // kept as-is deliberately.
            let sign_a = if (h & 1) == 0 { 1.0 } else { -1.0 };
            let sign_b = if (h & 2) == 0 { 1.0 } else { -1.0 };
            v[idx_a] += sign_a;
            v[idx_b] += sign_b;
        }
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
        v
    }
}
impl Default for DeterministicEmbedder {
fn default() -> Self {
Self::new()
}
}
/// Purely in-process implementation: no I/O happens despite the async API.
#[async_trait]
impl Embedder for DeterministicEmbedder {
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds each text independently and always returns `Ok`.
    ///
    /// `_input_type` is ignored, so a document and a query with the same text
    /// receive identical vectors.
    async fn embed(&self, texts: &[String], _input_type: InputType) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed_one(text));
        }
        Ok(vectors)
    }
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` (rather than dividing by zero) when the slices are empty,
/// differ in length, or either one has zero magnitude. For finite inputs the
/// result lies in `[-1, 1]` up to floating-point rounding.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embeds a single `text` and returns its vector.
    async fn embed_single(
        embedder: &DeterministicEmbedder,
        text: &str,
        input_type: InputType,
    ) -> Vec<f32> {
        let mut batch = embedder
            .embed(&[text.to_string()], input_type)
            .await
            .unwrap();
        batch.swap_remove(0)
    }

    #[tokio::test]
    async fn deterministic_is_stable_and_normalized() {
        let embedder = DeterministicEmbedder::new();
        let as_doc = embed_single(&embedder, "hello world", InputType::Document).await;
        let as_query = embed_single(&embedder, "hello world", InputType::Query).await;
        assert_eq!(as_doc.len(), DEFAULT_EMBEDDING_DIM);
        assert_eq!(as_doc, as_query, "same text must yield the same vector");
        let norm = as_doc.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "expected unit norm, got {norm}");
    }

    #[tokio::test]
    async fn deterministic_similar_text_is_closer() {
        let texts: Vec<String> = [
            "the quick brown fox jumps",
            "the quick brown fox leaps",
            "completely unrelated banana finance report",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        let vecs = DeterministicEmbedder::new()
            .embed(&texts, InputType::Document)
            .await
            .unwrap();
        let (close, far) = (
            cosine_similarity(&vecs[0], &vecs[1]),
            cosine_similarity(&vecs[0], &vecs[2]),
        );
        assert!(
            close > far,
            "shared-token texts should be more similar ({close} vs {far})"
        );
    }

    #[tokio::test]
    async fn custom_dim_respected() {
        let embedder = DeterministicEmbedder::with_dim(1536);
        let v = embed_single(&embedder, "x", InputType::Document).await;
        assert_eq!(v.len(), 1536);
    }

    #[tokio::test]
    async fn known_input_produces_known_vector_prefix() {
        let embedder = DeterministicEmbedder::with_dim(16);
        let v = embed_single(&embedder, "return policy refund", InputType::Document).await;
        assert_eq!(v.len(), 16);
        // Pinned output guarding against hashing/tokenisation drift: only these
        // four slots are non-zero for this input.
        let mut expected = [0.0_f32; 16];
        for (slot, value) in [
            (0, -0.288_675_13),
            (3, -0.288_675_13),
            (13, -0.288_675_13),
            (15, -0.866_025_45),
        ] {
            expected[slot] = value;
        }
        for (i, (got, want)) in v.iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "vector drift at index {i}: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn cosine_similarity_basics() {
        let identical = cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        let orthogonal = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!((identical - 1.0).abs() < 1e-6);
        assert!(orthogonal.abs() < 1e-6);
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
}