Skip to main content

llmsdk_provider/
embedding_model.rs

//! Embedding model trait and supporting types.
//!
//! Mirrors `@ai-sdk/provider/src/embedding-model/v4/*`. Inputs are
//! currently text only (`String`); providers that accept binary inputs
//! (e.g. image embeddings) will get a parallel trait when needed.
// Rust guideline compliant 2026-02-21
7
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10
11use crate::error::Result;
12use crate::shared::{Headers, ProviderMetadata, ProviderOptions, RequestInfo, ResponseInfo};
13
14/// Contract every text-embedding model implements.
15///
16/// Mirrors `EmbeddingModelV4`. We pin the input type to `String` for now —
17/// audio / image embeddings will introduce a parallel trait when needed.
18#[async_trait]
19pub trait EmbeddingModel: Send + Sync + std::fmt::Debug {
20    /// Provider id, e.g. `"openai"`.
21    fn provider(&self) -> &str;
22
23    /// Provider-specific model id, e.g. `"text-embedding-3-small"`.
24    fn model_id(&self) -> &str;
25
26    /// Specification version (currently `"v4"`).
27    ///
28    /// Mirrors `EmbeddingModelV4.specificationVersion` (ai-sdk
29    /// `embedding-model-v4.ts`). Provider impls inherit the default.
30    fn specification_version(&self) -> &'static str {
31        "v4"
32    }
33
34    /// Maximum inputs the provider accepts per call.
35    ///
36    /// `None` means "no documented limit"; callers should still batch
37    /// conservatively. Defaults to `None`.
38    async fn max_embeddings_per_call(&self) -> Option<u32> {
39        None
40    }
41
42    /// Whether the model can handle multiple embed calls in parallel.
43    async fn supports_parallel_calls(&self) -> bool {
44        true
45    }
46
47    /// Embed a batch of inputs.
48    ///
49    /// # Errors
50    ///
51    /// Returns a [`crate::ProviderError`] when the upstream call fails or
52    /// the response is malformed.
53    async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult>;
54}
55
56/// Options for one [`EmbeddingModel::do_embed`] call.
57#[derive(Debug, Clone, Default, Serialize, Deserialize)]
58pub struct EmbedOptions {
59    /// Inputs to embed.
60    pub values: Vec<String>,
61    /// Extra HTTP headers (HTTP providers only).
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub headers: Option<Headers>,
64    /// Provider-specific options.
65    #[serde(
66        default,
67        rename = "providerOptions",
68        skip_serializing_if = "Option::is_none"
69    )]
70    pub provider_options: Option<ProviderOptions>,
71}
72
/// A single embedding: one `f32` component per dimension.
pub type Embedding = std::vec::Vec<f32>;
75
76/// Result of [`EmbeddingModel::do_embed`].
77#[derive(Debug, Clone)]
78pub struct EmbedResult {
79    /// Embeddings in input order.
80    pub embeddings: Vec<Embedding>,
81    /// Token usage if reported.
82    pub usage: Option<EmbeddingUsage>,
83    /// Provider-specific metadata.
84    pub provider_metadata: Option<ProviderMetadata>,
85    /// Request info (telemetry).
86    pub request: Option<RequestInfo>,
87    /// Response info (telemetry).
88    pub response: Option<ResponseInfo>,
89}
90
91/// Token usage for an embedding call.
92#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
93pub struct EmbeddingUsage {
94    /// Tokens consumed.
95    pub tokens: Option<u64>,
96}