use async_trait::async_trait;
use crate::error::Result;
/// A source of dense vector embeddings for text.
///
/// Implementors must be `Send + Sync`, which allows a provider to be shared
/// across tasks, e.g. behind an `Arc<dyn EmbeddingProvider>`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single piece of text into a vector.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider fails to produce an embedding.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds several texts, returning one vector per input in input order.
    ///
    /// The default implementation awaits [`embed`](Self::embed) once per text,
    /// one after another. It stops at the first failure, so the texts after a
    /// failing one are never embedded. Implementors may override this, e.g.
    /// to use a backend's batch endpoint.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`embed`](Self::embed).
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            let vector = self.embed(text).await?;
            embeddings.push(vector);
        }
        Ok(embeddings)
    }

    /// The length of the vectors this provider returns.
    fn dimensions(&self) -> usize;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Compile-time check that `EmbeddingProvider` can be used as a trait object.
    ///
    /// The inner functions are never called; the test passes if this compiles.
    #[test]
    fn test_embedding_provider_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbeddingProvider) {}
        fn _assert_boxable(_: Box<dyn EmbeddingProvider>) {}
        fn _assert_arcable(_: Arc<dyn EmbeddingProvider>) {}
    }

    /// Mock provider that returns a constant vector of `dims` elements for any input.
    struct MockEmbeddingProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1; self.dims])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }
    }

    /// Mock provider whose one-element embedding is the input's byte length.
    ///
    /// Unlike `MockEmbeddingProvider`, its output depends on the input, so
    /// tests can tell which text produced which vector.
    struct LengthEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for LengthEmbeddingProvider {
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            Ok(vec![text.len() as f32])
        }

        fn dimensions(&self) -> usize {
            1
        }
    }

    #[test]
    fn test_dimensions_returns_configured_value() {
        let provider = MockEmbeddingProvider { dims: 384 };
        assert_eq!(provider.dimensions(), 384);
        let provider = MockEmbeddingProvider { dims: 1536 };
        assert_eq!(provider.dimensions(), 1536);
    }

    #[tokio::test]
    async fn test_embed_returns_correct_length() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let result = provider.embed("test text").await.unwrap();
        assert_eq!(result.len(), 384);
    }

    #[tokio::test]
    async fn test_embed_batch_default_impl() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let texts = &["hello", "world", "test"];
        let results = provider.embed_batch(texts).await.unwrap();
        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.len(), 384);
        }
    }

    /// The default `embed_batch` must embed each text individually and keep input order.
    #[tokio::test]
    async fn test_embed_batch_preserves_input_order() {
        let provider = LengthEmbeddingProvider;
        let results = provider.embed_batch(&["a", "abc", "ab"]).await.unwrap();
        assert_eq!(results, vec![vec![1.0], vec![3.0], vec![2.0]]);
    }

    /// The default `embed_batch` is reachable through dynamic dispatch.
    #[tokio::test]
    async fn test_embed_batch_through_trait_object() {
        let provider: Arc<dyn EmbeddingProvider> = Arc::new(MockEmbeddingProvider { dims: 8 });
        let results = provider.embed_batch(&["x", "y"]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|v| v.len() == provider.dimensions()));
    }

    #[tokio::test]
    async fn test_embed_batch_empty_input() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }
}