// autoagents_llm/embedding/mod.rs
use async_trait::async_trait;

use crate::error::LLMError;

5#[async_trait]
6pub trait EmbeddingProvider {
7 async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError>;
8}
9
#[cfg(test)]
mod tests {
    // Unit tests for `EmbeddingProvider`, exercised through a deterministic in-memory mock.

    use super::*;
    use crate::error::LLMError;

    /// Deterministic stand-in for a real embedding backend.
    struct MockEmbeddingProvider {
        /// When `true`, `embed` returns an error instead of vectors.
        should_fail: bool,
        /// Number of elements in every vector produced by `embed`.
        dimension: usize,
    }

    impl MockEmbeddingProvider {
        /// Creates a mock that succeeds and emits vectors with `dimension` elements.
        fn new(dimension: usize) -> Self {
            Self {
                should_fail: false,
                dimension,
            }
        }

        /// Creates a mock whose `embed` always fails.
        fn new_failing() -> Self {
            Self {
                should_fail: true,
                dimension: 0,
            }
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        /// Returns one vector per input string without looking at the text itself:
        /// element `j` of the vector for input `i` is `(i + j) / 10`, computed in `f32`.
        ///
        /// Fails with `LLMError::ProviderError` when the mock was built with `new_failing`.
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError(
                    "Mock embedding failure".to_string(),
                ));
            }

            // Only the position of each input matters, so iterate over indices.
            let embeddings: Vec<Vec<f32>> = (0..input.len())
                .map(|i| {
                    (0..self.dimension)
                        .map(|j| (i as f32 + j as f32) / 10.0)
                        .collect::<Vec<f32>>()
                })
                .collect();
            Ok(embeddings)
        }
    }

    /// A single input yields a single vector with the expected values.
    #[tokio::test]
    async fn test_embedding_provider_single_text() {
        let provider = MockEmbeddingProvider::new(3);
        let input = vec!["Hello world".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed a single text");

        assert_eq!(embeddings.len(), 1);
        // Exact comparison is fine here: the mock computes `k / 10.0` in `f32`, which is the
        // same correctly rounded value as the `f32` literals below.
        assert_eq!(embeddings[0], vec![0.0_f32, 0.1, 0.2]);
    }

    /// Several inputs yield distinct vectors, one per input, in order.
    #[tokio::test]
    async fn test_embedding_provider_multiple_texts() {
        let provider = MockEmbeddingProvider::new(2);
        let input = vec![
            "First text".to_string(),
            "Second text".to_string(),
            "Third text".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed multiple texts");

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 2);
        }

        // Neighbouring inputs must not share an embedding.
        assert_ne!(embeddings[0], embeddings[1]);
        assert_ne!(embeddings[1], embeddings[2]);

        // Spot-check the values for the first two inputs.
        assert_eq!(embeddings[0], vec![0.0_f32, 0.1]);
        assert_eq!(embeddings[1], vec![0.1_f32, 0.2]);
    }

    /// An empty batch succeeds and produces no vectors.
    #[tokio::test]
    async fn test_embedding_provider_empty_input() {
        let provider = MockEmbeddingProvider::new(5);
        let input: Vec<String> = vec![];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should accept an empty batch");

        assert!(embeddings.is_empty());
    }

    /// A failing provider surfaces its error message to the caller.
    #[tokio::test]
    async fn test_embedding_provider_failure() {
        let provider = MockEmbeddingProvider::new_failing();
        let input = vec!["Test text".to_string()];

        let error = provider
            .embed(input)
            .await
            .expect_err("failing mock provider must return an error");

        assert!(error.to_string().contains("Mock embedding failure"));
    }

    /// A very long input string still yields exactly one vector of the configured size.
    #[tokio::test]
    async fn test_embedding_provider_large_input() {
        let provider = MockEmbeddingProvider::new(10);
        let large_text = "x".repeat(10000);
        let input = vec![large_text];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed a large text");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 10);
    }

    /// Non-ASCII input is accepted and embedded like any other text.
    #[tokio::test]
    async fn test_embedding_provider_unicode_text() {
        let provider = MockEmbeddingProvider::new(3);
        // NOTE(review): these literals look like mis-decoded UTF-8 (mojibake) — confirm the
        // intended text against the upstream file. They are kept verbatim here.
        let input = vec![
            "Hello δΈη".to_string(),
            "π Earth".to_string(),
            "ζ΅θ― test".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed unicode text");

        assert_eq!(embeddings.len(), 3);
        for embedding in embeddings {
            assert_eq!(embedding.len(), 3);
        }
    }

    /// Punctuation, escapes and quotes in the input do not affect the result count.
    #[tokio::test]
    async fn test_embedding_provider_special_characters() {
        let provider = MockEmbeddingProvider::new(2);
        let input = vec![
            "Special chars: !@#$%^&*()".to_string(),
            "Newlines\nand\ttabs".to_string(),
            "\"Quotes\" and 'apostrophes'".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed special characters");

        assert_eq!(embeddings.len(), 3);
    }

    /// A large dimension produces a vector of that length with the expected values.
    #[tokio::test]
    async fn test_embedding_provider_very_large_dimension() {
        let provider = MockEmbeddingProvider::new(1000);
        let input = vec!["Test".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should handle a large dimension");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 1000);

        // For the first input the mock's formula reduces to `index / 10`.
        for (i, value) in embeddings[0].iter().enumerate() {
            assert_eq!(*value, i as f32 / 10.0);
        }
    }

    /// A zero dimension still yields one (empty) vector per input.
    #[tokio::test]
    async fn test_embedding_provider_zero_dimension() {
        let provider = MockEmbeddingProvider::new(0);
        let input = vec!["Test".to_string()];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should handle a zero dimension");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 0);
    }

    /// Empty, short, long and numeric inputs are all embedded with the same dimension.
    #[tokio::test]
    async fn test_embedding_provider_mixed_content() {
        let provider = MockEmbeddingProvider::new(4);
        let input = vec![
            "".to_string(),
            "Single word".to_string(),
            "This is a longer sentence with more words.".to_string(),
            "123 456 789".to_string(),
        ];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed mixed content");

        assert_eq!(embeddings.len(), 4);
        for embedding in embeddings {
            assert_eq!(embedding.len(), 4);
        }
    }

    /// Repeated calls with the same input return the same embeddings.
    #[tokio::test]
    async fn test_embedding_provider_consistency() {
        let provider = MockEmbeddingProvider::new(3);
        let input = vec!["Consistent test".to_string()];

        let first = provider.embed(input.clone()).await.unwrap();
        let second = provider.embed(input.clone()).await.unwrap();
        let third = provider.embed(input).await.unwrap();

        assert_eq!(first, second);
        assert_eq!(second, third);
    }

    /// A batch of 100 inputs yields 100 vectors, with neighbours differing from each other.
    #[tokio::test]
    async fn test_embedding_provider_batch_processing() {
        let provider = MockEmbeddingProvider::new(2);
        let batch_size = 100;
        let input: Vec<String> = (0..batch_size)
            .map(|i| format!("Text number {i}"))
            .collect();

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should embed a batch");

        assert_eq!(embeddings.len(), batch_size);

        // Compare every adjacent pair without manual index arithmetic.
        for pair in embeddings.windows(2) {
            assert_ne!(pair[0], pair[1]);
        }
    }
}