use serde::Serialize;
use validator::*;
use super::{tools::*, traits::*};
/// Body of a chat request, generic over the model value `N` and the message type `M`.
///
/// The struct derives `Serialize` and `Validate`. Every `Option` field is omitted from
/// the serialized output while it is `None`, and the `#[validate(...)]` attributes
/// declare the accepted range or length of the fields that carry them.
///
/// The `(N, M): Bounded` bound ties the message type to the model type; presumably it
/// restricts which message kinds a given model accepts (confirm against the trait
/// definition).
#[derive(Debug, Clone, Validate, Serialize)]
pub struct ChatBody<N, M>
where
N: ModelName,
(N, M): Bounded,
{
/// Model value sent in the `model` field.
pub model: N,
/// Messages sent in the `messages` field, in insertion order.
pub messages: Vec<M>,
/// Optional request identifier; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
/// Optional thinking setting; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<ThinkingType>,
/// Optional `do_sample` flag; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
/// Optional `stream` flag; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional `tool_stream` flag; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_stream: Option<bool>,
/// Optional temperature; validated to lie within `0.0..=1.0`, skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 1.0))]
pub temperature: Option<f32>,
/// Optional `top_p`; validated to lie within `0.0..=1.0`, skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 1.0))]
pub top_p: Option<f32>,
/// Optional token limit; validated to lie within `1..=98304`, skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 1, max = 98304))]
pub max_tokens: Option<u32>,
/// Optional list of tools; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tools>>,
/// Optional user identifier; validated to have a length of 6 to 128, skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(length(min = 6, max = 128))]
pub user_id: Option<String>,
/// Optional stop list; validated to hold exactly one entry, skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(length(min = 1, max = 1))]
pub stop: Option<Vec<String>>,
/// Optional response format; skipped when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
}
/// Constructor and chainable setters available for every model type.
///
/// Each setter consumes the body and returns it, so calls can be chained. None of the
/// setters checks its argument against the struct's `#[validate(...)]` rules; those
/// rules are only declared on the fields.
impl<N, M> ChatBody<N, M>
where
    N: ModelName,
    (N, M): Bounded,
{
    /// Creates a body for `model` whose message list contains only `messages`.
    ///
    /// Every optional field starts as `None`.
    pub fn new(model: N, messages: M) -> Self {
        Self {
            model,
            messages: vec![messages],
            request_id: None,
            thinking: None,
            do_sample: None,
            stream: None,
            tool_stream: None,
            temperature: None,
            top_p: None,
            max_tokens: None,
            tools: None,
            user_id: None,
            stop: None,
            response_format: None,
        }
    }

    /// Appends a single message to the message list.
    ///
    /// Despite the plural name this takes one message and behaves exactly like
    /// [`Self::add_message`], to which it delegates.
    pub fn add_messages(self, messages: M) -> Self {
        self.add_message(messages)
    }

    /// Appends a single message to the message list.
    pub fn add_message(mut self, message: M) -> Self {
        self.messages.push(message);
        self
    }

    /// Appends every message yielded by `messages`, preserving iteration order.
    pub fn extend_messages(mut self, messages: impl IntoIterator<Item = M>) -> Self {
        self.messages.extend(messages);
        self
    }

    /// Sets the request identifier, replacing any previous value.
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        self.request_id = Some(request_id.into());
        self
    }

    /// Sets the `do_sample` flag, replacing any previous value.
    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
        self.do_sample = Some(do_sample);
        self
    }

    /// Sets the `stream` flag, replacing any previous value.
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets the temperature, replacing any previous value.
    ///
    /// The field's validation rule accepts `0.0..=1.0`; the value is stored as given.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p`, replacing any previous value.
    ///
    /// The field's validation rule accepts `0.0..=1.0`; the value is stored as given.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the token limit, replacing any previous value.
    ///
    /// The field's validation rule accepts `1..=98304`; the value is stored as given.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the whole tool list with `tools`.
    ///
    /// Deprecated in favour of the accumulating `add_tools` and `extend_tools`.
    #[deprecated(note = "with_tools is deprecated; use add_tools/extend_tools instead")]
    pub fn with_tools(mut self, tools: impl Into<Vec<Tools>>) -> Self {
        self.tools = Some(tools.into());
        self
    }

    /// Appends a single tool, creating the tool list if it does not exist yet.
    pub fn add_tools(mut self, tools: Tools) -> Self {
        self.tools.get_or_insert_with(Vec::new).push(tools);
        self
    }

    /// Appends every tool in `tools`, creating the tool list if it does not exist yet.
    ///
    /// Passing an empty vector still turns a `None` list into `Some(vec![])`.
    pub fn extend_tools(mut self, tools: Vec<Tools>) -> Self {
        self.tools.get_or_insert_with(Vec::new).extend(tools);
        self
    }

    /// Sets the user identifier, replacing any previous value.
    ///
    /// The field's validation rule accepts a length of 6 to 128; the value is stored as given.
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Appends `stop` to the stop list, creating the list if it does not exist yet.
    ///
    /// Note that this accumulates rather than replaces, while the field's validation
    /// rule is `length(min = 1, max = 1)`: a body on which this was called more than
    /// once holds several entries and would be rejected by that rule.
    pub fn with_stop(mut self, stop: String) -> Self {
        self.stop.get_or_insert_with(Vec::new).push(stop);
        self
    }
}
/// Thinking control, offered only when the model type implements `ThinkEnable`.
impl<N, M> ChatBody<N, M>
where
    N: ModelName + ThinkEnable,
    (N, M): Bounded,
{
    /// Stores `thinking` in the body, replacing any earlier value, and returns the body.
    ///
    /// All other fields are carried over untouched.
    pub fn with_thinking(self, thinking: ThinkingType) -> Self {
        Self {
            thinking: Some(thinking),
            ..self
        }
    }
}
/// Tool-stream control, offered only when the model type implements `ToolStreamEnable`.
impl<N, M> ChatBody<N, M>
where
    N: ModelName + ToolStreamEnable,
    (N, M): Bounded,
{
    /// Stores `tool_stream` in the body and returns the body.
    ///
    /// Passing `true` also forces `stream` to `Some(true)`; passing `false` leaves
    /// whatever `stream` value was already present.
    pub fn with_tool_stream(self, tool_stream: bool) -> Self {
        let stream = if tool_stream {
            Some(true)
        } else {
            self.stream
        };
        Self {
            tool_stream: Some(tool_stream),
            stream,
            ..self
        }
    }
}
/// Lets a single tool be used wherever a `Vec<Tools>` conversion is accepted.
impl From<Tools> for Vec<Tools> {
    /// Wraps `tool` in a one-element vector.
    fn from(tool: Tools) -> Self {
        std::iter::once(tool).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{chat_message_types::TextMessage, chat_models::GLM4_6};

    /// Builds a body for `GLM4_6` that holds a single user message with `text`.
    fn body_with(text: &'static str) -> ChatBody<GLM4_6, TextMessage> {
        ChatBody::new(GLM4_6 {}, TextMessage::user(text))
    }

    /// Builds a function tool called `name` with a minimal object schema.
    fn function_tool(name: &'static str) -> crate::model::tools::Tools {
        let function = crate::model::tools::Function::new(
            name,
            "A test function",
            serde_json::json!({"type": "object"}),
        );
        crate::model::tools::Tools::Function { function }
    }

    #[test]
    fn test_with_tool_stream_sets_both_fields() {
        let body = body_with("test").with_tool_stream(true);
        assert_eq!(body.tool_stream, Some(true));
        assert_eq!(body.stream, Some(true));
    }

    #[test]
    fn test_with_tool_stream_false_does_not_force_stream() {
        let body = body_with("test").with_tool_stream(false);
        assert_eq!(body.tool_stream, Some(false));
        assert_ne!(body.stream, Some(true));
    }

    #[test]
    fn test_add_tools_accumulates() {
        // The first call has to create the list...
        let body = body_with("test").add_tools(function_tool("test_fn"));
        assert_eq!(body.tools.as_ref().map(Vec::len), Some(1));

        // ...and the second call has to append to it rather than replace it.
        let body = body.add_tools(function_tool("other_fn"));
        assert_eq!(body.tools.as_ref().map(Vec::len), Some(2));
    }

    #[test]
    fn test_extend_messages() {
        let body = body_with("first").extend_messages(vec![
            TextMessage::assistant("second"),
            TextMessage::user("third"),
        ]);
        assert_eq!(body.messages.len(), 3);
        // The message passed to `new` must still come first.
        match &body.messages[0] {
            TextMessage::User { content } => assert_eq!(content, "first"),
            _ => panic!("Expected User message"),
        }
    }

    #[test]
    fn test_add_message() {
        let body = body_with("first").add_message(TextMessage::assistant("second"));
        assert_eq!(body.messages.len(), 2);
    }
}