ai_sdk_core/embed/
single.rs

1use super::EmbeddingConfig;
2use crate::error::EmbedError;
3use crate::retry::RetryPolicy;
4use crate::Result;
5use ai_sdk_provider::{EmbedOptions, EmbeddingModel, EmbeddingUsage};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9/// Builder for single value embedding
10pub struct EmbedBuilder<M, V, Val> {
11    model: M,
12    value: V,
13    config: EmbeddingConfig,
14    _marker: PhantomData<Val>,
15}
16
17// 1. Initial State
18impl<Val> EmbedBuilder<(), (), Val>
19where
20    Val: Send + Sync + Clone + 'static,
21{
22    /// Create a new embed builder
23    pub fn new() -> Self {
24        Self {
25            model: (),
26            value: (),
27            config: EmbeddingConfig::default(),
28            _marker: PhantomData,
29        }
30    }
31}
32
33impl<Val> Default for EmbedBuilder<(), (), Val>
34where
35    Val: Send + Sync + Clone + 'static,
36{
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42// 2. Configuration Setters (Available in ANY state)
43impl<M, V, Val> EmbedBuilder<M, V, Val> {
44    /// Set custom retry policy
45    pub fn retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
46        self.config.retry_policy = retry_policy;
47        self
48    }
49}
50
51// 3. State Transition: Set Model
52impl<V, Val> EmbedBuilder<(), V, Val>
53where
54    Val: Send + Sync + Clone + 'static,
55{
56    /// Set the embedding model
57    pub fn model<Mod: EmbeddingModel<Val> + 'static>(
58        self,
59        model: Mod,
60    ) -> EmbedBuilder<Arc<dyn EmbeddingModel<Val>>, V, Val> {
61        EmbedBuilder {
62            model: Arc::new(model),
63            value: self.value,
64            config: self.config,
65            _marker: PhantomData,
66        }
67    }
68}
69
70// 4. State Transition: Set Value
71impl<M, Val> EmbedBuilder<M, (), Val> {
72    /// Set the value to embed
73    pub fn value(self, value: Val) -> EmbedBuilder<M, Val, Val> {
74        EmbedBuilder {
75            model: self.model,
76            value,
77            config: self.config,
78            _marker: PhantomData,
79        }
80    }
81}
82
83// 5. Execution Logic (Requires Model and Value)
84impl<Val> EmbedBuilder<Arc<dyn EmbeddingModel<Val>>, Val, Val>
85where
86    Val: Send + Sync + Clone + 'static,
87{
88    /// Execute the embedding
89    pub async fn execute(self) -> Result<EmbedResult<Val>> {
90        let model = self.model;
91        let value = self.value;
92        let config = self.config;
93
94        // Call model with retry
95        let response = config
96            .retry_policy
97            .retry(|| {
98                let options = EmbedOptions {
99                    values: vec![value.clone()],
100                    provider_options: None,
101                    headers: None,
102                };
103                let model = model.clone();
104                async move { model.do_embed(options).await }
105            })
106            .await
107            .map_err(EmbedError::ProviderError)?;
108
109        // Extract first embedding
110        let embedding = response
111            .embeddings
112            .into_iter()
113            .next()
114            .ok_or(EmbedError::EmptyResponse)?;
115
116        Ok(EmbedResult {
117            value,
118            embedding,
119            usage: response.usage,
120        })
121    }
122}
123
/// Result of embedding a single value with `EmbedBuilder::execute`.
#[derive(Debug, Clone)]
pub struct EmbedResult<Val> {
    /// The original value that was embedded, handed back to the caller.
    pub value: Val,
    /// The embedding vector (the first embedding in the model's response).
    pub embedding: Vec<f32>,
    /// Token usage as reported by the model; `None` if the model reported none.
    pub usage: Option<EmbeddingUsage>,
}
134
135impl<Val> EmbedResult<Val> {
136    /// Get the embedding vector
137    pub fn embedding(&self) -> &[f32] {
138        &self.embedding
139    }
140
141    /// Get the original value
142    pub fn value(&self) -> &Val {
143        &self.value
144    }
145
146    /// Get token usage
147    pub fn usage(&self) -> Option<&EmbeddingUsage> {
148        self.usage.as_ref()
149    }
150}
151
152/// Entry point function for embedding a single value
153pub fn embed<Val>() -> EmbedBuilder<(), (), Val>
154where
155    Val: Send + Sync + Clone + 'static,
156{
157    EmbedBuilder::new()
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use ai_sdk_provider::{EmbedResponse, Result};
    use async_trait::async_trait;

    /// Test double: checks that it receives exactly `expected_value` and replies with a
    /// fixed list of embeddings plus 10 tokens of usage.
    struct FixedModel {
        expected_value: &'static str,
        embeddings: Vec<Vec<f32>>,
    }

    #[async_trait]
    impl EmbeddingModel<String> for FixedModel {
        fn provider(&self) -> &str {
            "test"
        }
        fn model_id(&self) -> &str {
            "dummy"
        }
        async fn max_embeddings_per_call(&self) -> Option<usize> {
            Some(100)
        }
        async fn supports_parallel_calls(&self) -> bool {
            true
        }
        async fn do_embed(&self, options: EmbedOptions<String>) -> Result<EmbedResponse> {
            assert_eq!(options.values.len(), 1);
            assert_eq!(options.values[0], self.expected_value);
            Ok(EmbedResponse {
                embeddings: self.embeddings.clone(),
                usage: Some(EmbeddingUsage { tokens: 10 }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[test]
    fn test_embed_builder_defaults() {
        // Every entry point starts with both the model and value slots unset.
        let EmbedBuilder {
            model: (),
            value: (),
            ..
        } = embed::<String>();
        let EmbedBuilder {
            model: (),
            value: (),
            ..
        } = EmbedBuilder::<(), (), String>::new();
        let EmbedBuilder {
            model: (),
            value: (),
            ..
        } = EmbedBuilder::<(), (), String>::default();
    }

    // Omitting `model` or `value` leaves `execute` unavailable, so that case is a
    // compile-time error and has no runtime test.

    #[tokio::test]
    async fn test_embed_success() {
        let model = FixedModel {
            expected_value: "test value",
            embeddings: vec![vec![0.1, 0.2, 0.3]],
        };

        let result = embed()
            .model(model)
            .value("test value".to_string())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.value(), "test value");
        assert_eq!(result.embedding().len(), 3);
        assert_eq!(result.embedding(), &[0.1, 0.2, 0.3]);
        assert_eq!(result.usage().unwrap().tokens, 10);
    }

    #[tokio::test]
    async fn test_embed_value_before_model() {
        // The two state transitions may be applied in either order.
        let model = FixedModel {
            expected_value: "test value",
            embeddings: vec![vec![0.5]],
        };

        let result = embed::<String>()
            .value("test value".to_string())
            .model(model)
            .execute()
            .await
            .unwrap();

        assert_eq!(result.embedding(), &[0.5]);
    }

    #[tokio::test]
    async fn test_embed_keeps_only_first_embedding() {
        let model = FixedModel {
            expected_value: "test value",
            embeddings: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
        };

        let result = embed::<String>()
            .model(model)
            .value("test value".to_string())
            .execute()
            .await
            .unwrap();

        assert_eq!(result.embedding(), &[1.0, 2.0]);
    }

    #[tokio::test]
    async fn test_embed_empty_response_is_error() {
        let model = FixedModel {
            expected_value: "test value",
            embeddings: Vec::new(),
        };

        let result = embed::<String>()
            .model(model)
            .value("test value".to_string())
            .execute()
            .await;

        assert!(
            result.is_err(),
            "a response without embeddings must be reported as an error"
        );
    }
}