use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub top_p: Option<f32>,
pub n: Option<u32>,
pub stream: Option<bool>,
pub stream_options: Option<StreamOptions>,
pub stop: Option<Vec<String>>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub logit_bias: Option<HashMap<String, f32>>,
pub user: Option<String>,
pub functions: Option<Vec<Function>>,
pub function_call: Option<FunctionCall>,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
pub response_format: Option<ResponseFormat>,
pub seed: Option<u32>,
pub logprobs: Option<bool>,
pub top_logprobs: Option<u32>,
pub modalities: Option<Vec<String>>,
pub audio: Option<AudioParams>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: MessageRole,
pub content: Option<MessageContent>,
pub name: Option<String>,
pub function_call: Option<FunctionCall>,
pub tool_calls: Option<Vec<ToolCall>>,
pub tool_call_id: Option<String>,
pub audio: Option<AudioContent>,
}
/// Author role of a [`ChatMessage`], serialized in lowercase
/// (`"system"`, `"user"`, `"assistant"`, `"function"`, `"tool"`).
///
/// `Eq` is derived alongside `Hash` so roles can be used as `HashMap` /
/// `HashSet` keys, which require both.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Instructions that set up the conversation.
    System,
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// Result of a legacy function call.
    Function,
    /// Result of a tool call.
    Tool,
}
impl std::fmt::Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageRole::System => write!(f, "system"),
MessageRole::User => write!(f, "user"),
MessageRole::Assistant => write!(f, "assistant"),
MessageRole::Function => write!(f, "function"),
MessageRole::Tool => write!(f, "tool"),
}
}
}
/// Message body: either a plain string or a list of typed content parts.
///
/// Untagged: variants are tried in declaration order, so a JSON string
/// deserializes to `Text` and a JSON array to `Parts`.
#[derive(Debug, Clone, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain text, serialized as a bare JSON string.
Text(String),
/// Multimodal content, serialized as a JSON array of [`ContentPart`]s.
Parts(Vec<ContentPart>),
}
/// One typed element of a [`MessageContent::Parts`] list.
///
/// Internally tagged: the variant's wire name goes into a `"type"` field next to
/// the variant's own fields, e.g. `{"type":"text","text":"..."}`.
#[derive(Debug, Clone, Hash, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentPart {
/// Text segment (`"type": "text"`).
#[serde(rename = "text")]
Text {
text: String,
},
/// Image reference (`"type": "image_url"`).
#[serde(rename = "image_url")]
ImageUrl {
image_url: ImageUrl,
},
/// Inline audio (`"type": "audio"`).
// NOTE(review): OpenAI's chat API names this part `input_audio` (with an
// `input_audio` field); confirm the target upstream accepts `audio`.
#[serde(rename = "audio")]
Audio {
audio: AudioContent,
},
}
#[derive(Debug, Clone, Hash, Serialize, Deserialize)]
pub struct ImageUrl {
pub url: String,
pub detail: Option<String>,
}
#[derive(Debug, Clone, Hash, Serialize, Deserialize)]
pub struct AudioContent {
pub data: String,
pub format: String,
}
/// `stream_options` request parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Ask for a trailing chunk carrying [`Usage`]; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}

/// Function definition exposed to the model, directly (legacy `functions`) or via [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    /// Human-readable description; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as raw JSON (not validated here); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}

/// A function invocation: its name plus arguments kept as an unparsed string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    /// Raw argument text; callers must parse it themselves (typically JSON).
    pub arguments: String,
}

/// Tool declaration sent in [`ChatCompletionRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Wire field `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
None(String), Auto(String), Required(String), Specific(ToolChoiceFunction),
}
/// Object form of [`ToolChoice`]: force the model to call one named function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
/// Wire field `type` (OpenAI uses `"function"`).
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolChoiceFunctionSpec,
}
/// Names the function selected by a [`ToolChoiceFunction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunctionSpec {
pub name: String,
}
/// A tool invocation emitted by the assistant in [`ChatMessage::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier; a later `tool` message refers back to it via `tool_call_id`.
pub id: String,
/// Wire field `type`.
#[serde(rename = "type")]
pub tool_type: String,
/// Function name and its unparsed argument string.
pub function: FunctionCall,
}
/// `response_format` request parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// Wire field `type`, e.g. `"text"`, `"json_object"` or `"json_schema"`.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Schema payload (raw JSON) for the `"json_schema"` type; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}

