fierros_core/
embedding.rs1use crate::{FierrosError, FierrosResult};
2use async_trait::async_trait;
3
4#[async_trait]
5pub trait Embedder: Send + Sync {
6 async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>>;
7}
8
9#[derive(Debug, Clone)]
10pub struct MockEmbedder {
11 dimension: usize,
12 forced_error: Option<FierrosError>,
13}
14
15impl MockEmbedder {
16 pub fn new(dimension: usize) -> Self {
17 Self {
18 dimension,
19 forced_error: None,
20 }
21 }
22
23 pub fn failing(dimension: usize, error: FierrosError) -> Self {
24 Self {
25 dimension,
26 forced_error: Some(error),
27 }
28 }
29
30 pub fn with_error(mut self, error: FierrosError) -> Self {
31 self.forced_error = Some(error);
32 self
33 }
34
35 fn embed_one(&self, input: &str) -> Vec<f32> {
36 let mut out = vec![0.0; self.dimension];
37 if self.dimension == 0 {
38 return out;
39 }
40
41 for (idx, byte) in input.bytes().enumerate() {
42 out[idx % self.dimension] += byte as f32 / 255.0;
43 }
44
45 out
46 }
47}
48
49#[async_trait]
50impl Embedder for MockEmbedder {
51 async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
52 if let Some(error) = &self.forced_error {
53 return Err(error.clone());
54 }
55
56 if self.dimension == 0 {
57 return Err(FierrosError::InvalidInput(
58 "embedding dimension must be greater than zero".into(),
59 ));
60 }
61
62 Ok(inputs.iter().map(|input| self.embed_one(input)).collect())
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::{Embedder, MockEmbedder};
    use crate::FierrosError;

    #[tokio::test]
    async fn mock_embedder_returns_expected_dimensions() {
        let embedder = MockEmbedder::new(4);
        let embeddings = embedder
            .embed(&["hello".to_string(), "world".to_string()])
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 2);
        // Every vector, not just the first, must have the configured length.
        assert!(embeddings.iter().all(|vector| vector.len() == 4));
    }

    #[tokio::test]
    async fn mock_embedder_folds_bytes_into_slots() {
        let embedder = MockEmbedder::new(4);
        // "abcde" is bytes 97..=101; the fifth byte wraps back into slot 0.
        let embeddings = embedder.embed(&["abcde".to_string()]).await.unwrap();
        let expected = vec![
            97.0_f32 / 255.0 + 101.0 / 255.0,
            98.0 / 255.0,
            99.0 / 255.0,
            100.0 / 255.0,
        ];
        assert_eq!(embeddings, vec![expected]);
    }

    #[tokio::test]
    async fn mock_embedder_is_deterministic() {
        let embedder = MockEmbedder::new(8);
        let inputs = ["same text".to_string()];
        let first = embedder.embed(&inputs).await.unwrap();
        let second = embedder.embed(&inputs).await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn mock_embedder_handles_empty_batch_and_empty_string() {
        let embedder = MockEmbedder::new(3);
        assert!(embedder.embed(&[]).await.unwrap().is_empty());

        let embeddings = embedder.embed(&[String::new()]).await.unwrap();
        assert_eq!(embeddings, vec![vec![0.0; 3]]);
    }

    #[tokio::test]
    async fn mock_embedder_rejects_zero_dimension() {
        let embedder = MockEmbedder::new(0);
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert!(matches!(error, FierrosError::InvalidInput(_)));
    }

    #[tokio::test]
    async fn mock_embedder_can_return_configured_error() {
        let embedder = MockEmbedder::failing(
            4,
            FierrosError::Provider("embedding endpoint timeout".into()),
        );
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("embedding endpoint timeout".into())
        );
    }

    #[tokio::test]
    async fn with_error_builder_overrides_successful_embedder() {
        let embedder =
            MockEmbedder::new(4).with_error(FierrosError::Provider("rate limited".into()));
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(error, FierrosError::Provider("rate limited".into()));
    }

    #[tokio::test]
    async fn forced_error_takes_precedence_over_zero_dimension() {
        let embedder = MockEmbedder::failing(0, FierrosError::Provider("down".into()));
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(error, FierrosError::Provider("down".into()));
    }
}