llmsdk_provider/embedding_model.rs
//! Embedding model trait and supporting types.
//!
//! Mirrors `@ai-sdk/provider/src/embedding-model/v4/*`. The upstream
//! interface is generic over the embedding input type because some
//! providers accept binary inputs (image embeddings) in addition to text;
//! this module currently pins the input type to `String`.
6// Rust guideline compliant 2026-02-21
7
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10
11use crate::error::Result;
12use crate::shared::{Headers, ProviderMetadata, ProviderOptions, RequestInfo, ResponseInfo};
13
14/// Contract every text-embedding model implements.
15///
16/// Mirrors `EmbeddingModelV4`. We pin the input type to `String` for now —
17/// audio / image embeddings will introduce a parallel trait when needed.
18#[async_trait]
19pub trait EmbeddingModel: Send + Sync + std::fmt::Debug {
20 /// Provider id, e.g. `"openai"`.
21 fn provider(&self) -> &str;
22
23 /// Provider-specific model id, e.g. `"text-embedding-3-small"`.
24 fn model_id(&self) -> &str;
25
26 /// Specification version (currently `"v4"`).
27 ///
28 /// Mirrors `EmbeddingModelV4.specificationVersion` (ai-sdk
29 /// `embedding-model-v4.ts`). Provider impls inherit the default.
30 fn specification_version(&self) -> &'static str {
31 "v4"
32 }
33
34 /// Maximum inputs the provider accepts per call.
35 ///
36 /// `None` means "no documented limit"; callers should still batch
37 /// conservatively. Defaults to `None`.
38 async fn max_embeddings_per_call(&self) -> Option<u32> {
39 None
40 }
41
42 /// Whether the model can handle multiple embed calls in parallel.
43 async fn supports_parallel_calls(&self) -> bool {
44 true
45 }
46
47 /// Embed a batch of inputs.
48 ///
49 /// # Errors
50 ///
51 /// Returns a [`crate::ProviderError`] when the upstream call fails or
52 /// the response is malformed.
53 async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult>;
54}
55
56/// Options for one [`EmbeddingModel::do_embed`] call.
57#[derive(Debug, Clone, Default, Serialize, Deserialize)]
58pub struct EmbedOptions {
59 /// Inputs to embed.
60 pub values: Vec<String>,
61 /// Extra HTTP headers (HTTP providers only).
62 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub headers: Option<Headers>,
64 /// Provider-specific options.
65 #[serde(
66 default,
67 rename = "providerOptions",
68 skip_serializing_if = "Option::is_none"
69 )]
70 pub provider_options: Option<ProviderOptions>,
71}
72
/// A single embedding vector, stored as `f32` components.
pub type Embedding = Vec<f32>;
75
76/// Result of [`EmbeddingModel::do_embed`].
77#[derive(Debug, Clone)]
78pub struct EmbedResult {
79 /// Embeddings in input order.
80 pub embeddings: Vec<Embedding>,
81 /// Token usage if reported.
82 pub usage: Option<EmbeddingUsage>,
83 /// Provider-specific metadata.
84 pub provider_metadata: Option<ProviderMetadata>,
85 /// Request info (telemetry).
86 pub request: Option<RequestInfo>,
87 /// Response info (telemetry).
88 pub response: Option<ResponseInfo>,
89}
90
91/// Token usage for an embedding call.
92#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
93pub struct EmbeddingUsage {
94 /// Tokens consumed.
95 pub tokens: Option<u64>,
96}