use super::shared::{FinishReason, ReasoningEffort, StopToken, Usage, WebSearchContextSize};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Display;
/// A complete (non-streamed) chat completion response.
///
/// When deserializing, only `choices` is required: `created` and `model` fall back to their
/// serde defaults, and missing `Option` fields become `None`. When serializing, `None` fields
/// are omitted.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    /// Response identifier, if the server sent one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// The completion choices.
    pub choices: Vec<ChatCompletionChoice>,
    /// Creation time; filled by `crate::shared::default_created` when absent
    /// (presumably a Unix timestamp in seconds — confirm against the API).
    #[serde(default = "crate::shared::default_created")]
    pub created: u32,
    /// Model that produced the response; an empty string when absent.
    #[serde(default = "default_model")]
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Object-type tag reported by the server, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub object: Option<String>,
    /// Token usage statistics, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// One chunk of a streamed chat completion (compiled only with the `stream` feature).
///
/// Same shape as [`ChatCompletionResponse`] except that it carries
/// [`ChatCompletionChunkChoice`]s (with incremental `delta` messages) and has no
/// `service_tier` field.
#[cfg(feature = "stream")]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionChunkResponse {
    /// Chunk identifier, if the server sent one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Incremental choices carried by this chunk.
    pub choices: Vec<ChatCompletionChunkChoice>,
    /// Creation time; filled by `crate::shared::default_created` when absent.
    #[serde(default = "crate::shared::default_created")]
    pub created: u32,
    /// Model name; an empty string when absent.
    #[serde(default = "default_model")]
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub object: Option<String>,
    /// Token usage statistics, if included in this chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// Serde default for `model` fields: an empty string when the payload omits the field.
