Skip to main content

fierros_core/
embedding.rs

1use crate::{FierrosError, FierrosResult};
2use async_trait::async_trait;
3
4#[async_trait]
5pub trait Embedder: Send + Sync {
6    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>>;
7}
8
9#[derive(Debug, Clone)]
10pub struct MockEmbedder {
11    dimension: usize,
12    forced_error: Option<FierrosError>,
13}
14
15impl MockEmbedder {
16    pub fn new(dimension: usize) -> Self {
17        Self {
18            dimension,
19            forced_error: None,
20        }
21    }
22
23    pub fn failing(dimension: usize, error: FierrosError) -> Self {
24        Self {
25            dimension,
26            forced_error: Some(error),
27        }
28    }
29
30    pub fn with_error(mut self, error: FierrosError) -> Self {
31        self.forced_error = Some(error);
32        self
33    }
34
35    fn embed_one(&self, input: &str) -> Vec<f32> {
36        let mut out = vec![0.0; self.dimension];
37        if self.dimension == 0 {
38            return out;
39        }
40
41        for (idx, byte) in input.bytes().enumerate() {
42            out[idx % self.dimension] += byte as f32 / 255.0;
43        }
44
45        out
46    }
47}
48
49#[async_trait]
50impl Embedder for MockEmbedder {
51    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
52        if let Some(error) = &self.forced_error {
53            return Err(error.clone());
54        }
55
56        if self.dimension == 0 {
57            return Err(FierrosError::InvalidInput(
58                "embedding dimension must be greater than zero".into(),
59            ));
60        }
61
62        Ok(inputs.iter().map(|input| self.embed_one(input)).collect())
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::{Embedder, MockEmbedder};
    use crate::FierrosError;

    #[tokio::test]
    async fn mock_embedder_returns_expected_dimensions() {
        let embedder = MockEmbedder::new(4);
        let embeddings = embedder
            .embed(&["hello".to_string(), "world".to_string()])
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 2);
        // Every vector, not just the first, must have the configured length.
        assert!(embeddings.iter().all(|vector| vector.len() == 4));
    }

    #[tokio::test]
    async fn mock_embedder_is_deterministic_and_byte_derived() {
        let embedder = MockEmbedder::new(2);
        let first = embedder.embed(&["abc".to_string()]).await.unwrap();
        let second = embedder.embed(&["abc".to_string()]).await.unwrap();
        assert_eq!(first, second);
        // Bytes fold round-robin into slots: 'a' (97) and 'c' (99) land in slot 0,
        // 'b' (98) in slot 1, each scaled by 1/255.
        let expected: Vec<f32> = vec![97.0 / 255.0 + 99.0 / 255.0, 98.0 / 255.0];
        assert_eq!(first, vec![expected]);
    }

    #[tokio::test]
    async fn mock_embedder_handles_empty_batch_and_empty_string() {
        let embedder = MockEmbedder::new(3);
        assert!(embedder.embed(&[]).await.unwrap().is_empty());
        let zeros = embedder.embed(&[String::new()]).await.unwrap();
        assert_eq!(zeros, vec![vec![0.0_f32; 3]]);
    }

    #[tokio::test]
    async fn mock_embedder_rejects_zero_dimension() {
        let embedder = MockEmbedder::new(0);
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(
            error,
            FierrosError::InvalidInput("embedding dimension must be greater than zero".into())
        );
    }

    #[tokio::test]
    async fn mock_embedder_can_return_configured_error() {
        let embedder = MockEmbedder::failing(
            4,
            FierrosError::Provider("embedding endpoint timeout".into()),
        );
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("embedding endpoint timeout".into())
        );
    }

    #[tokio::test]
    async fn with_error_turns_working_mock_into_failing_one() {
        let embedder =
            MockEmbedder::new(4).with_error(FierrosError::Provider("rate limited".into()));
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(error, FierrosError::Provider("rate limited".into()));
    }

    #[tokio::test]
    async fn forced_error_takes_precedence_over_zero_dimension() {
        let embedder = MockEmbedder::failing(0, FierrosError::Provider("down".into()));
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(error, FierrosError::Provider("down".into()));
    }
}