use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use crate::{api_resources::TokenUsage, Client, Result};
/// Request body that is POSTed to the `embeddings` endpoint.
///
/// The fields are private; values are produced either by deserializing JSON
/// or through the derived `EmbeddingParamBuilder`. The builder setters accept
/// anything convertible into the field type (`setter(into)`), take the inner
/// value for `Option` fields (`strip_option`), and leave unset fields at their
/// `Default` value (`builder(default)`).
///
/// `#[skip_serializing_none]` omits `Option` fields that are `None` from the
/// serialized output instead of emitting `null`.
#[skip_serializing_none]
#[derive(Builder, Debug, Default, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct EmbeddingParam {
/// Name of the embedding model to use, e.g. `"text-embedding-ada-002"`.
model: String,
/// Text to compute an embedding for.
input: String,
/// Optional caller-supplied user value; left out of the request when `None`.
/// NOTE(review): presumably an end-user identifier forwarded to the API — confirm.
user: Option<String>,
}
impl EmbeddingParamBuilder {
    /// Creates a builder with `model` and `input` already filled in.
    ///
    /// Both arguments are converted into owned `String`s. Every other builder
    /// field starts out with the value provided by the builder's `Default`
    /// implementation.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        // Shadow the parameters with the wrapped values the builder stores.
        let model = Some(model.into());
        let input = Some(input.into());

        Self {
            model,
            input,
            ..Default::default()
        }
    }
}
/// Response body returned by the `embeddings` endpoint.
///
/// Because of `#[serde(default)]`, any field missing from the payload is
/// filled with its `Default` value instead of failing deserialization.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Embedding {
    /// Object type tag of the response (e.g. `"list"`).
    pub object: String,
    /// Embedding entries contained in the response.
    pub data: Vec<EmbeddingData>,
    /// Name of the model reported by the API (e.g. `"text-embedding-ada-002"`).
    ///
    /// The API sends this value under the JSON key `model`. Without an explicit
    /// mapping that key was ignored as unknown and this field always stayed
    /// empty. The Rust field name is kept for source compatibility; on the wire
    /// it is now read from and written as `model`, and the legacy key `mode` is
    /// still accepted when deserializing.
    #[serde(rename = "model", alias = "mode")]
    pub mode: String,
    /// Token accounting for the request, when the API reports it.
    pub usage: Option<TokenUsage>,
}
/// A single entry of the `data` array in an [`Embedding`] response.
///
/// Unlike [`Embedding`], this type has no `#[serde(default)]`, so all three
/// fields must be present in the payload for deserialization to succeed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingData {
/// Object type tag of the entry (e.g. `"embedding"`).
pub object: String,
/// The embedding vector itself.
pub embedding: Embeddings,
/// Index reported by the API for this entry.
/// NOTE(review): presumably the position of the matching input — confirm.
pub index: u64,
}
/// An embedding vector: a list of `f32` components.
type Embeddings = Vec<f32>;
pub async fn create(client: &Client, param: &EmbeddingParam) -> Result<Embedding> {
client.create_embeddings(param).await
}
impl Client {
    /// POSTs `param` to the `embeddings` endpoint and returns the response
    /// decoded as an [`Embedding`].
    ///
    /// # Errors
    ///
    /// Returns the error reported by the client's `post` method.
    async fn create_embeddings(&self, param: &EmbeddingParam) -> Result<Embedding> {
        let body = Some(param);
        let pending = self.post::<EmbeddingParam, Embedding>("embeddings", body);
        pending.await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Request payload with only `model` and `input`; `user` is left out.
    const REQUEST_JSON: &str = r#"
{
"model": "text-embedding-ada-002",
"input": "The food was delicious and the waiter..."
}
"#;

    // Sample response payload holding one embedding with three components.
    const RESPONSE_JSON: &str = r#"
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [
0.0023064255,
-0.009327292,
-0.0028842222
],
"index": 0
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
}
"#;

    /// Deserializes a sample request and a sample response and checks the
    /// decoded values.
    #[test]
    fn test_create_embedding() {
        let request = serde_json::from_str::<EmbeddingParam>(REQUEST_JSON).unwrap();
        let response = serde_json::from_str::<Embedding>(RESPONSE_JSON).unwrap();

        // Request: the model is read and the omitted `user` becomes `None`.
        assert_eq!(request.model, "text-embedding-ada-002");
        assert_eq!(request.user, None);

        // Response: one entry whose vector has three components.
        let entries = &response.data;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].embedding.len(), 3);
    }
}