fn default_model() -> String {
    String::new()
}
/// Request body for a chat completion call.
///
/// Usually built with [`ChatCompletionParametersBuilder`]:
/// - setters accept `impl Into<_>`;
/// - `Option` fields are set with their inner value (`strip_option`);
/// - fields left unset take their value from `ChatCompletionParameters::default()` (the
///   builder-level `default`).
///
/// Every `Option` field is omitted from the serialized request when `None`.
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
#[builder(name = "ChatCompletionParametersBuilder")]
#[builder(setter(into, strip_option), default)]
pub struct ChatCompletionParameters {
    /// Conversation history, in order.
    pub messages: Vec<ChatMessage>,
    /// Model identifier; the derived `Default` leaves it as an empty string.
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<Modality>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prediction: Option<PredictedOutput>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<AudioParameters>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ChatCompletionResponseFormat>,
    // NOTE(review): `u32` cannot represent negative seeds; confirm the range the API accepts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopToken>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ChatCompletionStreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatCompletionTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ChatCompletionToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_identifier: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search_options: Option<WebSearchOptions>,
    /// Extra fields merged into the top level of the serialized request (`#[serde(flatten)]`).
    /// Should be a JSON object: serde can only flatten maps and structs.
    #[serde(flatten)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extra_body: Option<Value>,
    // NOTE(review): as declared, this field is serialized into the JSON body under the key
    // `query_params` when `Some`. If it is meant only for the URL query string, confirm the
    // client strips it before sending, or mark it `#[serde(skip)]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query_params: Option<HashMap<String, String>>,
}
/// Options for streamed responses (the `stream_options` request field).
///
/// Both fields are omitted from the serialized request when `None`. This matches every other
/// optional request field in this module, so an unset option is never sent as an explicit
/// `null`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionStreamOptions {
    /// Whether usage statistics should be included in the stream.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
    /// Whether usage statistics should be reported continuously
    /// (presumably on every chunk — confirm with the target server).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continuous_usage_stats: Option<bool>,
}
/// Object form of `tool_choice` that names one specific function.
///
/// Serializes as `{"type":"function","function":{"name":...}}`; `type` is omitted when `None`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionToolChoiceFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#type: Option<ChatCompletionToolType>,
    pub function: ChatCompletionToolChoiceFunctionName,
}
/// Wrapper carrying the function name for [`ChatCompletionToolChoiceFunction`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionToolChoiceFunctionName {
    /// Name of the function being selected.
    pub name: String,
}
/// A function definition exposed to the model through [`ChatCompletionTool`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionFunction {
    /// Function name.
    pub name: String,
    /// Optional description; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Raw JSON describing the parameters (presumably a JSON Schema object — confirm).
    pub parameters: serde_json::Value,
}
/// Requested output format (the `response_format` request field).
///
/// Internally tagged by `type` in snake_case, e.g. `{"type":"text"}`,
/// `{"type":"json_object"}`, or `{"type":"json_schema","json_schema":{...}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatCompletionResponseFormat {
    Text,
    JsonObject,
    JsonSchema { json_schema: JsonSchema },
}
/// Schema payload for [`ChatCompletionResponseFormat::JsonSchema`].
///
/// Construct it with [`JsonSchemaBuilder`]:
/// - setters accept `impl Into<_>`;
/// - `Option` fields take the inner value;
/// - unset fields come from `JsonSchema::default()`.
///
/// `None` fields are omitted when serializing.
// NOTE(review): unlike every other type in this module, these fields are not `pub`, so code
// outside this module cannot read a deserialized schema; confirm that is intentional.
#[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)]
#[builder(name = "JsonSchemaBuilder")]
#[builder(setter(into, strip_option), default)]
pub struct JsonSchema {
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    name: String,
    /// The schema itself as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    schema: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    strict: Option<bool>,
}
/// A tool made available to the model. [`ChatCompletionToolType`] currently has only the
/// `function` kind.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionTool {
    pub r#type: ChatCompletionToolType,
    pub function: ChatCompletionFunction,
}
/// A conversation message, used both in requests and in [`ChatCompletionChoice::message`].
///
/// Internally tagged by `role` in snake_case (`developer`, `system`, `user`, `assistant`,
/// `tool`). Optional fields are omitted from the serialized output when `None`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum ChatMessage {
    Developer {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    System {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A model reply. Every field is optional; `content` may be absent, presumably when the
    /// reply carries only `tool_calls` or a `refusal`.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<ChatMessageContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<AudioDataIdParameter>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
    },
    /// Result of a tool call. `tool_call_id` identifies the call being answered (presumably
    /// matching a [`ToolCall::id`]).
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}
impl ChatMessage {
    /// Returns the content of this message, if it has any.
    ///
    /// Developer, system, user and tool messages always carry content. An assistant message
    /// returns `None` when its optional `content` is absent.
    pub fn message(&self) -> Option<&ChatMessageContent> {
        match self {
            ChatMessage::Developer { content, .. }
            | ChatMessage::System { content, .. }
            | ChatMessage::User { content, .. }
            | ChatMessage::Tool { content, .. } => Some(content),
            ChatMessage::Assistant { content, .. } => content.as_ref(),
        }
    }

    /// Returns the content as plain text when it is [`ChatMessageContent::Text`].
    ///
    /// Returns `None` in three cases: multi-part content, [`ChatMessageContent::None`], and an
    /// assistant message without content. It is built on [`ChatMessage::message`], so the two
    /// accessors always agree on which messages carry content.
    pub fn text(&self) -> Option<&str> {
        match self.message()? {
            ChatMessageContent::Text(text) => Some(text.as_str()),
            ChatMessageContent::ContentPart(_) | ChatMessageContent::None => None,
        }
    }

