autoagents_llm/embedding/
mod.rs1use crate::error::LLMError;
2use async_trait::async_trait;
3
4pub mod model_provider;
5pub use model_provider::EmbeddingBuilder;
6
7#[async_trait]
8pub trait EmbeddingProvider: Sync + Send {
9 async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError>;
10}
11
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LLMError;

    /// Deterministic in-memory [`EmbeddingProvider`] used to exercise the trait.
    struct MockEmbeddingProvider {
        /// When `true`, every `embed` call returns a provider error.
        should_fail: bool,
        /// Number of components in each generated embedding vector.
        dimension: usize,
    }

    impl MockEmbeddingProvider {
        /// Creates a provider that succeeds and emits vectors of length `dimension`.
        fn new(dimension: usize) -> Self {
            Self {
                should_fail: false,
                dimension,
            }
        }

        /// Creates a provider whose `embed` call always fails.
        fn new_failing() -> Self {
            Self {
                should_fail: true,
                dimension: 0,
            }
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        /// Produces one vector per input; component `j` of the vector for input
        /// `i` is `(i + j) / 10.0`. The text content itself is ignored.
        async fn embed(&self, input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            if self.should_fail {
                return Err(LLMError::ProviderError(
                    "Mock embedding failure".to_string(),
                ));
            }

            let dimension = self.dimension;
            let embeddings: Vec<Vec<f32>> = (0..input.len())
                .map(|i| {
                    (0..dimension)
                        .map(|j| (i as f32 + j as f32) / 10.0)
                        .collect::<Vec<f32>>()
                })
                .collect();
            Ok(embeddings)
        }
    }

    /// Converts string literals into the owned `Vec<String>` that `embed` takes.
    fn owned_texts(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|&text| text.to_owned()).collect()
    }

    #[tokio::test]
    async fn test_embedding_provider_single_text() {
        let provider = MockEmbeddingProvider::new(3);

        let embeddings = provider
            .embed(owned_texts(&["Hello world"]))
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 1);
        // Exact float equality is fine here: each literal rounds to the same
        // `f32` as the mock's single `(i + j) / 10.0` division.
        assert_eq!(embeddings[0], vec![0.0, 0.1, 0.2]);
    }

    #[tokio::test]
    async fn test_embedding_provider_multiple_texts() {
        let provider = MockEmbeddingProvider::new(2);
        let input = owned_texts(&["First text", "Second text", "Third text"]);

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 2);
        }

        // Vectors depend on the input position, so neighbours must differ.
        assert_ne!(embeddings[0], embeddings[1]);
        assert_ne!(embeddings[1], embeddings[2]);

        assert_eq!(embeddings[0], vec![0.0, 0.1]);
        assert_eq!(embeddings[1], vec![0.1, 0.2]);
    }

    #[tokio::test]
    async fn test_embedding_provider_empty_input() {
        let provider = MockEmbeddingProvider::new(5);

        let embeddings = provider
            .embed(Vec::new())
            .await
            .expect("mock provider should succeed");

        assert!(embeddings.is_empty());
    }

    #[tokio::test]
    async fn test_embedding_provider_failure() {
        let provider = MockEmbeddingProvider::new_failing();

        let error = provider
            .embed(owned_texts(&["Test text"]))
            .await
            .expect_err("failing mock must return an error");

        assert!(error.to_string().contains("Mock embedding failure"));
    }

    #[tokio::test]
    async fn test_embedding_provider_large_input() {
        let provider = MockEmbeddingProvider::new(10);
        let input = vec!["x".repeat(10000)];

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 10);
    }

    #[tokio::test]
    async fn test_embedding_provider_unicode_text() {
        let provider = MockEmbeddingProvider::new(3);
        // Multi-byte UTF-8 input: CJK characters and an emoji.
        let input = owned_texts(&["Hello 世界", "🌍 Earth", "测试 test"]);

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 3);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_special_characters() {
        let provider = MockEmbeddingProvider::new(2);
        let input = owned_texts(&[
            "Special chars: !@#$%^&*()",
            "Newlines\nand\ttabs",
            "\"Quotes\" and 'apostrophes'",
        ]);

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 3);
    }

    #[tokio::test]
    async fn test_embedding_provider_very_large_dimension() {
        let provider = MockEmbeddingProvider::new(1000);

        let embeddings = provider
            .embed(owned_texts(&["Test"]))
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 1000);

        // For the first input (index 0) component `i` is `i / 10.0`; the
        // expected value uses the same f32 expression, so equality is exact.
        for (i, value) in embeddings[0].iter().enumerate() {
            assert_eq!(*value, i as f32 / 10.0);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_zero_dimension() {
        let provider = MockEmbeddingProvider::new(0);

        let embeddings = provider
            .embed(owned_texts(&["Test"]))
            .await
            .expect("mock provider should succeed");

        // One (empty) vector is still produced for the single input.
        assert_eq!(embeddings.len(), 1);
        assert!(embeddings[0].is_empty());
    }

    #[tokio::test]
    async fn test_embedding_provider_mixed_content() {
        let provider = MockEmbeddingProvider::new(4);
        let input = owned_texts(&[
            // Empty string.
            "",
            "Single word",
            "This is a longer sentence with more words.",
            // Digits only.
            "123 456 789",
        ]);

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), 4);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 4);
        }
    }

    #[tokio::test]
    async fn test_embedding_provider_consistency() {
        let provider = MockEmbeddingProvider::new(3);
        let input = owned_texts(&["Consistent test"]);

        let first = provider
            .embed(input.clone())
            .await
            .expect("first call should succeed");
        let second = provider
            .embed(input.clone())
            .await
            .expect("second call should succeed");
        let third = provider
            .embed(input)
            .await
            .expect("third call should succeed");

        // Repeated calls with the same input must yield identical output.
        assert_eq!(first, second);
        assert_eq!(second, third);
    }

    #[tokio::test]
    async fn test_embedding_provider_batch_processing() {
        let provider = MockEmbeddingProvider::new(2);
        let batch_size = 100;
        let input: Vec<String> = (0..batch_size)
            .map(|i| format!("Text number {i}"))
            .collect();

        let embeddings = provider
            .embed(input)
            .await
            .expect("mock provider should succeed");

        assert_eq!(embeddings.len(), batch_size);

        // Every pair of adjacent embeddings must differ.
        for pair in embeddings.windows(2) {
            assert_ne!(pair[0], pair[1]);
        }
    }
}