use serde::{Deserialize, Serialize};
/// Request body for an embeddings call.
///
/// Serialized with serde. Every `Option` field carries
/// `skip_serializing_if = "Option::is_none"`, so the JSON always contains
/// `input` and `model` plus only those optional fields that were set.
/// Construct with [`EmbeddingRequest::new`] (one text) or
/// [`EmbeddingRequest::batch`] (several texts) and chain the setters.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest {
    /// Text to embed: either a single string or a batch of strings.
    pub input: EmbeddingInput,
    /// Name of the embedding model, e.g. `"text-embedding-ada-002"`.
    pub model: String,
    /// Encoding format string sent as-is (e.g. `"float"`); omitted when `None`.
    /// NOTE(review): accepted values are provider-defined — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
    /// Presumably the requested length of the output vectors; omitted when `None`.
    /// NOTE(review): support depends on the model/provider — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
    /// Caller-supplied user identifier (e.g. `"user-123"`), forwarded verbatim;
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Arbitrary JSON key/value metadata attached to the request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
impl EmbeddingRequest {
pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
Self {
input: EmbeddingInput::Single(input.into()),
model: model.into(),
encoding_format: None,
dimensions: None,
user: None,
metadata: None,
}
}
pub fn batch(model: impl Into<String>, inputs: Vec<String>) -> Self {
Self {
input: EmbeddingInput::Batch(inputs),
model: model.into(),
encoding_format: None,
dimensions: None,
user: None,
metadata: None,
}
}
pub fn encoding_format(mut self, format: impl Into<String>) -> Self {
self.encoding_format = Some(format.into());
self
}
pub fn dimensions(mut self, dims: u32) -> Self {
self.dimensions = Some(dims);
self
}
pub fn user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
pub fn metadata(
mut self,
metadata: std::collections::HashMap<String, serde_json::Value>,
) -> Self {
self.metadata = Some(metadata);
self
}
}
#[derive(Debug, Clone)]
pub enum EmbeddingInput {
Single(String),
Batch(Vec<String>),
}
impl Serialize for EmbeddingInput {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
EmbeddingInput::Single(s) => serializer.serialize_str(s),
EmbeddingInput::Batch(v) => v.serialize(serializer),
}
}
}
/// Response body of an embeddings call, deserialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type string from the response, stored verbatim (not validated).
    pub object: String,
    /// The returned embeddings; each entry carries its own `index`.
    pub data: Vec<EmbeddingData>,
    /// Model name as given in the response.
    pub model: String,
    /// Token usage counts included in the response.
    pub usage: EmbeddingUsage,
}
impl EmbeddingResponse {
pub fn embedding(&self) -> Option<&[f32]> {
self.data.first().map(|d| d.embedding.as_slice())
}
pub fn embeddings(&self) -> Vec<&[f32]> {
self.data.iter().map(|d| d.embedding.as_slice()).collect()
}
}
/// One embedding entry inside [`EmbeddingResponse::data`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
    /// Object type string from the response, stored verbatim (not validated).
    pub object: String,
    /// The embedding vector, stored as `f32` components.
    pub embedding: Vec<f32>,
    /// Position tag for this entry. Presumably the index of the matching input
    /// in the request batch. NOTE(review): confirm against the provider's API docs.
    pub index: usize,
}
/// Token usage counts reported in an [`EmbeddingResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Token count attributed to the input text(s).
    pub prompt_tokens: u32,
    /// Total token count for the request.
    /// NOTE(review): likely equal to `prompt_tokens` for embeddings — confirm.
    pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_request_single() {
        let req = EmbeddingRequest::new("text-embedding-ada-002", "Hello, world!");
        assert_eq!(req.model, "text-embedding-ada-002");
        match req.input {
            EmbeddingInput::Single(s) => assert_eq!(s, "Hello, world!"),
            _ => panic!("Expected single input"),
        }
    }

    #[test]
    fn test_embedding_request_batch() {
        let req = EmbeddingRequest::batch(
            "text-embedding-ada-002",
            vec!["Hello".to_string(), "World".to_string()],
        );
        assert_eq!(req.model, "text-embedding-ada-002");
        match req.input {
            EmbeddingInput::Batch(v) => assert_eq!(v, vec!["Hello", "World"]),
            _ => panic!("Expected batch input"),
        }
    }

    #[test]
    fn test_embedding_request_builder() {
        let req = EmbeddingRequest::new("model", "text")
            .dimensions(1536)
            .encoding_format("float")
            .user("user-123");
        assert_eq!(req.dimensions, Some(1536));
        assert_eq!(req.encoding_format, Some("float".to_string()));
        assert_eq!(req.user, Some("user-123".to_string()));
    }

    /// A single input must serialize as a plain JSON string, and unset
    /// optional fields must be omitted entirely.
    #[test]
    fn test_single_request_serialization() {
        let value = serde_json::to_value(EmbeddingRequest::new("model", "text")).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "input": "text", "model": "model" })
        );
    }

    /// A batch input must serialize as a JSON array; only the optional
    /// fields that were set appear in the output.
    #[test]
    fn test_batch_request_serialization() {
        let req = EmbeddingRequest::batch("model", vec!["a".to_string(), "b".to_string()])
            .dimensions(8);
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "input": ["a", "b"], "model": "model", "dimensions": 8 })
        );
    }

    #[test]
    fn test_response_deserialization_and_helpers() {
        let body = r#"{
            "object": "list",
            "data": [
                { "object": "embedding", "embedding": [0.5, 1.0], "index": 0 },
                { "object": "embedding", "embedding": [0.25], "index": 1 }
            ],
            "model": "model",
            "usage": { "prompt_tokens": 3, "total_tokens": 3 }
        }"#;
        let resp: EmbeddingResponse = serde_json::from_str(body).unwrap();
        let first: &[f32] = &[0.5, 1.0];
        let second: &[f32] = &[0.25];
        assert_eq!(resp.embedding(), Some(first));
        assert_eq!(resp.embeddings(), vec![first, second]);
        assert_eq!(resp.usage.total_tokens, 3);
    }
}