ai_sdk_provider/embedding_model/trait_def.rs
use super::*;
use crate::{Result, SharedHeaders, SharedProviderOptions};
use async_trait::async_trait;
use std::future::{Future, IntoFuture};
use std::pin::Pin;

/// A builder for constructing and executing embedding requests.
///
/// This type provides a fluent interface for configuring embedding generation requests
/// before sending them to an embedding model. It allows you to set provider-specific
/// options and custom headers before execution.
///
/// # Generics
///
/// - `M`: The embedding model implementation that will process the request
/// - `VALUE`: The type of values to embed (e.g., `String` for text embeddings)
///
/// # Examples
///
/// ```ignore
/// let model = /* ... */;
/// let embeddings = model.embed(vec!["hello", "world"])
///     .send()
///     .await?;
/// ```
pub struct EmbedBuilder<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync> {
    model: &'a M,
    options: EmbedOptions<VALUE>,
}

impl<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync> EmbedBuilder<'a, M, VALUE> {
    /// Creates a new embedding request builder with the provided values.
    ///
    /// # Arguments
    ///
    /// * `model` - A reference to the embedding model that will process the request
    /// * `values` - The input values to be embedded, which can be any type implementing `Into<Vec<VALUE>>`
    ///
    /// # Returns
    ///
    /// A new `EmbedBuilder` instance configured with the provided values and default options.
    pub fn new(model: &'a M, values: impl Into<Vec<VALUE>>) -> Self {
        Self {
            model,
            options: EmbedOptions {
                values: values.into(),
                provider_options: None,
                headers: None,
            },
        }
    }

    /// Sets provider-specific options for the embedding request.
    ///
    /// Provider-specific options allow you to configure model-specific parameters that are
    /// not part of the standard embedding request interface. This is useful for passing
    /// additional configuration to specialized embedding providers.
    ///
    /// # Arguments
    ///
    /// * `provider_options` - A set of provider-specific options to apply to the request
    ///
    /// # Returns
    ///
    /// This builder instance for method chaining.
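    ///
    /// # Examples
    ///
    /// A minimal sketch of chaining provider options; how a `SharedProviderOptions`
    /// value is constructed is provider-specific and elided here.
    ///
    /// ```ignore
    /// let provider_options: SharedProviderOptions = /* ... */;
    /// let response = model.embed(vec!["hello"])
    ///     .provider_options(provider_options)
    ///     .send()
    ///     .await?;
    /// ```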
    pub fn provider_options(mut self, provider_options: SharedProviderOptions) -> Self {
        self.options.provider_options = Some(provider_options);
        self
    }

    /// Sets custom HTTP headers for the embedding request.
    ///
    /// Use this method to include custom HTTP headers in the request sent to the embedding
    /// service. This is commonly used for authentication tokens, custom user agents, or
    /// other provider-specific headers.
    ///
    /// # Arguments
    ///
    /// * `headers` - Custom HTTP headers to include in the request
    ///
    /// # Returns
    ///
    /// This builder instance for method chaining.
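    ///
    /// # Examples
    ///
    /// A minimal sketch of attaching custom headers; how a `SharedHeaders` value is
    /// built is left to the caller and elided here.
    ///
    /// ```ignore
    /// let headers: SharedHeaders = /* ... */;
    /// let response = model.embed(vec!["hello"])
    ///     .headers(headers)
    ///     .send()
    ///     .await?;
    /// ```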
    pub fn headers(mut self, headers: SharedHeaders) -> Self {
        self.options.headers = Some(headers);
        self
    }

    /// Sends the embedding request to the model and returns the generated embeddings.
    ///
    /// This method executes the configured embedding request asynchronously, sending the
    /// input values to the embedding model for processing. The method respects all options
    /// configured through the builder.
    ///
    /// # Returns
    ///
    /// A `Result` containing the `EmbedResponse` with the generated embeddings on success,
    /// or an error if the request fails.
    pub async fn send(self) -> Result<EmbedResponse> {
        self.model.do_embed(self.options).await
    }
}

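/// Allows an `EmbedBuilder` to be awaited directly: `builder.await` is equivalent
/// to `builder.send().await`.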
impl<'a, M: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'a> IntoFuture
    for EmbedBuilder<'a, M, VALUE>
{
    type Output = Result<EmbedResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.model.do_embed(self.options).await })
    }
}

/// The core trait for embedding model implementations following the v3 specification.
///
/// This trait defines the interface that all embedding models must implement to generate
/// vector representations of input values. It supports a generic `VALUE` type to enable
/// future extensibility beyond text embeddings (e.g., image embeddings, audio embeddings).
///
/// # Generic Parameters
///
/// - `VALUE`: The type of values that this model can embed (e.g., `String` for text,
///   potentially `Vec<u8>` for images in future versions). Must be `Send + Sync`.
///
/// # Design Notes
///
/// Implementations should be stateless or use `Arc`/`Box` for shared state management.
/// The trait provides metadata about the model's capabilities and limits, allowing
/// client code to optimize request batching and concurrency strategies.
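///
/// # Examples
///
/// A sketch of calling any implementation generically; the concrete model type is
/// assumed to come from a provider crate and is elided here.
///
/// ```ignore
/// async fn embed_texts<M: EmbeddingModel<String>>(
///     model: &M,
///     texts: Vec<String>,
/// ) -> Result<EmbedResponse> {
///     model.embed(texts).send().await
/// }
/// ```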
#[async_trait]
pub trait EmbeddingModel<VALUE>: Send + Sync
where
    VALUE: Send + Sync,
{
    /// Returns the API specification version implemented by this model.
    ///
    /// This returns the version of the embedding model specification being used.
    /// Currently always returns `"v3"`.
    fn specification_version(&self) -> &str {
        "v3"
    }

    /// Returns the provider identifier for this embedding model.
    ///
    /// This is typically the name of the service provider (e.g., `"openai"`, `"anthropic"`,
    /// `"cohere"`). It identifies which external service provides the embedding functionality.
    ///
    /// # Examples
    ///
    /// - `"openai"` - For OpenAI's embedding models
    /// - `"cohere"` - For Cohere's embedding models
    fn provider(&self) -> &str;

    /// Returns the provider-specific model identifier.
    ///
    /// This is the model identifier as used by the provider's API. Different providers
    /// use different naming conventions for their models.
    ///
    /// # Examples
    ///
    /// - `"text-embedding-3-small"` - OpenAI's small embedding model
    /// - `"text-embedding-3-large"` - OpenAI's large embedding model
    /// - `"embed-english-v3.0"` - Cohere's English embedding model
    fn model_id(&self) -> &str;

    /// Returns the maximum number of embeddings that can be generated in a single API call.
    ///
    /// This defines the batch size limit for embedding requests to this model. When embedding
    /// large numbers of values, clients should respect this limit by splitting requests into
    /// appropriately sized batches.
    ///
    /// # Returns
    ///
    /// - `Some(n)` - The maximum number of embeddings per call is `n`
    /// - `None` - The model has no documented batch size limit
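    ///
    /// # Examples
    ///
    /// A sketch of client-side batching based on this limit; the surrounding collection
    /// logic is illustrative and not part of this crate.
    ///
    /// ```ignore
    /// let limit = model.max_embeddings_per_call().await.unwrap_or(values.len()).max(1);
    /// let mut embeddings = Vec::new();
    /// for batch in values.chunks(limit) {
    ///     let response = model.embed(batch.to_vec()).send().await?;
    ///     embeddings.extend(response.embeddings);
    /// }
    /// ```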
    async fn max_embeddings_per_call(&self) -> Option<usize>;

    /// Indicates whether this model supports parallel embedding requests.
    ///
    /// When `true`, multiple embedding requests can be sent to the service concurrently
    /// for improved throughput. When `false`, requests should be processed sequentially
    /// to avoid rate limiting or service errors.
    ///
    /// # Returns
    ///
    /// `true` if the model can handle multiple concurrent embedding requests, `false` otherwise.
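    ///
    /// # Examples
    ///
    /// A sketch of choosing a dispatch strategy from this flag; the actual concurrency
    /// machinery (e.g. joining futures) is elided here.
    ///
    /// ```ignore
    /// if model.supports_parallel_calls().await {
    ///     // fan out one request per batch concurrently and join the results
    /// } else {
    ///     // send the batches one at a time, awaiting each before the next
    /// }
    /// ```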
    async fn supports_parallel_calls(&self) -> bool;

    /// Creates a builder for constructing and executing an embedding request.
    ///
    /// This is the primary way to initiate an embedding request. It provides a fluent
    /// interface for configuring the request before sending it to the model.
    ///
    /// # Arguments
    ///
    /// * `values` - The input values to be embedded, converted into a vector
    ///
    /// # Returns
    ///
    /// An `EmbedBuilder` that can be further configured with additional options before sending.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// let response = model.embed(vec!["hello world"])
    ///     .headers(custom_headers)
    ///     .send()
    ///     .await?;
    /// ```
    fn embed(&self, values: impl Into<Vec<VALUE>>) -> EmbedBuilder<'_, Self, VALUE>
    where
        Self: Sized,
    {
        EmbedBuilder::new(self, values)
    }

    /// Performs the actual embedding generation for the provided values.
    ///
    /// This is the internal method that implements the core embedding functionality.
    /// Clients should prefer using the `embed()` builder method instead of calling
    /// this directly. The `do_` prefix indicates this is an internal implementation detail.
    ///
    /// # Arguments
    ///
    /// * `options` - The embedding request options, including values, headers, and provider options
    ///
    /// # Returns
    ///
    /// A `Result` containing the `EmbedResponse` with generated embeddings on success,
    /// or an error if embedding generation fails.
    ///
    /// # Implementation Notes
    ///
    /// Implementations of this method should:
    /// - Validate the input values according to model constraints
    /// - Respect the `provider_options` and `headers` in the request options
    /// - Return meaningful error messages for failures
    /// - Handle API-specific response processing and error handling
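    ///
    /// # Examples
    ///
    /// A hedged skeleton of an implementation; the steps in the comments describe a
    /// typical provider integration and are assumptions, not part of this crate.
    ///
    /// ```ignore
    /// async fn do_embed(&self, options: EmbedOptions<String>) -> Result<EmbedResponse> {
    ///     // 1. Build the provider request from `options.values`, applying
    ///     //    `options.headers` and `options.provider_options` if present.
    ///     // 2. Call the provider API and map failures into this crate's error type.
    ///     // 3. Convert the provider payload into an `EmbedResponse` with
    ///     //    `embeddings`, `usage`, `provider_metadata`, and `response`.
    ///     todo!()
    /// }
    /// ```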
    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse>;
}

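// Blanket delegation impls so that `Box<T>`, `Arc<T>`, and `&T` can be used anywhere
// an `EmbeddingModel` is expected; every method forwards to the wrapped implementation.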
#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE>
    for Box<T>
{
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE>
    for std::sync::Arc<T>
{
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[async_trait]
impl<T: EmbeddingModel<VALUE> + ?Sized, VALUE: Send + Sync + 'static> EmbeddingModel<VALUE> for &T {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn max_embeddings_per_call(&self) -> Option<usize> {
        (**self).max_embeddings_per_call().await
    }

    async fn supports_parallel_calls(&self) -> bool {
        (**self).supports_parallel_calls().await
    }

    async fn do_embed(&self, options: EmbedOptions<VALUE>) -> Result<EmbedResponse> {
        (**self).do_embed(options).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct DummyEmbeddingModel;

    #[async_trait]
    impl EmbeddingModel<String> for DummyEmbeddingModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn max_embeddings_per_call(&self) -> Option<usize> {
            Some(100)
        }

        async fn supports_parallel_calls(&self) -> bool {
            true
        }

        async fn do_embed(&self, _options: EmbedOptions<String>) -> Result<EmbedResponse> {
            Ok(EmbedResponse {
                embeddings: vec![vec![0.1, 0.2, 0.3]],
                usage: Some(EmbeddingUsage { tokens: 10 }),
                provider_metadata: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn test_embedding_model_trait() {
        let model = DummyEmbeddingModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.model_id(), "dummy");
        assert_eq!(model.specification_version(), "v3");
        assert_eq!(model.max_embeddings_per_call().await, Some(100));
        assert!(model.supports_parallel_calls().await);

        // Awaiting the builder directly exercises the `IntoFuture` impl,
        // which is equivalent to calling `.send().await`.
        let res = model.embed(vec!["test".to_string()]).await;
        assert!(res.is_ok());
    }
}