use serde::{Deserialize, Serialize};
use crate::types::shared::{Tool, ToolCall, ToolChoice, Usage};
/// Author of a chat message.
///
/// Each variant is (de)serialized as its lowercase name: `"system"`, `"user"`,
/// `"assistant"` or `"tool"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum Role {
    #[serde(rename = "system")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "tool")]
    Tool,
}
/// Requested reasoning effort level.
///
/// Serialized as `"low"`, `"medium"` or `"high"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ReasoningEffort {
    #[serde(rename = "low")]
    Low,
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}
/// Body of a [`ChatMessage`]: either a plain string or a list of typed parts.
///
/// Because the enum is `#[serde(untagged)]`, [`MessageContent::Text`] appears in
/// JSON as a bare string and [`MessageContent::Parts`] as a bare array. On
/// deserialization serde tries the variants in declaration order.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text content.
    Text(String),
    /// Multimodal content made of [`ContentPart`]s.
    Parts(Vec<ContentPart>),
}

impl From<String> for MessageContent {
    /// Wraps an owned string as [`MessageContent::Text`].
    fn from(text: String) -> Self {
        Self::Text(text)
    }
}

impl From<&str> for MessageContent {
    /// Copies a borrowed string into [`MessageContent::Text`].
    fn from(text: &str) -> Self {
        Self::Text(String::from(text))
    }
}

impl From<Vec<ContentPart>> for MessageContent {
    /// Wraps a list of parts as [`MessageContent::Parts`].
    fn from(parts: Vec<ContentPart>) -> Self {
        Self::Parts(parts)
    }
}
/// One element of multimodal [`MessageContent::Parts`].
///
/// The enum is internally tagged. Each part serializes as an object whose
/// `"type"` key is `"text"` or `"image_url"`, next to the variant's own field.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ContentPart {
    /// A text fragment.
    Text { text: String },
    /// An image referenced by URL.
    ImageUrl { image_url: ImageUrl },
}

impl ContentPart {
    /// Builds a [`ContentPart::Text`] part.
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text { text }
    }

    /// Builds a [`ContentPart::ImageUrl`] part with no `detail` hint.
    pub fn image_url(url: impl Into<String>) -> Self {
        let image_url = ImageUrl {
            url: url.into(),
            detail: None,
        };
        Self::ImageUrl { image_url }
    }
}
/// Payload of [`ContentPart::ImageUrl`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional `detail` hint; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
/// A single message in a chat-completions request.
///
/// Every `Option` field is omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// Message body. Always `Some` for messages built by the constructors below.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContent>,
    /// Tool calls attached to the message. The constructors below never set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers. Only set by [`ChatMessage::tool`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

impl ChatMessage {
    /// Shared constructor behind the role-specific helpers. It leaves `tool_calls` unset.
    fn build_with_role(
        role: Role,
        content: MessageContent,
        tool_call_id: Option<String>,
    ) -> Self {
        Self {
            role,
            content: Some(content),
            tool_calls: None,
            tool_call_id,
        }
    }

    /// Creates a [`Role::System`] message.
    pub fn system(content: impl Into<MessageContent>) -> Self {
        Self::build_with_role(Role::System, content.into(), None)
    }

    /// Creates a [`Role::User`] message.
    pub fn user(content: impl Into<MessageContent>) -> Self {
        Self::build_with_role(Role::User, content.into(), None)
    }

    /// Creates a [`Role::Assistant`] message.
    pub fn assistant(content: impl Into<MessageContent>) -> Self {
        Self::build_with_role(Role::Assistant, content.into(), None)
    }

    /// Creates a [`Role::Tool`] message that answers the tool call `tool_call_id`.
    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<MessageContent>) -> Self {
        Self::build_with_role(Role::Tool, content.into(), Some(tool_call_id.into()))
    }
}
/// Request body for a chat-completions call.
///
/// `model` and `messages` are always serialized. Every `Option` field is left out
/// of the JSON when `None`. Usually assembled through [`ChatCompletionRequest::builder`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
    /// Model identifier.
    pub model: String,
    /// Conversation messages, kept in insertion order.
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Passed through verbatim as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Passed through verbatim as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// One or several stop sequences; see [`StopSequence`] for the JSON shape.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
}
/// Value of the `stop` request field.
///
/// Untagged: [`StopSequence::One`] serializes as a bare JSON string and
/// [`StopSequence::Many`] as a JSON array of strings.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StopSequence {
    /// A single stop string.
    One(String),
    /// Several stop strings.
    Many(Vec<String>),
}
impl ChatCompletionRequest {
pub fn builder() -> ChatCompletionRequestBuilder {
ChatCompletionRequestBuilder::default()
}
}
/// Fluent builder for [`ChatCompletionRequest`].
///
/// Obtain one with [`ChatCompletionRequest::builder`], chain setters, and finish
/// with `build()`. Each setter consumes the builder and returns the updated copy,
/// so ignoring a returned builder silently drops the configuration. That is a bug,
/// hence `#[must_use]`.
#[must_use = "a builder does nothing until `.build()` is called"]
#[derive(Debug, Clone, Default)]
pub struct ChatCompletionRequestBuilder {
    model: Option<String>,
    messages: Vec<ChatMessage>,
    stream: Option<bool>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    reasoning_effort: Option<ReasoningEffort>,
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
    response_format: Option<serde_json::Value>,
    n: Option<u32>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    logit_bias: Option<serde_json::Value>,
    user: Option<String>,
    seed: Option<i64>,
    stop: Option<StopSequence>,
}
impl ChatCompletionRequestBuilder {
    /// Sets the model identifier. Required: [`build`](Self::build) fails without it.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Appends one message to the conversation.
    pub fn message(mut self, message: ChatMessage) -> Self {
        self.messages.push(message);
        self
    }

