1use crate::error::EmbeddingError;
4
5#[async_trait::async_trait]
7pub trait EmbeddingEngine: Send + Sync {
8 fn name(&self) -> &str;
10
11 fn dimensions(&self) -> usize;
13
14 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
16
17 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
19}
20
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Test double whose results are produced by a closure.
    ///
    /// Each test decides what `embed` returns; `embed_batch` applies the same
    /// closure to every input.
    struct MockEngine {
        /// Invoked once per input text by both `embed` and `embed_batch`.
        embed_fn: Arc<dyn Fn(&str) -> Result<Vec<f32>, EmbeddingError> + Send + Sync>,
        /// Value reported by `dimensions()`.
        dims: usize,
    }

    impl MockEngine {
        /// Mock that reports `dims` and returns a clone of `values` for every input.
        fn returning_ok(dims: usize, values: Vec<f32>) -> Self {
            Self {
                embed_fn: Arc::new(move |_| Ok(values.clone())),
                dims,
            }
        }

        /// Mock that fails every call with the error built by `err_fn`.
        ///
        /// A fresh error is built per call rather than cloned; presumably
        /// `EmbeddingError` is not `Clone` (NOTE(review): confirm).
        fn returning_err(err_fn: impl Fn() -> EmbeddingError + Send + Sync + 'static) -> Self {
            Self {
                embed_fn: Arc::new(move |_| Err(err_fn())),
                dims: 0,
            }
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingEngine for MockEngine {
        fn name(&self) -> &str {
            "mock"
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            (self.embed_fn)(text)
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            // `collect` into `Result` stops at the first `Err`.
            texts.iter().map(|t| (self.embed_fn)(t)).collect()
        }
    }

    #[tokio::test]
    async fn mock_engine_reports_its_name() {
        let mock = MockEngine::returning_ok(1, vec![0.0]);
        assert_eq!(mock.name(), "mock");
    }

    #[tokio::test]
    async fn mock_engine_returns_configured_embedding() {
        let mock = MockEngine::returning_ok(3, vec![0.1, 0.2, 0.3]);
        assert_eq!(mock.dimensions(), 3);
        let result = mock.embed("hello").await.unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn mock_engine_can_simulate_errors() {
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed("test").await;
        assert!(
            matches!(result, Err(EmbeddingError::Inference(_))),
            "unexpected result: {result:?}"
        );
    }

    #[tokio::test]
    async fn mock_engine_batch_returns_one_embedding_per_input() {
        let mock = MockEngine::returning_ok(2, vec![1.0, 2.0]);
        let result = mock.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(result, vec![vec![1.0_f32, 2.0]; 3]);
    }

    #[tokio::test]
    async fn mock_engine_batch_propagates_errors() {
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed_batch(&["a", "b"]).await;
        assert!(
            matches!(result, Err(EmbeddingError::Inference(_))),
            "unexpected result: {result:?}"
        );
    }

    #[tokio::test]
    async fn mock_engine_batch_of_nothing_is_empty() {
        // The closure is never invoked for an empty slice, so even a failing
        // mock yields an empty, successful batch.
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed_batch(&[]).await.unwrap();
        assert!(result.is_empty());
    }
}