use serde::{Deserialize, Serialize};
/// Input payload for an embedding request.
///
/// Serialized `untagged`, so `Single` becomes a bare JSON string and
/// `Batch` becomes a JSON array of strings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// One text to embed; serializes as a JSON string.
    Single(String),
    /// Several texts to embed in one call; serializes as a JSON array.
    Batch(Vec<String>),
}

impl From<String> for EmbeddingInput {
    /// Wraps an owned string as a single-text input.
    fn from(text: String) -> Self {
        EmbeddingInput::Single(text)
    }
}

impl From<&str> for EmbeddingInput {
    /// Copies a string slice into a single-text input.
    fn from(text: &str) -> Self {
        EmbeddingInput::Single(text.to_owned())
    }
}

impl From<Vec<String>> for EmbeddingInput {
    /// Wraps a vector of texts as a batch input.
    fn from(texts: Vec<String>) -> Self {
        EmbeddingInput::Batch(texts)
    }
}
/// Wire format for the returned embedding vectors.
///
/// Variants serialize in lowercase per `rename_all`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    /// Serializes as `"float"`.
    Float,
    /// Serializes as `"base64"`.
    Base64,
}
/// Request body for the embeddings endpoint.
///
/// Serialize-only: this type is sent to the API, never parsed from it.
/// Optional fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
// Model identifier, e.g. "openai/text-embedding-3-small".
pub model: String,
// Text(s) to embed; a bare string or an array (see `EmbeddingInput`).
pub input: EmbeddingInput,
// Desired encoding of returned vectors; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
// Provider routing preferences; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub provider: Option<crate::models::provider_preferences::ProviderPreferences>,
}
/// One embedding entry in an embeddings response.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingData {
// The embedding vector itself.
pub embedding: Vec<f64>,
// Position of this entry relative to the request's input order.
pub index: usize,
// Object type tag (e.g. "embedding"); defaults to "" if the field is absent.
#[serde(default)]
pub object: String,
}
/// Token accounting reported for an embeddings call.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingUsage {
// Tokens consumed by the input text(s).
pub prompt_tokens: u32,
// Total tokens billed for the call.
pub total_tokens: u32,
}
/// Deserialized response from the embeddings endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingResponse {
// Object type tag (e.g. "list").
pub object: String,
// One entry per input text, carrying the embedding vector and its index.
pub data: Vec<EmbeddingData>,
// The model that produced the embeddings.
pub model: String,
// Token usage; `None` when the provider omits the field.
pub usage: Option<EmbeddingUsage>,
}
impl EmbeddingResponse {
    /// Returns the embedding vector of the first entry, or `None` when
    /// the response carries no data.
    pub fn first_embedding(&self) -> Option<&Vec<f64>> {
        match self.data.first() {
            Some(entry) => Some(&entry.embedding),
            None => None,
        }
    }

    /// Collects a reference to every embedding vector, in response order.
    pub fn embeddings(&self) -> Vec<&Vec<f64>> {
        let mut vectors = Vec::with_capacity(self.data.len());
        for entry in &self.data {
            vectors.push(&entry.embedding);
        }
        vectors
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A single-string input must serialize as a bare JSON string.
    #[test]
    fn test_embedding_request_single_serialization() {
        let request = EmbeddingRequest {
            provider: None,
            encoding_format: None,
            input: EmbeddingInput::Single("Hello world".to_string()),
            model: "openai/text-embedding-3-small".to_string(),
        };
        let body = serde_json::to_value(&request).unwrap();
        assert_eq!(body["model"], "openai/text-embedding-3-small");
        assert_eq!(body["input"], "Hello world");
    }

    // A batch input must serialize as a JSON array, and the encoding
    // format must appear lowercase.
    #[test]
    fn test_embedding_request_batch_serialization() {
        let texts = vec!["Hello".to_string(), "World".to_string()];
        let request = EmbeddingRequest {
            provider: None,
            encoding_format: Some(EncodingFormat::Float),
            input: EmbeddingInput::Batch(texts),
            model: "openai/text-embedding-3-small".to_string(),
        };
        let body = serde_json::to_value(&request).unwrap();
        assert_eq!(body["input"].as_array().unwrap().len(), 2);
        assert_eq!(body["encoding_format"], "float");
    }

    // Round-trips a full single-entry response payload.
    #[test]
    fn test_embedding_response_deserialization() {
        let payload = r#"{
"object": "list",
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4],
"index": 0,
"object": "embedding"
}
],
"model": "openai/text-embedding-3-small",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}"#;
        let parsed: EmbeddingResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.object, "list");
        assert_eq!(parsed.data.len(), 1);
        assert_eq!(parsed.data[0].embedding, vec![0.1, 0.2, 0.3, 0.4]);
        assert_eq!(parsed.data[0].index, 0);
        assert_eq!(parsed.model, "openai/text-embedding-3-small");
        assert_eq!(parsed.usage.as_ref().unwrap().prompt_tokens, 5);
    }

    // Multi-entry responses expose all vectors via the accessors.
    #[test]
    fn test_embedding_response_batch() {
        let payload = r#"{
"object": "list",
"data": [
{"embedding": [0.1, 0.2], "index": 0, "object": "embedding"},
{"embedding": [0.3, 0.4], "index": 1, "object": "embedding"}
],
"model": "openai/text-embedding-3-small",
"usage": {"prompt_tokens": 10, "total_tokens": 10}
}"#;
        let parsed: EmbeddingResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.data.len(), 2);
        assert_eq!(parsed.first_embedding().unwrap(), &vec![0.1, 0.2]);
        assert_eq!(parsed.embeddings().len(), 2);
    }

    // A bare JSON string must deserialize into the Single variant.
    #[test]
    fn test_embedding_input_single_deserialization() {
        let parsed: EmbeddingInput = serde_json::from_str(r#""Hello world""#).unwrap();
        if let EmbeddingInput::Single(text) = parsed {
            assert_eq!(text, "Hello world");
        } else {
            panic!("Expected Single variant");
        }
    }

    // A JSON array must deserialize into the Batch variant.
    #[test]
    fn test_embedding_input_batch_deserialization() {
        let parsed: EmbeddingInput = serde_json::from_str(r#"["Hello", "World"]"#).unwrap();
        if let EmbeddingInput::Batch(items) = parsed {
            assert_eq!(items, vec!["Hello", "World"]);
        } else {
            panic!("Expected Batch variant");
        }
    }
}