    /// Appends every message from `messages`, after any already added.
    pub fn messages(mut self, messages: impl IntoIterator<Item = ChatMessage>) -> Self {
        self.messages.extend(messages);
        self
    }

    /// Sets `stream`.
    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets `max_tokens`.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `temperature`.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets `reasoning_effort`.
    pub fn reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(effort);
        self
    }

    /// Sets the tool definitions, replacing any set earlier.
    ///
    /// An empty collection clears the field instead of storing `Some(vec![])`.
    /// `Some(vec![])` would serialize as `"tools": []`, which the OpenAI API rejects
    /// (it requires at least one tool). Omitting the field already means "no tools".
    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
        let tools: Vec<Tool> = tools.into_iter().collect();
        self.tools = if tools.is_empty() { None } else { Some(tools) };
        self
    }

    /// Sets `tool_choice`.
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Sets `response_format` as raw JSON.
    pub fn response_format(mut self, response_format: serde_json::Value) -> Self {
        self.response_format = Some(response_format);
        self
    }

    /// Sets `n`.
    pub fn n(mut self, n: u32) -> Self {
        self.n = Some(n);
        self
    }

    /// Sets `presence_penalty`.
    pub fn presence_penalty(mut self, penalty: f32) -> Self {
        self.presence_penalty = Some(penalty);
        self
    }

    /// Sets `frequency_penalty`.
    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
        self.frequency_penalty = Some(penalty);
        self
    }

    /// Sets `logit_bias` as raw JSON.
    pub fn logit_bias(mut self, logit_bias: serde_json::Value) -> Self {
        self.logit_bias = Some(logit_bias);
        self
    }

    /// Sets `user`.
    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }

    /// Sets `seed`.
    pub fn seed(mut self, seed: i64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets a single stop sequence, serialized as a JSON string. Replaces any
    /// earlier `stop`/`stop_sequences` value.
    pub fn stop(mut self, stop: impl Into<String>) -> Self {
        self.stop = Some(StopSequence::One(stop.into()));
        self
    }

    /// Sets several stop sequences, serialized as a JSON array. Replaces any
    /// earlier `stop`/`stop_sequences` value.
    pub fn stop_sequences(mut self, stop: impl IntoIterator<Item = String>) -> Self {
        self.stop = Some(StopSequence::Many(stop.into_iter().collect()));
        self
    }

    /// Consumes the builder and produces the request.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::MissingModel`] if [`model`](Self::model) was never called.
    pub fn build(self) -> Result<ChatCompletionRequest, BuildError> {
        let model = self.model.ok_or(BuildError::MissingModel)?;
        Ok(ChatCompletionRequest {
            model,
            messages: self.messages,
            stream: self.stream,
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            top_p: self.top_p,
            reasoning_effort: self.reasoning_effort,
            tools: self.tools,
            tool_choice: self.tool_choice,
            response_format: self.response_format,
            n: self.n,
            presence_penalty: self.presence_penalty,
            frequency_penalty: self.frequency_penalty,
            logit_bias: self.logit_bias,
            user: self.user,
            seed: self.seed,
            stop: self.stop,
        })
    }
}
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[non_exhaustive]
pub enum BuildError {
#[error("`model` is required")]
MissingModel,
#[error("`prompt` is required")]
MissingPrompt,
#[error("`input` is required")]
MissingInput,
}
/// Why generation of a choice ended.
///
/// The wire names are the snake_case forms of the variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum FinishReason {
    #[serde(rename = "stop")]
    Stop,
    #[serde(rename = "length")]
    Length,
    #[serde(rename = "tool_calls")]
    ToolCalls,
    #[serde(rename = "content_filter")]
    ContentFilter,
}
/// Non-streaming chat-completions response. Every field is required when deserializing.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Generated choices; see [`Choice`].
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
/// One completed choice in a [`ChatCompletionResponse`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Choice {
    pub index: u32,
    /// The generated message.
    pub message: ResponseMessage,
    /// Required here, unlike the optional field on streamed [`ChunkChoice`]s.
    pub finish_reason: FinishReason,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResponseMessage {
pub role: Role,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
/// One streamed fragment of a chat-completions response.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Per-choice deltas carried by this chunk.
    pub choices: Vec<ChunkChoice>,
    /// Token usage, when present; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// A single choice inside a [`ChatCompletionChunk`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChunkChoice {
    pub index: u32,
    /// Incremental message fragment.
    pub delta: Delta,
    /// `None` until a chunk reports why generation stopped; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<FinishReason>,
}
/// Partial message carried by a [`ChunkChoice`].
///
/// Every field is optional and left out of the serialized JSON when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Fragments of tool calls; see [`ToolCallDelta`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Fragment of a tool call inside a streamed [`Delta`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ToolCallDelta {
    /// Position of the tool call this fragment belongs to (presumably used to
    /// stitch fragments together across chunks — confirm with the consumer).
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Stored under the JSON key `"type"`.
    #[serde(rename = "type")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kind: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
