use super::provider::{EmbeddingProvider, EmbeddingResult};
use anyhow::Result;
use async_trait::async_trait;
/// Mock embedding provider that derives its vectors from a BLAKE3 hash of
/// the input text instead of calling a real model.
///
/// The same text always maps to the same vector, and no network or model
/// files are involved.
#[derive(Debug, Clone)]
// NOTE(review): presumably silences warnings in builds where the mock is only
// used from tests — confirm and narrow the allow if possible.
#[allow(dead_code)]
pub struct MockEmbedder {
/// Number of components in every embedding this provider produces.
dimensions: usize,
/// Name returned by the provider's `model_name()` accessor.
model_name: String,
}
/// Constructors and the hash-based vector generator for [`MockEmbedder`].
#[allow(dead_code)]
impl MockEmbedder {
    /// Creates a mock embedder that emits vectors with `dimensions` components
    /// and reports the model name `"mock-embedder"`.
    pub fn new(dimensions: usize) -> Self {
        Self::with_model_name(dimensions, "mock-embedder")
    }

    /// Creates a mock embedder that emits vectors with `dimensions` components
    /// and reports `model_name` as its model name.
    pub fn with_model_name(dimensions: usize, model_name: &str) -> Self {
        let model_name = String::from(model_name);
        Self {
            dimensions,
            model_name,
        }
    }

    /// Maps `text` to a deterministic vector of `self.dimensions` components.
    ///
    /// The text is hashed with BLAKE3 and the 32 digest bytes are reused
    /// cyclically. Each further pass over the digest adds another multiple of
    /// 17 (wrapping in `u8`) so consecutive 32-component blocks differ. Every
    /// byte is mapped linearly from `0..=255` to `-1.0..=1.0`, then the whole
    /// vector is scaled to unit L2 norm. With zero dimensions the result is
    /// an empty vector.
    fn text_to_embedding(&self, text: &str) -> Vec<f32> {
        // Number of digest bytes consumed before the cycle restarts.
        const BLOCK_LEN: usize = 32;

        let digest = blake3::hash(text.as_bytes());
        let seed = digest.as_bytes();

        let mut components: Vec<f32> = (0..self.dimensions)
            .map(|index| {
                // The block number is deliberately truncated to u8: the
                // per-block offset simply wraps for very large dimensions.
                let offset = ((index / BLOCK_LEN) as u8).wrapping_mul(17);
                let byte = seed[index % BLOCK_LEN].wrapping_add(offset);
                f32::from(byte) / 127.5 - 1.0
            })
            .collect();

        let norm = components.iter().map(|c| c * c).sum::<f32>().sqrt();
        // The norm is zero only for an empty vector; skip the division then.
        if norm > 0.0 {
            components.iter_mut().for_each(|c| *c /= norm);
        }

        components
    }
}
/// [`EmbeddingProvider`] implementation backed by the local hash-based
/// generator; neither method performs I/O or returns an error.
#[async_trait]
impl EmbeddingProvider for MockEmbedder {
    /// Returns the number of components in each produced embedding.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds a single text.
    ///
    /// The reported token count is the number of whitespace-separated words
    /// in `text`.
    async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
        let embedding = self.text_to_embedding(text);
        let token_count = Some(text.split_whitespace().count());
        Ok(EmbeddingResult {
            embedding,
            token_count,
        })
    }

    /// Embeds every entry of `texts`, returning one result per input in the
    /// same order. Each result is built exactly as in [`Self::embed`].
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(EmbeddingResult {
                embedding: self.text_to_embedding(text),
                token_count: Some(text.split_whitespace().count()),
            });
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The configured dimension count is reported back unchanged.
    #[test]
    fn test_mock_embedder_dimensions() {
        for expected in [384_usize, 1536] {
            let embedder = MockEmbedder::new(expected);
            assert_eq!(embedder.dimensions(), expected);
        }
    }

    /// Hashing the same text twice yields the same vector.
    #[test]
    fn test_embedding_determinism() {
        let embedder = MockEmbedder::new(384);
        let first = embedder.text_to_embedding("hello world");
        let second = embedder.text_to_embedding("hello world");
        assert_eq!(first, second);
    }

    /// Distinct texts yield distinct vectors.
    #[test]
    fn test_different_texts_different_embeddings() {
        let embedder = MockEmbedder::new(384);
        let hello = embedder.text_to_embedding("hello world");
        let goodbye = embedder.text_to_embedding("goodbye world");
        assert_ne!(hello, goodbye);
    }

    /// Produced vectors have (approximately) unit L2 norm.
    #[test]
    fn test_embedding_normalized() {
        let embedder = MockEmbedder::new(384);
        let vector = embedder.text_to_embedding("test text");
        let norm: f32 = vector.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    /// The vector length matches the configured dimension count.
    #[test]
    fn test_embedding_correct_dimensions() {
        let embedder = MockEmbedder::new(1536);
        let vector = embedder.text_to_embedding("test");
        assert_eq!(vector.len(), 1536);
    }

    /// The async single-text entry point returns a full-length vector and a
    /// token count.
    #[tokio::test]
    async fn test_embed_async() {
        let embedder = MockEmbedder::new(384);
        let outcome = embedder.embed("hello world").await.unwrap();
        assert_eq!(outcome.embedding.len(), 384);
        assert!(outcome.token_count.is_some());
    }

    /// The batch entry point returns one full-length vector per input.
    #[tokio::test]
    async fn test_embed_batch() {
        let embedder = MockEmbedder::new(384);
        let inputs: Vec<String> = ["first text", "second text", "third text"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let outcomes = embedder.embed_batch(&inputs).await.unwrap();
        assert_eq!(outcomes.len(), 3);
        for outcome in outcomes {
            assert_eq!(outcome.embedding.len(), 384);
        }
    }

    /// Case and trailing-whitespace differences change the vector.
    #[test]
    fn test_similar_texts_have_different_embeddings() {
        let embedder = MockEmbedder::new(384);
        let lower = embedder.text_to_embedding("hello");
        let capitalized = embedder.text_to_embedding("Hello");
        let trailing_space = embedder.text_to_embedding("hello ");
        assert_ne!(lower, capitalized);
        assert_ne!(lower, trailing_space);
        assert_ne!(capitalized, trailing_space);
    }
}