use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
Tool,
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: MessageRole,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: Some(content.into()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: Some(content.into()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: Some(content.into()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<std::collections::HashMap<String, f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub format_type: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatChunkChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunkChoice {
pub index: u32,
pub delta: ChatDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<MessageRole>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub tool_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_constructors() {
let system = ChatMessage::system("You are helpful.");
assert_eq!(system.role, MessageRole::System);
assert_eq!(system.content, Some("You are helpful.".to_string()));
let user = ChatMessage::user("Hello!");
assert_eq!(user.role, MessageRole::User);
let assistant = ChatMessage::assistant("Hi there!");
assert_eq!(assistant.role, MessageRole::Assistant);
}
#[test]
fn test_request_serialization() {
let request = ChatCompletionRequest {
model: "gpt-4o-mini".to_string(),
messages: vec![
ChatMessage::system("Be helpful"),
ChatMessage::user("Hello"),
],
temperature: Some(0.7),
top_p: None,
n: None,
stream: false,
stop: None,
max_tokens: Some(100),
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
tools: None,
tool_choice: None,
response_format: None,
seed: None,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("gpt-4o-mini"));
assert!(json.contains("system"));
assert!(json.contains("user"));
}
#[test]
fn test_response_deserialization() {
let json = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4o-mini",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "Hello!"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}
}"#;
let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.id, "chatcmpl-123");
assert_eq!(response.choices[0].message.role, MessageRole::Assistant);
assert_eq!(response.usage.unwrap().total_tokens, 21);
}
}