/// Function-name and argument fragments of a streamed [`ToolCallDelta`].
///
/// Both fields are optional and omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_requires_model() {
        let err = ChatCompletionRequest::builder()
            .message(ChatMessage::user("hi"))
            .build()
            .unwrap_err();
        assert_eq!(err, BuildError::MissingModel);
    }

    #[test]
    fn builder_sets_fields() {
        let req = ChatCompletionRequest::builder()
            .model("m")
            .message(ChatMessage::system("be brief"))
            .message(ChatMessage::user("hi"))
            .temperature(0.5)
            .stream(true)
            .build()
            .unwrap();
        assert_eq!(req.model, "m");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.temperature, Some(0.5));
        assert_eq!(req.stream, Some(true));
    }

    #[test]
    fn omits_none_fields_in_json() {
        let req = ChatCompletionRequest::builder()
            .model("m")
            .message(ChatMessage::user("hi"))
            .build()
            .unwrap();
        let json = serde_json::to_value(&req).unwrap();
        assert!(json.get("temperature").is_none());
        assert!(json.get("stream").is_none());
        assert_eq!(json["model"], "m");
    }

    #[test]
    fn user_message_serializes_role_lowercase() {
        let msg = ChatMessage::user("hi");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "hi");
    }

    #[test]
    fn multimodal_content_roundtrips() {
        let msg = ChatMessage::user(vec![
            ContentPart::text("look:"),
            ContentPart::image_url("https://example.com/a.png"),
        ]);
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["content"][0]["type"], "text");
        assert_eq!(json["content"][1]["type"], "image_url");
        assert_eq!(
            json["content"][1]["image_url"]["url"],
            "https://example.com/a.png"
        );
        // Close the loop: the serialized form must deserialize back to the same value.
        let back: ChatMessage = serde_json::from_value(json).unwrap();
        assert_eq!(back, msg);
    }

    #[test]
    fn tool_message_carries_call_id() {
        let msg = ChatMessage::tool("call_1", "42");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["tool_call_id"], "call_1");
    }

    #[test]
    fn stop_serializes_as_string_or_array() {
        let single = ChatCompletionRequest::builder()
            .model("m")
            .stop("END")
            .build()
            .unwrap();
        let json = serde_json::to_value(&single).unwrap();
        assert_eq!(json["stop"], "END");

        let many = ChatCompletionRequest::builder()
            .model("m")
            .stop_sequences(vec!["a".to_owned(), "b".to_owned()])
            .build()
            .unwrap();
        let json = serde_json::to_value(&many).unwrap();
        assert_eq!(json["stop"], serde_json::json!(["a", "b"]));
    }

    #[test]
    fn text_message_deserializes_from_plain_string() {
        let msg: ChatMessage =
            serde_json::from_str(r#"{"role":"user","content":"hi"}"#).unwrap();
        assert_eq!(msg, ChatMessage::user("hi"));
    }

    #[test]
    fn chunk_choice_tolerates_sparse_delta() {
        let choice: ChunkChoice = serde_json::from_str(
            r#"{"index":0,"delta":{"content":"he"},"finish_reason":null}"#,
        )
        .unwrap();
        assert_eq!(choice.delta.content.as_deref(), Some("he"));
        assert!(choice.delta.role.is_none());
        assert!(choice.delta.tool_calls.is_none());
        assert!(choice.finish_reason.is_none());
    }

    #[test]
    fn finish_reason_uses_snake_case() {
        let reason: FinishReason = serde_json::from_str(r#""tool_calls""#).unwrap();
        assert_eq!(reason, FinishReason::ToolCalls);
        assert_eq!(
            serde_json::to_value(FinishReason::ContentFilter).unwrap(),
            "content_filter"
        );
    }
}