use std::marker::PhantomData;
use serde::Serialize;
use validator::Validate;
use super::super::{chat_base_request::*, tools::*, traits::*};
use crate::client::http::HttpClient;
/// A chat-completion request builder.
///
/// `N` is the model, `M` the message payload, and `S` a typestate marker
/// (defaulting to `StreamOff`) recording whether streaming has been enabled.
/// Requests start out non-streaming via `ChatCompletion::new`; use
/// `enable_stream` / `disable_stream` to move between the two typestates.
pub struct ChatCompletion<N: ModelName + Chat, M, S: StreamState = StreamOff>
where
    (N, M): Bounded,
    ChatBody<N, M>: Serialize,
{
    /// The request payload; mutable access is offered via `body_mut` on
    /// non-streaming requests.
    body: ChatBody<N, M>,
    /// API key, exposed to the HTTP layer through `HttpClient::api_key`.
    pub key: String,
    /// Endpoint, exposed to the HTTP layer through `HttpClient::api_url`.
    pub url: String,
    /// Zero-sized marker that carries the stream typestate `S`.
    _stream: PhantomData<S>,
}
impl<N, M> ChatCompletion<N, M, StreamOff>
where
    N: ModelName + Chat,
    (N, M): Bounded,
    ChatBody<N, M>: Serialize,
{
    /// Endpoint used by `new` (standard chat-completions API).
    const DEFAULT_URL: &str = "https://open.bigmodel.cn/api/paas/v4/chat/completions";
    /// Endpoint selected by `with_coding_plan`.
    const CODING_PLAN_URL: &str = "https://open.bigmodel.cn/api/coding/paas/v4/chat/completions";

    /// Creates a non-streaming chat-completion request for `model` with the
    /// initial `messages`, authenticated with `key`.
    ///
    /// The request targets the standard chat-completions endpoint; use
    /// [`with_url`](Self::with_url) or [`with_coding_plan`](Self::with_coding_plan)
    /// to point it elsewhere.
    #[must_use]
    pub fn new(model: N, messages: M, key: String) -> ChatCompletion<N, M, StreamOff> {
        let body = ChatBody::new(model, messages);
        ChatCompletion {
            body,
            key,
            url: Self::DEFAULT_URL.to_string(),
            _stream: PhantomData,
        }
    }

    /// Mutable access to the underlying request body, for settings not
    /// covered by the builder methods.
    ///
    /// Setting `stream` to `Some(true)` through this handle makes
    /// [`validate`](Self::validate) (and therefore [`send`](Self::send)) fail;
    /// use [`enable_stream`](Self::enable_stream) instead.
    pub fn body_mut(&mut self) -> &mut ChatBody<N, M> {
        &mut self.body
    }

    /// Adds `messages` to the request body (delegates to `ChatBody::add_messages`).
    #[must_use]
    pub fn add_messages(mut self, messages: M) -> Self {
        self.body = self.body.add_messages(messages);
        self
    }

    /// Sets the request id on the body.
    #[must_use]
    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
        self.body = self.body.with_request_id(request_id);
        self
    }

    /// Sets the `do_sample` flag on the body.
    #[must_use]
    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
        self.body = self.body.with_do_sample(do_sample);
        self
    }

    /// Sets the sampling temperature on the body.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.body = self.body.with_temperature(temperature);
        self
    }

    /// Sets the nucleus-sampling `top_p` value on the body.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.body = self.body.with_top_p(top_p);
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.body = self.body.with_max_tokens(max_tokens);
        self
    }

    /// Adds a single tool definition (delegates to `ChatBody::add_tools`).
    #[must_use]
    pub fn add_tool(mut self, tool: Tools) -> Self {
        self.body = self.body.add_tools(tool);
        self
    }

    /// Adds several tool definitions (delegates to `ChatBody::extend_tools`).
    #[must_use]
    pub fn add_tools(mut self, tools: Vec<Tools>) -> Self {
        self.body = self.body.extend_tools(tools);
        self
    }

    /// Sets the end-user id on the body.
    #[must_use]
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.body = self.body.with_user_id(user_id);
        self
    }

    /// Configures the `stop` parameter (delegates to `ChatBody::with_stop`).
    #[must_use]
    pub fn with_stop(mut self, stop: String) -> Self {
        self.body = self.body.with_stop(stop);
        self
    }

    /// Overrides the endpoint URL.
    #[must_use]
    pub fn with_url(mut self, url: impl Into<String>) -> Self {
        self.url = url.into();
        self
    }

    /// Points the request at the coding-plan endpoint.
    ///
    /// This overwrites any URL set earlier, including one from
    /// [`with_url`](Self::with_url).
    #[must_use]
    pub fn with_coding_plan(mut self) -> Self {
        self.url = Self::CODING_PLAN_URL.to_string();
        self
    }

    /// Sets the thinking mode. Only available for models implementing `ThinkEnable`.
    #[must_use]
    pub fn with_thinking(mut self, thinking: ThinkingType) -> Self
    where
        N: ThinkEnable,
    {
        self.body = self.body.with_thinking(thinking);
        self
    }

    /// Switches to streaming mode.
    ///
    /// Sets `stream = Some(true)` on the body and returns the `StreamOn`
    /// typestate. The key, URL and body carry over unchanged otherwise.
    #[must_use]
    pub fn enable_stream(mut self) -> ChatCompletion<N, M, StreamOn> {
        self.body.stream = Some(true);
        ChatCompletion {
            key: self.key,
            url: self.url,
            body: self.body,
            _stream: PhantomData,
        }
    }

    /// Checks the request body before a non-streaming send.
    ///
    /// # Errors
    ///
    /// - Any `validator` failure on the body, converted into `ZaiError`.
    /// - `ZaiError::ApiError` with code `1200` if `stream` is `Some(true)`.
    ///   That can only happen via [`body_mut`](Self::body_mut), and a streaming
    ///   body must go through the streaming APIs instead.
    pub fn validate(&self) -> crate::ZaiResult<()> {
        self.body
            .validate()
            .map_err(crate::client::error::ZaiError::from)?;
        if matches!(self.body.stream, Some(true)) {
            return Err(crate::client::error::ZaiError::ApiError {
                code: 1200,
                message: "stream=true detected; use enable_stream() and streaming APIs instead"
                    .to_string(),
            });
        }
        Ok(())
    }

    /// Validates the request, posts it, and decodes the JSON response into a
    /// `ChatCompletionResponse`.
    ///
    /// NOTE(review): whether non-2xx HTTP statuses are turned into errors
    /// depends on `HttpClient::post`, which is not visible here — confirm.
    ///
    /// # Errors
    ///
    /// - Validation errors (see [`validate`](Self::validate)).
    /// - Errors from `post`.
    /// - Errors while decoding the response body.
    pub async fn send(
        &self,
    ) -> crate::ZaiResult<crate::model::chat_base_response::ChatCompletionResponse>
    where
        N: serde::Serialize,
        M: serde::Serialize,
    {
        self.validate()?;
        let resp: reqwest::Response = self.post().await?;
        let parsed = resp
            .json::<crate::model::chat_base_response::ChatCompletionResponse>()
            .await?;
        Ok(parsed)
    }
}
impl<N, M> ChatCompletion<N, M, StreamOn>
where
    N: ModelName + Chat,
    (N, M): Bounded,
    ChatBody<N, M>: Serialize,
{
    /// Sets the `tool_stream` flag on the body.
    ///
    /// Only available for streaming requests on models implementing
    /// `ToolStreamEnable`.
    #[must_use]
    pub fn with_tool_stream(mut self, tool_stream: bool) -> Self
    where
        N: ToolStreamEnable,
    {
        self.body = self.body.with_tool_stream(tool_stream);
        self
    }

    /// Switches back to non-streaming mode.
    ///
    /// Sets `stream = Some(false)` and clears `tool_stream`, because that flag
    /// is only settable in the streaming typestate. Returns the `StreamOff`
    /// typestate; the key and URL carry over unchanged.
    #[must_use]
    pub fn disable_stream(self) -> ChatCompletion<N, M, StreamOff> {
        let ChatCompletion { key, url, mut body, .. } = self;
        body.stream = Some(false);
        body.tool_stream = None;
        ChatCompletion {
            key,
            url,
            body,
            _stream: PhantomData,
        }
    }
}
/// Exposes a `ChatCompletion` to the shared HTTP layer, in either stream
/// typestate, by handing out borrowed views of its URL, key and body.
impl<N, M, S> HttpClient for ChatCompletion<N, M, S>
where
    S: StreamState,
    (N, M): Bounded,
    M: Serialize,
    N: ModelName + Serialize + Chat,
{
    type ApiKey = String;
    type ApiUrl = String;
    type Body = ChatBody<N, M>;

    /// Borrow of the request payload.
    fn body(&self) -> &ChatBody<N, M> {
        let payload = &self.body;
        payload
    }

    /// Borrow of the API key.
    fn api_key(&self) -> &String {
        let credential = &self.key;
        credential
    }

    /// Borrow of the endpoint URL.
    fn api_url(&self) -> &String {
        let endpoint = &self.url;
        endpoint
    }
}
/// Opts streaming (`StreamOn`) requests into `SseStreamable`.
///
/// No items are overridden here; the trait's provided items supply all
/// behavior.
impl<N: ModelName + Serialize + Chat, M: Serialize> crate::model::traits::SseStreamable
    for ChatCompletion<N, M, StreamOn>
where
    (N, M): Bounded,
{
}
/// Opts streaming (`StreamOn`) requests into the `StreamChatLikeExt`
/// extension methods.
///
/// No items are overridden here; the trait's provided items supply all
/// behavior.
impl<N: ModelName + Serialize + Chat, M: Serialize> crate::model::stream_ext::StreamChatLikeExt
    for ChatCompletion<N, M, StreamOn>
where
    (N, M): Bounded,
{
}