autoagents_llm/embedding/
mod.rs

1use async_trait::async_trait;
2
3use crate::error::LLMError;
4
/// Abstraction over backends that turn text into embedding vectors.
///
/// The trait exposes a single asynchronous operation, [`EmbeddingProvider::embed`],
/// and is declared through the `async_trait` macro.
#[async_trait]
pub trait EmbeddingProvider {
    /// Embeds a batch of strings.
    ///
    /// Takes ownership of `input` and resolves to a list of `f32` vectors.
    ///
    /// NOTE(review): presumably one vector per input string, in input order —
    /// the signature does not enforce this; confirm against implementors.
    ///
    /// # Errors
    ///
    /// Resolves to an [`LLMError`] when the implementor cannot produce embeddings.
    async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError>;
}
9
10#[cfg(test)]
11mod tests {
12    use super::*;
13    use crate::error::LLMError;
14
15    // Mock embedding provider for testing
16    struct MockEmbeddingProvider {
17        should_fail: bool,
18        dimension: usize,
19    }
20
21    impl MockEmbeddingProvider {
22        fn new(dimension: usize) -> Self {
23            Self {
24                should_fail: false,
25                dimension,
26            }
27        }
28
29        fn new_failing() -> Self {
30            Self {
31                should_fail: true,
32                dimension: 0,
33            }
34        }
35    }
36
37    #[async_trait::async_trait]
38    impl EmbeddingProvider for MockEmbeddingProvider {
39        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
40            if self.should_fail {
41                return Err(LLMError::ProviderError(
42                    "Mock embedding failure".to_string(),
43                ));
44            }
45
46            let mut embeddings = Vec::new();
47            for (i, _text) in input.iter().enumerate() {
48                let mut embedding = Vec::new();
49                for j in 0..self.dimension {
50                    embedding.push((i as f32 + j as f32) / 10.0);
51                }
52                embeddings.push(embedding);
53            }
54            Ok(embeddings)
55        }
56    }
57
58    #[tokio::test]
59    async fn test_embedding_provider_single_text() {
60        let provider = MockEmbeddingProvider::new(3);
61        let input = vec!["Hello world".to_string()];
62
63        let result = provider.embed(input).await;
64        assert!(result.is_ok());
65
66        let embeddings = result.unwrap();
67        assert_eq!(embeddings.len(), 1);
68        assert_eq!(embeddings[0].len(), 3);
69        assert_eq!(embeddings[0][0], 0.0);
70        assert_eq!(embeddings[0][1], 0.1);
71        assert_eq!(embeddings[0][2], 0.2);
72    }
73
74    #[tokio::test]
75    async fn test_embedding_provider_multiple_texts() {
76        let provider = MockEmbeddingProvider::new(2);
77        let input = vec![
78            "First text".to_string(),
79            "Second text".to_string(),
80            "Third text".to_string(),
81        ];
82
83        let result = provider.embed(input).await;
84        assert!(result.is_ok());
85
86        let embeddings = result.unwrap();
87        assert_eq!(embeddings.len(), 3);
88
89        // Check dimensions
90        for embedding in &embeddings {
91            assert_eq!(embedding.len(), 2);
92        }
93
94        // Check that each embedding is different
95        assert_ne!(embeddings[0], embeddings[1]);
96        assert_ne!(embeddings[1], embeddings[2]);
97
98        // Check specific values
99        assert_eq!(embeddings[0][0], 0.0);
100        assert_eq!(embeddings[0][1], 0.1);
101        assert_eq!(embeddings[1][0], 0.1);
102        assert_eq!(embeddings[1][1], 0.2);
103    }
104
105    #[tokio::test]
106    async fn test_embedding_provider_empty_input() {
107        let provider = MockEmbeddingProvider::new(5);
108        let input: Vec<String> = vec![];
109
110        let result = provider.embed(input).await;
111        assert!(result.is_ok());
112
113        let embeddings = result.unwrap();
114        assert!(embeddings.is_empty());
115    }
116
117    #[tokio::test]
118    async fn test_embedding_provider_failure() {
119        let provider = MockEmbeddingProvider::new_failing();
120        let input = vec!["Test text".to_string()];
121
122        let result = provider.embed(input).await;
123        assert!(result.is_err());
124
125        let error = result.unwrap_err();
126        assert!(error.to_string().contains("Mock embedding failure"));
127    }
128
129    #[tokio::test]
130    async fn test_embedding_provider_large_input() {
131        let provider = MockEmbeddingProvider::new(10);
132        let large_text = "x".repeat(10000);
133        let input = vec![large_text];
134
135        let result = provider.embed(input).await;
136        assert!(result.is_ok());
137
138        let embeddings = result.unwrap();
139        assert_eq!(embeddings.len(), 1);
140        assert_eq!(embeddings[0].len(), 10);
141    }
142
143    #[tokio::test]
144    async fn test_embedding_provider_unicode_text() {
145        let provider = MockEmbeddingProvider::new(3);
146        let input = vec![
147            "Hello δΈ–η•Œ".to_string(),
148            "🌍 Earth".to_string(),
149            "ζ΅‹θ―• test".to_string(),
150        ];
151
152        let result = provider.embed(input).await;
153        assert!(result.is_ok());
154
155        let embeddings = result.unwrap();
156        assert_eq!(embeddings.len(), 3);
157
158        for embedding in embeddings {
159            assert_eq!(embedding.len(), 3);
160        }
161    }
162
163    #[tokio::test]
164    async fn test_embedding_provider_special_characters() {
165        let provider = MockEmbeddingProvider::new(2);
166        let input = vec![
167            "Special chars: !@#$%^&*()".to_string(),
168            "Newlines\nand\ttabs".to_string(),
169            "\"Quotes\" and 'apostrophes'".to_string(),
170        ];
171
172        let result = provider.embed(input).await;
173        assert!(result.is_ok());
174
175        let embeddings = result.unwrap();
176        assert_eq!(embeddings.len(), 3);
177    }
178
179    #[tokio::test]
180    async fn test_embedding_provider_very_large_dimension() {
181        let provider = MockEmbeddingProvider::new(1000);
182        let input = vec!["Test".to_string()];
183
184        let result = provider.embed(input).await;
185        assert!(result.is_ok());
186
187        let embeddings = result.unwrap();
188        assert_eq!(embeddings.len(), 1);
189        assert_eq!(embeddings[0].len(), 1000);
190
191        // Check that values are within expected range
192        for (i, value) in embeddings[0].iter().enumerate() {
193            assert_eq!(*value, i as f32 / 10.0);
194        }
195    }
196
197    #[tokio::test]
198    async fn test_embedding_provider_zero_dimension() {
199        let provider = MockEmbeddingProvider::new(0);
200        let input = vec!["Test".to_string()];
201
202        let result = provider.embed(input).await;
203        assert!(result.is_ok());
204
205        let embeddings = result.unwrap();
206        assert_eq!(embeddings.len(), 1);
207        assert_eq!(embeddings[0].len(), 0);
208    }
209
210    #[tokio::test]
211    async fn test_embedding_provider_mixed_content() {
212        let provider = MockEmbeddingProvider::new(4);
213        let input = vec![
214            "".to_string(),                                           // Empty string
215            "Single word".to_string(),                                // Short text
216            "This is a longer sentence with more words.".to_string(), // Long text
217            "123 456 789".to_string(),                                // Numbers
218        ];
219
220        let result = provider.embed(input).await;
221        assert!(result.is_ok());
222
223        let embeddings = result.unwrap();
224        assert_eq!(embeddings.len(), 4);
225
226        for embedding in embeddings {
227            assert_eq!(embedding.len(), 4);
228        }
229    }
230
231    #[tokio::test]
232    async fn test_embedding_provider_consistency() {
233        let provider = MockEmbeddingProvider::new(3);
234        let input = vec!["Consistent test".to_string()];
235
236        // Run embedding multiple times
237        let result1 = provider.embed(input.clone()).await.unwrap();
238        let result2 = provider.embed(input.clone()).await.unwrap();
239        let result3 = provider.embed(input).await.unwrap();
240
241        // Results should be identical
242        assert_eq!(result1, result2);
243        assert_eq!(result2, result3);
244    }
245
246    #[tokio::test]
247    async fn test_embedding_provider_batch_processing() {
248        let provider = MockEmbeddingProvider::new(2);
249        let batch_size = 100;
250        let input: Vec<String> = (0..batch_size)
251            .map(|i| format!("Text number {i}"))
252            .collect();
253
254        let result = provider.embed(input).await;
255        assert!(result.is_ok());
256
257        let embeddings = result.unwrap();
258        assert_eq!(embeddings.len(), batch_size);
259
260        // Check that each embedding is different
261        for i in 0..batch_size - 1 {
262            assert_ne!(embeddings[i], embeddings[i + 1]);
263        }
264    }
265}