use core::fmt;
use reqwest::{Client as ReqwestClient, Method, RequestBuilder, Url};
use serde::{Deserialize, Serialize};
use crate::error::Error;
/// HTTP client for the embeddings endpoint.
///
/// Pairs a base URL, onto which endpoint paths are joined, with the
/// `reqwest` client used to build and send each request.
pub struct EmbeddingClient {
/// Base URL that endpoint paths such as `embeddings` are resolved against.
base_url: Url,
/// Underlying `reqwest` client used to build and send requests.
http_client: ReqwestClient,
}
impl EmbeddingClient {
    /// Creates a client that resolves endpoint paths against `base_url` and
    /// sends every request through `http_client`.
    pub fn new(base_url: Url, http_client: ReqwestClient) -> Self {
        Self {
            base_url,
            http_client,
        }
    }

    /// Sends `payload` as a JSON `POST` to the `embeddings` path and decodes
    /// the reply into a [`CreateEmbeddingResponse`].
    ///
    /// # Errors
    ///
    /// Returns an error if the request URL cannot be built, the request fails
    /// to send, the server answers with a 4xx/5xx status, or the response body
    /// cannot be decoded as a [`CreateEmbeddingResponse`].
    pub async fn create_embedding(
        &self,
        payload: CreateEmbedding,
    ) -> Result<CreateEmbeddingResponse, Error> {
        let response = self
            .request(Method::POST, "embeddings")?
            .json(&payload)
            .send()
            .await?
            // Reject 4xx/5xx replies here; otherwise an error body would be
            // fed to the JSON decoder below and reported as a confusing
            // deserialization failure instead of an HTTP failure.
            .error_for_status()?;

        let embedding = response.json::<CreateEmbeddingResponse>().await?;
        Ok(embedding)
    }

    /// Starts a request for `method` against `path` resolved on the base URL.
    ///
    /// NOTE: `Url::join` resolves `path` as a relative reference, so a base
    /// URL whose path does not end in `/` has its last segment replaced
    /// rather than extended.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UrlParse`] carrying the parser's message when `path`
    /// cannot be joined onto the base URL.
    fn request(&self, method: Method, path: &str) -> Result<RequestBuilder, Error> {
        let url = self
            .base_url
            .join(path)
            .map_err(|err| Error::UrlParse(err.to_string()))?;
        Ok(self.http_client.request(method, url))
    }
}
/// A single embedding entry: an index, the embedding vector, and an object tag.
#[derive(Debug, Serialize, Deserialize)]
pub struct Embedding {
/// Index of this entry (presumably its position among the inputs — confirm against the API).
pub index: u64,
/// The embedding vector as floating-point components.
pub embedding: Vec<f64>,
/// Object tag identifying this entry's kind.
pub object: EmbeddingKind,
}
/// Object tag attached to embedding entries.
///
/// Variant names are serialized in lowercase, so `Embedding` is written and
/// read as `"embedding"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingKind {
/// Serialized as `"embedding"`.
Embedding,
}
/// Request body for creating an embedding.
///
/// `dimensions` and `user` are left out of the serialized output entirely
/// when they are `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateEmbedding {
/// Text to embed.
pub input: String,
/// Model that should produce the embedding.
pub model: EmbeddingModel,
/// Encoding requested for the returned embedding.
pub encoding_format: EncodingFormat,
/// Optional dimension count; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u64>,
/// Optional user value; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl CreateEmbedding {
    /// Builds a request for `input` using `model`.
    ///
    /// The encoding format starts as [`EncodingFormat::Float`]; `dimensions`
    /// and `user` start unset.
    pub fn new(input: impl Into<String>, model: EmbeddingModel) -> Self {
        Self {
            input: input.into(),
            model,
            encoding_format: EncodingFormat::Float,
            dimensions: None,
            user: None,
        }
    }

    /// Replaces the text to embed and returns the updated request.
    pub fn with_input(mut self, input: impl Into<String>) -> Self {
        self.input = input.into();
        self
    }

    /// Replaces the model and returns the updated request.
    pub fn with_model(mut self, model: EmbeddingModel) -> Self {
        self.model = model;
        self
    }

    /// Replaces the requested encoding format and returns the updated request.
    pub fn with_encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
        self.encoding_format = encoding_format;
        self
    }

    /// Sets the `dimensions` value and returns the updated request.
    pub fn with_dimensions(mut self, dimensions: u64) -> Self {
        self.dimensions = Some(dimensions);
        self
    }

    /// Sets the `user` value and returns the updated request.
    ///
    /// Accepts anything convertible into a `String` (including `String`
    /// itself), matching the `input` parameter of [`CreateEmbedding::new`]
    /// and [`CreateEmbedding::with_input`].
    pub fn with_user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }
}
/// Embedding models this client can name.
///
/// Each variant is serialized to and deserialized from the exact string in
/// its `rename` attribute.
///
/// NOTE(review): deserialization fails for any model string not listed here;
/// confirm the server never reports a different identifier in responses.
#[derive(Debug, Serialize, Deserialize)]
pub enum EmbeddingModel {
/// Serialized as `"text-embedding-3-large"`.
#[serde(rename = "text-embedding-3-large")]
TextEmbedding3Large,
/// Serialized as `"text-embedding-3-small"`.
#[serde(rename = "text-embedding-3-small")]
TextEmbedding3Small,
/// Serialized as `"text-embedding-ada-002"`.
#[serde(rename = "text-embedding-ada-002")]
TextEmbeddingAda002,
}
impl fmt::Display for EmbeddingModel {
    /// Writes the model's identifier string, the same text used for its
    /// serialized form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the identifier first, then emit it with a single write.
        let identifier = match self {
            Self::TextEmbedding3Large => "text-embedding-3-large",
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
        };
        f.write_str(identifier)
    }
}
/// Encoding requested for returned embeddings.
///
/// Variant names are serialized in lowercase (`"float"`, `"base64"`).
///
/// NOTE(review): the response types in this file hold embeddings as
/// `Vec<f64>`; verify that a `Base64` request still yields a body those
/// types can decode.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
/// Serialized as `"float"`.
Float,
/// Serialized as `"base64"`.
Base64,
}
/// Response body decoded by `EmbeddingClient::create_embedding`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
/// Object tag of the response as a free-form string.
pub object: String,
/// Embedding entries returned by the server.
pub data: Vec<EmbeddingData>,
/// Model reported in the response.
pub model: EmbeddingModel,
/// Token usage reported in the response.
pub usage: EmbeddingUsage,
}
/// One entry of the `data` list in a [`CreateEmbeddingResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingData {
/// Index of this entry (presumably its position among the inputs — confirm against the API).
pub index: u64,
/// Object tag identifying this entry's kind.
pub object: EmbeddingKind,
/// The embedding vector as floating-point components.
pub embedding: Vec<f64>,
}
/// Token usage figures carried in a [`CreateEmbeddingResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingUsage {
/// Value of the `prompt_tokens` field as reported by the server.
pub prompt_tokens: u32,
/// Value of the `total_tokens` field as reported by the server.
pub total_tokens: u32,
}