use crate::common::Usage;
use serde::{Deserialize, Serialize};
/// Request body for an embeddings call.
///
/// Optional fields that are `None` are left out of the serialized JSON
/// entirely rather than emitted as `null`. When deserializing, a missing
/// optional field becomes `None`.
///
/// Derives `PartialEq`/`Eq` (every field supports them) so whole requests can
/// be compared directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    /// Model identifier, e.g. `"text-embedding-ada-002"`. Omitted from the
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// The text or token input to embed; see [`InputText`]. Always serialized.
    pub input: InputText,
    /// Requested encoding for the returned embeddings. Omitted when `None`.
    /// NOTE(review): not validated here; presumably must be a value the
    /// upstream service accepts. Confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
    /// Caller-supplied identifier, serialized as-is (presumably an end-user
    /// id). Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Serializing a request must emit `model` and `input` in field order.
/// It must omit the `None`-valued `encoding_format` and `user` keys.
#[test]
fn test_embedding_serialize_embedding_request() {
    // (input, exact JSON expected for a request built around that input)
    let cases: Vec<(InputText, &str)> = vec![
        (
            "Hello, world!".into(),
            r#"{"model":"text-embedding-ada-002","input":"Hello, world!"}"#,
        ),
        (
            vec!["Hello, world!", "This is a test string"].into(),
            r#"{"model":"text-embedding-ada-002","input":["Hello, world!","This is a test string"]}"#,
        ),
    ];

    for (input, expected_json) in cases {
        let request = EmbeddingRequest {
            model: Some(String::from("text-embedding-ada-002")),
            input,
            encoding_format: None,
            user: None,
        };
        let actual_json = serde_json::to_string(&request).unwrap();
        assert_eq!(actual_json, expected_json);
    }
}
/// Deserializing a request must recover `model` and `input`.
/// It must leave absent optional fields as `None`.
#[test]
fn test_embedding_deserialize_embedding_request() {
    // (JSON payload, the `InputText` it should decode to)
    let cases = vec![
        (
            r#"{"model":"text-embedding-ada-002","input":"Hello, world!"}"#,
            InputText::from("Hello, world!"),
        ),
        (
            r#"{"model":"text-embedding-ada-002","input":["Hello, world!","This is a test string"]}"#,
            InputText::from(vec!["Hello, world!", "This is a test string"]),
        ),
    ];

    for (payload, expected_input) in cases {
        let request: EmbeddingRequest = serde_json::from_str(payload).unwrap();
        assert_eq!(request.model.as_deref(), Some("text-embedding-ada-002"));
        assert_eq!(request.input, expected_input);
        assert!(request.encoding_format.is_none());
        assert!(request.user.is_none());
    }
}
/// Input for an embeddings request. It is one of: a single string, several
/// strings, or input given as integer tokens.
///
/// Because of `#[serde(untagged)]`, a value serializes as the bare inner JSON
/// value (a string, an array of strings, an array of integers, or an array of
/// integer arrays) with no variant tag. When deserializing, serde tries the
/// variants in declaration order and keeps the first one that matches. The
/// order below is therefore significant: an empty JSON array `[]`, for
/// example, decodes as `ArrayOfStrings`. Do not reorder the variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum InputText {
    /// A single piece of text.
    String(String),
    /// Several pieces of text.
    ArrayOfStrings(Vec<String>),
    /// One input given as integers (presumably pre-tokenized token ids; confirm).
    ArrayOfTokens(Vec<i64>),
    /// Several inputs, each given as a list of integers.
    ArrayOfTokenArrays(Vec<Vec<i64>>),
}
impl From<&str> for InputText {
fn from(s: &str) -> Self {
InputText::String(s.to_string())
}
}
impl From<&String> for InputText {
fn from(s: &String) -> Self {
InputText::String(s.to_string())
}
}
impl From<String> for InputText {
fn from(s: String) -> Self {
InputText::String(s)
}
}
impl From<&[String]> for InputText {
fn from(s: &[String]) -> Self {
InputText::ArrayOfStrings(s.to_vec())
}
}
/// Converts a vector of string slices into [`InputText::ArrayOfStrings`].
/// Each slice is copied into an owned `String` and the order is kept.
impl From<Vec<&str>> for InputText {
    fn from(strings: Vec<&str>) -> Self {
        // Consume the Vec and convert each `&str` directly. This avoids
        // borrowing the Vec we already own and calling `to_string` through a
        // `&&str`, and removes the closure parameter that shadowed `s`.
        InputText::ArrayOfStrings(strings.into_iter().map(String::from).collect())
    }
}
impl From<Vec<String>> for InputText {
fn from(s: Vec<String>) -> Self {
InputText::ArrayOfStrings(s)
}
}
impl From<&[i64]> for InputText {
fn from(s: &[i64]) -> Self {
InputText::ArrayOfTokens(s.to_vec())
}
}
impl From<Vec<i64>> for InputText {
fn from(s: Vec<i64>) -> Self {
InputText::ArrayOfTokens(s)
}
}
impl From<Vec<Vec<i64>>> for InputText {
fn from(s: Vec<Vec<i64>>) -> Self {
InputText::ArrayOfTokenArrays(s)
}
}
/// Top-level response body of an embeddings call.
///
/// NOTE(review): unlike its sibling types this does not derive `Clone` or
/// `PartialEq`. Adding them would require `Usage` to implement those traits,
/// and that cannot be checked from here.
#[derive(Deserialize, Serialize, Debug)]
pub struct EmbeddingsResponse {
    /// Object type tag as reported in the response.
    pub object: String,
    /// Embedding entries (presumably one per input); see [`EmbeddingObject`].
    pub data: Vec<EmbeddingObject>,
    /// Model name as reported in the response.
    pub model: String,
    /// Usage accounting, using the shared `crate::common::Usage` type.
    pub usage: Usage,
}
/// A single embedding entry inside an [`EmbeddingsResponse`].
///
/// Derives `PartialEq` so entries can be compared directly. `Eq` is not
/// derived because `f64` is not `Eq` (NaN != NaN).
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct EmbeddingObject {
    /// Position of this entry (presumably the index of the corresponding
    /// input; confirm).
    pub index: u64,
    /// Object type tag as reported in the response.
    pub object: String,
    /// The embedding vector.
    /// NOTE(review): this expects a JSON array of numbers. A non-float
    /// encoding requested via `encoding_format` would fail to deserialize
    /// here; confirm how that case is handled.
    pub embedding: Vec<f64>,
}
/// Request to split a file into text chunks.
///
/// Derives `PartialEq`/`Eq` (all fields support them) so requests can be
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunksRequest {
    /// Identifier for the request or file; echoed in [`ChunksResponse`]
    /// (presumably; confirm).
    pub id: String,
    /// Name of the file to chunk.
    pub filename: String,
    /// Size limit per chunk.
    /// NOTE(review): the unit (bytes, characters, or tokens) is not visible
    /// here; confirm.
    pub chunk_capacity: usize,
}
/// Result of splitting a file into text chunks.
///
/// Derives `PartialEq`/`Eq` (all fields support them) so responses can be
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunksResponse {
    /// Identifier, presumably echoed from the originating [`ChunksRequest`].
    pub id: String,
    /// Name of the file that was chunked.
    pub filename: String,
    /// The chunks, in the order they appear in `Vec`.
    pub chunks: Vec<String>,
}