use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use super::traits::EmbeddingModel;
/// Deterministic, hash-based text embedder.
///
/// Every component of an embedding is derived from a hash of the component
/// index together with the whole input text, so equal inputs map to equal
/// vectors, while texts differing by even one byte get unrelated vectors.
/// The output therefore carries no semantic similarity.
///
/// NOTE(review): `DefaultHasher`'s algorithm is documented as unspecified and
/// may change between Rust releases, so embeddings should not be assumed
/// stable across toolchain upgrades (relevant if they are ever persisted).
pub struct EmbeddingEngine {
    /// Number of components in every vector this engine produces.
    dimension: usize,
}

impl EmbeddingEngine {
    /// Creates an engine whose embeddings have `dimension` components.
    pub fn new(dimension: usize) -> Self {
        EmbeddingEngine { dimension }
    }

    /// Embeds `text` into a vector of length [`Self::dimension`].
    ///
    /// Component `i` is the 64-bit hash of `(i, text)` mapped linearly onto
    /// `[-1.0, 1.0]`; the whole vector is then scaled to unit length by
    /// [`normalize`] (which leaves it as-is if its magnitude is at most
    /// `f32::EPSILON`).
    pub fn embed(&self, text: &str) -> Vec<f32> {
        let mut components: Vec<f32> = (0..self.dimension)
            .map(|index| {
                let mut state = DefaultHasher::new();
                index.hash(&mut state);
                text.hash(&mut state);
                // Map the full u64 range onto [0, 1], then stretch onto [-1, 1].
                let fraction = state.finish() as f64 / u64::MAX as f64;
                (fraction * 2.0 - 1.0) as f32
            })
            .collect();
        normalize(&mut components);
        components
    }

    /// Embeds each entry of `texts`, returning the vectors in input order.
    pub fn batch_embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(texts.len());
        for &text in texts {
            out.push(self.embed(text));
        }
        out
    }

    /// Number of components in each embedding produced by this engine.
    pub fn dimension(&self) -> usize {
        self.dimension
    }
}
/// [`EmbeddingModel`] implementation that forwards every call to the
/// inherent methods of [`EmbeddingEngine`].
impl EmbeddingModel for EmbeddingEngine {
    fn dimension(&self) -> usize {
        EmbeddingEngine::dimension(self)
    }

    fn embed(&self, text: &str) -> Vec<f32> {
        // `Type::name` path resolution prefers inherent items over trait
        // items, so this calls the inherent method rather than recursing.
        EmbeddingEngine::embed(self, text)
    }

    fn batch_embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        EmbeddingEngine::batch_embed(self, texts)
    }
}
/// Scales `vector` in place to unit Euclidean (L2) length.
///
/// If the vector's length is not greater than `f32::EPSILON` (e.g. an empty
/// slice or the all-zero vector), it is left untouched rather than divided by
/// a near-zero value.
pub fn normalize(vector: &mut [f32]) {
    let norm = vector
        .iter()
        .map(|component| component * component)
        .sum::<f32>()
        .sqrt();
    if norm > f32::EPSILON {
        vector.iter_mut().for_each(|component| *component /= norm);
    }
}
/// Cosine similarity of `a` and `b`: `a·b / (‖a‖·‖b‖)`, clamped to `[-1, 1]`.
///
/// Unlike a bare dot product this is correct for vectors that are not unit
/// length; for unit vectors (such as normalized embeddings) the result equals
/// their dot product up to rounding. The clamp absorbs tiny rounding drift
/// past ±1.
///
/// Returns `0.0` when either vector has zero magnitude (including empty
/// slices), since the angle is undefined there. NaN components propagate.
///
/// # Panics
///
/// In debug builds, panics if `a.len() != b.len()`. In release builds the
/// check is skipped and the longer slice's trailing components are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have equal dimensions");
    // Single pass accumulating the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Take each root separately to reduce underflow of the product.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    (dot / denom).clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_determinism() {
        let engine = EmbeddingEngine::new(384);
        let v1 = engine.embed("hello world");
        let v2 = engine.embed("hello world");
        assert_eq!(v1, v2);
        // A separately constructed engine with the same dimension must agree too.
        assert_eq!(v1, EmbeddingEngine::new(384).embed("hello world"));
    }

    #[test]
    fn test_embedding_dimension() {
        let engine = EmbeddingEngine::new(128);
        assert_eq!(engine.dimension(), 128);
        let v = engine.embed("test");
        assert_eq!(v.len(), 128);
    }

    #[test]
    fn test_zero_dimension_embedding_is_empty() {
        assert!(EmbeddingEngine::new(0).embed("anything").is_empty());
    }

    #[test]
    fn test_embedding_normalized() {
        let engine = EmbeddingEngine::new(384);
        // Include the empty string: it must still hash to a usable unit vector.
        for &text in &["test normalization", ""] {
            let v = engine.embed(text);
            let magnitude: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (magnitude - 1.0).abs() < 1e-5,
                "Expected unit vector for {:?}, got magnitude {}",
                text,
                magnitude
            );
        }
    }

    #[test]
    fn test_normalize_scales_to_unit_length() {
        let mut v = [3.0f32, 4.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_leaves_zero_vector_untouched() {
        let mut zero = [0.0f32; 4];
        normalize(&mut zero);
        assert_eq!(zero, [0.0f32; 4]);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let engine = EmbeddingEngine::new(384);
        let v = engine.embed("same text");
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_different() {
        let engine = EmbeddingEngine::new(384);
        let v1 = engine.embed("hello world");
        let v2 = engine.embed("completely different text about cats");
        let sim = cosine_similarity(&v1, &v2);
        assert!(sim < 1.0);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_unit_vectors() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_batch_embed() {
        let engine = EmbeddingEngine::new(64);
        let texts = ["one", "two", "three"];
        let embeddings = engine.batch_embed(&texts);
        assert_eq!(embeddings.len(), texts.len());
        for (text, emb) in texts.iter().zip(&embeddings) {
            assert_eq!(emb.len(), 64);
            // Batch output must equal embedding each text individually, in order.
            assert_eq!(emb, &engine.embed(text));
        }
    }
}