use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Body of a chat-completions request.
///
/// Every `Option` field is left out of the serialized JSON entirely when it is
/// `None` (no `null` is emitted), so only explicitly-set options reach the wire.
#[derive(Serialize, Debug)]
pub(crate) struct Request<'a> {
    /// Model identifier, serialized under the `model` key.
    pub model: &'a str,
    /// Conversation history, serialized in order.
    pub messages: Vec<Message>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// Whether the response should be streamed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Extra streaming options (e.g. usage reporting).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool<'a>>>,
    /// Raw tool-choice value, passed through unchanged (the tests use `"auto"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<Value>,
    /// Requested output format (e.g. a JSON schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat<'a>>,
}
/// One chat message inside a [`Request`].
///
/// `content` is always serialized (`None` becomes JSON `null`), whereas
/// `tool_calls` and `tool_call_id` are dropped from the output when absent.
#[derive(Serialize, Debug)]
pub(crate) struct Message {
    /// Role string written to the `role` key (the tests use `"user"`).
    pub role: &'static str,
    /// Message body: plain text or a list of content parts.
    pub content: Option<MessageContent>,
    /// Tool calls attached to this message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallRequest>>,
    /// Identifier of the tool call this message answers, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// Message body, serialized without an enum tag.
///
/// `Text` becomes a bare JSON string; `Parts` becomes a JSON array of
/// [`ContentPart`] objects.
#[derive(Serialize, Debug)]
#[serde(untagged)]
pub(crate) enum MessageContent {
    /// Plain-text body.
    Text(String),
    /// Multi-part body (e.g. text mixed with images).
    Parts(Vec<ContentPart>),
}
/// One element of a multi-part message body.
///
/// Internally tagged: the variant name, converted to snake_case, is written to
/// a `type` key next to the variant's fields (`"text"` / `"image_url"`).
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum ContentPart {
    /// A text fragment, serialized as `{"type": "text", "text": ...}`.
    Text { text: String },
    /// An image reference, serialized as `{"type": "image_url", "image_url": {...}}`.
    ImageUrl { image_url: ImageUrl },
}
/// Image location for a [`ContentPart::ImageUrl`] part.
#[derive(Serialize, Debug)]
pub(crate) struct ImageUrl {
    /// Image URL; the tests pass a `data:` URL with base64 content.
    pub url: String,
}
/// A tool call echoed back to the server as part of a [`Message`].
#[derive(Serialize, Debug)]
pub(crate) struct ToolCallRequest {
    /// Identifier of the tool call.
    pub id: String,
    /// Call kind, serialized under the reserved-word key `type`.
    #[serde(rename = "type")]
    pub call_type: &'static str,
    /// Function name and its encoded arguments.
    pub function: FunctionCallRequest,
}
/// Function invocation carried by a [`ToolCallRequest`].
#[derive(Serialize, Debug)]
pub(crate) struct FunctionCallRequest {
    /// Name of the function that was called.
    pub name: String,
    /// Arguments as an already-encoded string (not a nested JSON object).
    pub arguments: String,
}
/// A tool definition offered to the model in [`Request::tools`].
#[derive(Serialize, Debug)]
pub(crate) struct Tool<'a> {
    /// Tool kind, serialized under the key `type` (the tests use `"function"`).
    #[serde(rename = "type")]
    pub tool_type: &'static str,
    /// Borrowed description of the callable function.
    pub function: FunctionDef<'a>,
}
/// Borrowed function signature inside a [`Tool`].
#[derive(Serialize, Debug)]
pub(crate) struct FunctionDef<'a> {
    /// Function name the model will refer to.
    pub name: &'a str,
    /// Human-readable description of the function.
    pub description: &'a str,
    /// Parameter description as a JSON value (a JSON-Schema object in the tests),
    /// borrowed so the caller's value is serialized without cloning.
    pub parameters: &'a Value,
}
/// Streaming options attached to a [`Request`].
#[derive(Serialize, Debug)]
pub(crate) struct StreamOptions {
    /// Serialized as `include_usage`; requests a usage report in the stream.
    pub include_usage: bool,
}
/// Requested output format for a [`Request`].
#[derive(Serialize, Debug)]
pub(crate) struct ResponseFormat<'a> {
    /// Format kind, serialized under `type` (the tests use `"json_schema"`).
    #[serde(rename = "type")]
    pub format_type: &'static str,
    /// Schema details; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<JsonSchemaFormat<'a>>,
}
/// Named JSON schema used by a [`ResponseFormat`].
#[derive(Serialize, Debug)]
pub(crate) struct JsonSchemaFormat<'a> {
    /// Schema name.
    pub name: &'static str,
    /// Borrowed schema document.
    pub schema: &'a Value,
    /// Value of the `strict` flag sent with the schema.
    pub strict: bool,
}
/// Non-streaming chat-completions response.
///
/// Keys not modeled here (such as `id`) are ignored during deserialization.
#[derive(Deserialize, Debug)]
pub(crate) struct Response {
    /// Generated choices; the key must be present.
    pub choices: Vec<Choice>,
    /// Model name reported by the server.
    pub model: String,
    /// Token accounting, if the server included it.
    pub usage: Option<ResponseUsage>,
}
/// One generated alternative in a [`Response`].
#[derive(Deserialize, Debug)]
pub(crate) struct Choice {
    /// The assistant message for this choice.
    pub message: ResponseMessage,
    /// Reason generation stopped (e.g. `"stop"`, `"tool_calls"`), if provided.
    pub finish_reason: Option<String>,
}
/// Assistant message inside a [`Choice`].
///
/// `content` may be `null` (as in tool-call responses); `role` is not modeled.
#[derive(Deserialize, Debug)]
pub(crate) struct ResponseMessage {
    /// Text content, if any.
    pub content: Option<String>,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}
