1use super::EmbeddingConfig;
2use crate::error::EmbedError;
3use crate::retry::RetryPolicy;
4use ai_sdk_provider::{EmbedOptions, EmbeddingModel, EmbeddingUsage};
5use futures::stream::{self, StreamExt};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
/// Type-state builder for embedding a batch of values with an embedding model.
///
/// The `M` and `V` type parameters track configuration progress: both start as
/// `()` and are replaced with concrete types by `model()` and `values()`, so
/// `execute()` only exists once both a model and values have been supplied.
pub struct EmbedManyBuilder<M, V, Val> {
    // Either `()` (unset) or `Arc<dyn EmbeddingModel<Val>>` once `model()` is called.
    model: M,
    // Either `()` (unset) or `Vec<Val>` once `values()` is called.
    values: V,
    // Retry policy and parallelism limits shared by all execution paths.
    config: EmbeddingConfig,
    // Ties the otherwise-unused `Val` type parameter to the struct.
    _marker: PhantomData<Val>,
}
16
17impl<Val> EmbedManyBuilder<(), (), Val>
19where
20 Val: Send + Sync + Clone + 'static,
21{
22 pub fn new() -> Self {
24 Self {
25 model: (),
26 values: (),
27 config: EmbeddingConfig::default(),
28 _marker: PhantomData,
29 }
30 }
31}
32
33impl<Val> Default for EmbedManyBuilder<(), (), Val>
34where
35 Val: Send + Sync + Clone + 'static,
36{
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl<M, V, Val> EmbedManyBuilder<M, V, Val> {
44 pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
46 self.config.retry_policy = retry_policy;
47 self
48 }
49
50 pub fn max_parallel_calls(mut self, max: usize) -> Self {
52 self.config.max_parallel_calls = Some(max);
53 self
54 }
55}
56
57impl<V, Val> EmbedManyBuilder<(), V, Val>
59where
60 Val: Send + Sync + Clone + 'static,
61{
62 pub fn model<Mod: EmbeddingModel<Val> + 'static>(
64 self,
65 model: Mod,
66 ) -> EmbedManyBuilder<Arc<dyn EmbeddingModel<Val>>, V, Val> {
67 EmbedManyBuilder {
68 model: Arc::new(model),
69 values: self.values,
70 config: self.config,
71 _marker: PhantomData,
72 }
73 }
74}
75
76impl<M, Val> EmbedManyBuilder<M, (), Val> {
78 pub fn values(self, values: Vec<Val>) -> EmbedManyBuilder<M, Vec<Val>, Val> {
80 EmbedManyBuilder {
81 model: self.model,
82 values,
83 config: self.config,
84 _marker: PhantomData,
85 }
86 }
87}
88
89impl<Val> EmbedManyBuilder<Arc<dyn EmbeddingModel<Val>>, Vec<Val>, Val>
91where
92 Val: Send + Sync + Clone + 'static,
93{
94 pub async fn execute(self) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
96 let model = self.model;
97 let values = self.values;
98 let config = self.config;
99
100 if values.is_empty() {
101 return Ok(EmbedManyResult {
102 values: Vec::new(),
103 embeddings: Vec::new(),
104 total_usage: EmbeddingUsage { tokens: 0 },
105 });
106 }
107
108 let max_embeddings_per_call = model.max_embeddings_per_call().await;
110 let supports_parallel = model.supports_parallel_calls().await;
111
112 let result = if let Some(max_per_call) = max_embeddings_per_call {
114 Self::embed_with_batching(model, values, config, max_per_call, supports_parallel)
115 .await?
116 } else {
117 Self::embed_single_call(model, values, config).await?
118 };
119
120 Ok(result)
121 }
122
123 async fn embed_single_call(
125 model: Arc<dyn EmbeddingModel<Val>>,
126 values: Vec<Val>,
127 config: EmbeddingConfig,
128 ) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
129 let response = config
130 .retry_policy
131 .retry(|| {
132 let options = EmbedOptions {
133 values: values.clone(),
134 provider_options: None,
135 headers: None,
136 };
137 let model = model.clone();
138 async move { model.do_embed(options).await }
139 })
140 .await
141 .map_err(EmbedError::ProviderError)?;
142
143 Ok(EmbedManyResult {
144 values,
145 embeddings: response.embeddings,
146 total_usage: response.usage.unwrap_or(EmbeddingUsage { tokens: 0 }),
147 })
148 }
149
150 async fn embed_with_batching(
152 model: Arc<dyn EmbeddingModel<Val>>,
153 values: Vec<Val>,
154 config: EmbeddingConfig,
155 max_per_call: usize,
156 supports_parallel: bool,
157 ) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
158 let batches: Vec<Vec<Val>> = values
160 .chunks(max_per_call)
161 .map(|chunk| chunk.to_vec())
162 .collect();
163
164 let mut all_embeddings = Vec::new();
165 let mut total_usage = EmbeddingUsage { tokens: 0 };
166 let max_parallel_calls = config.max_parallel_calls.unwrap_or(usize::MAX);
167
168 if supports_parallel && max_parallel_calls > 1 {
169 let max_concurrent = max_parallel_calls.min(batches.len());
171
172 let results = stream::iter(batches)
173 .map(|batch| {
174 let model = model.clone();
175 let retry_policy = config.retry_policy.clone();
176 async move {
177 retry_policy
178 .retry(|| {
179 let options = EmbedOptions {
180 values: batch.clone(),
181 provider_options: None,
182 headers: None,
183 };
184 async { model.do_embed(options).await }
185 })
186 .await
187 }
188 })
189 .buffer_unordered(max_concurrent)
190 .collect::<Vec<_>>()
191 .await;
192
193 for result in results {
195 let response = result.map_err(EmbedError::ProviderError)?;
196 all_embeddings.extend(response.embeddings);
197 if let Some(usage) = response.usage {
198 total_usage.tokens += usage.tokens;
199 }
200 }
201 } else {
202 for batch in batches {
204 let response = config
205 .retry_policy
206 .retry(|| {
207 let options = EmbedOptions {
208 values: batch.clone(),
209 provider_options: None,
210 headers: None,
211 };
212 let model = model.clone();
213 async move { model.do_embed(options).await }
214 })
215 .await
216 .map_err(EmbedError::ProviderError)?;
217
218 all_embeddings.extend(response.embeddings);
219 if let Some(usage) = response.usage {
220 total_usage.tokens += usage.tokens;
221 }
222 }
223 }
224
225 Ok(EmbedManyResult {
226 values,
227 embeddings: all_embeddings,
228 total_usage,
229 })
230 }
231}
232
/// Outcome of a batch embedding run: the inputs, their embeddings, and the
/// token usage accumulated across all provider calls.
#[derive(Debug, Clone)]
pub struct EmbedManyResult<Val> {
    /// The values that were embedded, in the order they were supplied.
    pub values: Vec<Val>,
    /// One embedding vector per input value, index-aligned with `values`.
    pub embeddings: Vec<Vec<f32>>,
    /// Token usage summed over every provider call made for this run.
    pub total_usage: EmbeddingUsage,
}
243
244impl<Val> EmbedManyResult<Val> {
245 pub fn embeddings(&self) -> &[Vec<f32>] {
247 &self.embeddings
248 }
249
250 pub fn embedding(&self, index: usize) -> Option<&[f32]> {
252 self.embeddings.get(index).map(|e| e.as_slice())
253 }
254
255 pub fn values(&self) -> &[Val] {
257 &self.values
258 }
259
260 pub fn usage(&self) -> &EmbeddingUsage {
262 &self.total_usage
263 }
264
265 pub fn iter(&self) -> impl Iterator<Item = (&Val, &[f32])> {
267 self.values
268 .iter()
269 .zip(self.embeddings.iter().map(|e| e.as_slice()))
270 }
271}
272
273pub fn embed_many<Val>() -> EmbedManyBuilder<(), (), Val>
275where
276 Val: Send + Sync + Clone + 'static,
277{
278 EmbedManyBuilder::new()
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use ai_sdk_provider::{EmbedResponse, Result};
    use async_trait::async_trait;

    /// Test double: the i-th value in a request embeds to `[i, i + 1]` and
    /// reported usage equals the request's value count.
    struct StubModel {
        max_per_call: Option<usize>,
        supports_parallel: bool,
    }

    #[async_trait]
    impl EmbeddingModel<String> for StubModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_embeddings_per_call(&self) -> Option<usize> {
            self.max_per_call
        }

        async fn supports_parallel_calls(&self) -> bool {
            self.supports_parallel
        }

        async fn do_embed(&self, options: EmbedOptions<String>) -> Result<EmbedResponse> {
            let count = options.values.len();
            let embeddings = (0..count)
                .map(|i| vec![i as f32, (i + 1) as f32])
                .collect();

            Ok(EmbedResponse {
                embeddings,
                usage: Some(EmbeddingUsage {
                    tokens: count as u32,
                }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[test]
    fn test_embed_many_builder_defaults() {
        // Builder construction alone must not panic.
        let _builder = embed_many::<String>();
    }

    #[tokio::test]
    async fn test_embed_many_empty_values() {
        let model = StubModel {
            max_per_call: Some(10),
            supports_parallel: true,
        };

        let result = embed_many()
            .model(model)
            .values(Vec::<String>::new())
            .execute()
            .await
            .unwrap();

        assert!(result.values().is_empty());
        assert!(result.embeddings().is_empty());
        assert_eq!(result.usage().tokens, 0);
    }

    #[tokio::test]
    async fn test_embed_many_single_call() {
        let model = StubModel {
            max_per_call: None,
            supports_parallel: true,
        };

        let inputs: Vec<String> = (1..=3).map(|i| format!("text{}", i)).collect();

        let result = embed_many()
            .model(model)
            .values(inputs)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 3);
        assert_eq!(result.embeddings().len(), 3);
        // One call embeds everything, so indices run 0..3 across the batch.
        for (i, embedding) in result.embeddings().iter().enumerate() {
            assert_eq!(embedding.as_slice(), &[i as f32, (i + 1) as f32]);
        }
        assert_eq!(result.usage().tokens, 3);
    }

    #[tokio::test]
    async fn test_embed_many_batched() {
        let model = StubModel {
            max_per_call: Some(2),
            supports_parallel: false,
        };

        let inputs: Vec<String> = (1..=5).map(|i| format!("text{}", i)).collect();

        let result = embed_many()
            .model(model)
            .values(inputs)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 5);
        assert_eq!(result.embeddings().len(), 5);
        assert_eq!(result.usage().tokens, 5);
    }

    #[tokio::test]
    async fn test_embed_many_parallel() {
        let model = StubModel {
            max_per_call: Some(2),
            supports_parallel: true,
        };

        let inputs: Vec<String> = (0..10).map(|i| format!("text{}", i)).collect();

        let result = embed_many()
            .model(model)
            .values(inputs)
            .max_parallel_calls(3)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 10);
        assert_eq!(result.embeddings().len(), 10);
        assert_eq!(result.usage().tokens, 10);
    }

    #[tokio::test]
    async fn test_embed_many_iter() {
        let model = StubModel {
            max_per_call: None,
            supports_parallel: true,
        };

        let inputs = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let result = embed_many()
            .model(model)
            .values(inputs)
            .execute()
            .await
            .unwrap();

        let pairs: Vec<(&String, &[f32])> = result.iter().collect();
        let expected = [("a", [0.0_f32, 1.0]), ("b", [1.0, 2.0]), ("c", [2.0, 3.0])];
        assert_eq!(pairs.len(), expected.len());
        for ((value, embedding), (want_value, want_embedding)) in
            pairs.iter().zip(expected.iter())
        {
            assert_eq!(value.as_str(), *want_value);
            assert_eq!(*embedding, want_embedding.as_slice());
        }
    }
}
454}