use serde::{Deserialize, Serialize};
/// Body of an incoming chat-completion request.
///
/// Only `model` and `messages` are mandatory. Every remaining field is an
/// `Option` tagged `#[serde(default)]`, so a key that is absent from the
/// input deserializes to `None` instead of raising an error.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names (the layout looks like an OpenAI-style wire format) — this type only
/// carries the values; confirm semantics against the consuming handler.
#[derive(Clone, Debug, Deserialize)]
pub struct ChatCompletionRequest {
    /// Name of the model the caller asked for.
    pub model: String,
    /// Conversation messages, in input order.
    pub messages: Vec<ChatMessage>,
    /// Sampling temperature, when supplied.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Nucleus-sampling (`top_p`) value, when supplied.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Requested number of choices, when supplied.
    #[serde(default)]
    pub n: Option<u32>,
    /// Streaming flag, when supplied.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences, when supplied. Only a JSON array of strings is accepted.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
    /// Upper bound on generated tokens, when supplied.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Presence penalty, when supplied.
    #[serde(default)]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty, when supplied.
    #[serde(default)]
    pub frequency_penalty: Option<f32>,
    /// Caller-provided end-user identifier, when supplied.
    #[serde(default)]
    pub user: Option<String>,
}
/// One message of a chat conversation.
///
/// Used in both directions: it is deserialized as part of
/// [`ChatCompletionRequest`] and serialized as part of [`ChatChoice`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Role label of the author, kept as a free-form string (not validated here).
    pub role: String,
    /// Text of the message.
    pub content: String,
    /// Optional author name; the key is left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Serialized body of a non-streaming chat-completion response.
