Skip to main content

autoagents_llm/embedding/mod.rs

1use crate::error::LLMError;
2use async_trait::async_trait;
3
4pub mod model_provider;
5pub use model_provider::EmbeddingBuilder;
6
7#[async_trait]
8pub trait EmbeddingProvider: Sync + Send {
9    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError>;
10}
11
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LLMError;

    /// Deterministic, in-memory [`EmbeddingProvider`] used to exercise the
    /// trait without a real model backend.
    ///
    /// The text content is ignored: for the input at position `i`, component
    /// `j` of the returned vector is `(i + j) / 10.0`. This makes the exact
    /// output predictable, and gives every input position a distinct vector
    /// whenever `dimension > 0`.
    struct MockEmbeddingProvider {
        /// When `true`, every call to `embed` fails with `LLMError::ProviderError`.
        should_fail: bool,
        /// Number of components in each returned embedding vector.
        dimension: usize,
    }

    impl MockEmbeddingProvider {
        /// Creates a provider that succeeds and returns `dimension`-length vectors.
        fn new(dimension: usize) -> Self {
            Self {
                should_fail: false,
                dimension,
            }
        }

        /// Creates a provider whose `embed` always returns an error.
        fn new_failing() -> Self {
            Self {
                should_fail: true,
                dimension: 0,
            }
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError(
                    "Mock embedding failure".to_string(),
                ));
            }

            // Only the position of each input matters, not its text.
            let embeddings: Vec<Vec<f32>> = (0..input.len())
                .map(|i| {
                    (0..self.dimension)
                        .map(|j| (i as f32 + j as f32) / 10.0)
                        .collect::<Vec<f32>>()
                })
                .collect();
            Ok(embeddings)
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_single_text() {
        let provider = MockEmbeddingProvider::new(3);
        let input = vec!["Hello world".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 3);
        assert_eq!(embeddings[0][0], 0.0);
        assert_eq!(embeddings[0][1], 0.1);
        assert_eq!(embeddings[0][2], 0.2);
    }

    #[tokio::test]
    async fn test_embedding_provider_multiple_texts() {
        let provider = MockEmbeddingProvider::new(2);
        let input = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");
        assert_eq!(embeddings.len(), 3);

        // Check dimensions
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 2);
        }

        // Consecutive embeddings must differ
        assert_ne!(embeddings[0], embeddings[1]);
        assert_ne!(embeddings[1], embeddings[2]);

        // Check specific values
        assert_eq!(embeddings[0][0], 0.0);
        assert_eq!(embeddings[0][1], 0.1);
        assert_eq!(embeddings[1][0], 0.1);
        assert_eq!(embeddings[1][1], 0.2);
    }

    #[tokio::test]
    async fn test_embedding_provider_empty_input() {
        let provider = MockEmbeddingProvider::new(5);
        let input: Vec<String> = vec![];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on empty input");

        assert!(embeddings.is_empty());
    }

    #[tokio::test]
    async fn test_embedding_provider_failure() {
        let provider = MockEmbeddingProvider::new_failing();
        let input = vec!["Test text".to_string()];

        let error = provider
            .embed(input)
            .await
            .expect_err("failing mock provider should return an error");

        assert!(matches!(error, LLMError::ProviderError(_)));
        assert!(error.to_string().contains("Mock embedding failure"));
    }

    #[tokio::test]
    async fn test_embedding_provider_large_input() {
        let provider = MockEmbeddingProvider::new(10);
        let large_text = "x".repeat(10000);
        let input = vec![large_text];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on large input");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 10);
    }

    #[tokio::test]
    async fn test_embedding_provider_unicode_text() {
        let provider = MockEmbeddingProvider::new(3);
        let input = vec![
            "Hello 世界".to_string(),
            "🌍 Earth".to_string(),
            "测试 test".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on unicode input");
        assert_eq!(embeddings.len(), 3);

        for embedding in embeddings {
            assert_eq!(embedding.len(), 3);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_special_characters() {
        let provider = MockEmbeddingProvider::new(2);
        let input = vec![
            "Special chars: !@#$%^&*()".to_string(),
            "Newlines\nand\ttabs".to_string(),
            "\"Quotes\" and 'apostrophes'".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on special characters");

        assert_eq!(embeddings.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_provider_very_large_dimension() {
        let provider = MockEmbeddingProvider::new(1000);
        let input = vec!["Test".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed with a large dimension");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 1000);

        // The only input is at index 0, so component `i` is exactly `i / 10.0`.
        for (i, value) in embeddings[0].iter().enumerate() {
            assert_eq!(*value, i as f32 / 10.0);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_zero_dimension() {
        let provider = MockEmbeddingProvider::new(0);
        let input = vec!["Test".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed with zero dimension");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 0);
    }

    #[tokio::test]
    async fn test_embedding_provider_mixed_content() {
        let provider = MockEmbeddingProvider::new(4);
        let input = vec![
            "".to_string(),                                           // Empty string
            "Single word".to_string(),                                // Short text
            "This is a longer sentence with more words.".to_string(), // Long text
            "123 456 789".to_string(),                                // Numbers
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on mixed content");
        assert_eq!(embeddings.len(), 4);

        for embedding in embeddings {
            assert_eq!(embedding.len(), 4);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_consistency() {
        let provider = MockEmbeddingProvider::new(3);
        let input = vec!["Consistent test".to_string()];

        // Run embedding multiple times
        let result1 = provider.embed(input.clone()).await.unwrap();
        let result2 = provider.embed(input.clone()).await.unwrap();
        let result3 = provider.embed(input).await.unwrap();

        // Results should be identical
        assert_eq!(result1, result2);
        assert_eq!(result2, result3);
    }

    #[tokio::test]
    async fn test_embedding_provider_batch_processing() {
        let provider = MockEmbeddingProvider::new(2);
        let batch_size = 100;
        let input: Vec<String> = (0..batch_size)
            .map(|i| format!("Text number {i}"))
            .collect();

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed on a batch");
        assert_eq!(embeddings.len(), batch_size);

        // Every pair of consecutive embeddings must differ.
        for pair in embeddings.windows(2) {
            assert_ne!(pair[0], pair[1]);
        }
    }
}