ai_sdk_core/embed/batch.rs

use super::EmbeddingConfig;
use crate::error::EmbedError;
use crate::retry::RetryPolicy;
use ai_sdk_provider::{EmbedOptions, EmbeddingModel, EmbeddingUsage};
use futures::stream::{self, StreamExt};
use std::marker::PhantomData;
use std::sync::Arc;

/// Builder for embedding multiple values
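///
/// Built with a typestate pattern: `model()` and `values()` each move the
/// builder into a new state, and `execute()` is only available once both a
/// model and values have been set.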
pub struct EmbedManyBuilder<M, V, Val> {
    model: M,
    values: V,
    config: EmbeddingConfig,
    _marker: PhantomData<Val>,
}

// 1. Initial State
impl<Val> EmbedManyBuilder<(), (), Val>
where
    Val: Send + Sync + Clone + 'static,
{
    /// Create a new embed_many builder
    pub fn new() -> Self {
        Self {
            model: (),
            values: (),
            config: EmbeddingConfig::default(),
            _marker: PhantomData,
        }
    }
}

impl<Val> Default for EmbedManyBuilder<(), (), Val>
where
    Val: Send + Sync + Clone + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

// 2. Configuration Setters (Available in ANY state)
impl<M, V, Val> EmbedManyBuilder<M, V, Val> {
    /// Set custom retry policy
    pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
        self.config.retry_policy = retry_policy;
        self
    }

    /// Set maximum number of parallel API calls (default: unlimited)
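    ///
    /// Only takes effect when the model supports parallel calls and the
    /// values are split across multiple batched requests.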
    pub fn max_parallel_calls(mut self, max: usize) -> Self {
        self.config.max_parallel_calls = Some(max);
        self
    }
}

// 3. State Transition: Set Model
impl<V, Val> EmbedManyBuilder<(), V, Val>
where
    Val: Send + Sync + Clone + 'static,
{
    /// Set the embedding model
    pub fn model<Mod: EmbeddingModel<Val> + 'static>(
        self,
        model: Mod,
    ) -> EmbedManyBuilder<Arc<dyn EmbeddingModel<Val>>, V, Val> {
        EmbedManyBuilder {
            model: Arc::new(model),
            values: self.values,
            config: self.config,
            _marker: PhantomData,
        }
    }
}

// 4. State Transition: Set Values
impl<M, Val> EmbedManyBuilder<M, (), Val> {
    /// Set the values to embed
    pub fn values(self, values: Vec<Val>) -> EmbedManyBuilder<M, Vec<Val>, Val> {
        EmbedManyBuilder {
            model: self.model,
            values,
            config: self.config,
            _marker: PhantomData,
        }
    }
}

// 5. Execution Logic (Requires Model and Values)
impl<Val> EmbedManyBuilder<Arc<dyn EmbeddingModel<Val>>, Vec<Val>, Val>
where
    Val: Send + Sync + Clone + 'static,
{
    /// Execute the embedding
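    ///
    /// If the model reports a `max_embeddings_per_call()` limit, the values are
    /// split into batches of that size; otherwise a single call is made. Batches
    /// run concurrently (bounded by `max_parallel_calls`) when the model supports
    /// parallel calls, and sequentially otherwise.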
    pub async fn execute(self) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
        let model = self.model;
        let values = self.values;
        let config = self.config;

        if values.is_empty() {
            return Ok(EmbedManyResult {
                values: Vec::new(),
                embeddings: Vec::new(),
                total_usage: EmbeddingUsage { tokens: 0 },
            });
        }

        // Get model capabilities
        let max_embeddings_per_call = model.max_embeddings_per_call().await;
        let supports_parallel = model.supports_parallel_calls().await;

        // Determine batching strategy
        let result = if let Some(max_per_call) = max_embeddings_per_call {
            Self::embed_with_batching(model, values, config, max_per_call, supports_parallel)
                .await?
        } else {
            Self::embed_single_call(model, values, config).await?
        };

        Ok(result)
    }

    /// Embed all values in a single call
    async fn embed_single_call(
        model: Arc<dyn EmbeddingModel<Val>>,
        values: Vec<Val>,
        config: EmbeddingConfig,
    ) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
        let response = config
            .retry_policy
            .retry(|| {
                let options = EmbedOptions {
                    values: values.clone(),
                    provider_options: None,
                    headers: None,
                };
                let model = model.clone();
                async move { model.do_embed(options).await }
            })
            .await
            .map_err(EmbedError::ProviderError)?;

        Ok(EmbedManyResult {
            values,
            embeddings: response.embeddings,
            total_usage: response.usage.unwrap_or(EmbeddingUsage { tokens: 0 }),
        })
    }

    /// Embed with batching and optional parallel execution
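    ///
    /// Batch results are aggregated so that embeddings come back in the same
    /// order as the input values.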
    async fn embed_with_batching(
        model: Arc<dyn EmbeddingModel<Val>>,
        values: Vec<Val>,
        config: EmbeddingConfig,
        max_per_call: usize,
        supports_parallel: bool,
    ) -> std::result::Result<EmbedManyResult<Val>, EmbedError> {
        // Split into batches
        let batches: Vec<Vec<Val>> = values
            .chunks(max_per_call)
            .map(|chunk| chunk.to_vec())
            .collect();

        let mut all_embeddings = Vec::new();
        let mut total_usage = EmbeddingUsage { tokens: 0 };
        let max_parallel_calls = config.max_parallel_calls.unwrap_or(usize::MAX);

        if supports_parallel && max_parallel_calls > 1 {
            // Parallel execution, bounded by max_parallel_calls
            let max_concurrent = max_parallel_calls.min(batches.len());

            let results = stream::iter(batches)
                .map(|batch| {
                    let model = model.clone();
                    let retry_policy = config.retry_policy.clone();
                    async move {
                        retry_policy
                            .retry(|| {
                                let options = EmbedOptions {
                                    values: batch.clone(),
                                    provider_options: None,
                                    headers: None,
                                };
                                let model = model.clone();
                                async move { model.do_embed(options).await }
                            })
                            .await
                    }
                })
                // `buffered` (not `buffer_unordered`) yields batch results in
                // input order, so embeddings stay aligned with `values`.
                .buffered(max_concurrent)
                .collect::<Vec<_>>()
                .await;

            // Aggregate results
            for result in results {
                let response = result.map_err(EmbedError::ProviderError)?;
                all_embeddings.extend(response.embeddings);
                if let Some(usage) = response.usage {
                    total_usage.tokens += usage.tokens;
                }
            }
        } else {
            // Sequential execution
            for batch in batches {
                let response = config
                    .retry_policy
                    .retry(|| {
                        let options = EmbedOptions {
                            values: batch.clone(),
                            provider_options: None,
                            headers: None,
                        };
                        let model = model.clone();
                        async move { model.do_embed(options).await }
                    })
                    .await
                    .map_err(EmbedError::ProviderError)?;

                all_embeddings.extend(response.embeddings);
                if let Some(usage) = response.usage {
                    total_usage.tokens += usage.tokens;
                }
            }
        }

        Ok(EmbedManyResult {
            values,
            embeddings: all_embeddings,
            total_usage,
        })
    }
}

/// Result of embedding multiple values
#[derive(Debug, Clone)]
pub struct EmbedManyResult<Val> {
    /// The original values that were embedded
    pub values: Vec<Val>,
    /// The embedding vectors (in the same order as values)
    pub embeddings: Vec<Vec<f32>>,
    /// Total token usage across all API calls
    pub total_usage: EmbeddingUsage,
}

impl<Val> EmbedManyResult<Val> {
    /// Get all embeddings
    pub fn embeddings(&self) -> &[Vec<f32>] {
        &self.embeddings
    }

    /// Get embedding at index
    pub fn embedding(&self, index: usize) -> Option<&[f32]> {
        self.embeddings.get(index).map(|e| e.as_slice())
    }

    /// Get original values
    pub fn values(&self) -> &[Val] {
        &self.values
    }

    /// Get total token usage
    pub fn usage(&self) -> &EmbeddingUsage {
        &self.total_usage
    }

    /// Iterate over (value, embedding) pairs
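    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Sketch: `result` is assumed to be an `EmbedManyResult<String>`
    /// // produced by `embed_many(...).execute().await`.
    /// for (value, embedding) in result.iter() {
    ///     println!("{value}: {} dimensions", embedding.len());
    /// }
    /// ```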
    pub fn iter(&self) -> impl Iterator<Item = (&Val, &[f32])> {
        self.values
            .iter()
            .zip(self.embeddings.iter().map(|e| e.as_slice()))
    }
}

/// Entry point function for embedding multiple values
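///
/// # Examples
///
/// A minimal sketch; `MyEmbeddingModel` is a stand-in for any type that
/// implements `EmbeddingModel<String>` (hypothetical here).
///
/// ```ignore
/// let result = embed_many()
///     .model(MyEmbeddingModel::default())
///     .values(vec!["hello".to_string(), "world".to_string()])
///     .max_parallel_calls(4)
///     .execute()
///     .await?;
///
/// assert_eq!(result.embeddings().len(), 2);
/// ```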
pub fn embed_many<Val>() -> EmbedManyBuilder<(), (), Val>
where
    Val: Send + Sync + Clone + 'static,
{
    EmbedManyBuilder::new()
}

#[cfg(test)]
mod tests {
    use super::*;
    use ai_sdk_provider::{EmbedResponse, Result};
    use async_trait::async_trait;

    struct DummyModel {
        max_per_call: Option<usize>,
        supports_parallel: bool,
    }

    #[async_trait]
    impl EmbeddingModel<String> for DummyModel {
        fn provider(&self) -> &str {
            "test"
        }
        fn model_id(&self) -> &str {
            "dummy"
        }
        async fn max_embeddings_per_call(&self) -> Option<usize> {
            self.max_per_call
        }
        async fn supports_parallel_calls(&self) -> bool {
            self.supports_parallel
        }
        async fn do_embed(&self, options: EmbedOptions<String>) -> Result<EmbedResponse> {
            let embeddings = options
                .values
                .iter()
                .enumerate()
                .map(|(i, _)| vec![i as f32, (i + 1) as f32])
                .collect();

            Ok(EmbedResponse {
                embeddings,
                usage: Some(EmbeddingUsage {
                    tokens: options.values.len() as u32,
                }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[test]
    fn test_embed_many_builder_defaults() {
        let builder = embed_many::<String>();
        let _ = builder;
        // config fields are internal; just check that the builder constructs
    }

    #[tokio::test]
    async fn test_embed_many_empty_values() {
        let model = DummyModel {
            max_per_call: Some(10),
            supports_parallel: true,
        };

        let result = embed_many()
            .model(model)
            .values(Vec::<String>::new())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 0);
        assert_eq!(result.embeddings().len(), 0);
        assert_eq!(result.usage().tokens, 0);
    }

    #[tokio::test]
    async fn test_embed_many_single_call() {
        let model = DummyModel {
            max_per_call: None, // No limit - single call
            supports_parallel: true,
        };

        let values = vec![
            "text1".to_string(),
            "text2".to_string(),
            "text3".to_string(),
        ];

        let result = embed_many()
            .model(model)
            .values(values.clone())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 3);
        assert_eq!(result.embeddings().len(), 3);
        assert_eq!(result.embedding(0).unwrap(), &[0.0, 1.0]);
        assert_eq!(result.embedding(1).unwrap(), &[1.0, 2.0]);
        assert_eq!(result.embedding(2).unwrap(), &[2.0, 3.0]);
        assert_eq!(result.usage().tokens, 3);
    }

    #[tokio::test]
    async fn test_embed_many_batched() {
        let model = DummyModel {
            max_per_call: Some(2), // 2 embeddings per call
            supports_parallel: false,
        };

        let values = vec![
            "text1".to_string(),
            "text2".to_string(),
            "text3".to_string(),
            "text4".to_string(),
            "text5".to_string(),
        ];

        let result = embed_many()
            .model(model)
            .values(values.clone())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 5);
        assert_eq!(result.embeddings().len(), 5);
        // Total tokens: 2 + 2 + 1 = 5 (batches of 2, 2, 1)
        assert_eq!(result.usage().tokens, 5);
    }

    #[tokio::test]
    async fn test_embed_many_parallel() {
        let model = DummyModel {
            max_per_call: Some(2),
            supports_parallel: true,
        };

        let values: Vec<String> = (0..10).map(|i| format!("text{}", i)).collect();

        let result = embed_many()
            .model(model)
            .values(values.clone())
            .max_parallel_calls(3)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.values().len(), 10);
        assert_eq!(result.embeddings().len(), 10);
        assert_eq!(result.usage().tokens, 10);
    }

    #[tokio::test]
    async fn test_embed_many_iter() {
        let model = DummyModel {
            max_per_call: None,
            supports_parallel: true,
        };

        let values = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let result = embed_many()
            .model(model)
            .values(values.clone())
            .execute()
            .await
            .unwrap();

        let pairs: Vec<(&String, &[f32])> = result.iter().collect();
        assert_eq!(pairs.len(), 3);
        assert_eq!(pairs[0].0, "a");
        assert_eq!(pairs[0].1, &[0.0, 1.0]);
        assert_eq!(pairs[1].0, "b");
        assert_eq!(pairs[1].1, &[1.0, 2.0]);
        assert_eq!(pairs[2].0, "c");
        assert_eq!(pairs[2].1, &[2.0, 3.0]);
    }
}