use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Parameters for an embedding-creation request.
///
/// Only `Serialize` is derived, so this type can be written out but not parsed back.
/// Optional fields that are `None` are left out of the serialized output entirely
/// (see the `skip_serializing_if` attributes below).
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct EmbeddingCreateParams {
/// Model identifier, serialized under the `model` key.
pub model: String,
/// Content to embed; see [`EmbeddingInput`] for the accepted shapes.
pub input: EmbeddingInput,
/// Optional `dimensions` value; omitted from the output when `None`.
// NOTE(review): presumably the requested length of each output vector — confirm against the API docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// Optional encoding format selector; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
/// Optional `user` string; omitted from the output when `None`.
// NOTE(review): presumably an end-user identifier forwarded to the service — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Additional key/value pairs. Because of `#[serde(flatten)]` they are written
/// as top-level keys next to the named fields rather than under an `extra` key.
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
impl EmbeddingCreateParams {
    /// Builds a parameter set from a model name and an input.
    ///
    /// `model` and `input` are converted with `Into`; `dimensions`,
    /// `encoding_format` and `user` start out as `None`, and `extra` starts empty.
    pub fn new(model: impl Into<String>, input: impl Into<EmbeddingInput>) -> Self {
        // Convert in argument order so the conversions run exactly as before.
        let model: String = model.into();
        let input: EmbeddingInput = input.into();
        let extra = HashMap::new();

        Self {
            model,
            input,
            dimensions: None,
            encoding_format: None,
            user: None,
            extra,
        }
    }
}
/// The content to embed in a request.
///
/// Serialized with `#[serde(untagged)]`: the output is the bare inner value
/// (a string, an array of strings, an array of integers, or an array of
/// integer arrays) with no variant name wrapped around it.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// A single string.
Text(String),
/// A list of strings.
Texts(Vec<String>),
/// A single list of `u32` values.
// NOTE(review): presumably pre-tokenized input (token ids) — confirm against the API docs.
Tokens(Vec<u32>),
/// A list of `u32` lists.
// NOTE(review): presumably one token-id sequence per item to embed — confirm.
TokenBatches(Vec<Vec<u32>>),
}
impl From<&str> for EmbeddingInput {
fn from(value: &str) -> Self {
Self::Text(value.to_string())
}
}
impl From<String> for EmbeddingInput {
fn from(value: String) -> Self {
Self::Text(value)
}
}
impl From<Vec<String>> for EmbeddingInput {
fn from(value: Vec<String>) -> Self {
Self::Texts(value)
}
}
impl From<Vec<u32>> for EmbeddingInput {
fn from(value: Vec<u32>) -> Self {
Self::Tokens(value)
}
}
impl From<Vec<Vec<u32>>> for EmbeddingInput {
fn from(value: Vec<Vec<u32>>) -> Self {
Self::TokenBatches(value)
}
}
/// Encoding format selector used by [`EmbeddingCreateParams::encoding_format`].
///
/// `#[serde(rename_all = "lowercase")]` makes the variants serialize to and
/// deserialize from the strings `"float"` and `"base64"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
/// Written as `"float"`.
Float,
/// Written as `"base64"`.
Base64,
}
/// Response body of an embedding-creation call.
///
/// `data` is the only required key when deserializing; `object`, `model` and
/// `usage` fall back to `None` when absent, and every key that is not one of
/// the named fields is collected into `extra`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct CreateEmbeddingResponse {
/// Value of the `object` key, or `None` when the key is absent.
// NOTE(review): no `skip_serializing_if` here, so a `None` is written back out as `null`.
#[serde(default)]
pub object: Option<String>,
/// The returned embedding entries (required).
pub data: Vec<Embedding>,
/// Value of the `model` key, or `None` when the key is absent.
#[serde(default)]
pub model: Option<String>,
/// Token usage figures, or `None` when the key is absent.
#[serde(default)]
pub usage: Option<EmbeddingUsage>,
/// Catch-all for keys not covered by the fields above (`#[serde(flatten)]`);
/// on serialization these entries are written back as top-level keys.
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
/// One entry of [`CreateEmbeddingResponse::data`].
///
/// `index` and `embedding` are required when deserializing; `object` falls back
/// to `None` when absent, and unrecognized keys are collected into `extra`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct Embedding {
/// Value of the `object` key, or `None` when the key is absent.
#[serde(default)]
pub object: Option<String>,
/// Value of the `index` key (required).
// NOTE(review): presumably the position of the matching input in the request — confirm.
pub index: u32,
/// The embedding payload; see [`EmbeddingVector`] for the two accepted shapes.
pub embedding: EmbeddingVector,
/// Catch-all for keys not covered by the fields above (`#[serde(flatten)]`).
#[serde(flatten)]
pub extra: HashMap<String, Value>,
}
/// Payload of a single embedding, in one of two shapes.
///
/// The enum is `#[serde(untagged)]`, so no variant name appears in the data.
/// When deserializing, serde tries the variants in declaration order: an array
/// of numbers becomes `Float`, a string becomes `Base64`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum EmbeddingVector {
/// The vector as a list of `f64` values.
Float(Vec<f64>),
/// The vector as a string; it is stored verbatim and not decoded here.
// NOTE(review): presumably base64-encoded vector bytes, given the variant name — confirm the exact layout.
Base64(String),
}
/// Token usage figures carried in [`CreateEmbeddingResponse::usage`].
///
/// Both fields are required when deserializing.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct EmbeddingUsage {
/// Value of the `prompt_tokens` key.
// NOTE(review): presumably the number of tokens in the request input — confirm.
pub prompt_tokens: u32,
/// Value of the `total_tokens` key.
pub total_tokens: u32,
}