///
/// This type performs no validation; every field is filled in by whoever
/// constructs it.
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionResponse {
    /// Identifier of this response.
    pub id: String,
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// Creation time as an integer — presumably Unix seconds; confirm with the producer.
    pub created: i64,
    /// Name of the model reported back to the client.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<ChatChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
/// A single entry of [`ChatCompletionResponse::choices`].
#[derive(Clone, Debug, Serialize)]
pub struct ChatChoice {
    /// Position of this choice, as assigned by the producer.
    pub index: u32,
    /// The message carried by this choice.
    pub message: ChatMessage,
    /// Free-form reason string explaining why generation ended.
    pub finish_reason: String,
}
/// Serialized chunk of a chat completion, carrying deltas instead of whole
/// messages.
///
/// NOTE(review): presumably emitted once per streaming event — confirm with
/// the streaming handler. Unlike [`ChatCompletionResponse`] it has no `usage`
/// field.
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionChunk {
    /// Identifier of the completion this chunk belongs to.
    pub id: String,
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// Creation time as an integer — presumably Unix seconds; confirm with the producer.
    pub created: i64,
    /// Name of the model reported back to the client.
    pub model: String,
    /// Per-choice deltas contained in this chunk.
    pub choices: Vec<ChatChunkChoice>,
}
/// A single entry of [`ChatCompletionChunk::choices`].
#[derive(Clone, Debug, Serialize)]
pub struct ChatChunkChoice {
    /// Position of the choice this delta applies to.
    pub index: u32,
    /// Incremental content for the choice.
    pub delta: ChatDelta,
    /// Reason generation ended; the key is left out of the output while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
/// Incremental piece of a chat message.
///
/// Both fields are optional and are skipped during serialization when `None`,
/// so the `Default` value serializes to an empty object.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatDelta {
    /// Role label, when this delta carries one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Text fragment, when this delta carries one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
/// Body of an incoming text-completion request.
///
/// Only `model` and `prompt` are mandatory. Every remaining field is an
/// `Option` tagged `#[serde(default)]`, so a key that is absent from the
/// input deserializes to `None` instead of raising an error.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names — this type only carries the values; confirm semantics against the
/// consuming handler.
#[derive(Clone, Debug, Deserialize)]
pub struct CompletionRequest {
    /// Name of the model the caller asked for.
    pub model: String,
    /// Prompt text. Only a single JSON string is accepted.
    pub prompt: String,
    /// Sampling temperature, when supplied.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Nucleus-sampling (`top_p`) value, when supplied.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Requested number of choices, when supplied.
    #[serde(default)]
    pub n: Option<u32>,
    /// Streaming flag, when supplied.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences, when supplied. Only a JSON array of strings is accepted.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
    /// Upper bound on generated tokens, when supplied.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Requested log-probability count, when supplied.
    #[serde(default)]
    pub logprobs: Option<u32>,
    /// Echo flag, when supplied.
    #[serde(default)]
    pub echo: Option<bool>,
    /// Suffix text, when supplied.
    #[serde(default)]
    pub suffix: Option<String>,
}
/// Serialized body of a text-completion response.
///
/// This type performs no validation; every field is filled in by whoever
/// constructs it.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionResponse {
    /// Identifier of this response.
    pub id: String,
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// Creation time as an integer — presumably Unix seconds; confirm with the producer.
    pub created: i64,
    /// Name of the model reported back to the client.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
/// A single entry of [`CompletionResponse::choices`].
#[derive(Clone, Debug, Serialize)]
pub struct CompletionChoice {
    /// Text carried by this choice.
    pub text: String,
    /// Position of this choice, as assigned by the producer.
    pub index: u32,
    /// Free-form reason string explaining why generation ended.
    pub finish_reason: String,
    /// Log-probability details; the key is left out of the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
}
/// Log-probability details attached to a [`CompletionChoice`].
///
/// NOTE(review): the four vectors are presumably parallel (one entry per
/// token); nothing in this type enforces equal lengths — confirm with the
/// producer.
#[derive(Clone, Debug, Serialize)]
pub struct LogProbs {
    /// Token strings.
    pub tokens: Vec<String>,
    /// Log-probability values for the tokens.
    pub token_logprobs: Vec<f32>,
    /// Maps from candidate token string to its log-probability.
    pub top_logprobs: Vec<std::collections::HashMap<String, f32>>,
    /// Offsets into the text — presumably where each token starts; unit not
    /// established here.
    pub text_offset: Vec<u32>,
}
/// Body of an incoming embedding request.
///
/// `model` and `input` are mandatory; the remaining `Option` fields are
/// tagged `#[serde(default)]` and become `None` when their key is absent.
#[derive(Clone, Debug, Deserialize)]
pub struct EmbeddingRequest {
    /// Name of the model the caller asked for.
    pub model: String,
    /// Text to embed: either one string or a list of strings.
    pub input: EmbeddingInput,
    /// Requested encoding format, kept as a free-form string (not validated here).
    #[serde(default)]
    pub encoding_format: Option<String>,
    /// Requested `dimensions` value, when supplied.
    #[serde(default)]
    pub dimensions: Option<u32>,
}
/// Input accepted by [`EmbeddingRequest`]: a bare string or an array of strings.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that matches the incoming value.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single JSON string.
    Single(String),
    /// A JSON array of strings.
    Multiple(Vec<String>),
}
/// Serialized body of an embedding response.
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingResponse {
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// The embedding entries.
    pub data: Vec<EmbeddingData>,
    /// Name of the model reported back to the client.
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingData {
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// Position of this entry — presumably the index of the matching input;
    /// confirm with the producer.
    pub index: u32,
    /// The embedding vector.
    pub embedding: Vec<f32>,
}
/// Token accounting reported in an [`EmbeddingResponse`].
///
/// Unlike [`Usage`], there is no `completion_tokens` field.
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingUsage {
    /// Number of tokens in the input.
    pub prompt_tokens: u32,
    /// Total number of tokens counted; set by the producer, not computed here.
    pub total_tokens: u32,
}
/// Serialized body of a model-listing response.
#[derive(Clone, Debug, Serialize)]
pub struct ModelsResponse {
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// The listed models.
    pub data: Vec<ModelObject>,
}
/// A single entry of [`ModelsResponse::data`].
#[derive(Clone, Debug, Serialize)]
pub struct ModelObject {
    /// Identifier of the model.
    pub id: String,
    /// Object-type tag string chosen by the caller.
    pub object: String,
    /// Creation time as an integer — presumably Unix seconds; confirm with the producer.
    pub created: i64,
    /// Owner label of the model.
    pub owned_by: String,
}
/// Token accounting attached to chat and text completion responses.
///
/// The derived `Default` sets all three counters to zero. Fields are public,
/// so `total_tokens` is only guaranteed to equal the sum of the other two when
/// the value was built through `Usage::new`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the completion.
    pub completion_tokens: u32,
    /// Combined token count.
    pub total_tokens: u32,
}
impl Usage {
    /// Builds a `Usage` from the prompt and completion token counts.
    ///
    /// `total_tokens` is set to the sum of the two arguments. The addition
    /// saturates at `u32::MAX`: a plain `+` would panic on overflow in debug
    /// builds and silently wrap around in release builds, and a clamped total
    /// is the less surprising outcome for an accounting value.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A chat request with only some optional keys parses, and the keys that
    /// were left out come back as `None`.
    #[test]
    fn test_chat_request_deserialization() {
        let json = r#"{
"model": "gpt-4",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
],
"temperature": 0.7,
"max_tokens": 100
}"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.max_tokens, Some(100));
        // Keys absent from the payload must fall back to `None`.
        assert!(req.stream.is_none());
        assert!(req.stop.is_none());
        assert!(req.messages[0].name.is_none());
    }

    /// A chat response serializes its id and message text, and omits the
    /// `name` key of a message whose `name` is `None`.
    #[test]
    fn test_chat_response_serialization() {
        let response = ChatCompletionResponse {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion".to_string(),
            created: 1677652288,
            model: "gpt-4".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hello!".to_string(),
                    name: None,
                },
                finish_reason: "stop".to_string(),
            }],
            usage: Usage::new(10, 5),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("chatcmpl-123"));
        assert!(json.contains("Hello!"));
        assert!(!json.contains("\"name\""));
    }

    /// The untagged `EmbeddingInput` picks `Single` for a string and
    /// `Multiple` for an array.
    #[test]
    fn test_embedding_input_variants() {
        let json = r#"{"model": "text-embedding-3-small", "input": "Hello"}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        // `matches!` only yields a bool; without `assert!` nothing is checked.
        assert!(matches!(req.input, EmbeddingInput::Single(ref text) if text == "Hello"));

        let json = r#"{"model": "text-embedding-3-small", "input": ["Hello", "World"]}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.input, EmbeddingInput::Multiple(ref items) if items.len() == 2));
    }

    /// `Usage::new` stores both counts and their sum.
    #[test]
    fn test_usage() {
        let usage = Usage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }
}