    /// Returns the optional participant `name`. Tool messages have no `name` field.
    pub fn name(&self) -> Option<&str> {
        match self {
            ChatMessage::Developer { name, .. }
            | ChatMessage::System { name, .. }
            | ChatMessage::User { name, .. }
            | ChatMessage::Assistant { name, .. } => name.as_deref(),
            ChatMessage::Tool { .. } => None,
        }
    }
}
/// Incremental message fragment carried by a streamed chunk.
///
/// Tagged by `role` like [`ChatMessage`]. A payload that matches none of the role-tagged
/// variants falls back to the `#[serde(untagged)]` [`DeltaChatMessage::Untagged`] variant. One
/// example is a fragment with no `role` key.
// NOTE(review): `Tool::content` is a `String` here but a `ChatMessageContent` in `ChatMessage`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum DeltaChatMessage {
    Developer {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    System {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: ChatMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<ChatMessageContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<DeltaToolCall>>,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
    /// Role-less fallback: every field is optional, so any object of this shape matches.
    #[serde(untagged)]
    Untagged {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<ChatMessageContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<DeltaToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_call_id: Option<String>,
    },
}
/// A complete tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
    /// Call identifier (presumably referenced by a tool message's `tool_call_id`).
    pub id: String,
    /// Tool kind as a free-form string (serialized as `type`).
    pub r#type: String,
    /// The function being called and its arguments.
    pub function: Function,
}
/// Streamed fragment of a [`ToolCall`]; every field except `function` may be absent.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DeltaToolCall {
    /// Presumably identifies which tool call this fragment belongs to — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub index: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    pub function: DeltaFunction,
}
/// Name and arguments of a called function.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Function {
    pub name: String,
    /// Arguments as a raw string (presumably JSON-encoded; parse before use).
    pub arguments: String,
}
/// Streamed fragment of a [`Function`]; fragments are combined with [`DeltaFunction::merge`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DeltaFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// A piece of the argument text; pieces are concatenated by `merge`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
/// One entry of [`ChatCompletionResponse::choices`].
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    /// Presumably this choice's position in `choices`.
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProps>,
}
/// Audio reference attached to an assistant message, identified by `id`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AudioDataIdParameter {
    pub id: String,
}
/// Audio settings for a request (the `audio` request field): voice and output format.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AudioParameters {
    pub voice: Voice,
    pub format: AudioFormat,
}
/// Log-probability information attached to a choice.
// NOTE(review): "LogProps" reads like a typo for "LogProbs"; renaming would break the public API.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LogProps {
    /// Per-token entries for the message content, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<LogPropsContent>>,
    /// Per-token entries for the refusal text, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Vec<LogPropsContent>>,
}
/// Log-probability entry for one token.
///
/// `token_info` is flattened, so `token`, `logprob` and `bytes` appear in the same JSON
/// object as `top_logprobs`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LogPropsContent {
    #[serde(flatten)]
    pub token_info: LogProbsContentInfo,
    /// Presumably the most likely alternative tokens at this position.
    pub top_logprobs: Vec<LogProbsContentInfo>,
}
/// A token, its log probability, and its optional raw bytes.
///
/// `bytes` has no `skip_serializing_if`, so `None` serializes as `null`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LogProbsContentInfo {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
/// One entry of [`ChatCompletionChunkResponse::choices`] (only with the `stream` feature).
#[cfg(feature = "stream")]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatCompletionChunkChoice {
    /// Optional choice index. It has no `skip_serializing_if`, so `None` serializes as `null`.
    pub index: Option<u32>,
    /// The incremental message fragment.
    pub delta: DeltaChatMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProps>,
}
/// Image location for an image content part, plus an optional detail level.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ImageUrlType {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageUrlDetail>,
}
/// Predicted output supplied with a request (the `prediction` request field).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictedOutput {
    pub r#type: PredictedOutputType,
    pub content: PredictedOutputContent,
}
/// Predicted content. It is untagged, so it is either a JSON string or an array of text parts.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum PredictedOutputContent {
    String(String),
    Array(Vec<PredictedOutputArrayPart>),
}
/// One `{"type": ..., "text": ...}` element of an array-form prediction.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct PredictedOutputArrayPart {
    pub r#type: String,
    pub text: String,
}
/// Prediction kind; the only value serializes as `"content"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum PredictedOutputType {
    Content,
}
/// Requested output modality: `"text"` or `"audio"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Modality {
    Text,
    Audio,
}
/// Image detail level: `"auto"`, `"high"` or `"low"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ImageUrlDetail {
    Auto,
    High,
    Low,
}
/// Message content in any of its wire forms.
///
/// The enum is untagged, so variants are matched by JSON shape in declaration order:
/// - a string becomes `Text`;
/// - an array becomes `ContentPart`;
/// - `null` becomes `None`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Text(String),
    ContentPart(Vec<ChatMessageContentPart>),
    None,
}
/// One element of multi-part message content.
///
/// The enum is untagged. Variants are tried in order (`Text`, `Image`, `Audio`) and are told
/// apart by their payload field (`text`, `image_url`, `input_audio`). The `type` string itself
/// is not checked.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ChatMessageContentPart {
    Text(ChatMessageTextContentPart),
    Image(ChatMessageImageContentPart),
    Audio(ChatMessageAudioContentPart),
}
/// Text part: `{"type": ..., "text": ...}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatMessageTextContentPart {
    pub r#type: String,
    pub text: String,
}
/// Image part: `{"type": ..., "image_url": {...}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatMessageImageContentPart {
    pub r#type: String,
    pub image_url: ImageUrlType,
}
/// Audio part: `{"type": ..., "input_audio": {...}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatMessageAudioContentPart {
    pub r#type: String,
    pub input_audio: InputAudioData,
}
/// Image URL with a required, free-form `detail` string.
// NOTE(review): not referenced in this section; image parts use `ImageUrlType` instead.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ChatMessageImageUrl {
    pub url: String,
    pub detail: String,
}
/// Inline audio payload and its format name, both as strings
/// (presumably base64 data — confirm).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct InputAudioData {
    pub data: String,
    pub format: String,
}
/// Web search configuration (the `web_search_options` request field).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_context_size: Option<WebSearchContextSize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_location: Option<ApproximateUserLocation>,
}
/// User location wrapper: `{"type":"approximate","approximate":{...}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ApproximateUserLocation {
    pub r#type: UserLocationType,
    pub approximate: WebSearchUserLocation,
}
impl Display for ChatMessageContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ChatMessageContent::Text(text) => write!(f, "{text}"),
ChatMessageContent::ContentPart(tcp) => {
for part in tcp {
write!(f, "{part:?}")?;
}
Ok(())
}
ChatMessageContent::None => write!(f, ""),
}
}
}
/// Kind of tool; the only value serializes as `"function"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChatCompletionToolType {
    Function,
}
/// Value of the `tool_choice` request field.
///
/// It is one of the strings `"none"`, `"auto"` or `"required"`. Alternatively, through the
/// untagged last variant, it is an object naming one function
/// (`{"type":"function","function":{"name":...}}`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChatCompletionToolChoice {
    None,
    Auto,
    Required,
    #[serde(untagged)]
    ChatCompletionToolChoiceFunction(ChatCompletionToolChoiceFunction),
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct WebSearchUserLocation {
pub city: Option<String>,
pub country: Option<String>,
pub region: Option<String>,
pub timezone: Option<String>,
}
/// Location kind; the only value serializes as `"approximate"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum UserLocationType {
    Approximate,
}
/// Output voice; each variant serializes as its lowercase name (e.g. `"alloy"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Voice {
    Alloy,
    Ash,
    Ballad,
    Coral,
    Echo,
    Sage,
    Shimmer,
    Verse,
}
/// Audio output format: `"wav"`, `"mp3"`, `"flac"`, `"opus"` or `"pcm16"`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum AudioFormat {
    Wav,
    Mp3,
    Flac,
    Opus,
    Pcm16,
}
impl Default for ChatMessageContent {
fn default() -> Self {
ChatMessageContent::Text("".to_string())
}
}
impl DeltaFunction {
    /// Folds the streamed fragment `other` into `self`.
    ///
    /// The name is copied from `other` only while `self` has none, so the first name seen
    /// wins. Argument fragments are appended to whatever text has accumulated so far.
    pub fn merge(&mut self, other: &Self) {
        if self.name.is_none() {
            // Copying `None` over `None` changes nothing, so `other.name` needs no check.
            self.name.clone_from(&other.name);
        }
        if let Some(fragment) = other.arguments.as_deref() {
            self.arguments
                .get_or_insert_with(String::new)
                .push_str(fragment);
        }
    }

