use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use crate::structured::ResponseFormat;
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub top_p: Option<f32>,
#[serde(default)]
pub n: Option<u32>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default)]
pub stop: Option<Vec<String>>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub presence_penalty: Option<f32>,
#[serde(default)]
pub frequency_penalty: Option<f32>,
#[serde(default)]
pub user: Option<String>,
#[serde(default)]
pub logprobs: Option<bool>,
#[serde(default)]
pub top_logprobs: Option<u32>,
#[serde(default)]
pub tools: Option<Vec<Tool>>,
#[serde(default)]
pub tool_choice: Option<ToolChoice>,
#[serde(default)]
pub parallel_tool_calls: Option<bool>,
#[serde(default)]
pub response_format: Option<ResponseFormat>,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ChatMessage {
pub role: String,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// A tool offered to the model. It wraps one function declaration.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct Tool {
    /// Tool kind, carried on the wire as `"type"` (the tests use `"function"`).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: FunctionDefinition,
}
/// Declaration of a function the model may call.
///
/// Every optional member is left out of the serialized JSON when it is `None`.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct FunctionDefinition {
    /// Name the model uses to refer to the function.
    pub name: String,
    /// Human-readable description of the function.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Argument schema as arbitrary JSON (the tests pass a JSON-Schema object).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    /// Strict-mode flag, passed through unchanged.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// One tool invocation emitted on an assistant message.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct ToolCall {
    /// Call identifier. Presumably echoed back in `ChatMessage::tool_call_id`
    /// on the tool's reply (the tests pair them this way).
    pub id: String,
    /// Call kind, carried on the wire as `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function to call and with what arguments.
    pub function: FunctionCall,
}
/// Function name and arguments for a single tool call.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as JSON-encoded text (a string, not a parsed object).
    pub arguments: String,
}
/// Value of the request's `tool_choice` field.
///
/// The enum is untagged, so a JSON string becomes [`ToolChoice::String`] and a
/// JSON object becomes [`ToolChoice::Tool`]. Variants are tried in declaration
/// order.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum ToolChoice {
    /// A mode string such as `"auto"` or `"none"`.
    String(String),
    /// An object forcing a specific function.
    Tool(ToolChoiceFunction),
}
/// Object form of `tool_choice`, naming the function the model must call.
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct ToolChoiceFunction {
    /// Choice kind, carried on the wire as `"type"`.
    #[serde(rename = "type")]
    pub choice_type: String,
    /// The function being forced.
    pub function: ToolChoiceFunctionName,
}
/// Names the function selected by a [`ToolChoiceFunction`].
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub struct ToolChoiceFunctionName {
    /// Function name.
    pub name: String,
}
/// Complete (non-chunked) chat completion response body.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag (the tests use `"chat.completion"`).
    pub object: String,
    /// Creation time. The tests use a Unix-seconds value; TODO confirm units
    /// with the producer.
    pub created: i64,
    /// Model that produced the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<ChatChoice>,
    /// Token accounting.
    pub usage: Usage,
}
/// One generated alternative inside a [`ChatCompletionResponse`].
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ChatChoice {
    /// Position of this choice in `choices`.
    pub index: u32,
    /// The generated assistant message.
    pub message: ChatMessage,
    /// Why generation stopped (e.g. `"stop"`).
    pub finish_reason: String,
    /// Per-token log-probabilities. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
}
/// Log-probability report attached to a chat choice.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ChatLogProbs {
    /// One entry per generated token.
    pub content: Vec<TokenLogProb>,
}
/// Log-probability data for a single generated token.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct TokenLogProb {
    /// The token text.
    pub token: String,
    /// Log-probability assigned to this token.
    pub logprob: f32,
    /// Byte representation of the token. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bytes: Option<Vec<u8>>,
    /// Alternative candidates at this position. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<Vec<TopLogProb>>,
}
/// One alternative token candidate and its log-probability.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct TopLogProb {
    /// Candidate token text.
    pub token: String,
    /// Log-probability of the candidate.
    pub logprob: f32,
    /// Byte representation of the candidate. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bytes: Option<Vec<u8>>,
}
/// One incremental piece of a chat completion. Presumably emitted when the
/// request sets `stream`; TODO confirm against the handler.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ChatCompletionChunk {
    /// Identifier of the completion this chunk belongs to.
    pub id: String,
    /// Object type tag, set by the code that builds the chunk.
    pub object: String,
    /// Creation time (same encoding as the non-chunked response).
    pub created: i64,
    /// Model producing the stream.
    pub model: String,
    /// Per-choice deltas carried by this chunk.
    pub choices: Vec<ChatChunkChoice>,
}
/// The delta for one choice inside a [`ChatCompletionChunk`].
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ChatChunkChoice {
    /// Index of the choice this delta extends.
    pub index: u32,
    /// Newly produced message fragment.
    pub delta: ChatDelta,
    /// Stop reason, present only once the choice has finished. Omitted from
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
/// Partial message content carried by a streaming chunk.
///
/// Both members are optional and are omitted from JSON when `None`.
/// `Default` yields an empty delta.
#[derive(Clone, Debug, Default, Serialize, ToSchema)]
pub struct ChatDelta {
    /// Role of the message being streamed, if this chunk announces it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Text fragment appended by this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub top_p: Option<f32>,
#[serde(default)]
pub n: Option<u32>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default)]
pub stop: Option<Vec<String>>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub logprobs: Option<u32>,
#[serde(default)]
pub echo: Option<bool>,
#[serde(default)]
pub suffix: Option<String>,
#[serde(default)]
pub presence_penalty: Option<f32>,
#[serde(default)]
pub frequency_penalty: Option<f32>,
}
/// Response body for a prompt-based text completion.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag, set by the producer.
    pub object: String,
    /// Creation time (same encoding as the chat responses).
    pub created: i64,
    /// Model that produced the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<CompletionChoice>,
    /// Token accounting.
    pub usage: Usage,
}
/// One generated alternative inside a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct CompletionChoice {
    /// Generated text.
    pub text: String,
    /// Position of this choice in `choices`.
    pub index: u32,
    /// Why generation stopped.
    pub finish_reason: String,
    /// Log-probability arrays. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
}
/// Column-oriented log-probability data for a text completion.
///
/// The four vectors are presumably parallel and indexed by token position.
/// TODO: confirm the producer keeps them the same length.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct LogProbs {
    /// Generated tokens.
    pub tokens: Vec<String>,
    /// Log-probability of each generated token.
    pub token_logprobs: Vec<f32>,
    /// Per position, a map from candidate token to its log-probability.
    pub top_logprobs: Vec<std::collections::HashMap<String, f32>>,
    /// Offset of each token in the text. Units (bytes vs chars) are not
    /// established here.
    pub text_offset: Vec<u32>,
}
/// Request body for creating embeddings.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub struct EmbeddingRequest {
    /// Embedding model identifier.
    pub model: String,
    /// One string or a list of strings to embed.
    pub input: EmbeddingInput,
    /// Requested encoding of the returned vectors; `None` when absent.
    #[serde(default)]
    pub encoding_format: Option<String>,
    /// Requested output dimensionality; `None` when absent.
    #[serde(default)]
    pub dimensions: Option<u32>,
}
/// The `input` of an embedding request.
///
/// Untagged: a JSON string becomes [`EmbeddingInput::Single`] and a JSON array
/// of strings becomes [`EmbeddingInput::Multiple`].
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text.
    Single(String),
    /// A batch of texts.
    Multiple(Vec<String>),
}
/// Response body carrying computed embeddings.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct EmbeddingResponse {
    /// Object type tag, set by the producer.
    pub object: String,
    /// Embedding entries, presumably one per input text.
    pub data: Vec<EmbeddingData>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
