1use derive_builder::Builder;
2use reqwest::Client as HttpClient;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 api::models,
7 error::OpenRouterError,
8 transport::{request as transport_request, response as transport_response},
9 types::{ApiResponse, ProviderPreferences},
10};
11
12#[derive(Serialize, Deserialize, Debug, Clone)]
14#[non_exhaustive]
15#[serde(rename_all = "lowercase")]
16pub enum EmbeddingEncodingFormat {
17 Float,
18 Base64,
19}
20
21#[derive(Serialize, Deserialize, Debug, Clone)]
23#[non_exhaustive]
24pub struct EmbeddingImageUrl {
25 pub url: String,
26}
27
28impl EmbeddingImageUrl {
29 pub fn new(url: impl Into<String>) -> Self {
30 Self { url: url.into() }
31 }
32}
33
34#[derive(Serialize, Deserialize, Debug, Clone)]
36#[non_exhaustive]
37pub struct EmbeddingMultimodalMedia {
38 pub data: String,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub format: Option<String>,
41}
42
43impl EmbeddingMultimodalMedia {
44 pub fn new(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
45 Self {
46 data: data.into(),
47 format: format.map(Into::into),
48 }
49 }
50}
51
52#[derive(Serialize, Deserialize, Debug, Clone)]
54#[non_exhaustive]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum EmbeddingContentPart {
57 Text {
58 text: String,
59 },
60 ImageUrl {
61 image_url: EmbeddingImageUrl,
62 },
63 InputAudio {
64 input_audio: EmbeddingMultimodalMedia,
65 },
66 InputVideo {
67 input_video: EmbeddingMultimodalMedia,
68 },
69 InputFile {
70 input_file: EmbeddingMultimodalMedia,
71 },
72}
73
74impl EmbeddingContentPart {
75 pub fn text(text: impl Into<String>) -> Self {
76 Self::Text { text: text.into() }
77 }
78
79 pub fn image_url(url: impl Into<String>) -> Self {
80 Self::ImageUrl {
81 image_url: EmbeddingImageUrl::new(url),
82 }
83 }
84
85 pub fn input_audio(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
86 Self::InputAudio {
87 input_audio: EmbeddingMultimodalMedia::new(data, format),
88 }
89 }
90
91 pub fn input_video(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
92 Self::InputVideo {
93 input_video: EmbeddingMultimodalMedia::new(data, format),
94 }
95 }
96
97 pub fn input_file(data: impl Into<String>, format: Option<impl Into<String>>) -> Self {
98 Self::InputFile {
99 input_file: EmbeddingMultimodalMedia::new(data, format),
100 }
101 }
102}
103
104#[derive(Serialize, Deserialize, Debug, Clone)]
106#[non_exhaustive]
107pub struct EmbeddingMultimodalInput {
108 pub content: Vec<EmbeddingContentPart>,
109}
110
111impl EmbeddingMultimodalInput {
112 pub fn new(content: Vec<EmbeddingContentPart>) -> Self {
113 Self { content }
114 }
115}
116
117#[derive(Serialize, Deserialize, Debug, Clone)]
119#[non_exhaustive]
120#[serde(untagged)]
121pub enum EmbeddingInput {
122 Text(String),
123 TextArray(Vec<String>),
124 TokenArray(Vec<f64>),
125 TokenArrayBatch(Vec<Vec<f64>>),
126 MultimodalArray(Vec<EmbeddingMultimodalInput>),
127}
128
129impl From<String> for EmbeddingInput {
130 fn from(value: String) -> Self {
131 Self::Text(value)
132 }
133}
134
135impl From<&str> for EmbeddingInput {
136 fn from(value: &str) -> Self {
137 Self::Text(value.to_string())
138 }
139}
140
141impl From<Vec<String>> for EmbeddingInput {
142 fn from(value: Vec<String>) -> Self {
143 Self::TextArray(value)
144 }
145}
146
147impl From<Vec<f64>> for EmbeddingInput {
148 fn from(value: Vec<f64>) -> Self {
149 Self::TokenArray(value)
150 }
151}
152
153impl From<Vec<Vec<f64>>> for EmbeddingInput {
154 fn from(value: Vec<Vec<f64>>) -> Self {
155 Self::TokenArrayBatch(value)
156 }
157}
158
159impl From<Vec<EmbeddingMultimodalInput>> for EmbeddingInput {
160 fn from(value: Vec<EmbeddingMultimodalInput>) -> Self {
161 Self::MultimodalArray(value)
162 }
163}
164
165#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
167#[builder(build_fn(error = "OpenRouterError"))]
168#[non_exhaustive]
169pub struct EmbeddingRequest {
170 #[builder(setter(into))]
171 pub input: EmbeddingInput,
172
173 #[builder(setter(into))]
174 pub model: String,
175
176 #[builder(setter(strip_option), default)]
177 #[serde(skip_serializing_if = "Option::is_none")]
178 pub encoding_format: Option<EmbeddingEncodingFormat>,
179
180 #[builder(setter(strip_option), default)]
181 #[serde(skip_serializing_if = "Option::is_none")]
182 pub dimensions: Option<u32>,
183
184 #[builder(setter(into, strip_option), default)]
185 #[serde(skip_serializing_if = "Option::is_none")]
186 pub user: Option<String>,
187
188 #[builder(setter(strip_option), default)]
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub provider: Option<ProviderPreferences>,
191
192 #[builder(setter(into, strip_option), default)]
193 #[serde(skip_serializing_if = "Option::is_none")]
194 pub input_type: Option<String>,
195}
196
197impl EmbeddingRequest {
198 pub fn builder() -> EmbeddingRequestBuilder {
199 EmbeddingRequestBuilder::default()
200 }
201
202 pub fn new(model: impl Into<String>, input: impl Into<EmbeddingInput>) -> Self {
203 Self::builder()
204 .model(model.into())
205 .input(input.into())
206 .build()
207 .expect("Failed to build EmbeddingRequest")
208 }
209}
210
211#[derive(Serialize, Deserialize, Debug, Clone)]
213#[non_exhaustive]
214#[serde(untagged)]
215pub enum EmbeddingVector {
216 Float(Vec<f64>),
217 Base64(String),
218}
219
220#[derive(Serialize, Deserialize, Debug, Clone)]
222#[non_exhaustive]
223pub struct EmbeddingData {
224 pub object: String,
225 pub embedding: EmbeddingVector,
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub index: Option<u32>,
228}
229
230#[derive(Serialize, Deserialize, Debug, Clone)]
232#[non_exhaustive]
233pub struct EmbeddingPromptTokensDetails {
234 #[serde(skip_serializing_if = "Option::is_none")]
235 pub audio_tokens: Option<u32>,
236 #[serde(skip_serializing_if = "Option::is_none")]
237 pub image_tokens: Option<u32>,
238 #[serde(skip_serializing_if = "Option::is_none")]
239 pub text_tokens: Option<u32>,
240 #[serde(skip_serializing_if = "Option::is_none")]
241 pub video_tokens: Option<u32>,
242}
243
244#[derive(Serialize, Deserialize, Debug, Clone)]
246#[non_exhaustive]
247pub struct EmbeddingCostDetails {
248 pub upstream_inference_completions_cost: f64,
249 pub upstream_inference_prompt_cost: f64,
250 #[serde(default, skip_serializing_if = "Option::is_none")]
251 pub upstream_inference_cost: Option<f64>,
252}
253
254#[derive(Serialize, Deserialize, Debug, Clone)]
256#[non_exhaustive]
257pub struct EmbeddingUsage {
258 pub prompt_tokens: u32,
259 pub total_tokens: u32,
260 #[serde(skip_serializing_if = "Option::is_none")]
261 pub prompt_tokens_details: Option<EmbeddingPromptTokensDetails>,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 pub cost: Option<f64>,
264 #[serde(default, skip_serializing_if = "Option::is_none")]
265 pub cost_details: Option<EmbeddingCostDetails>,
266}
267
268#[derive(Serialize, Deserialize, Debug, Clone)]
270#[non_exhaustive]
271pub struct EmbeddingResponse {
272 #[serde(skip_serializing_if = "Option::is_none")]
273 pub id: Option<String>,
274 pub object: String,
275 pub data: Vec<EmbeddingData>,
276 pub model: String,
277 #[serde(skip_serializing_if = "Option::is_none")]
278 pub usage: Option<EmbeddingUsage>,
279}
280
281pub async fn create_embedding(
283 base_url: &str,
284 api_key: &str,
285 x_title: &Option<String>,
286 http_referer: &Option<String>,
287 app_categories: &Option<Vec<String>>,
288 request: &EmbeddingRequest,
289) -> Result<EmbeddingResponse, OpenRouterError> {
290 let http_client = crate::transport::new_client()?;
291 create_embedding_with_client(
292 &http_client,
293 base_url,
294 api_key,
295 x_title,
296 http_referer,
297 app_categories,
298 request,
299 )
300 .await
301}
302
303pub(crate) async fn create_embedding_with_client(
304 http_client: &HttpClient,
305 base_url: &str,
306 api_key: &str,
307 x_title: &Option<String>,
308 http_referer: &Option<String>,
309 app_categories: &Option<Vec<String>>,
310 request: &EmbeddingRequest,
311) -> Result<EmbeddingResponse, OpenRouterError> {
312 let url = format!("{base_url}/embeddings");
313
314 let response = transport_request::with_client_request_headers(
315 transport_request::post(http_client, &url),
316 api_key,
317 x_title,
318 http_referer,
319 app_categories,
320 )?
321 .json(request)
322 .send()
323 .await?;
324
325 if response.status().is_success() {
326 let embedding_response: EmbeddingResponse =
327 transport_response::parse_json_response(response, "embedding").await?;
328 Ok(embedding_response)
329 } else {
330 transport_response::handle_error(response).await?;
331 unreachable!()
332 }
333}
334
335pub async fn list_embedding_models(
337 base_url: &str,
338 api_key: &str,
339) -> Result<Vec<models::Model>, OpenRouterError> {
340 let http_client = crate::transport::new_client()?;
341 list_embedding_models_with_client(&http_client, base_url, api_key).await
342}
343
344pub(crate) async fn list_embedding_models_with_client(
345 http_client: &HttpClient,
346 base_url: &str,
347 api_key: &str,
348) -> Result<Vec<models::Model>, OpenRouterError> {
349 let url = format!("{base_url}/embeddings/models");
350
351 let response =
352 transport_request::with_bearer_auth(transport_request::get(http_client, &url), api_key)
353 .send()
354 .await?;
355
356 if response.status().is_success() {
357 let models_response: ApiResponse<Vec<models::Model>> =
358 transport_response::parse_json_response(response, "embedding models").await?;
359 Ok(models_response.data)
360 } else {
361 transport_response::handle_error(response).await?;
362 unreachable!()
363 }
364}