/// `audio` request parameter: voice and output format for audio responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioParams {
    pub voice: String,
    pub format: String,
}
/// Non-streaming chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
/// Object kind string as sent by the provider.
pub object: String,
/// Creation timestamp (presumably Unix seconds — confirm per provider).
pub created: u64,
pub model: String,
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatChoice>,
/// Token accounting; `None` when the provider omits it.
pub usage: Option<Usage>,
}
/// One candidate completion inside a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub logprobs: Option<Logprobs>,
/// Stop reason as a raw string (not validated against a fixed set).
pub finish_reason: Option<String>,
}
/// Per-token log-probability information for a choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
pub content: Option<Vec<ContentLogprob>>,
}
/// Log probability of one generated token, plus optional alternatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentLogprob {
pub token: String,
pub logprob: f64,
/// Token bytes as a JSON array of integers.
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Option<Vec<TopLogprob>>,
}
/// One alternative token considered at a position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopLogprob {
pub token: String,
pub logprob: f64,
pub bytes: Option<Vec<u8>>,
}
/// Token accounting reported by the provider.
///
/// `Default` yields all counts zero and no detail breakdowns.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
/// Taken as reported; not recomputed from the other two counts.
pub total_tokens: u32,
/// Optional breakdown of `prompt_tokens`.
pub prompt_tokens_details: Option<PromptTokensDetails>,
/// Optional breakdown of `completion_tokens`.
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Sub-counts of prompt tokens (cached, audio).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTokensDetails {
pub cached_tokens: Option<u32>,
pub audio_tokens: Option<u32>,
}
/// Sub-counts of completion tokens (reasoning, audio).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
pub audio_tokens: Option<u32>,
}
/// One event of a streamed chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatChoiceDelta>,
/// Usage, when the provider attaches it to a chunk (see [`StreamOptions`]).
pub usage: Option<Usage>,
}
/// Incremental update for one choice within a [`ChatCompletionChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ChatMessageDelta,
pub logprobs: Option<Logprobs>,
pub finish_reason: Option<String>,
}
/// Partial message fields; callers accumulate these across chunks.
///
/// Unlike [`ChatMessage`], `content` is plain text only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessageDelta {
pub role: Option<MessageRole>,
pub content: Option<String>,
pub function_call: Option<FunctionCallDelta>,
pub tool_calls: Option<Vec<ToolCallDelta>>,
pub audio: Option<AudioDelta>,
}
/// Fragment of a legacy function call; `arguments` arrives in pieces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
pub name: Option<String>,
pub arguments: Option<String>,
}
/// Fragment of a tool call; `index` identifies which call the fragment extends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: u32,
pub id: Option<String>,
#[serde(rename = "type")]
pub tool_type: Option<String>,
pub function: Option<FunctionCallDelta>,
}
/// Fragment of streamed audio output and/or its transcript.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDelta {
pub data: Option<String>,
pub transcript: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
pub max_tokens: Option<u32>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub n: Option<u32>,
pub stream: Option<bool>,
pub stop: Option<Vec<String>>,
pub presence_penalty: Option<f64>,
pub frequency_penalty: Option<f64>,
pub logit_bias: Option<HashMap<String, f64>>,
pub user: Option<String>,
}
/// Response of the legacy text completions endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Option<Usage>,
}
/// One generated text candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
/// Log-probability data kept as raw JSON (not typed here).
pub logprobs: Option<serde_json::Value>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub model: String,
pub input: serde_json::Value,
pub user: Option<String>,
}
/// Response of an embeddings call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
pub object: String,
pub data: Vec<EmbeddingObject>,
pub model: String,
/// Required here (not `Option`): a response without `usage` fails to deserialize.
pub usage: EmbeddingUsage,
}
/// One embedding vector; `index` is its position in the request input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingObject {
pub object: String,
pub embedding: Vec<f64>,
pub index: u32,
}
/// Token accounting for an embeddings call; `Default` is all zeros.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingUsage {
pub prompt_tokens: u32,
pub total_tokens: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationRequest {
pub prompt: String,
pub model: Option<String>,
pub n: Option<u32>,
pub size: Option<String>,
pub response_format: Option<String>,
pub user: Option<String>,
}
/// Response of an image-generation call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationResponse {
/// Creation timestamp (presumably Unix seconds — confirm per provider).
pub created: u64,
pub data: Vec<ImageObject>,
}
/// One generated image, as a URL or inline base64 JSON string.
///
/// NOTE(review): presumably exactly one field is set, depending on the
/// request's `response_format` — nothing here enforces that.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageObject {
pub url: Option<String>,
pub b64_json: Option<String>,
}
/// One entry of a model listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
/// Response body of a list-models call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelListResponse {
pub object: String,
pub data: Vec<Model>,
}
/// Chat choice whose `logprobs` is kept as raw JSON instead of [`Logprobs`].
///
/// NOTE(review): this overlaps with `ChatChoice` (same fields, different
/// `logprobs` type and field order); confirm whether both are still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
pub logprobs: Option<serde_json::Value>,
}
impl Default for ChatCompletionRequest {
    /// Builds an empty request for `"gpt-3.5-turbo"` with no messages and every
    /// optional parameter unset. Callers override only what they need via
    /// struct-update syntax (`..Default::default()`).
    fn default() -> Self {
        let model = String::from("gpt-3.5-turbo");
        Self {
            model,
            messages: Vec::new(),
            // Sampling controls.
            temperature: None,
            top_p: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            seed: None,
            // Output length and count.
            max_tokens: None,
            max_completion_tokens: None,
            n: None,
            stop: None,
            // Streaming.
            stream: None,
            stream_options: None,
            // Function / tool calling.
            functions: None,
            function_call: None,
            tools: None,
            tool_choice: None,
            // Output shape and extras.
            response_format: None,
            logprobs: None,
            top_logprobs: None,
            modalities: None,
            audio: None,
            user: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks exact wire values (not substrings) and that the JSON parses back.
    #[test]
    fn test_chat_completion_request_serialization() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: Some(MessageContent::Text("Hello".to_string())),
                name: None,
                function_call: None,
                tool_calls: None,
                tool_call_id: None,
                audio: None,
            }],
            temperature: Some(0.7),
            ..Default::default()
        };
        let json = serde_json::to_string(&request).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(value["model"], "gpt-4");
        assert_eq!(value["messages"][0]["role"], "user");
        assert_eq!(value["messages"][0]["content"], "Hello");
        let temperature = value["temperature"]
            .as_f64()
            .expect("temperature should serialize as a number");
        assert!((temperature - 0.7).abs() < 1e-6);
        // Unset options are either omitted or `null`; indexing a missing key yields `Null`.
        assert!(value["top_p"].is_null());
        assert!(value["tools"].is_null());

        let parsed: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model, "gpt-4");
        assert_eq!(parsed.messages.len(), 1);
        assert_eq!(parsed.messages[0].role, MessageRole::User);
        assert!(matches!(
            &parsed.messages[0].content,
            Some(MessageContent::Text(text)) if text == "Hello"
        ));
        let parsed_temperature = parsed.temperature.expect("temperature should round-trip");
        assert!((parsed_temperature - 0.7).abs() < 1e-6);
        assert!(parsed.top_p.is_none());
    }

    /// Pins the exact JSON for both `MessageContent` shapes.
    #[test]
    fn test_message_content_variants() {
        let text_content = MessageContent::Text("Hello".to_string());
        let json = serde_json::to_string(&text_content).unwrap();
        assert_eq!(json, "\"Hello\"");

        let parts_content = MessageContent::Parts(vec![ContentPart::Text {
            text: "Hello".to_string(),
        }]);
        let json = serde_json::to_string(&parts_content).unwrap();
        // Internally tagged: the `type` tag is written before the variant's fields.
        assert_eq!(json, r#"[{"type":"text","text":"Hello"}]"#);
    }

    /// The untagged `MessageContent` picks its variant from the JSON shape.
    #[test]
    fn test_message_content_deserializes_by_shape() {
        let text: MessageContent = serde_json::from_str("\"Hi\"").unwrap();
        assert!(matches!(&text, MessageContent::Text(s) if s == "Hi"));

        let parts: MessageContent = serde_json::from_str(
            r#"[{"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]"#,
        )
        .unwrap();
        match parts {
            MessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 1);
                assert!(matches!(
                    &parts[0],
                    ContentPart::ImageUrl { image_url }
                        if image_url.url == "https://example.com/cat.png"
                            && image_url.detail.is_none()
                ));
            }
            other => panic!("expected content parts, got {:?}", other),
        }
    }

    /// `Display` must spell roles exactly as serde does on the wire.
    #[test]
    fn test_message_role_display_matches_serde() {
        let roles = [
            MessageRole::System,
            MessageRole::User,
            MessageRole::Assistant,
            MessageRole::Function,
            MessageRole::Tool,
        ];
        for role in roles {
            let wire = serde_json::to_string(&role).unwrap();
            assert_eq!(wire, format!("\"{}\"", role));
            let back: MessageRole = serde_json::from_str(&wire).unwrap();
            assert_eq!(back, role);
        }
    }
}