use serde::Serialize;
use crate::error::BaochuanError;
use super::message::ChatMessage;
use super::tools::{Tool, ToolChoice};
/// Serializable chat request body.
///
/// Usually assembled through [`ChatRequestBuilder`], whose `build` requires a model and
/// at least one message; the fields are public, so it can also be constructed directly
/// (in which case nothing is validated). `model`, `messages`, and `stream` are always
/// serialized; every `Option` field is omitted from the output when it is `None`.
/// serde's derive emits keys in field declaration order, so reordering the fields below
/// changes the serialized key order.
#[derive(Debug, Clone, Serialize)]
pub struct ChatRequest {
/// Model identifier, serialized as `model`.
pub model: String,
/// Conversation messages, serialized in the order given.
pub messages: Vec<ChatMessage>,
/// Streaming flag; always serialized, even when `false`.
pub stream: bool,
/// Optional token limit; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional sampling temperature; omitted when `None`. Not range-checked here.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional `top_p` value; omitted when `None`. Not range-checked here.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Optional list of modality names, passed through unchecked; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub modalities: Option<Vec<String>>,
/// Optional audio voice/format settings. Serialized under the key `audio` (not
/// `audio_output`); omitted when `None`.
#[serde(rename = "audio", skip_serializing_if = "Option::is_none")]
pub audio_output: Option<AudioOutputConfig>,
/// Optional tool definitions; omitted when `None`. Note that `Some(vec![])` is still
/// serialized as an empty list.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Optional tool-selection setting (see `ToolChoice`); omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
/// Voice and format settings for audio output, carried in
/// [`ChatRequest::audio_output`] and serialized there under the `audio` key.
///
/// Both values are free-form strings; nothing in this type checks them against what
/// the remote service accepts.
#[derive(Debug, Clone, Serialize)]
pub struct AudioOutputConfig {
/// Voice name, passed through as-is.
pub voice: String,
/// Audio format name, passed through as-is.
pub format: String,
}
/// Fluent builder for [`ChatRequest`].
///
/// Start from [`ChatRequestBuilder::new`], chain setters, then call `build`. A builder
/// obtained from `Default::default()` has no model, so `build` rejects it.
///
/// Every setter consumes the builder and returns it. The type is therefore
/// `#[must_use]`: calling a setter and discarding its result silently drops that
/// configuration. `Clone` allows a partially configured builder to act as a reusable
/// template (e.g. shared model/temperature, different messages per call).
#[must_use = "builder setters return the updated builder; call `build` to get a request"]
#[derive(Debug, Clone, Default)]
pub struct ChatRequestBuilder {
    /// Model identifier; `None` only when the builder came from `Default`.
    model: Option<String>,
    /// Messages accumulated so far; `build` rejects an empty list.
    messages: Vec<ChatMessage>,
    /// Streaming flag; defaults to `false`.
    stream: bool,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    top_p: Option<f32>,
    modalities: Option<Vec<String>>,
    audio_output: Option<AudioOutputConfig>,
    /// `None` until a tool is added or a list is set.
    tools: Option<Vec<Tool>>,
    tool_choice: Option<ToolChoice>,
}
impl ChatRequestBuilder {
pub fn new(model: impl Into<String>) -> Self {
Self { model: Some(model.into()), ..Default::default() }
}
pub fn message(mut self, message: ChatMessage) -> Self {
self.messages.push(message);
self
}
pub fn messages(mut self, messages: Vec<ChatMessage>) -> Self {
self.messages = messages;
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = stream;
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn modalities(mut self, modalities: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.modalities = Some(modalities.into_iter().map(Into::into).collect());
self
}
pub fn audio_output(mut self, voice: impl Into<String>, format: impl Into<String>) -> Self {
self.audio_output = Some(AudioOutputConfig { voice: voice.into(), format: format.into() });
self
}
pub fn tool(mut self, tool: Tool) -> Self {
self.tools.get_or_insert_with(Vec::new).push(tool);
self
}
pub fn tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
self.tool_choice = Some(choice);
self
}
pub fn build(self) -> Result<ChatRequest, BaochuanError> {
let model = self.model.ok_or_else(|| {
BaochuanError::InvalidRequest("model must be specified".to_string())
})?;
if self.messages.is_empty() {
return Err(BaochuanError::InvalidRequest(
"at least one message is required".to_string(),
));
}
Ok(ChatRequest {
model,
messages: self.messages,
stream: self.stream,
max_tokens: self.max_tokens,
temperature: self.temperature,
top_p: self.top_p,
modalities: self.modalities,
audio_output: self.audio_output,
tools: self.tools,
tool_choice: self.tool_choice,
})
}
}
/// Serializable TTS (text-to-speech) request body.
///
/// Usually assembled through [`TtsRequestBuilder`], whose `build` requires a model,
/// input, and voice. `model`, `input`, and `voice` are always serialized; the optional
/// fields are omitted when `None`. Field order below determines serialized key order.
#[derive(Debug, Clone, Serialize)]
pub struct TtsRequest {
/// Model identifier.
pub model: String,
/// Input text, passed through as-is.
pub input: String,
/// Voice name, passed through as-is.
pub voice: String,
/// Optional output format. Serialized under the key `response_format`; omitted when
/// `None`.
#[serde(rename = "response_format", skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
/// Optional speed value; omitted when `None`. Not range-checked here.
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<f32>,
}
/// Fluent builder for [`TtsRequest`].
///
/// Start from [`TtsRequestBuilder::new`] with a model and input, set a voice (required
/// by `build`), optionally set format and speed, then call `build`. A builder from
/// `Default::default()` has every field unset.
///
/// Setters consume the builder and return it, so the type is `#[must_use]`: dropping a
/// setter's result silently loses that setting. `Clone` lets one configured builder be
/// reused for several requests.
#[must_use = "builder setters return the updated builder; call `build` to get a request"]
#[derive(Debug, Clone, Default)]
pub struct TtsRequestBuilder {
    /// Model identifier; required by `build`.
    model: Option<String>,
    /// Input text; required by `build`.
    input: Option<String>,
    /// Voice name; required by `build` but not set by `new`.
    voice: Option<String>,
    /// Optional output format.
    format: Option<String>,
    /// Optional speed value.
    speed: Option<f32>,
}
impl TtsRequestBuilder {
    /// Creates a builder with `model` and `input` set; voice, format, and speed start
    /// unset.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        let mut builder = Self::default();
        builder.model = Some(model.into());
        builder.input = Some(input.into());
        builder
    }

    /// Sets the voice (required by `build`).
    pub fn voice(self, voice: impl Into<String>) -> Self {
        Self { voice: Some(voice.into()), ..self }
    }

    /// Sets the output format.
    pub fn format(self, format: impl Into<String>) -> Self {
        Self { format: Some(format.into()), ..self }
    }

    /// Sets the speed value (not range-checked).
    pub fn speed(self, speed: f32) -> Self {
        Self { speed: Some(speed), ..self }
    }

    /// Consumes the builder and produces a [`TtsRequest`].
    ///
    /// # Errors
    ///
    /// Returns `BaochuanError::InvalidRequest` for the first missing required field,
    /// checked in the order model, input, voice.
    pub fn build(self) -> Result<TtsRequest, BaochuanError> {
        let invalid = |msg: &str| BaochuanError::InvalidRequest(msg.to_string());

        let model = self.model.ok_or_else(|| invalid("model required"))?;
        let input = self.input.ok_or_else(|| invalid("input required"))?;
        let voice = self.voice.ok_or_else(|| invalid("voice required"))?;

        Ok(TtsRequest {
            model,
            input,
            voice,
            format: self.format,
            speed: self.speed,
        })
    }
}