use crate::error::EmbeddingError;
/// A backend that turns text into `f32` embedding vectors.
///
/// The `Send + Sync` supertraits allow one engine instance to be shared
/// across threads and async tasks.
#[async_trait::async_trait]
pub trait EmbeddingEngine: Send + Sync {
    /// Name identifying this engine implementation.
    fn name(&self) -> &str;

    /// Number of dimensions this engine reports for its embeddings.
    ///
    /// NOTE(review): callers presumably rely on this matching the length of
    /// every vector returned by [`embed`](Self::embed) and
    /// [`embed_batch`](Self::embed_batch); the trait itself does not enforce
    /// it — confirm implementations uphold it.
    fn dimensions(&self) -> usize;

    /// Embeds a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns an [`EmbeddingError`] if the engine cannot produce an embedding
    /// for `text`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;

    /// Embeds several texts.
    ///
    /// The default implementation calls [`embed`](Self::embed) once per input,
    /// sequentially, and returns one vector per input in input order. Engines
    /// that support native batching should override it.
    ///
    /// # Errors
    ///
    /// The default implementation stops at the first failing input and returns
    /// that [`EmbeddingError`]; no partial results are returned.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed(text).await?);
        }
        Ok(embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Closure signature a [`MockEngine`] delegates every embedding call to.
    type EmbedFn = dyn Fn(&str) -> Result<Vec<f32>, EmbeddingError> + Send + Sync;

    /// Test double whose embedding behaviour is supplied by a closure.
    struct MockEngine {
        embed_fn: Arc<EmbedFn>,
        dims: usize,
    }

    impl MockEngine {
        /// Builds an engine that reports `dims` and delegates every call to `f`.
        fn from_fn(
            dims: usize,
            f: impl Fn(&str) -> Result<Vec<f32>, EmbeddingError> + Send + Sync + 'static,
        ) -> Self {
            Self {
                embed_fn: Arc::new(f),
                dims,
            }
        }

        /// Builds an engine that returns a clone of `values` for every input.
        fn returning_ok(dims: usize, values: Vec<f32>) -> Self {
            Self::from_fn(dims, move |_| Ok(values.clone()))
        }

        /// Builds an engine (reporting 0 dimensions) whose every call fails.
        /// `err_fn` is invoked on each call, so a fresh error is produced every time.
        fn returning_err(err_fn: impl Fn() -> EmbeddingError + Send + Sync + 'static) -> Self {
            Self::from_fn(0, move |_| Err(err_fn()))
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingEngine for MockEngine {
        fn name(&self) -> &str {
            "mock"
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            (self.embed_fn)(text)
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            // Collecting into `Result` short-circuits on the first error.
            texts.iter().map(|t| (self.embed_fn)(t)).collect()
        }
    }

    #[tokio::test]
    async fn mock_engine_returns_configured_embedding() {
        let mock = MockEngine::returning_ok(3, vec![0.1, 0.2, 0.3]);
        assert_eq!(mock.name(), "mock");
        assert_eq!(mock.dimensions(), 3);
        let result = mock.embed("hello").await.unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn mock_engine_can_simulate_errors() {
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed("test").await;
        assert!(matches!(result, Err(EmbeddingError::Inference(_))));
    }

    #[tokio::test]
    async fn embed_batch_returns_one_embedding_per_input_in_order() {
        let mock = MockEngine::from_fn(1, |text| Ok(vec![text.len() as f32]));
        let result = mock.embed_batch(&["a", "bbb", "cc"]).await.unwrap();
        assert_eq!(result, vec![vec![1.0], vec![3.0], vec![2.0]]);
    }

    #[tokio::test]
    async fn embed_batch_propagates_errors() {
        let mock = MockEngine::from_fn(1, |text| {
            if text == "bad" {
                Err(EmbeddingError::Inference("bad input".into()))
            } else {
                Ok(vec![1.0])
            }
        });
        let result = mock.embed_batch(&["ok", "bad", "ok"]).await;
        assert!(matches!(result, Err(EmbeddingError::Inference(_))));
    }

    #[tokio::test]
    async fn embed_batch_of_empty_input_is_empty() {
        let mock = MockEngine::returning_ok(3, vec![0.1, 0.2, 0.3]);
        let texts: [&str; 0] = [];
        let result = mock.embed_batch(&texts).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn engine_is_usable_as_shared_trait_object() {
        let engine: Arc<dyn EmbeddingEngine> = Arc::new(MockEngine::returning_ok(2, vec![1.0, 2.0]));
        assert_eq!(engine.name(), "mock");
        assert_eq!(engine.dimensions(), 2);
        assert_eq!(engine.embed("x").await.unwrap(), vec![1.0, 2.0]);
    }
}