use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
}
impl ChatRequest {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
..Default::default()
}
}
pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
self.messages = messages;
self
}
pub fn add_message(mut self, message: Message) -> Self {
self.messages.push(message);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn with_stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn with_presence_penalty(mut self, penalty: f32) -> Self {
self.presence_penalty = Some(penalty);
self
}
pub fn with_frequency_penalty(mut self, penalty: f32) -> Self {
self.frequency_penalty = Some(penalty);
self
}
pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
self.tool_choice = Some(tool_choice);
self
}
pub fn with_response_format(mut self, format: ResponseFormat) -> Self {
self.response_format = Some(format);
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl Default for Message {
fn default() -> Self {
Self {
role: Role::User,
content: String::new(),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
}
impl Message {
pub fn new(role: Role, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new(Role::System, content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new(Role::User, content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new(Role::Assistant, content)
}
pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: content.into(),
tool_call_id: Some(tool_call_id.into()),
name: None,
tool_calls: None,
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
self.tool_calls = Some(tool_calls);
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Mode(String),
Function {
#[serde(rename = "type")]
tool_type: String,
function: FunctionChoice,
},
}
impl ToolChoice {
pub fn none() -> Self {
Self::Mode("none".to_string())
}
pub fn auto() -> Self {
Self::Mode("auto".to_string())
}
pub fn required() -> Self {
Self::Mode("required".to_string())
}
pub fn function(name: impl Into<String>) -> Self {
Self::Function {
tool_type: "function".to_string(),
function: FunctionChoice {
name: name.into(),
},
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionChoice {
pub name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: FunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub format_type: String,
}