use crate::error::Result;
use async_trait::async_trait;
/// Interface for turning text into `f32` embedding vectors.
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute is
/// applied so the trait can declare an `async fn`.
#[async_trait]
pub trait EmbeddingGenerator: Send + Sync {
/// Asynchronously produces a vector of `f32` values for `text`, or an error.
async fn generate(&self, text: &str) -> Result<Vec<f32>>;
/// Returns the embedding dimension.
///
/// NOTE(review): presumably the length of every vector returned by
/// `generate`, which is how the mock implementation in this file behaves;
/// confirm for other implementors.
fn dimension(&self) -> usize;
}
/// Mock embedding generator configured with a fixed output dimension.
///
/// Derives `Debug` and `Clone` so instances can be printed in assertion
/// failures and duplicated freely in tests.
#[derive(Debug, Clone)]
pub struct MockEmbeddingGenerator {
    /// Number of components in every embedding this generator produces.
    dimension: usize,
}
impl MockEmbeddingGenerator {
pub fn new(dimension: usize) -> Self {
Self { dimension }
}
}
#[async_trait]
impl EmbeddingGenerator for MockEmbeddingGenerator {
    /// Derives a pseudo-embedding from a hash of `text`.
    ///
    /// The text is hashed once with `std::collections::hash_map::DefaultHasher`.
    /// Component `i` is that hash wrapping-multiplied by `i + 1`, converted to
    /// `f32`, and divided by `u64::MAX as f32`. The resulting vector is divided
    /// by its Euclidean norm unless that norm is zero. This implementation
    /// always returns `Ok`.
    async fn generate(&self, text: &str) -> Result<Vec<f32>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // One 64-bit seed per input text; every component is derived from it.
        let seed = {
            let mut state = DefaultHasher::new();
            text.hash(&mut state);
            state.finish()
        };

        // Dividing by u64::MAX keeps each raw component within [0, 1].
        let scale = u64::MAX as f32;
        let mut components: Vec<f32> = (0..self.dimension)
            .map(|index| seed.wrapping_mul(index as u64 + 1) as f32 / scale)
            .collect();

        // Scale to unit length; a zero norm is left untouched to avoid
        // dividing by zero.
        let magnitude: f32 = components.iter().map(|c| c * c).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            components.iter_mut().for_each(|c| *c /= magnitude);
        }

        Ok(components)
    }

    /// Returns the dimension this generator was constructed with.
    fn dimension(&self) -> usize {
        self.dimension
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding size requested by every test in this module.
    const DIM: usize = 384;

    /// Two calls with the same text must return equal vectors of length `DIM`.
    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        let mock = MockEmbeddingGenerator::new(DIM);
        let first = mock.generate("test text").await.unwrap();
        let second = mock.generate("test text").await.unwrap();

        assert_eq!(first.len(), DIM);
        assert_eq!(first, second);
    }

    /// The Euclidean norm of a generated vector must be within 0.001 of 1.
    #[tokio::test]
    async fn test_mock_embedding_normalized() {
        let mock = MockEmbeddingGenerator::new(DIM);
        let vector = mock.generate("test").await.unwrap();

        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
}