use crate::IntoRequest;
use derive_builder::Builder;
use reqwest::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};
/// Request body for the embeddings endpoint (sent as JSON to `{base_url}/embeddings`).
///
/// Create one with [`EmbeddingRequest::new`] / [`EmbeddingRequest::new_array`], or with the
/// generated `EmbeddingRequestBuilder` (mutable builder pattern) when optional fields
/// need to be set.
#[derive(Debug, Clone, Serialize, Builder)]
#[builder(pattern = "mutable")]
pub struct EmbeddingRequest {
    /// Text(s) to embed. This is the only field without a builder default, so it must be set.
    input: EmbeddingInput,
    /// Model to use; the builder falls back to `EmbeddingModel::default()`.
    #[builder(default)]
    model: EmbeddingModel,
    /// Optional encoding for the returned vectors; omitted from the JSON when `None`.
    /// The builder setter takes the bare value (`strip_option`).
    #[builder(default, setter(strip_option))]
    #[serde(skip_serializing_if = "Option::is_none")]
    encoding_format: Option<EmbeddingEncodingFormat>,
    /// Optional `user` field, omitted from the JSON when `None`. The builder setter accepts
    /// anything `Into<String>`. NOTE(review): presumably an end-user identifier — confirm
    /// against the API reference.
    #[builder(default, setter(strip_option, into))]
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
}
/// Text input for an embedding request.
///
/// `#[serde(untagged)]` serializes the bare payload with no variant tag:
/// a JSON string for `String`, a JSON array of strings for `StringArray`.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single piece of text.
    String(String),
    /// Several texts submitted in one request.
    StringArray(Vec<String>),
}
/// Embedding models this client knows how to request.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum EmbeddingModel {
    /// `text-embedding-ada-002`, which is both the default and the only variant.
    /// Serialized as its API name via the serde rename below.
    #[default]
    #[serde(rename = "text-embedding-ada-002")]
    TextEmbeddingAda002,
}
/// Encoding requested for the returned embedding vectors.
///
/// Serialized in snake_case (`"float"` / `"base64"`); defaults to `Float`.
///
/// NOTE(review): with `Base64` the service presumably returns each embedding as a base64
/// string, which would not deserialize into `EmbeddingData::embedding: Vec<f32>` — confirm
/// before relying on this variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingEncodingFormat {
    #[default]
    Float,
    Base64,
}
/// Response body returned by the embeddings endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag for the envelope (expected to be `"list"`).
    pub object: String,
    /// One entry per embedded input, each carrying its own `index`.
    pub data: Vec<EmbeddingData>,
    /// Model name as reported by the service. Kept as a plain `String` because the reported
    /// name may not match an `EmbeddingModel` value exactly (e.g. `text-embedding-ada-002-v2`).
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
/// Token usage reported alongside an embedding response.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
    /// Tokens counted for the input text(s).
    pub prompt_tokens: usize,
    /// Total tokens billed for the request.
    pub total_tokens: usize,
}
/// A single embedding within an [`EmbeddingResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
    /// Position of this entry; presumably matches the position of the corresponding input.
    pub index: usize,
    /// The embedding vector. NOTE(review): this only deserializes when the response uses
    /// float encoding; a base64-encoded response would presumably fail here — confirm.
    pub embedding: Vec<f32>,
    /// Object type tag for the entry (expected to be `"embedding"`).
    pub object: String,
}
impl IntoRequest for EmbeddingRequest {
    /// Builds a `POST {base_url}/embeddings` request with `self` serialized as the JSON body.
    ///
    /// Trailing `/` characters are stripped from `base_url` before joining, so a base such
    /// as `https://host/v1/` yields `https://host/v1/embeddings` rather than
    /// `https://host/v1//embeddings`.
    fn into_request(self, base_url: &str, client: Client) -> RequestBuilder {
        let url = format!("{}/embeddings", base_url.trim_end_matches('/'));
        client.post(url).json(&self)
    }
}
impl EmbeddingRequest {
    /// Creates a request for `input` with the default model and no optional fields set.
    ///
    /// `input` may be anything convertible into [`EmbeddingInput`], such as `&str`,
    /// `String`, `Vec<String>` or `&[String]`.
    ///
    /// # Panics
    ///
    /// Only if the builder's invariants are broken. `input` is set here, and every other
    /// field carries `#[builder(default)]`, so `build()` has nothing left uninitialized.
    pub fn new(input: impl Into<EmbeddingInput>) -> Self {
        EmbeddingRequestBuilder::default()
            .input(input.into())
            .build()
            .expect("`input` is always set and all other fields have builder defaults")
    }

    /// Creates a request that embeds each string in `input` separately.
    ///
    /// Equivalent to [`EmbeddingRequest::new`] with a `Vec<String>`. It is kept as an
    /// explicit constructor for array input.
    pub fn new_array(input: Vec<String>) -> Self {
        Self::new(input)
    }
}
impl From<String> for EmbeddingInput {
fn from(s: String) -> Self {
Self::String(s)
}
}
impl From<Vec<String>> for EmbeddingInput {
fn from(s: Vec<String>) -> Self {
Self::StringArray(s)
}
}
impl From<&[String]> for EmbeddingInput {
fn from(s: &[String]) -> Self {
Self::StringArray(s.to_vec())
}
}
impl From<&str> for EmbeddingInput {
fn from(s: &str) -> Self {
Self::String(s.to_owned())
}
}
#[cfg(test)]
mod tests {
    //! Round-trip tests against the live embeddings endpoint via `SDK`.
    //!
    //! NOTE(review): these presumably need network access and valid credentials for `SDK`;
    //! consider `#[ignore]` if they must not run in offline CI.

    use super::*;
    use crate::SDK;
    use anyhow::Result;

    /// Vector length the tests expect for `text-embedding-ada-002`.
    const ADA_002_DIMENSIONS: usize = 1536;

    /// Checks the envelope of `res` and every entry in it.
    ///
    /// The model name is only required to start with `text-embedding-ada-002`. The name the
    /// service reports can carry a suffix (e.g. `text-embedding-ada-002-v2`) that is not
    /// stable over time, so an exact match makes the test brittle.
    fn assert_ada_002_response(res: &EmbeddingResponse, expected_len: usize) {
        assert_eq!(res.object, "list");
        assert!(
            res.model.starts_with("text-embedding-ada-002"),
            "unexpected model: {}",
            res.model
        );
        assert_eq!(res.data.len(), expected_len);
        for (position, data) in res.data.iter().enumerate() {
            assert_eq!(data.index, position);
            assert_eq!(data.object, "embedding");
            assert_eq!(data.embedding.len(), ADA_002_DIMENSIONS);
        }
    }

    #[tokio::test]
    async fn string_embedding_should_work() -> Result<()> {
        let req = EmbeddingRequest::new("The quick brown fox jumped over the lazy dog.");
        let res = SDK.embedding(req).await?;
        assert_ada_002_response(&res, 1);
        Ok(())
    }

    #[tokio::test]
    async fn array_string_embedding_should_work() -> Result<()> {
        let req = EmbeddingRequest::new_array(vec![
            "The quick brown fox jumped over the lazy dog.".into(),
            "我是谁?宇宙有没有尽头?".into(),
        ]);
        let res = SDK.embedding(req).await?;
        assert_ada_002_response(&res, 2);
        Ok(())
    }
}