use serde::{Deserialize, Serialize};
/// Request body for a chat-completion call.
///
/// Field names appear to mirror the OpenAI Chat Completions request schema
/// (TODO confirm against the upstream API version being targeted).
///
/// Serialization notes:
/// * every `Option` field is omitted from the JSON when it is `None`, so only
///   parameters the caller actually set are emitted;
/// * fields are emitted in declaration order, so do not reorder them casually;
/// * top-level keys that match no named field are captured in `extra`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Either a single stop string or a list of them (see [`StopSequence`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Kept as raw JSON; its shape is not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Either a mode string or a specific function selector (see [`ToolChoice`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Catch-all for unrecognised top-level keys. Because of `flatten`, its
    /// entries are merged back into the request object when serializing.
    ///
    /// NOTE(review): serde can only flatten maps/structs, so a non-object
    /// value here presumably fails to serialize. Also, a flattened `Option`
    /// likely deserializes as `Some` (possibly an empty object) even when no
    /// extra keys are present. Confirm before relying on `is_none()`.
    #[serde(flatten)]
    pub extra: Option<serde_json::Value>,
}
/// A single conversation message, used both in requests and in
/// non-streaming response choices.
///
/// `None` fields are left out of the serialized JSON entirely (no `null`).
/// Keys that match no named field are captured in `extra`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Free-form role string; not validated against a fixed set here.
    pub role: String,
    /// Plain text or a list of typed parts (see [`MessageContent`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Unknown message keys, merged back into the message object on output.
    ///
    /// NOTE(review): a non-object value presumably cannot be flattened and
    /// would fail to serialize. After parsing this is likely `Some(..)` even
    /// with no extra keys. Confirm before depending on either behavior.
    #[serde(flatten)]
    pub extra: Option<serde_json::Value>,
}
/// Message `content` payload, encoded without any discriminator.
///
/// Because the enum is `untagged`, serde tries the variants in declaration
/// order: a JSON string becomes `Text`, a JSON array becomes `Parts`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Bare string content.
    Text(String),
    /// Multi-part content such as text mixed with images.
    Parts(Vec<ContentPart>),
}
/// One element of a multi-part message, discriminated by its `"type"` key.
///
/// Variant names are snake-cased on the wire (`Text` -> `"text"`,
/// `ImageUrl` -> `"image_url"`). `rename_all` on an enum affects only variant
/// names, not the fields inside them. Any other `"type"` value fails to
/// deserialize.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// `{"type": "text", "text": "..."}`
    Text { text: String },
    /// `{"type": "image_url", "image_url": {"url": "...", ...}}`
    ImageUrl { image_url: ImageUrl },
}
/// Image reference carried inside a [`ContentPart::ImageUrl`] part.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail hint; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
/// Extra options attached to a streaming request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamOptions {
    /// Omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
/// The `stop` parameter, which may be one string or an array of strings.
///
/// Untagged: a JSON string maps to `Single`, a JSON array of strings to
/// `Multiple`, and each serializes back to the same bare form.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StopSequence {
    Single(String),
    Multiple(Vec<String>),
}
/// A tool the model may call, as listed in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Tool {
    /// Serialized as `"type"`, which is a reserved word in Rust.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionDefinition,
}
/// Declaration of a callable function inside a [`Tool`].
///
/// Only `name` is required. The other fields are dropped from the JSON when
/// `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Raw JSON; presumably a JSON-Schema object describing the arguments
    /// (not validated here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// The `tool_choice` request parameter.
///
/// Untagged: a bare JSON string deserializes as `Mode`, and an object with
/// `"type"` and `"function"` keys deserializes as `Specific`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// A mode keyword, kept as an unvalidated string.
    Mode(String),
    /// Forces a particular function by name.
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}
/// Function selector used by [`ToolChoice::Specific`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}
/// A complete tool invocation recorded on a [`ChatMessage`].
///
/// All fields are required when deserializing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    pub id: String,
    /// Wire key is `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}
/// Function name and arguments of a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionCall {
    pub name: String,
    /// Kept as an opaque string. Presumably JSON-encoded arguments, but it is
    /// not parsed here.
    pub arguments: String,
}
/// The `response_format` request parameter.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseFormat {
    /// Wire key is `"type"`; the value is an unvalidated string.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Raw JSON, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}
/// Non-streaming chat-completion response body.
///
/// `id`, `object`, `created`, `model` and `choices` are required when
/// deserializing. The trailing optional fields are omitted from the output
/// when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Integer timestamp; presumably Unix seconds (TODO confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
}
/// One candidate completion in a [`ChatCompletionResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: u32,
    pub message: ChatMessage,
    /// Omitted from the JSON (rather than emitted as `null`) when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Raw JSON, passed through unvalidated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}
/// Token accounting attached to a response or to a final stream chunk.
///
/// The three counters are required when deserializing. The `*_details`
/// objects are optional raw JSON and are omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<serde_json::Value>,
}
/// One incremental event of a streamed chat completion.
///
/// Same envelope as [`ChatCompletionResponse`], except that each choice
/// carries a partial [`ChunkDelta`] instead of a full message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    /// Integer timestamp; presumably Unix seconds (TODO confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChunkChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
}
/// A choice inside a [`ChatCompletionChunk`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChunkChoice {
    pub index: u32,
    pub delta: ChunkDelta,
    /// A JSON `null` parses as `None`, and `None` is omitted on output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Raw JSON, passed through unvalidated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}
