use serde::{Deserialize, Serialize};
use crate::client::Client;
use crate::error::Result;
use crate::logoi::output::Usage;
/// Handle for embedding requests that borrows the [`Client`] used to send them.
///
/// Only constructible inside the crate, through `Embeddings::new`.
pub struct Embeddings<'a> {
/// Borrowed client handed to `super::post_json` whenever a request is issued.
client: &'a Client,
}
/// Request payload that `Embeddings::create` hands to `super::post_json` for the
/// `/embeddings` path.
///
/// Every `Option` field is skipped during serialization when it is `None`, so
/// unset options do not appear in the serialized output at all.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
/// Model name sent with the request.
pub model: String,
/// Content to embed; see [`EmbeddingInput`] for the accepted shapes.
pub input: EmbeddingInput,
/// Requested encoding of the returned vectors; `Embeddings::create_one` sets
/// this to `"float"`.
///
/// NOTE(review): `Embedding::embedding` is declared as `Vec<f32>`, so a format
/// that makes the server return anything other than a number array would
/// presumably fail to deserialize — confirm which values are supported.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
/// Optional `dimensions` value; presumably the requested length of each
/// returned vector — confirm against the API documentation.
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// Optional `user` value; presumably an end-user identifier forwarded to the
/// API — confirm against the API documentation.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
/// Input carried by an [`EmbeddingRequest`].
///
/// The enum is `#[serde(untagged)]`: each variant serializes as its inner value
/// alone, with no variant name in the output.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// A single piece of text.
String(String),
/// Several pieces of text sent in one request.
Strings(Vec<String>),
/// One sequence of integers; presumably already-tokenized input (token ids) —
/// confirm against the API documentation.
TokenIds(Vec<i64>),
/// Several integer sequences; presumably one token-id sequence per input —
/// confirm against the API documentation.
TokenIdGroups(Vec<Vec<i64>>),
}
impl From<String> for EmbeddingInput {
fn from(s: String) -> Self {
EmbeddingInput::String(s)
}
}
impl From<&str> for EmbeddingInput {
fn from(s: &str) -> Self {
EmbeddingInput::String(s.to_string())
}
}
impl From<Vec<String>> for EmbeddingInput {
fn from(v: Vec<String>) -> Self {
EmbeddingInput::Strings(v)
}
}
/// Response type returned by `Embeddings::create`.
///
/// No field is optional or defaulted, so all four must be present in the
/// payload for deserialization to succeed.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
/// The `object` field as returned; presumably a type tag for the payload —
/// confirm against the API documentation.
pub object: String,
/// Returned embedding entries; `Embeddings::create_one` takes the first one.
pub data: Vec<Embedding>,
/// The `model` field as returned.
pub model: String,
/// Usage information, typed as the crate's shared `Usage`.
pub usage: Usage,
}
/// One entry of [`EmbeddingResponse::data`].
#[derive(Debug, Clone, Deserialize)]
pub struct Embedding {
/// The `object` field as returned; presumably a type tag for this entry —
/// confirm against the API documentation.
pub object: String,
/// The embedding vector, deserialized as a list of `f32` values.
pub embedding: Vec<f32>,
/// The `index` field as returned; presumably the position of the matching
/// input within the request — confirm against the API documentation.
pub index: u32,
}
/// Builder that assembles an [`EmbeddingRequest`] step by step.
///
/// The model is supplied up front via `new`; every other field starts as `None`
/// and is filled in by the chained setter of the same name.
pub struct EmbeddingRequestBuilder {
/// Model name, fixed when the builder is created.
model: String,
/// Input to embed; if still `None` when `build` runs, an empty string input
/// is substituted.
input: Option<EmbeddingInput>,
/// Optional value copied into `EmbeddingRequest::encoding_format`.
encoding_format: Option<String>,
/// Optional value copied into `EmbeddingRequest::dimensions`.
dimensions: Option<u32>,
/// Optional value copied into `EmbeddingRequest::user`.
user: Option<String>,
}
impl EmbeddingRequestBuilder {
    /// Starts a builder for `model` with no input and every optional field unset.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            input: None,
            encoding_format: None,
            dimensions: None,
            user: None,
        }
    }

    /// Sets the input to embed, replacing any input set earlier.
    ///
    /// Accepts anything convertible into [`EmbeddingInput`].
    #[must_use]
    pub fn input(mut self, input: impl Into<EmbeddingInput>) -> Self {
        self.input = Some(input.into());
        self
    }

    /// Sets the `encoding_format` field of the request.
    #[must_use]
    pub fn encoding_format(mut self, f: impl Into<String>) -> Self {
        self.encoding_format = Some(f.into());
        self
    }

    /// Sets the `dimensions` field of the request.
    #[must_use]
    pub fn dimensions(mut self, d: u32) -> Self {
        self.dimensions = Some(d);
        self
    }

    /// Sets the `user` field of the request.
    #[must_use]
    pub fn user(mut self, u: impl Into<String>) -> Self {
        self.user = Some(u.into());
        self
    }

    /// Consumes the builder and produces the finished [`EmbeddingRequest`].
    ///
    /// This never fails: if [`input`](Self::input) was not called, the request
    /// carries an empty string as its input. Callers that need a non-empty
    /// input must set one before building.
    #[must_use]
    pub fn build(self) -> EmbeddingRequest {
        EmbeddingRequest {
            model: self.model,
            // Fall back to an empty single-string input so `build` stays infallible.
            input: self
                .input
                .unwrap_or_else(|| EmbeddingInput::String(String::new())),
            encoding_format: self.encoding_format,
            dimensions: self.dimensions,
            user: self.user,
        }
    }
}
impl<'a> Embeddings<'a> {
pub(crate) fn new(client: &'a Client) -> Self {
Self { client }
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "debug", skip_all, fields(endpoint = "embeddings", model = %req.model))
)]
pub async fn create(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse> {
super::post_json(self.client, "/embeddings", &req).await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "debug", skip_all, fields(endpoint = "embeddings.one"))
)]
pub async fn create_one(
&self,
text: impl Into<String>,
model: impl Into<String>,
) -> Result<Vec<f32>> {
let req = EmbeddingRequestBuilder::new(model)
.input(text.into())
.encoding_format("float")
.build();
let resp = self.create(req).await?;
resp.data
.into_iter()
.next()
.map(|e| e.embedding)
.ok_or_else(|| crate::error::OpenAiError::Stream("empty embeddings response".into()))
}
}