/// A tool call requested by the model; the wire `type` key is ignored.
#[derive(Deserialize, Debug)]
pub(crate) struct ToolCallResponse {
    /// Server-assigned call identifier.
    pub id: String,
    /// Function name and encoded arguments.
    pub function: FunctionCallResponse,
}
/// Function invocation inside a [`ToolCallResponse`].
#[derive(Deserialize, Debug)]
pub(crate) struct FunctionCallResponse {
    /// Name of the function to call.
    pub name: String,
    /// Arguments kept as the raw string the server sent; not parsed here.
    pub arguments: String,
}
/// Token accounting for a response or for the final stream chunk.
///
/// `total_tokens` is not modeled and is ignored if present.
#[derive(Deserialize, Debug)]
pub(crate) struct ResponseUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u64,
    /// Tokens produced in the completion.
    pub completion_tokens: u64,
    /// Optional breakdown of completion tokens. A missing key deserializes to
    /// `None` without an explicit `#[serde(default)]`, because the field is an `Option`.
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct CompletionTokensDetails {
#[serde(default)]
pub reasoning_tokens: Option<u64>,
}
/// Error envelope of the form `{"error": {...}}`.
#[derive(Deserialize, Debug)]
pub(crate) struct ErrorResponse {
    /// The wrapped error details.
    pub error: ErrorDetail,
}
/// Error details; extra keys such as `type` and `code` are ignored.
#[derive(Deserialize, Debug)]
pub(crate) struct ErrorDetail {
    /// Human-readable error message.
    pub message: String,
}
#[derive(Debug, Deserialize)]
pub(crate) struct StreamChunk {
pub choices: Vec<StreamChoice>,
pub usage: Option<ResponseUsage>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct StreamChoice {
pub delta: StreamDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub(crate) struct StreamDelta {
pub content: Option<String>,
pub tool_calls: Option<Vec<StreamToolCall>>,
}
/// A fragment of a tool call inside a [`StreamDelta`].
#[derive(Deserialize, Debug)]
pub(crate) struct StreamToolCall {
    /// Position of the tool call this fragment belongs to.
    pub index: u32,
    /// Call identifier; optional because a fragment need not carry it.
    pub id: Option<String>,
    /// Partial function data; optional for the same reason.
    pub function: Option<StreamFunctionCall>,
}
/// Partial function data inside a [`StreamToolCall`].
#[derive(Deserialize, Debug)]
pub(crate) struct StreamFunctionCall {
    /// Function name, if present in this fragment.
    pub name: Option<String>,
    /// A piece of the encoded arguments string, if present in this fragment.
    pub arguments: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `"user"` message carrying `content` and no tool metadata.
    fn user_message(content: MessageContent) -> Message {
        Message {
            role: "user",
            content: Some(content),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_request_serialization_minimal() {
        let request = Request {
            model: "gpt-4o",
            messages: vec![user_message(MessageContent::Text("Hello".into()))],
            temperature: None,
            max_completion_tokens: None,
            stream: None,
            stream_options: None,
            tools: None,
            tool_choice: None,
            response_format: None,
        };
        let body = serde_json::to_value(&request).unwrap();
        assert_eq!(body["model"], "gpt-4o");
        // Unset options must be omitted, not serialized as `null`.
        for absent in ["temperature", "tools", "stream"] {
            assert!(body.get(absent).is_none(), "{} should be omitted", absent);
        }
    }

    #[test]
    fn test_request_with_tools() {
        let params = json!({
            "type": "object",
            "properties": { "city": { "type": "string" } },
            "required": ["city"]
        });
        let weather_tool = Tool {
            tool_type: "function",
            function: FunctionDef {
                name: "get_weather",
                description: "Get weather for a city",
                parameters: &params,
            },
        };
        let request = Request {
            model: "gpt-4o",
            messages: Vec::new(),
            temperature: Some(0.7),
            max_completion_tokens: Some(1024),
            stream: Some(true),
            stream_options: Some(StreamOptions { include_usage: true }),
            tools: Some(vec![weather_tool]),
            tool_choice: Some(json!("auto")),
            response_format: None,
        };
        let body = serde_json::to_value(&request).unwrap();
        let first_tool = &body["tools"][0];
        assert_eq!(first_tool["type"], "function");
        assert_eq!(first_tool["function"]["name"], "get_weather");
        assert_eq!(body["stream_options"]["include_usage"], true);
    }

    #[test]
    fn test_response_deserialization() {
        let payload = json!({
            "id": "chatcmpl-123",
            "choices": [{
                "message": { "role": "assistant", "content": "Hello!" },
                "finish_reason": "stop"
            }],
            "model": "gpt-4o-2024-08-06",
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        });
        let response: Response = serde_json::from_value(payload).unwrap();
        assert_eq!(response.choices.len(), 1);
        let choice = &response.choices[0];
        assert_eq!(choice.message.content.as_deref(), Some("Hello!"));
        assert_eq!(choice.finish_reason.as_deref(), Some("stop"));
        assert_eq!(response.usage.as_ref().map(|u| u.prompt_tokens), Some(10));
    }

    #[test]
    fn test_response_with_tool_calls() {
        let payload = json!({
            "id": "chatcmpl-456",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"Tokyo\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "model": "gpt-4o",
            "usage": { "prompt_tokens": 50, "completion_tokens": 20, "total_tokens": 70 }
        });
        let response: Response = serde_json::from_value(payload).unwrap();
        let calls = response.choices[0]
            .message
            .tool_calls
            .as_ref()
            .expect("tool_calls should be present");
        let call = &calls[0];
        assert_eq!(call.id, "call_abc");
        assert_eq!(call.function.name, "get_weather");
    }

    #[test]
    fn test_response_with_reasoning_tokens() {
        let payload = json!({
            "id": "chatcmpl-789",
            "choices": [{
                "message": { "role": "assistant", "content": "42" },
                "finish_reason": "stop"
            }],
            "model": "o1-mini",
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 100,
                "total_tokens": 110,
                "completion_tokens_details": { "reasoning_tokens": 80 }
            }
        });
        let response: Response = serde_json::from_value(payload).unwrap();
        let usage = response.usage.expect("usage should be present");
        let details = usage
            .completion_tokens_details
            .expect("completion_tokens_details should be present");
        assert_eq!(details.reasoning_tokens, Some(80));
    }

    #[test]
    fn test_error_response_deserialization() {
        let payload = json!({
            "error": {
                "message": "Invalid API key",
                "type": "invalid_api_key",
                "code": "invalid_api_key"
            }
        });
        let parsed: ErrorResponse = serde_json::from_value(payload).unwrap();
        assert_eq!(parsed.error.message, "Invalid API key");
    }

    #[test]
    fn test_stream_chunk_deserialization() {
        let payload = json!({
            "id": "chatcmpl-123",
            "choices": [{
                "delta": { "content": "Hello" },
                "finish_reason": null
            }]
        });
        let chunk: StreamChunk = serde_json::from_value(payload).unwrap();
        let choice = &chunk.choices[0];
        assert_eq!(choice.delta.content.as_deref(), Some("Hello"));
        assert!(choice.finish_reason.is_none());
    }

    #[test]
    fn test_stream_tool_call_deserialization() {
        let payload = json!({
            "id": "chatcmpl-456",
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_abc",
                        "type": "function",
                        "function": { "name": "get_weather", "arguments": "" }
                    }]
                },
                "finish_reason": null
            }]
        });
        let chunk: StreamChunk = serde_json::from_value(payload).unwrap();
        let fragments = chunk.choices[0]
            .delta
            .tool_calls
            .as_ref()
            .expect("tool_calls should be present");
        let fragment = &fragments[0];
        assert_eq!(fragment.index, 0);
        assert_eq!(fragment.id.as_deref(), Some("call_abc"));
    }

    #[test]
    fn test_message_content_text_serialization() {
        let message = user_message(MessageContent::Text("Hello".into()));
        let body = serde_json::to_value(&message).unwrap();
        assert_eq!(body["content"], "Hello");
    }

    #[test]
    fn test_message_content_parts_serialization() {
        let parts = vec![
            ContentPart::Text {
                text: "What's in this image?".into(),
            },
            ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: "data:image/png;base64,abc123".into(),
                },
            },
        ];
        let message = user_message(MessageContent::Parts(parts));
        let body = serde_json::to_value(&message).unwrap();
        assert_eq!(body["content"][0]["type"], "text");
        assert_eq!(body["content"][1]["type"], "image_url");
    }

    #[test]
    fn test_response_format_json_schema() {
        let schema = json!({"type": "object"});
        let format = ResponseFormat {
            format_type: "json_schema",
            json_schema: Some(JsonSchemaFormat {
                name: "output",
                schema: &schema,
                strict: true,
            }),
        };
        let body = serde_json::to_value(&format).unwrap();
        assert_eq!(body["type"], "json_schema");
        assert_eq!(body["json_schema"]["strict"], true);
    }
}