/// Partial message content carried by one stream chunk.
///
/// Every field is optional and omitted from the JSON when `None`. Keys not
/// listed here are ignored on input (no `deny_unknown_fields`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChunkDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChunkToolCall>>,
}
/// A fragment of a tool call within a [`ChunkDelta`].
///
/// Only `index` is required. Presumably it identifies which call the fragment
/// belongs to so fragments can be stitched together (TODO confirm with the
/// stream consumer).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChunkToolCall {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Wire key is `"type"`; omitted when `None`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<ChunkFunctionCall>,
}
/// Partial function name and/or argument text inside a [`ChunkToolCall`].
///
/// Both parts are optional and omitted from the JSON when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChunkFunctionCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
#[cfg(test)]
mod tests {
    //! Serde round-trip tests for the request, response and stream-chunk
    //! wire types defined in the parent module.
    use super::*;

    #[test]
    fn test_deserialize_simple_request() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Hello"}
            ]
        }"#;
        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].role, "user");
        // A bare JSON string must land in the `Text` variant of the untagged enum.
        match &request.messages[0].content {
            Some(MessageContent::Text(text)) => assert_eq!(text, "Hello"),
            other => panic!("Expected text content, got {:?}", other),
        }
    }

    #[test]
    fn test_deserialize_multimodal_content() {
        let json = r#"{
            "model": "gpt-4-vision",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What's in this image?"},
                        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
                    ]
                }
            ]
        }"#;
        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.messages.len(), 1);
        let parts = match &request.messages[0].content {
            Some(MessageContent::Parts(parts)) => parts,
            other => panic!("Expected parts content, got {:?}", other),
        };
        assert_eq!(parts.len(), 2);
        // Check that the `"type"` tag selected the right variant for each part.
        assert!(matches!(
            &parts[0],
            ContentPart::Text { text } if text == "What's in this image?"
        ));
        assert!(matches!(
            &parts[1],
            ContentPart::ImageUrl { image_url }
                if image_url.url == "https://example.com/image.jpg" && image_url.detail.is_none()
        ));
    }

    #[test]
    fn test_deserialize_with_tools() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "What's the weather?"}],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the weather",
                        "parameters": {"type": "object", "properties": {}}
                    }
                }
            ]
        }"#;
        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        // `expect` replaces the `is_some()` + `unwrap()` pair and names the failure.
        let tools = request.tools.expect("tools should be present");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].tool_type, "function");
        assert_eq!(tools[0].function.name, "get_weather");
        assert_eq!(tools[0].function.description.as_deref(), Some("Get the weather"));
    }

    #[test]
    fn test_unknown_request_fields_round_trip_through_extra() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hi"}],
            "custom_field": 42
        }"#;
        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        // Keys with no matching named field are captured by the flattened `extra`...
        let captured = request
            .extra
            .as_ref()
            .and_then(|extra| extra.get("custom_field"));
        assert_eq!(captured, Some(&serde_json::json!(42)));

        // ...and are written back at the top level when serializing.
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["custom_field"], 42);
        assert_eq!(value["model"], "gpt-4");
    }

    #[test]
    fn test_deserialize_stream_chunk_with_tool_call_delta() {
        let json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": ""}
                    }]
                },
                "finish_reason": null
            }]
        }"#;
        let chunk: ChatCompletionChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        let choice = &chunk.choices[0];
        // An explicit JSON `null` must parse as `None`.
        assert!(choice.finish_reason.is_none());
        let calls = choice
            .delta
            .tool_calls
            .as_ref()
            .expect("tool call delta should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].index, 0);
        assert_eq!(calls[0].id.as_deref(), Some("call_1"));
        // `"type"` is renamed onto the optional `call_type` field.
        assert_eq!(calls[0].call_type.as_deref(), Some("function"));
        assert_eq!(
            calls[0].function.as_ref().and_then(|f| f.name.as_deref()),
            Some("get_weather")
        );
    }

    #[test]
    fn test_serialize_response() {
        let response = ChatCompletionResponse {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "gpt-4".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: Some(MessageContent::Text("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    extra: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
                prompt_tokens_details: None,
                completion_tokens_details: None,
            }),
            system_fingerprint: None,
            service_tier: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("chatcmpl-123"));

        // Re-parse generically to pin the exact wire shape, not just a substring.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["object"], "chat.completion");
        assert_eq!(value["usage"]["total_tokens"], 15);
        assert_eq!(value["choices"][0]["finish_reason"], "stop");
        let message = &value["choices"][0]["message"];
        assert_eq!(message["role"], "assistant");
        // Untagged `Text` content serializes as a bare string.
        assert_eq!(message["content"], "Hello!");
        // `None` fields, including the flattened `extra`, must add no keys at all.
        assert_eq!(message.as_object().map(|m| m.len()), Some(2));
        assert!(value.get("system_fingerprint").is_none());
        assert!(value.get("service_tier").is_none());
        assert!(value["choices"][0].get("logprobs").is_none());
    }
}