/// A single embedding vector.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct EmbeddingData {
    /// Object type tag, set by the producer.
    pub object: String,
    /// Position of this entry, presumably matching the input's position.
    pub index: u32,
    /// The embedding values.
    pub embedding: Vec<f32>,
}
/// Token accounting for an embedding request. Embeddings produce no
/// completion tokens, so only prompt and total counts are reported.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct EmbeddingUsage {
    /// Tokens consumed by the input.
    pub prompt_tokens: u32,
    /// Total tokens billed.
    pub total_tokens: u32,
}
/// Response body listing available models.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ModelsResponse {
    /// Object type tag, set by the producer.
    pub object: String,
    /// The listed models.
    pub data: Vec<ModelObject>,
}
/// Description of one model in a [`ModelsResponse`].
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct ModelObject {
    /// Model identifier.
    pub id: String,
    /// Object type tag, set by the producer.
    pub object: String,
    /// Creation time; presumably Unix seconds. TODO confirm.
    pub created: i64,
    /// Owner of the model.
    pub owned_by: String,
}
/// Token accounting reported on completion responses.
///
/// Prefer [`Usage::new`], which derives `total_tokens` from the other two
/// counts. `Default` yields all zeros.
#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
impl Usage {
    /// Builds a usage record. `total_tokens` is the sum of the prompt and
    /// completion counts.
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing. Plain `+` would
    /// panic in debug builds and silently wrap to a tiny total in release builds.
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
/// Serde round-trip tests for the OpenAI-compatible wire types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_request_deserialization() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ],
            "temperature": 0.7,
            "max_tokens": 100
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.max_tokens, Some(100));
    }

    #[test]
    fn test_chat_response_serialization() {
        let response = ChatCompletionResponse {
            id: "inf-chat-123".to_string(),
            object: "chat.completion".to_string(),
            created: 1677652288,
            model: "gpt-4".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hello!".to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: "stop".to_string(),
                logprobs: None,
            }],
            usage: Usage::new(10, 5),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("inf-chat-123"));
        assert!(json.contains("Hello!"));
        assert!(!json.contains("logprobs"));
        assert!(!json.contains("tool_calls"));
    }

    #[test]
    fn test_embedding_input_variants() {
        // `matches!` only evaluates to a bool. It must be wrapped in `assert!`,
        // otherwise the test checks nothing.
        let json = r#"{"model": "text-embedding-3-small", "input": "Hello"}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.input, EmbeddingInput::Single(ref s) if s == "Hello"));
        let json = r#"{"model": "text-embedding-3-small", "input": ["Hello", "World"]}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.input, EmbeddingInput::Multiple(ref v) if v.len() == 2));
    }

    #[test]
    fn test_usage() {
        let usage = Usage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_chat_request_with_logprobs() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello!"}],
            "logprobs": true,
            "top_logprobs": 5
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.logprobs, Some(true));
        assert_eq!(req.top_logprobs, Some(5));
    }

    #[test]
    fn test_chat_request_logprobs_defaults() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello!"}]
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.logprobs, None);
        assert_eq!(req.top_logprobs, None);
    }

    #[test]
    fn test_logprobs_serialization() {
        let logprobs = ChatLogProbs {
            content: vec![TokenLogProb {
                token: "Hello".to_string(),
                logprob: -0.5,
                bytes: Some(vec![72, 101, 108, 108, 111]),
                top_logprobs: Some(vec![
                    TopLogProb {
                        token: "Hello".to_string(),
                        logprob: -0.5,
                        bytes: None,
                    },
                    TopLogProb {
                        token: "Hi".to_string(),
                        logprob: -1.2,
                        bytes: None,
                    },
                ]),
            }],
        };
        let json = serde_json::to_string(&logprobs).unwrap();
        assert!(json.contains("\"token\":\"Hello\""));
        assert!(json.contains("\"logprob\":-0.5"));
        assert!(json.contains("\"top_logprobs\""));
    }

    #[test]
    fn test_chat_response_with_logprobs() {
        let response = ChatCompletionResponse {
            id: "inf-chat-456".to_string(),
            object: "chat.completion".to_string(),
            created: 1677652288,
            model: "gpt-4".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hi!".to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: "stop".to_string(),
                logprobs: Some(ChatLogProbs {
                    content: vec![TokenLogProb {
                        token: "Hi".to_string(),
                        logprob: -0.3,
                        bytes: None,
                        top_logprobs: None,
                    }],
                }),
            }],
            usage: Usage::new(5, 2),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"logprobs\""));
        assert!(json.contains("\"token\":\"Hi\""));
        assert!(json.contains("\"logprob\":-0.3"));
    }

    #[test]
    fn test_token_logprob_minimal() {
        let token_logprob = TokenLogProb {
            token: "test".to_string(),
            logprob: -1.0,
            bytes: None,
            top_logprobs: None,
        };
        let json = serde_json::to_string(&token_logprob).unwrap();
        assert!(json.contains("\"token\":\"test\""));
        assert!(json.contains("\"logprob\":-1"));
        assert!(!json.contains("\"bytes\""));
        assert!(!json.contains("\"top_logprobs\""));
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: "get_weather".to_string(),
                description: Some("Get the current weather".to_string()),
                parameters: Some(serde_json::json!({
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"}
                    },
                    "required": ["location"]
                })),
                strict: None,
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
        assert!(json.contains("\"description\":\"Get the current weather\""));
        assert!(json.contains("\"location\""));
    }

    #[test]
    fn test_tool_deserialization() {
        let json = r#"{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"}
                    }
                }
            }
        }"#;
        let tool: Tool = serde_json::from_str(json).unwrap();
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "get_weather");
        assert_eq!(
            tool.function.description,
            Some("Get current weather".to_string())
        );
    }

    #[test]
    fn test_tool_call_serialization() {
        let tool_call = ToolCall {
            id: "call_abc123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"location": "San Francisco"}"#.to_string(),
            },
        };
        let json = serde_json::to_string(&tool_call).unwrap();
        assert!(json.contains("\"id\":\"call_abc123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
        assert!(json.contains("San Francisco"));
    }

    #[test]
    fn test_tool_call_deserialization() {
        let json = r#"{
            "id": "call_xyz789",
            "type": "function",
            "function": {
                "name": "calculate",
                "arguments": "{\"expression\": \"2+2\"}"
            }
        }"#;
        let tool_call: ToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.id, "call_xyz789");
        assert_eq!(tool_call.call_type, "function");
        assert_eq!(tool_call.function.name, "calculate");
    }

    #[test]
    fn test_chat_request_with_tools() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "What's the weather?"}],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get weather for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string"}
                            }
                        }
                    }
                }
            ],
            "tool_choice": "auto"
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert!(req.tools.is_some());
        let tools = req.tools.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].function.name, "get_weather");
    }

    #[test]
    fn test_tool_choice_string_variant() {
        let choice = ToolChoice::String("auto".to_string());
        let json = serde_json::to_string(&choice).unwrap();
        assert_eq!(json, "\"auto\"");
        let parsed: ToolChoice = serde_json::from_str("\"none\"").unwrap();
        match parsed {
            ToolChoice::String(s) => assert_eq!(s, "none"),
            _ => panic!("Expected String variant"),
        }
    }

    #[test]
    fn test_tool_choice_function_variant() {
        let choice = ToolChoice::Tool(ToolChoiceFunction {
            choice_type: "function".to_string(),
            function: ToolChoiceFunctionName {
                name: "get_weather".to_string(),
            },
        });
        let json = serde_json::to_string(&choice).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
    }

    #[test]
    fn test_message_with_tool_calls() {
        let message = ChatMessage {
            role: "assistant".to_string(),
            content: "".to_string(),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "NYC"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
        };
        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"call_123\""));
    }

    #[test]
    fn test_tool_response_message() {
        let message = ChatMessage {
            role: "tool".to_string(),
            content: "72°F, sunny".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };
        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("72°F"));
    }

    #[test]
    fn test_function_definition_minimal() {
        let func = FunctionDefinition {
            name: "simple_func".to_string(),
            description: None,
            parameters: None,
            strict: None,
        };
        let json = serde_json::to_string(&func).unwrap();
        assert!(json.contains("\"name\":\"simple_func\""));
        assert!(!json.contains("\"description\""));
        assert!(!json.contains("\"parameters\""));
        assert!(!json.contains("\"strict\""));
    }

    #[test]
    fn test_function_definition_with_strict() {
        let func = FunctionDefinition {
            name: "strict_func".to_string(),
            description: Some("A strict function".to_string()),
            parameters: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };
        let json = serde_json::to_string(&func).unwrap();
        assert!(json.contains("\"strict\":true"));
    }

    #[test]
    fn test_chat_request_with_response_format_text() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello!"}],
            "response_format": {"type": "text"}
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert!(req.response_format.is_some());
        let format = req.response_format.unwrap();
        assert!(matches!(format, ResponseFormat::Text));
    }

    #[test]
    fn test_chat_request_with_response_format_json_object() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Give me JSON"}],
            "response_format": {"type": "json_object"}
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert!(req.response_format.is_some());
        let format = req.response_format.unwrap();
        assert!(matches!(format, ResponseFormat::JsonObject));
        assert!(format.requires_json());
    }

    #[test]
    fn test_chat_request_with_response_format_json_schema() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Give me a person object"}],
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "person",
                    "description": "A person object",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"}
                        },
                        "required": ["name", "age"]
                    },
                    "strict": true
                }
            }
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert!(req.response_format.is_some());
        let format = req.response_format.unwrap();
        assert!(format.requires_json());
        assert!(format.is_strict());
        let schema = format.schema();
        assert!(schema.is_some());
        let schema = schema.unwrap();
        assert_eq!(schema.name, "person");
        assert_eq!(schema.description, Some("A person object".to_string()));
    }

    #[test]
    fn test_chat_request_without_response_format() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello!"}]
        }"#;
        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert!(req.response_format.is_none());
    }
}