ai_sdk_provider/embedding_model/trait_def.rs

use super::*;
use crate::{Result, SharedHeaders, SharedProviderOptions};
use async_trait::async_trait;
use std::future::{Future, IntoFuture};
use std::pin::Pin;

/// A builder for constructing and executing embedding requests.
///
/// This type provides a fluent interface for configuring embedding generation requests
/// before sending them to an embedding model. It allows you to set provider-specific
/// options and custom headers before execution.
///
/// # Generics
///
/// - `M`: The embedding model implementation that will process the request
/// - `VALUE`: The type of values to embed (e.g., `String` for text embeddings)
///
/// # Examples
///
/// ```ignore
/// let model = /* ... */;
/// let embeddings = model.embed(vec!["hello", "world"])
///     .send()
///     .await?;
/// ```
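///
/// Because the builder implements `IntoFuture`, it can also be awaited directly
/// instead of calling `send()`:
///
/// ```ignore
/// let embeddings = model.embed(vec!["hello", "world"]).await?;
/// ```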
pub struct EmbedBuilder<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync> {
    model: &'a M,
    options: EmbedOptions<VALUE>,
}

impl<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync> EmbedBuilder<'a, M, VALUE> {
    /// Creates a new embedding request builder with the provided values.
    ///
    /// # Arguments
    ///
    /// * `model` - A reference to the embedding model that will process the request
    /// * `values` - The input values to be embedded, which can be any type implementing `Into<Vec<VALUE>>`
    ///
    /// # Returns
    ///
    /// A new `EmbedBuilder` instance configured with the provided values and default options.
    pub fn new(model: &'a M, values: impl Into<Vec<VALUE>>) -> Self {
        Self {
            model,
            options: EmbedOptions {
                values: values.into(),
                provider_options: None,
                headers: None,
            },
        }
    }

    /// Sets provider-specific options for the embedding request.
    ///
    /// Provider-specific options allow you to configure model-specific parameters that are
    /// not part of the standard embedding request interface. This is useful for passing
    /// additional configuration to specialized embedding providers.
    ///
    /// # Arguments
    ///
    /// * `provider_options` - A set of provider-specific options to apply to the request
    ///
    /// # Returns
    ///
    /// This builder instance for method chaining.
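    ///
    /// # Examples
    ///
    /// A minimal sketch; how `SharedProviderOptions` is constructed depends on the provider:
    ///
    /// ```ignore
    /// let response = model.embed(vec!["hello".to_string()])
    ///     .provider_options(/* provider-specific options */)
    ///     .send()
    ///     .await?;
    /// ```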
    pub fn provider_options(mut self, provider_options: SharedProviderOptions) -> Self {
        self.options.provider_options = Some(provider_options);
        self
    }

    /// Sets custom HTTP headers for the embedding request.
    ///
    /// Use this method to include custom HTTP headers in the request sent to the embedding
    /// service. This is commonly used for authentication tokens, custom user agents, or
    /// other provider-specific headers.
    ///
    /// # Arguments
    ///
    /// * `headers` - Custom HTTP headers to include in the request
    ///
    /// # Returns
    ///
    /// This builder instance for method chaining.
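    ///
    /// # Examples
    ///
    /// A minimal sketch; constructing `SharedHeaders` is left to the caller:
    ///
    /// ```ignore
    /// let response = model.embed(vec!["hello".to_string()])
    ///     .headers(/* e.g. an authorization or request-id header */)
    ///     .send()
    ///     .await?;
    /// ```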
    pub fn headers(mut self, headers: SharedHeaders) -> Self {
        self.options.headers = Some(headers);
        self
    }

    /// Sends the embedding request to the model and returns the generated embeddings.
    ///
    /// This method executes the configured embedding request asynchronously, sending the
    /// input values to the embedding model for processing. The method respects all options
    /// configured through the builder.
    ///
    /// # Returns
    ///
    /// A `Result` containing the `EmbedResponse` with the generated embeddings on success,
    /// or an error if the request fails.
    pub async fn send(self) -> Result<EmbedResponse> {
        self.model.do_embed(self.options).await
    }
}

impl<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'a> IntoFuture
    for EmbedBuilder<'a, M, VALUE>
{
    type Output = Result<EmbedResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.model.do_embed(self.options).await })
    }
}

/// The core trait for embedding model implementations following the v3 specification.
///
/// This trait defines the interface that all embedding models must implement to generate
/// vector representations of input values. It supports a generic `VALUE` type to enable
/// future extensibility beyond text embeddings (e.g., image embeddings, audio embeddings).
///
/// # Generic Parameters
///
/// - `VALUE`: The type of values that this model can embed (e.g., `String` for text,
///   potentially `Vec<u8>` for images in future versions). Must be `Send + Sync`.
///
/// # Design Notes
///
/// Implementations should be stateless or use `Arc`/`Box` for shared state management.
/// The trait provides metadata about the model's capabilities and limits, allowing
/// client code to optimize request batching and concurrency strategies.
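///
/// # Examples
///
/// A caller that is generic over any text embedding model (illustrative only):
///
/// ```ignore
/// async fn embed_texts<M: EmbeddingModel<String>>(
///     model: &M,
///     texts: Vec<String>,
/// ) -> Result<EmbedResponse> {
///     model.embed(texts).send().await
/// }
/// ```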
#[async_trait]
pub trait EmbeddingModel<VALUE>: Send + Sync
where
    VALUE: Send + Sync,
{
    /// Returns the API specification version implemented by this model.
    ///
    /// This identifies the version of the embedding model specification in use.
    /// The default implementation returns `"v3"`.
    fn specification_version(&self) -> &str {
        "v3"
    }

    /// Returns the provider identifier for this embedding model.
    ///
    /// This is typically the name of the service provider (e.g., `"openai"` or
    /// `"cohere"`). It identifies which external service provides the embedding functionality.
    ///
    /// # Examples
    ///
    /// - `"openai"` - For OpenAI's embedding models
    /// - `"cohere"` - For Cohere's embedding models
    fn provider(&self) -> &str;

    /// Returns the provider-specific model identifier.
    ///
    /// This is the model identifier as used by the provider's API. Different providers
    /// use different naming conventions for their models.
    ///
    /// # Examples
    ///
    /// - `"text-embedding-3-small"` - OpenAI's small embedding model
    /// - `"text-embedding-3-large"` - OpenAI's large embedding model
    /// - `"embed-english-v3.0"` - Cohere's English embedding model
    fn model_id(&self) -> &str;

    /// Returns the maximum number of embeddings that can be generated in a single API call.
    ///
    /// This defines the batch size limit for embedding requests to this model. When embedding
    /// large numbers of values, clients should respect this limit by splitting requests into
    /// appropriately sized batches.
    ///
    /// # Returns
    ///
    /// - `Some(n)` - The maximum number of embeddings per call is `n`
    /// - `None` - The model has no documented batch size limit
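    ///
    /// # Examples
    ///
    /// One way a caller might batch a full list of `values` (a sketch, not a prescribed
    /// strategy):
    ///
    /// ```ignore
    /// let limit = model.max_embeddings_per_call().await.unwrap_or(usize::MAX);
    /// let mut all_embeddings = Vec::new();
    /// for chunk in values.chunks(limit) {
    ///     let response = model.embed(chunk.to_vec()).send().await?;
    ///     all_embeddings.extend(response.embeddings);
    /// }
    /// ```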
    async fn max_embeddings_per_call(&self) -> Option<usize>;

    /// Indicates whether this model supports parallel embedding requests.
    ///
    /// When `true`, multiple embedding requests can be sent to the service concurrently
    /// for improved throughput. When `false`, requests should be processed sequentially
    /// to avoid rate limiting or service errors.
    ///
    /// # Returns
    ///
    /// `true` if the model can handle multiple concurrent embedding requests, `false` otherwise.
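    ///
    /// # Examples
    ///
    /// A sketch of how a caller might pick a dispatch strategy for some prepared `batches`
    /// (the `futures` crate is used here purely for illustration):
    ///
    /// ```ignore
    /// let requests = batches.into_iter().map(|batch| model.embed(batch).send());
    /// let responses = if model.supports_parallel_calls().await {
    ///     futures::future::join_all(requests).await
    /// } else {
    ///     let mut out = Vec::new();
    ///     for request in requests {
    ///         out.push(request.await);
    ///     }
    ///     out
    /// };
    /// ```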
    async fn supports_parallel_calls(&self) -> bool;

    /// Creates a builder for constructing and executing an embedding request.
    ///
    /// This is the primary way to initiate an embedding request. It provides a fluent
    /// interface for configuring the request before sending it to the model.
    ///
    /// # Arguments
    ///
    /// * `values` - The input values to be embedded, converted into a vector
    ///
    /// # Returns
    ///
    /// An `EmbedBuilder` that can be further configured with additional options before sending.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let response = model.embed(vec!["hello world"])
    ///     .headers(custom_headers)
    ///     .send()
    ///     .await?;
    /// ```
    fn embed(&self, values: impl Into<Vec<VALUE>>) -> EmbedBuilder<'_, Self, VALUE>
    where
        Self: Sized,
    {
        EmbedBuilder::new(self, values)
    }

    /// Performs the actual embedding generation for the provided values.
    ///
    /// This is the internal method that implements the core embedding functionality.
    /// Clients should prefer using the `embed()` builder method instead of calling
    /// this directly. The `do_` prefix indicates this is an internal implementation detail.
    ///
    /// # Arguments
    ///
    /// * `options` - The embedding request options, including values, headers, and provider options
    ///
    /// # Returns
    ///
    /// A `Result` containing the `EmbedResponse` with generated embeddings on success,
    /// or an error if embedding generation fails.
    ///
    /// # Implementation Notes
    ///
    /// Implementations of this method should:
    /// - Validate the input values according to model constraints
    /// - Respect the `provider_options` and `headers` in the request options
    /// - Return meaningful error messages for failures
    /// - Handle API-specific response processing and error handling
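    ///
    /// A skeletal text-embedding implementation might take the following shape
    /// (provider details omitted; the field values are placeholders):
    ///
    /// ```ignore
    /// async fn do_embed(&self, options: EmbedOptions<String>) -> Result<EmbedResponse> {
    ///     // Send `options.values` to the provider, honoring `options.headers`
    ///     // and `options.provider_options`, then map the provider payload:
    ///     Ok(EmbedResponse {
    ///         embeddings: /* one vector per input value */,
    ///         usage: /* token usage, if reported */,
    ///         provider_metadata: None,
    ///         response: None,
    ///     })
    /// }
    /// ```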
    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse>;
}

#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE>
    for Box<T>
{
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE>
    for std::sync::Arc<T>
{
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE> for &T {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct DummyEmbeddingModel;

    #[async_trait]
    impl EmbeddingModel<String> for DummyEmbeddingModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_embeddings_per_call(&self) -> Option<usize> {
            Some(100)
        }

        async fn supports_parallel_calls(&self) -> bool {
            true
        }

        async fn do_embed(&self, _options: EmbedOptions<String>) -> Result<EmbedResponse> {
            Ok(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3]],
                usage: Some(EmbeddingUsage { tokens: 10 }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn test_embedding_model_trait() {
        let model = DummyEmbeddingModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.model_id(), "dummy");
        assert_eq!(model.specification_version(), "v3");
        assert_eq!(model.max_embeddings_per_call().await, Some(100));
        assert!(model.supports_parallel_calls().await);

        // The builder can be awaited directly via its `IntoFuture` impl
        let res = model.embed(vec!["test".to_string()]).await;
        assert!(res.is_ok());
    }
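
    // Sanity check that the `Box`, `Arc`, and `&T` blanket impls delegate
    // to the wrapped model.
    #[tokio::test]
    async fn test_blanket_impls_delegate() {
        let boxed: Box<dyn EmbeddingModel<String>> = Box::new(DummyEmbeddingModel);
        assert_eq!(boxed.provider(), "test");
        assert_eq!(boxed.max_embeddings_per_call().await, Some(100));

        let shared = std::sync::Arc::new(DummyEmbeddingModel);
        assert_eq!(shared.model_id(), "dummy");
        assert!(shared.supports_parallel_calls().await);

        let model = DummyEmbeddingModel;
        let by_ref = &model;
        let res = by_ref.embed(vec!["test".to_string()]).send().await;
        assert!(res.is_ok());
    }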
}