Skip to main content

erio_embedding/
engine.rs

//! Embedding engine trait definition.
2
3use crate::error::EmbeddingError;
4
5/// An engine that computes vector embeddings from text.
6#[async_trait::async_trait]
7pub trait EmbeddingEngine: Send + Sync {
8    /// Returns the engine name (e.g., "gemma", "remote").
9    fn name(&self) -> &str;
10
11    /// Returns the output vector dimensions.
12    fn dimensions(&self) -> usize;
13
14    /// Computes an embedding vector for a single text input.
15    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
16
17    /// Computes embedding vectors for a batch of text inputs.
18    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
19}
20
#[cfg(test)]
mod tests {
    use super::*;

    /// Closure type a [`MockEngine`] uses to answer every embedding request.
    type EmbedFn = dyn Fn(&str) -> Result<Vec<f32>, EmbeddingError> + Send + Sync;

    /// A configurable mock embedding engine for tests.
    ///
    /// Every `embed` call, and every element of an `embed_batch` call, is answered by the
    /// stored closure.
    struct MockEngine {
        // `Box` rather than `Arc`: the closure is owned by exactly one mock and never shared.
        embed_fn: Box<EmbedFn>,
        dims: usize,
    }

    impl MockEngine {
        /// Builds a mock reporting `dims` dimensions that returns a copy of `values` for
        /// every input text.
        fn returning_ok(dims: usize, values: Vec<f32>) -> Self {
            Self {
                embed_fn: Box::new(move |_| Ok(values.clone())),
                dims,
            }
        }

        /// Builds a mock (reporting 0 dimensions) whose every embedding call fails with the
        /// error produced by `err_fn`.
        fn returning_err(err_fn: impl Fn() -> EmbeddingError + Send + Sync + 'static) -> Self {
            Self {
                embed_fn: Box::new(move |_| Err(err_fn())),
                dims: 0,
            }
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingEngine for MockEngine {
        fn name(&self) -> &str {
            "mock"
        }

        fn dimensions(&self) -> usize {
            self.dims
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            (self.embed_fn)(text)
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            texts.iter().map(|t| (self.embed_fn)(t)).collect()
        }
    }

    #[tokio::test]
    async fn mock_engine_returns_configured_embedding() {
        let mock = MockEngine::returning_ok(3, vec![0.1, 0.2, 0.3]);
        assert_eq!(mock.name(), "mock");
        assert_eq!(mock.dimensions(), 3);
        let result = mock.embed("hello").await.unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3]);
    }

    #[tokio::test]
    async fn mock_engine_can_simulate_errors() {
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed("test").await;
        assert!(
            matches!(result, Err(EmbeddingError::Inference(_))),
            "expected an inference error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn mock_engine_embeds_each_batch_item() {
        let mock = MockEngine::returning_ok(2, vec![1.0, 2.0]);
        let batch = mock.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(batch, vec![vec![1.0, 2.0]; 3]);
    }

    #[tokio::test]
    async fn mock_engine_empty_batch_is_empty() {
        // An erroring mock proves the closure is never invoked for an empty batch.
        let mock = MockEngine::returning_err(|| EmbeddingError::Inference("unused".into()));
        let batch = mock.embed_batch(&[]).await.unwrap();
        assert!(batch.is_empty());
    }

    #[tokio::test]
    async fn mock_engine_batch_propagates_errors() {
        let mock =
            MockEngine::returning_err(|| EmbeddingError::Inference("simulated failure".into()));
        let result = mock.embed_batch(&["a", "b"]).await;
        assert!(
            matches!(result, Err(EmbeddingError::Inference(_))),
            "expected an inference error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn engine_is_usable_as_trait_object() {
        let engine: Box<dyn EmbeddingEngine> = Box::new(MockEngine::returning_ok(1, vec![0.5]));
        assert_eq!(engine.name(), "mock");
        assert_eq!(engine.dimensions(), 1);
        assert_eq!(engine.embed("x").await.unwrap(), vec![0.5]);
    }
}