use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use crate::error::OpenAIError;
/// What to embed: text, or pre-encoded integer sequences.
///
/// Serialized `untagged`, so each variant appears on the wire as its bare
/// JSON value (a string, an array of strings, an array of integers, or an
/// array of integer arrays) with no surrounding tag.
///
/// NOTE: when deserializing, serde tries untagged variants in declaration
/// order and keeps the first that matches, so the order below is significant.
/// Do not reorder.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// One string.
    String(String),
    /// Many strings.
    StringArray(Vec<String>),
    /// One integer sequence (presumably token IDs; confirm against the API docs).
    IntegerArray(Vec<u32>),
    /// Many integer sequences.
    ArrayOfIntegerArray(Vec<Vec<u32>>),
}
/// Encoding requested for the returned embedding vectors.
///
/// Serialized in lowercase, i.e. `"float"` or `"base64"`.
/// [`EncodingFormat::Float`] is the [`Default`].
///
/// This is a fieldless enum, so it is cheap to copy and can be compared and
/// hashed. `Copy`, `Eq` and `Hash` are derived so callers can pass it by
/// value and use it as a map/set key.
///
/// Presumably `Float` pairs with [`CreateEmbeddingResponse`] and `Base64` with
/// [`CreateBase64EmbeddingResponse`]; confirm against the client code.
#[derive(Debug, Serialize, Default, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    /// Sent as `"float"`; the default.
    #[default]
    Float,
    /// Sent as `"base64"`.
    Base64,
}
/// Request body for creating embeddings.
///
/// Usually built through the generated [`CreateEmbeddingRequestArgs`] builder:
/// - it uses the mutable builder pattern;
/// - every setter accepts `impl Into<_>`;
/// - optional fields take the inner value directly (`strip_option`);
/// - any field left unset falls back to its `Default`;
/// - `build()` reports failures as `OpenAIError`.
///
/// Optional fields that are `None` are left out of the serialized JSON.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, Builder)]
#[builder(name = "CreateEmbeddingRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateEmbeddingRequest {
    /// Model identifier, sent as `model`.
    pub model: String,

    /// Content to embed; see [`EmbeddingInput`] for the accepted shapes.
    pub input: EmbeddingInput,

    /// Requested vector encoding; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,

    /// Optional `user` value, omitted when `None`
    /// (presumably an end-user identifier; confirm against the API docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// Optional `dimensions` value, omitted when `None`
    /// (presumably the requested output vector length; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}
/// One embedding whose vector is carried as `f32` values.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Embedding {
    /// Position of this entry (presumably matches the input's index; confirm).
    pub index: u32,
    /// Object type string exactly as received.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f32>,
}
/// Embedding vector kept as the raw base64 string it arrived as.
///
/// This type performs no decoding itself. Any conversion to numeric values,
/// if present, is implemented elsewhere.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Base64EmbeddingVector(
    /// The base64 text, unmodified.
    pub String,
);
/// One embedding whose vector is carried as a base64 string.
///
/// Field layout mirrors [`Embedding`]; only the vector's representation differs.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Base64Embedding {
    /// Position of this entry (presumably matches the input's index; confirm).
    pub index: u32,
    /// Object type string exactly as received.
    pub object: String,
    /// The still-encoded embedding vector.
    pub embedding: Base64EmbeddingVector,
}
/// Token accounting attached to an embeddings response.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Tokens counted for the prompt/input.
    pub prompt_tokens: u32,
    /// Total tokens counted for the request.
    pub total_tokens: u32,
}
/// Embeddings response whose vectors are decoded `f32` values.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
    /// Object type string exactly as received.
    pub object: String,
    /// Model identifier echoed back in the response.
    pub model: String,
    /// One [`Embedding`] per returned vector.
    pub data: Vec<Embedding>,
    /// Token usage for the request.
    pub usage: EmbeddingUsage,
}
/// Embeddings response whose vectors are base64 strings.
///
/// Same shape as [`CreateEmbeddingResponse`] except for the element type of
/// `data`. It is presumably the reply to a request that sets `encoding_format`
/// to [`EncodingFormat::Base64`]; confirm against the client code.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateBase64EmbeddingResponse {
    /// Object type string exactly as received.
    pub object: String,
    /// Model identifier echoed back in the response.
    pub model: String,
    /// One [`Base64Embedding`] per returned vector.
    pub data: Vec<Base64Embedding>,
    /// Token usage for the request.
    pub usage: EmbeddingUsage,
}