use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::provider::{ModelName, ProviderId, ToolCall, ToolChoice, ToolSpec};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatRequest {
pub model: ModelName,
pub messages: Vec<Message>,
pub tools: Vec<ToolSpec>,
pub tool_choice: ToolChoice,
pub response_format: Option<ResponseFormat>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub max_output_tokens: Option<u32>,
pub stop: Vec<String>,
pub metadata: Value,
}
impl ChatRequest {
    /// Starts a request for `model` with everything else empty or unset.
    ///
    /// Specifically:
    /// - no messages, tools or stop strings;
    /// - `ToolChoice::default()` as the tool choice;
    /// - no response format or sampling options;
    /// - `Value::Null` metadata.
    #[must_use]
    pub fn new(model: ModelName) -> Self {
        Self {
            model,
            tool_choice: ToolChoice::default(),
            messages: vec![],
            tools: vec![],
            stop: vec![],
            response_format: None,
            temperature: None,
            top_p: None,
            max_output_tokens: None,
            metadata: Value::Null,
        }
    }

    /// Returns the request with `message` appended to the end of `messages`.
    #[must_use]
    pub fn with_message(self, message: Message) -> Self {
        let mut request = self;
        request.messages.push(message);
        request
    }

    /// Appends a user message containing the single text part `text`.
    #[must_use]
    pub fn with_user_text(self, text: impl Into<String>) -> Self {
        let message = Message::user_text(text);
        self.with_message(message)
    }

    /// Returns the request with `tool` appended to the end of `tools`.
    #[must_use]
    pub fn with_tool(self, tool: ToolSpec) -> Self {
        let mut request = self;
        request.tools.push(tool);
        request
    }
}
/// A single chat message, discriminated by role.
///
/// Serialized internally tagged on `"role"` with snake_case role names:
/// `"system"`, `"user"`, `"assistant"` and `"tool"`. The variant's fields sit
/// next to the tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "role")]
#[non_exhaustive]
pub enum Message {
    /// System message.
    System {
        content: Vec<ContentPart>,
    },
    /// User message.
    User {
        content: Vec<ContentPart>,
    },
    /// Assistant message, optionally carrying tool calls.
    Assistant {
        content: Vec<ContentPart>,
        /// Tool calls attached to this message. An input without a `tool_calls`
        /// key deserializes to an empty list, the same shape
        /// `Message::assistant_text` builds. Previously such input was rejected
        /// as a missing field.
        #[serde(default)]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation.
    Tool {
        /// Presumably the id of the `ToolCall` this answers; confirm against providers.
        tool_call_id: String,
        /// Name of the tool.
        name: String,
        content: Vec<ContentPart>,
    },
}
impl Message {
    /// Builds a `System` message holding exactly one text part.
    #[must_use]
    pub fn system_text(text: impl Into<String>) -> Self {
        let content = vec![ContentPart::text(text)];
        Self::System { content }
    }

    /// Builds a `User` message holding exactly one text part.
    #[must_use]
    pub fn user_text(text: impl Into<String>) -> Self {
        let content = vec![ContentPart::text(text)];
        Self::User { content }
    }

    /// Builds an `Assistant` message with one text part and no tool calls.
    #[must_use]
    pub fn assistant_text(text: impl Into<String>) -> Self {
        let content = vec![ContentPart::text(text)];
        Self::Assistant {
            content,
            tool_calls: vec![],
        }
    }

    /// Builds a `Tool` message for `tool_call_id` / `name` with one text part.
    #[must_use]
    pub fn tool_text(
        tool_call_id: impl Into<String>,
        name: impl Into<String>,
        text: impl Into<String>,
    ) -> Self {
        let tool_call_id: String = tool_call_id.into();
        let name: String = name.into();
        let content = vec![ContentPart::text(text)];
        Self::Tool {
            tool_call_id,
            name,
            content,
        }
    }
}
/// One piece of message content.
///
/// Serialized internally tagged on `"type"` with snake_case variant names:
/// `"text"`, `"json"` and `"image_url"`. The variant's fields sit next to the tag.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum ContentPart {
    /// A text fragment.
    Text { text: String },
    /// An arbitrary `serde_json::Value`.
    Json { value: Value },
    /// An image referenced by URL, with an optional MIME-type string.
    ImageUrl {
        url: String,
        mime_type: Option<String>,
    },
}
impl ContentPart {
    /// Wraps `text` in a `Text` part.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        let owned: String = text.into();
        Self::Text { text: owned }
    }

    /// Wraps `value` in a `Json` part.
    #[must_use]
    pub fn json(value: Value) -> Self {
        Self::Json { value }
    }

    /// Builds an `ImageUrl` part from `url` and an optional MIME type.
    #[must_use]
    pub fn image_url(url: impl Into<String>, mime_type: Option<String>) -> Self {
        let url: String = url.into();
        Self::ImageUrl { url, mime_type }
    }
}
/// Requested format for the model's output.
///
/// Serialized internally tagged on `"type"` with snake_case names:
/// - `Text` becomes `{"type":"text"}`;
/// - `JsonObject` becomes `{"type":"json_object"}`;
/// - `JsonSchema` becomes `{"type":"json_schema","name":…,"schema":…,"strict":…}`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum ResponseFormat {
    /// Plain text output.
    Text,
    /// A JSON object output.
    JsonObject,
    /// Output described by a named JSON schema.
    JsonSchema {
        name: String,
        schema: Value,
        strict: bool,
    },
}
/// The result returned for a [`ChatRequest`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ChatResponse {
    /// Provider that produced the response.
    pub provider: ProviderId,
    /// Model that produced the response.
    pub model: ModelName,
    /// The returned message.
    pub message: Message,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Token counts, when available.
    pub usage: Option<TokenUsage>,
    /// Optional untyped JSON. Presumably the provider's raw payload; confirm with callers.
    pub raw: Option<Value>,
}
/// Why the model stopped producing output.
///
/// Uses serde's default (externally tagged) representation with snake_case names:
/// - unit variants serialize as plain strings (`"stop"`, `"tool_calls"`, `"length"`,
///   `"content_filter"`, `"error"`);
/// - `Unknown(s)` serializes as `{"unknown": s}`.
///
/// NOTE(review): a bare unrecognized string such as `"foo"` does NOT deserialize
/// into `Unknown`; serde reports an unknown-variant error. Providers' raw strings
/// must be mapped to `Unknown` explicitly.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum FinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
    Error,
    /// Any reason not covered above, carried verbatim.
    Unknown(String),
}
/// Token counts for one exchange.
///
/// `total_tokens` is a stored field rather than a computed one, so a value built
/// by hand or deserialized need not equal `input_tokens + output_tokens`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct TokenUsage {
    /// Input token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    /// Total token count.
    pub total_tokens: u64,
}
impl TokenUsage {
    /// Builds a usage record from input and output token counts.
    ///
    /// `total_tokens` is `input_tokens + output_tokens`, saturating at `u64::MAX`.
    /// A plain `+` would panic on overflow when overflow checks are enabled (the
    /// debug default) and silently wrap otherwise. Saturation keeps the total
    /// monotone and the function panic-free. `saturating_add` is a `const fn`, so
    /// `new` stays usable in const contexts.
    #[must_use]
    pub const fn new(input_tokens: u64, output_tokens: u64) -> Self {
        Self {
            input_tokens,
            output_tokens,
            total_tokens: input_tokens.saturating_add(output_tokens),
        }
    }
}