    /// True when the fragment carries neither a name nor any argument text.
    pub fn is_empty(&self) -> bool {
        matches!((&self.name, &self.arguments), (None, None))
    }
}
// Serialization round-trip tests for the request/response types in this module.
// `serde_json` is reached through the extern prelude; a `use serde_json;` line would be a
// redundant single-component import.
#[cfg(test)]
mod tests {
    use crate::chat::{
        ChatCompletionResponseFormat, ChatCompletionToolChoice, ChatCompletionToolChoiceFunction,
        ChatCompletionToolChoiceFunctionName, ChatCompletionToolType, ChatMessage,
        ChatMessageContent, ChatMessageContentPart, ChatMessageTextContentPart, JsonSchemaBuilder,
    };

    #[test]
    fn test_chat_completion_response_format_serialization_deserialization() {
        let json_schema = JsonSchemaBuilder::default()
            .description("This is a test schema".to_string())
            .name("test_schema".to_string())
            .schema(Some(serde_json::json!({"type": "object"})))
            .strict(true)
            .build()
            .unwrap();
        let response_format = ChatCompletionResponseFormat::JsonSchema { json_schema };
        let serialized = serde_json::to_string(&response_format).unwrap();
        assert_eq!(serialized, "{\"type\":\"json_schema\",\"json_schema\":{\"description\":\"This is a test schema\",\"name\":\"test_schema\",\"schema\":{\"type\":\"object\"},\"strict\":true}}");
        let deserialized: ChatCompletionResponseFormat = serde_json::from_str(&serialized).unwrap();
        match deserialized {
            ChatCompletionResponseFormat::JsonSchema { json_schema } => {
                assert_eq!(
                    json_schema.description,
                    Some("This is a test schema".to_string())
                );
                assert_eq!(json_schema.name, "test_schema".to_string());
                assert_eq!(
                    json_schema.schema,
                    Some(serde_json::json!({"type": "object"}))
                );
                assert_eq!(json_schema.strict, Some(true));
            }
            _ => panic!("Deserialized format should be JsonSchema"),
        }
    }

