use crate::{
client::OpenRouterClient,
error::{OpenRouterError, Result},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Input payload for an embeddings request.
///
/// Serialized `untagged`: each variant is written as its bare JSON shape (a string, an
/// array of strings, an array of numbers, an array of number arrays, or an array of
/// content items), with no wrapping tag.
///
/// Variant order is load-bearing for deserialization: serde tries untagged variants top to
/// bottom and keeps the first one that fits. For example, an empty JSON array `[]` becomes
/// `StringArray`, not `NumberArray`.
///
/// NOTE(review): `NumberArray` / `NestedArray` hold `f64`, and serde_json writes whole
/// `f64` values with a decimal point (e.g. `1.0`). If these are meant to carry token ids,
/// confirm the API accepts that form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// A single text input, serialized as a JSON string.
String(String),
/// Several text inputs, serialized as a JSON array of strings.
StringArray(Vec<String>),
/// A flat array of numbers (presumably pre-tokenized input — confirm).
NumberArray(Vec<f64>),
/// An array of number arrays (presumably several pre-tokenized inputs — confirm).
NestedArray(Vec<Vec<f64>>),
/// Structured items whose parts may be text or image URLs.
ContentArray(Vec<EmbeddingContentItem>),
}
/// One structured input item: an ordered list of text and/or image parts.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingContentItem {
    /// The parts making up this item, serialized under the `content` key.
    pub content: Vec<EmbeddingContentPart>,
}
/// A single part of an [`EmbeddingContentItem`].
///
/// Internally tagged by a `"type"` key whose value is the snake_case variant name, e.g.
/// `{"type": "text", "text": "..."}` or `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbeddingContentPart {
    /// Plain text content (tag `"text"`).
    Text { text: String },
    /// An image referenced by URL (tag `"image_url"`).
    ImageUrl { image_url: EmbeddingImageUrl },
}
/// Wrapper object for an image reference, serialized as `{"url": "..."}`.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingImageUrl {
    /// The image URL string, sent verbatim.
    pub url: String,
}
/// JSON body sent to the embeddings endpoint.
///
/// Every `Option` field is left out of the serialized JSON entirely when it is `None`,
/// so only explicitly set options reach the server. Field order here is also the order
/// of keys in the serialized output.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingRequest {
    /// The content to embed.
    pub input: EmbeddingInput,
    /// Identifier of the embedding model to use.
    pub model: String,
    /// Requested encoding of the returned vectors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EmbeddingEncodingFormat>,
    /// Requested number of output dimensions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
    /// Caller-supplied `user` value, passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Raw JSON passed through unchanged as the `provider` field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<Value>,
    /// Raw `input_type` string, passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_type: Option<String>,
}
/// Encoding requested for the returned embedding vectors.
///
/// Serialized in snake_case: `Float` becomes `"float"` and `Base64` becomes `"base64"`.
/// The enum is fieldless, so it is `Copy` and can be compared and hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingEncodingFormat {
    /// Numeric vectors (presumably arriving as `EmbeddingData::FloatArray` — confirm).
    Float,
    /// Base64 text (presumably arriving as `EmbeddingData::Base64String` — confirm).
    Base64,
}
/// One embedding vector as it appears in the response.
///
/// Deserialized `untagged`, trying variants in declaration order: a JSON array of numbers
/// becomes `FloatArray`, a JSON string becomes `Base64String`. Which shape the server sends
/// presumably follows `EmbeddingRequest::encoding_format` — confirm against the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingData {
/// The vector as plain numbers.
FloatArray(Vec<f64>),
/// The vector as an encoded string (not decoded here).
Base64String(String),
}
/// Body returned by `POST {base_url}/embeddings`.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingResponse {
    /// Response identifier, if the server supplied one; omitted when re-serialized as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The `object` type string reported by the server.
    pub object: String,
    /// The returned embeddings.
    pub data: Vec<EmbeddingDataItem>,
    /// The model the server reports having used.
    pub model: String,
    /// Token accounting for this request.
    pub usage: EmbeddingUsage,
}
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingDataItem {
    /// The `object` type string reported by the server for this entry.
    pub object: String,
    /// The embedding vector itself, as floats or an encoded string.
    pub embedding: EmbeddingData,
    /// Position of this entry; presumably the index of the matching input — confirm.
    ///
    /// NOTE(review): stored as `f64`; an integer type would be more natural, but changing
    /// it is a breaking change for callers.
    pub index: f64,
}
/// Token and cost accounting attached to an [`EmbeddingResponse`].
///
/// NOTE(review): token counts are stored as `f64` rather than an integer type.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingUsage {
    /// Tokens counted for the input.
    pub prompt_tokens: f64,
    /// Total tokens counted for the request.
    pub total_tokens: f64,
    /// Reported cost, if present; omitted when re-serialized as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
}
/// One entry of the `GET {base_url}/embeddings/models` listing.
///
/// Several nested objects are kept as raw JSON (`Value`) rather than typed structs.
/// `Option` fields are left out when re-serialized as `None`.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingModel {
    /// Model identifier.
    pub id: String,
    /// Canonical slug reported for the model.
    pub canonical_slug: String,
    /// Hugging Face identifier, when the server provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hugging_face_id: Option<String>,
    /// Display name.
    pub name: String,
    /// Creation time (presumably a Unix timestamp — confirm).
    pub created: f64,
    /// Free-text description.
    pub description: String,
    /// Pricing information, kept as raw JSON.
    pub pricing: Value,
    /// Context length, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<f64>,
    /// Architecture information, kept as raw JSON.
    pub architecture: Value,
    /// `top_provider` information, kept as raw JSON.
    pub top_provider: Value,
    /// Per-request limits, kept as raw JSON.
    pub per_request_limits: Value,
    /// Names of request parameters the model supports.
    pub supported_parameters: Vec<String>,
    /// Default parameter values, kept as raw JSON.
    pub default_parameters: Value,
    /// Expiration date string, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expiration_date: Option<String>,
}
/// Body returned by `GET {base_url}/embeddings/models`.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub struct EmbeddingModelsResponse {
    /// The listed embedding models.
    pub data: Vec<EmbeddingModel>,
}
/// Fluent builder for an [`EmbeddingRequest`].
///
/// Every setter consumes and returns the builder. Call `build` to obtain the request;
/// a builder that is created and then dropped does nothing, hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing unless `build` is called"]
pub struct EmbeddingBuilder {
    /// The request being assembled.
    request: EmbeddingRequest,
}
impl EmbeddingBuilder {
pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
Self {
request: EmbeddingRequest {
input: EmbeddingInput::String(input.into()),
model: model.into(),
encoding_format: None,
dimensions: None,
user: None,
provider: None,
input_type: None,
},
}
}
pub fn new_with_array(model: impl Into<String>, inputs: Vec<String>) -> Self {
Self {
request: EmbeddingRequest {
input: EmbeddingInput::StringArray(inputs),
model: model.into(),
encoding_format: None,
dimensions: None,
user: None,
provider: None,
input_type: None,
},
}
}
pub fn encoding_format(mut self, format: EmbeddingEncodingFormat) -> Self {
self.request.encoding_format = Some(format);
self
}
pub fn dimensions(mut self, dims: u32) -> Self {
self.request.dimensions = Some(dims);
self
}
pub fn build(self) -> EmbeddingRequest {
self.request
}
}
impl OpenRouterClient {
pub async fn create_embedding(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
let url = format!("{}/embeddings", self.base_url);
let headers = self.build_headers()?;
let response = self
.client
.post(&url)
.headers(headers)
.json(&request)
.send()
.await
.map_err(OpenRouterError::HttpError)?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(OpenRouterError::ApiError {
code: status.as_u16(),
message: error_text,
});
}
let result = response
.json::<EmbeddingResponse>()
.await
.map_err(OpenRouterError::HttpError)?;
Ok(result)
}
pub async fn list_embedding_models(&self) -> Result<EmbeddingModelsResponse> {
let url = format!("{}/embeddings/models", self.base_url);
let headers = self.build_headers()?;
let response = self
.client
.get(&url)
.headers(headers)
.send()
.await
.map_err(OpenRouterError::HttpError)?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(OpenRouterError::ApiError {
code: status.as_u16(),
message: error_text,
});
}
let result = response
.json::<EmbeddingModelsResponse>()
.await
.map_err(OpenRouterError::HttpError)?;
Ok(result)
}
}