// openrouter_rs/api/embeddings.rs

1use derive_builder::Builder;
2use reqwest::Client as HttpClient;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    api::models,
7    error::OpenRouterError,
8    transport::{request as transport_request, response as transport_response},
9    types::{ApiResponse, ProviderPreferences},
10};
11
12/// Supported embedding encoding formats.
13#[derive(Serialize, Deserialize, Debug, Clone)]
14#[non_exhaustive]
15#[serde(rename_all = "lowercase")]
16pub enum EmbeddingEncodingFormat {
17    Float,
18    Base64,
19}
20
21/// Image URL content for multimodal embedding input.
22#[derive(Serialize, Deserialize, Debug, Clone)]
23#[non_exhaustive]
24pub struct EmbeddingImageUrl {
25    pub url: String,
26}
27
28impl EmbeddingImageUrl {
29    pub fn new(url: impl Into<String>) -> Self {
30        Self { url: url.into() }
31    }
32}
33
34/// Base64 or data-URL backed multimodal media for embedding content parts.
35#[derive(Serialize, Deserialize, Debug, Clone)]
36#[non_exhaustive]
37pub struct EmbeddingMultimodalMedia {
38    pub data: String,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub format: Option<String>,
41}
42
43impl EmbeddingMultimodalMedia {
44    pub fn new(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
45        Self {
46            data: data.into(),
47            format: format.map(Into::into),
48        }
49    }
50}
51
52/// One multimodal content part for embedding input.
53#[derive(Serialize, Deserialize, Debug, Clone)]
54#[non_exhaustive]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum EmbeddingContentPart {
57    Text {
58        text: String,
59    },
60    ImageUrl {
61        image_url: EmbeddingImageUrl,
62    },
63    InputAudio {
64        input_audio: EmbeddingMultimodalMedia,
65    },
66    InputVideo {
67        input_video: EmbeddingMultimodalMedia,
68    },
69    InputFile {
70        input_file: EmbeddingMultimodalMedia,
71    },
72}
73
74impl EmbeddingContentPart {
75    pub fn text(text: impl Into<String>) -> Self {
76        Self::Text { text: text.into() }
77    }
78
79    pub fn image_url(url: impl Into<String>) -> Self {
80        Self::ImageUrl {
81            image_url: EmbeddingImageUrl::new(url),
82        }
83    }
84
85    pub fn input_audio(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
86        Self::InputAudio {
87            input_audio: EmbeddingMultimodalMedia::new(data, format),
88        }
89    }
90
91    pub fn input_video(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
92        Self::InputVideo {
93            input_video: EmbeddingMultimodalMedia::new(data, format),
94        }
95    }
96
97    pub fn input_file(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
98        Self::InputFile {
99            input_file: EmbeddingMultimodalMedia::new(data, format),
100        }
101    }
102}
103
104/// One multimodal embedding input item.
105#[derive(Serialize, Deserialize, Debug, Clone)]
106#[non_exhaustive]
107pub struct EmbeddingMultimodalInput {
108    pub content: Vec<EmbeddingContentPart>,
109}
110
111impl EmbeddingMultimodalInput {
112    pub fn new(content: Vec<EmbeddingContentPart>) -> Self {
113        Self { content }
114    }
115}
116
117/// Embedding request input variants.
118#[derive(Serialize, Deserialize, Debug, Clone)]
119#[non_exhaustive]
120#[serde(untagged)]
121pub enum EmbeddingInput {
122    Text(String),
123    TextArray(Vec<String>),
124    TokenArray(Vec<f64>),
125    TokenArrayBatch(Vec<Vec<f64>>),
126    MultimodalArray(Vec<EmbeddingMultimodalInput>),
127}
128
129impl From<String> for EmbeddingInput {
130    fn from(value: String) -> Self {
131        Self::Text(value)
132    }
133}
134
135impl From<&str> for EmbeddingInput {
136    fn from(value: &str) -> Self {
137        Self::Text(value.to_string())
138    }
139}
140
141impl From<Vec<String>> for EmbeddingInput {
142    fn from(value: Vec<String>) -> Self {
143        Self::TextArray(value)
144    }
145}
146
147impl From<Vec<f64>> for EmbeddingInput {
148    fn from(value: Vec<f64>) -> Self {
149        Self::TokenArray(value)
150    }
151}
152
153impl From<Vec<Vec<f64>>> for EmbeddingInput {
154    fn from(value: Vec<Vec<f64>>) -> Self {
155        Self::TokenArrayBatch(value)
156    }
157}
158
159impl From<Vec<EmbeddingMultimodalInput>> for EmbeddingInput {
160    fn from(value: Vec<EmbeddingMultimodalInput>) -> Self {
161        Self::MultimodalArray(value)
162    }
163}
164
165/// Request body for `POST /embeddings`.
166#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
167#[builder(build_fn(error = "OpenRouterError"))]
168#[non_exhaustive]
169pub struct EmbeddingRequest {
170    #[builder(setter(into))]
171    pub input: EmbeddingInput,
172
173    #[builder(setter(into))]
174    pub model: String,
175
176    #[builder(setter(strip_option), default)]
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub encoding_format: Option<EmbeddingEncodingFormat>,
179
180    #[builder(setter(strip_option), default)]
181    #[serde(skip_serializing_if = "Option::is_none")]
182    pub dimensions: Option<u32>,
183
184    #[builder(setter(into, strip_option), default)]
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub user: Option<String>,
187
188    #[builder(setter(strip_option), default)]
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub provider: Option<ProviderPreferences>,
191
192    #[builder(setter(into, strip_option), default)]
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub input_type: Option<String>,
195}
196
197impl EmbeddingRequest {
198    pub fn builder() -> EmbeddingRequestBuilder {
199        EmbeddingRequestBuilder::default()
200    }
201
202    pub fn new(model: impl Into<String>, input: impl Into<EmbeddingInput>) -> Self {
203        Self::builder()
204            .model(model.into())
205            .input(input.into())
206            .build()
207            .expect("Failed to build EmbeddingRequest")
208    }
209}
210
211/// One embedding vector payload.
212#[derive(Serialize, Deserialize, Debug, Clone)]
213#[non_exhaustive]
214#[serde(untagged)]
215pub enum EmbeddingVector {
216    Float(Vec<f64>),
217    Base64(String),
218}
219
220/// One embedding item.
221#[derive(Serialize, Deserialize, Debug, Clone)]
222#[non_exhaustive]
223pub struct EmbeddingData {
224    pub object: String,
225    pub embedding: EmbeddingVector,
226    #[serde(skip_serializing_if = "Option::is_none")]
227    pub index: Option<u32>,
228}
229
230/// Token breakdown details for embedding requests.
231#[derive(Serialize, Deserialize, Debug, Clone)]
232#[non_exhaustive]
233pub struct EmbeddingPromptTokensDetails {
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub audio_tokens: Option<u32>,
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub image_tokens: Option<u32>,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub text_tokens: Option<u32>,
240    #[serde(skip_serializing_if = "Option::is_none")]
241    pub video_tokens: Option<u32>,
242}
243
244/// Provider-level cost breakdown for embedding requests.
245#[derive(Serialize, Deserialize, Debug, Clone)]
246#[non_exhaustive]
247pub struct EmbeddingCostDetails {
248    pub upstream_inference_completions_cost: f64,
249    pub upstream_inference_prompt_cost: f64,
250    #[serde(default, skip_serializing_if = "Option::is_none")]
251    pub upstream_inference_cost: Option<f64>,
252}
253
254/// Token/cost usage for embedding request.
255#[derive(Serialize, Deserialize, Debug, Clone)]
256#[non_exhaustive]
257pub struct EmbeddingUsage {
258    pub prompt_tokens: u32,
259    pub total_tokens: u32,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub prompt_tokens_details: Option<EmbeddingPromptTokensDetails>,
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub cost: Option<f64>,
264    #[serde(default, skip_serializing_if = "Option::is_none")]
265    pub cost_details: Option<EmbeddingCostDetails>,
266}
267
268/// Response body for `POST /embeddings`.
269#[derive(Serialize, Deserialize, Debug, Clone)]
270#[non_exhaustive]
271pub struct EmbeddingResponse {
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub id: Option<String>,
274    pub object: String,
275    pub data: Vec<EmbeddingData>,
276    pub model: String,
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub usage: Option<EmbeddingUsage>,
279}
280
281/// Submit an embedding request.
282pub async fn create_embedding(
283    base_url: &str,
284    api_key: &str,
285    x_title: &Option<String>,
286    http_referer: &Option<String>,
287    app_categories: &Option<Vec<String>>,
288    request: &EmbeddingRequest,
289) -> Result<EmbeddingResponse, OpenRouterError> {
290    let http_client = crate::transport::new_client()?;
291    create_embedding_with_client(
292        &http_client,
293        base_url,
294        api_key,
295        x_title,
296        http_referer,
297        app_categories,
298        request,
299    )
300    .await
301}
302
303pub(crate) async fn create_embedding_with_client(
304    http_client: &HttpClient,
305    base_url: &str,
306    api_key: &str,
307    x_title: &Option<String>,
308    http_referer: &Option<String>,
309    app_categories: &Option<Vec<String>>,
310    request: &EmbeddingRequest,
311) -> Result<EmbeddingResponse, OpenRouterError> {
312    let url = format!("{base_url}/embeddings");
313
314    let response = transport_request::with_client_request_headers(
315        transport_request::post(http_client, &url),
316        api_key,
317        x_title,
318        http_referer,
319        app_categories,
320    )?
321    .json(request)
322    .send()
323    .await?;
324
325    if response.status().is_success() {
326        let embedding_response: EmbeddingResponse =
327            transport_response::parse_json_response(response, "embedding").await?;
328        Ok(embedding_response)
329    } else {
330        transport_response::handle_error(response).await?;
331        unreachable!()
332    }
333}
334
335/// List all embedding models.
336pub async fn list_embedding_models(
337    base_url: &str,
338    api_key: &str,
339) -> Result<Vec<models::Model>, OpenRouterError> {
340    let http_client = crate::transport::new_client()?;
341    list_embedding_models_with_client(&http_client, base_url, api_key).await
342}
343
344pub(crate) async fn list_embedding_models_with_client(
345    http_client: &HttpClient,
346    base_url: &str,
347    api_key: &str,
348) -> Result<Vec<models::Model>, OpenRouterError> {
349    let url = format!("{base_url}/embeddings/models");
350
351    let response =
352        transport_request::with_bearer_auth(transport_request::get(http_client, &url), api_key)
353            .send()
354            .await?;
355
356    if response.status().is_success() {
357        let models_response: ApiResponse<Vec<models::Model>> =
358            transport_response::parse_json_response(response, "embedding models").await?;
359        Ok(models_response.data)
360    } else {
361        transport_response::handle_error(response).await?;
362        unreachable!()
363    }
364}