    #[test]
    fn test_chat_completion_tool_choice_required_serialization_deserialization() {
        let tool_choice = ChatCompletionToolChoice::Required;
        let serialized = serde_json::to_string(&tool_choice).unwrap();
        assert_eq!(serialized, "\"required\"");
        let deserialized: ChatCompletionToolChoice =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, tool_choice)
    }

    #[test]
    fn test_chat_completion_tool_choice_named_function_serialization_deserialization() {
        let tool_choice = ChatCompletionToolChoice::ChatCompletionToolChoiceFunction(
            ChatCompletionToolChoiceFunction {
                r#type: Some(ChatCompletionToolType::Function),
                function: ChatCompletionToolChoiceFunctionName {
                    name: "get_current_weather".to_string(),
                },
            },
        );
        let serialized = serde_json::to_string(&tool_choice).unwrap();
        assert_eq!(
            serialized,
            "{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\"}}"
        );
        let deserialized: ChatCompletionToolChoice =
            serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, tool_choice)
    }

    #[test]
    fn test_chat_message_tool_content_string_serialization_deserialization() {
        let tool_message = ChatMessage::Tool {
            content: ChatMessageContent::Text("tool_result".to_string()),
            tool_call_id: "tool_call_id".to_string(),
        };
        let serialized = serde_json::to_string(&tool_message).unwrap();
        assert_eq!(
            serialized,
            "{\"role\":\"tool\",\"content\":\"tool_result\",\"tool_call_id\":\"tool_call_id\"}"
        );
        let deserialized: ChatMessage = serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, tool_message)
    }

    #[test]
    fn test_chat_message_tool_content_array_serialization_deserialization() {
        let content_array = vec![ChatMessageContentPart::Text(ChatMessageTextContentPart {
            r#type: "text".to_string(),
            text: "tool_result".to_string(),
        })];
        let tool_message = ChatMessage::Tool {
            content: ChatMessageContent::ContentPart(content_array),
            tool_call_id: "tool_call_id".to_string(),
        };
        let serialized = serde_json::to_string(&tool_message).unwrap();
        assert_eq!(
            serialized,
            "{\"role\":\"tool\",\"content\":[{\"type\":\"text\",\"text\":\"tool_result\"}],\"tool_call_id\":\"tool_call_id\"}"
        );
        let deserialized: ChatMessage = serde_json::from_str(serialized.as_str()).unwrap();
        assert_eq!(deserialized, tool_message)
    }
}