use crate::core::{GenericProvider, HttpClient, Protocol};
use crate::error::LlmConnectorError;
use crate::types::{
ChatRequest, ChatResponse, Choice, ImageSource, Message as TypeMessage, MessageBlock, Role,
Tool, ToolChoice,
};
use serde_json::{Value, json};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Splits a complete Zhipu reply into `(reasoning, answer)`.
///
/// The split only happens when the text contains both the `###Thinking` and
/// the `###Response` markers. Everything before the first `###Response`
/// (with every `###Thinking` marker removed and surrounding whitespace
/// trimmed) becomes the reasoning; everything after it (trimmed) becomes the
/// answer. Any further `###Response` markers stay inside the answer.
///
/// When a marker is missing, or the reasoning part turns out to be empty,
/// the input is returned untouched as the answer and the reasoning is `None`.
fn extract_zhipu_reasoning_content(content: &str) -> (Option<String>, String) {
    if content.contains("###Thinking") {
        if let Some((head, tail)) = content.split_once("###Response") {
            let without_marker = head.replace("###Thinking", "");
            let thinking = without_marker.trim();
            if !thinking.is_empty() {
                return (Some(thinking.to_owned()), tail.trim().to_owned());
            }
        }
    }
    // Fallback: no usable reasoning section, hand the text back unchanged.
    (None, content.to_owned())
}
/// Where the streaming parser currently stands relative to the
/// `###Thinking` / `###Response` markers in the incoming deltas.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone, PartialEq)]
enum ZhipuStreamPhase {
/// Starting phase; left once a delta containing `###Thinking` is seen.
Initial,
/// `###Thinking` has been seen; deltas are reported as reasoning until
/// `###Response` shows up.
InThinking,
/// `###Response` has been seen; deltas are reported as answer content.
InResponse,
}
/// Incremental splitter that routes streamed delta text into reasoning and
/// answer parts based on the `###Thinking` / `###Response` markers.
#[cfg(feature = "streaming")]
struct ZhipuStreamState {
    /// Scratch space for the delta currently being processed.
    buffer: String,
    /// Current position relative to the markers.
    phase: ZhipuStreamPhase,
}

#[cfg(feature = "streaming")]
impl ZhipuStreamState {
    /// Marker that opens the reasoning section.
    const THINKING_MARKER: &'static str = "###Thinking";
    /// Marker that closes the reasoning section and opens the answer.
    const RESPONSE_MARKER: &'static str = "###Response";

    /// Creates a splitter in the [`ZhipuStreamPhase::Initial`] phase with an
    /// empty buffer.
    fn new() -> Self {
        Self {
            phase: ZhipuStreamPhase::Initial,
            buffer: String::new(),
        }
    }

    /// Feeds one delta and returns `(reasoning_delta, content_delta)`.
    ///
    /// The buffer is emptied on every path through this method, so a marker
    /// is only recognised when it arrives whole inside a single delta.
    fn process(&mut self, delta_content: &str) -> (Option<String>, Option<String>) {
        self.buffer.push_str(delta_content);

        if self.phase == ZhipuStreamPhase::Initial {
            if !self.buffer.contains(Self::THINKING_MARKER) {
                // No reasoning marker: plain answer text, phase stays Initial.
                return (None, Some(std::mem::take(&mut self.buffer)));
            }
            // Drop the marker(s) and the whitespace that follows, then carry
            // on below exactly as an in-thinking delta would.
            let without_marker = self.buffer.replace(Self::THINKING_MARKER, "");
            self.buffer = without_marker.trim_start().to_owned();
            self.phase = ZhipuStreamPhase::InThinking;
        }

        if self.phase == ZhipuStreamPhase::InResponse {
            return (None, Some(std::mem::take(&mut self.buffer)));
        }

        // Only the InThinking phase can reach this point.
        if self.buffer.contains(Self::RESPONSE_MARKER) {
            self.handle_response_marker()
        } else {
            (Some(std::mem::take(&mut self.buffer)), None)
        }
    }

    /// Splits the buffer at the first `###Response` marker and switches to
    /// the [`ZhipuStreamPhase::InResponse`] phase.
    ///
    /// The text before the marker (trimmed) is the last piece of reasoning,
    /// the text after it (leading whitespace removed) is the first piece of
    /// the answer; an empty piece is reported as `None`. If the buffer holds
    /// no marker, nothing is changed and `(None, None)` is returned.
    fn handle_response_marker(&mut self) -> (Option<String>, Option<String>) {
        let Some((before, after)) = self.buffer.split_once(Self::RESPONSE_MARKER) else {
            return (None, None);
        };

        let reasoning = Some(before.trim())
            .filter(|text| !text.is_empty())
            .map(str::to_owned);
        let content = Some(after.trim_start())
            .filter(|text| !text.is_empty())
            .map(str::to_owned);

        self.buffer.clear();
        self.phase = ZhipuStreamPhase::InResponse;
        (reasoning, content)
    }
}
/// Protocol settings for talking to the Zhipu API: the API key plus a flag
/// recording whether the OpenAI-compatible format was requested.
#[derive(Debug, Clone)]
pub struct ZhipuProtocol {
    /// Key sent as the bearer token.
    api_key: String,
    /// `true` when built through [`ZhipuProtocol::new_openai_compatible`].
    use_openai_format: bool,
}

impl ZhipuProtocol {
    /// Shared constructor behind the two public ones.
    fn with_format(api_key: &str, use_openai_format: bool) -> Self {
        Self {
            api_key: api_key.to_owned(),
            use_openai_format,
        }
    }

    /// Creates a protocol with the OpenAI-compatible flag switched off.
    pub fn new(api_key: &str) -> Self {
        Self::with_format(api_key, false)
    }

    /// Creates a protocol with the OpenAI-compatible flag switched on.
    pub fn new_openai_compatible(api_key: &str) -> Self {
        Self::with_format(api_key, true)
    }

    /// Returns the API key this protocol was created with.
    pub fn api_key(&self) -> &str {
        self.api_key.as_str()
    }

    /// Reports whether the OpenAI-compatible flag is set.
    pub fn is_openai_compatible(&self) -> bool {
        self.use_openai_format
    }
}
#[async_trait::async_trait]
impl Protocol for ZhipuProtocol {
type Request = ZhipuRequest;
type Response = ZhipuResponse;
/// Returns the provider identifier, `"zhipu"`.
fn name(&self) -> &str {
    const PROVIDER_NAME: &str = "zhipu";
    PROVIDER_NAME
}
/// Builds the chat-completions URL for the given base URL.
///
/// Trailing `/` characters on `base_url` are ignored, so
/// `https://host` and `https://host/` both give
/// `https://host/api/paas/v4/chat/completions` rather than a URL with a
/// doubled slash.
fn chat_endpoint(&self, base_url: &str) -> String {
    format!(
        "{}/api/paas/v4/chat/completions",
        base_url.trim_end_matches('/')
    )
}
/// Returns the single `Authorization: Bearer <api_key>` header pair.
fn auth_headers(&self) -> Vec<(String, String)> {
    let bearer = format!("Bearer {}", self.api_key);
    vec![("Authorization".to_owned(), bearer)]
}
fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
let messages: Vec<ZhipuMessage> = request
.messages
.iter()
.map(|msg| {
let role = match msg.role {
Role::System => "system".to_string(),
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
Role::Tool => "tool".to_string(),
};
let has_image = msg.content.iter().any(|block| block.is_image());
let content = if has_image {
let blocks: Vec<Value> = msg.content.iter().map(|block| {
match block {
MessageBlock::Text { text } => json!({
"type": "text",
"text": text
}),
MessageBlock::Image { source } => json!({
"type": "image_url",
"image_url": {
"url": match source {
ImageSource::Base64 { media_type, data } => format!("data:{};base64,{}", media_type, data),
ImageSource::Url { url } => url.clone(),
}
}
}),
MessageBlock::ImageUrl { image_url } => json!({
"type": "image_url",
"image_url": { "url": image_url.url }
}),
}
}).collect();
json!(blocks)
} else {
json!(msg.content_as_text())
};
ZhipuMessage {
role,
content,
tool_calls: msg.tool_calls.as_ref().map(|calls| {
calls.iter().map(|c| serde_json::to_value(c).unwrap_or_default()).collect()
}),
tool_call_id: msg.tool_call_id.clone(),
name: msg.name.clone(),
}
})
.collect();
Ok(ZhipuRequest {
model: request.model.clone(),
messages,
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p,
stream: request.stream,
tools: request.tools.clone(),
tool_choice: request.tool_choice.clone(),
})
}
fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
let parsed: ZhipuResponse = serde_json::from_str(response).map_err(|e| {
LlmConnectorError::InvalidRequest(format!("Failed to parse response: {}", e))
})?;
if let Some(choices) = parsed.choices
&& let Some(first_choice) = choices.first()
{
let content_str = match &first_choice.message.content {
Value::String(s) => s.clone(),
other => other.to_string(),
};
let (reasoning_content, final_content) = extract_zhipu_reasoning_content(&content_str);
let type_message = TypeMessage {
role: match first_choice.message.role.as_str() {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::Assistant,
},
content: vec![crate::types::MessageBlock::text(&final_content)],
tool_calls: first_choice.message.tool_calls.as_ref().map(|calls| {
calls
.iter()
.filter_map(|v| serde_json::from_value(v.clone()).ok())
.collect()
}),
..Default::default()
};
let choice = Choice {
index: first_choice.index.unwrap_or(0),
message: type_message,
finish_reason: first_choice.finish_reason.clone(),
logprobs: None,
};
return Ok(ChatResponse {
id: parsed.id.unwrap_or_else(|| "unknown".to_string()),
object: "chat.completion".to_string(),
created: parsed.created.unwrap_or(0),
model: parsed.model.unwrap_or_else(|| "unknown".to_string()),
content: final_content,
reasoning_content,
choices: vec![choice],
usage: parsed.usage.and_then(|v| serde_json::from_value(v).ok()),
system_fingerprint: None,
});
}
Err(LlmConnectorError::InvalidRequest(
"Empty or invalid response".to_string(),
))
}
/// Maps an HTTP error status and body to an [`LlmConnectorError`].
///
/// A body that mentions a context/token limit (matched case-insensitively)
/// becomes [`LlmConnectorError::ContextLengthExceeded`] whatever the status
/// code; everything else is classified from the status code.
fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
    const CONTEXT_LIMIT_HINTS: [&str; 3] = [
        "context_length_exceeded",
        "maximum context length",
        "token limit",
    ];

    let normalized = body.to_lowercase();
    let hit_context_limit = CONTEXT_LIMIT_HINTS
        .iter()
        .any(|hint| normalized.contains(*hint));

    if hit_context_limit {
        LlmConnectorError::ContextLengthExceeded(format!("Zhipu: {}", body))
    } else {
        LlmConnectorError::from_status_code(status, format!("Zhipu API error: {}", body))
    }
}
/// Turns a streaming HTTP response into a stream of parsed chunks.
///
/// Stage one reassembles the byte stream into lines and keeps the payload of
/// every non-empty `data:` line except `[DONE]`. Bytes are buffered raw and
/// only complete lines are decoded, so a multi-byte UTF-8 character that is
/// split across two network chunks is decoded intact. A trailing line that is
/// never terminated by a newline is not emitted.
///
/// Stage two parses each payload as a `StreamingResponse` and passes the
/// first choice's delta content through [`ZhipuStreamState`], which moves
/// `###Thinking` text into `delta.reasoning_content` and leaves only answer
/// text in `delta.content` / `response.content`.
///
/// # Errors
///
/// The returned stream yields [`LlmConnectorError::NetworkError`] for
/// transport failures and [`LlmConnectorError::ParseError`] for payloads that
/// are not valid `StreamingResponse` JSON.
#[cfg(feature = "streaming")]
async fn parse_stream_response(
    &self,
    response: reqwest::Response,
) -> Result<crate::types::ChatStream, LlmConnectorError> {
    use crate::types::StreamingResponse;
    use futures_util::StreamExt;

    let stream = response.bytes_stream();
    let events_stream = stream
        .scan(Vec::<u8>::new(), |pending, chunk_result| {
            let mut out: Vec<Result<String, LlmConnectorError>> = Vec::new();
            match chunk_result {
                Ok(chunk) => {
                    pending.extend_from_slice(&chunk);
                    // b'\n' never occurs inside a multi-byte UTF-8 sequence,
                    // so cutting at it always yields whole characters.
                    while let Some(newline_idx) = pending.iter().position(|&b| b == b'\n') {
                        let line_bytes: Vec<u8> = pending.drain(..=newline_idx).collect();
                        let line = String::from_utf8_lossy(&line_bytes);
                        // `trim` also removes the '\r' of a CRLF line ending.
                        let trimmed = line.trim();
                        if trimmed.is_empty() {
                            continue;
                        }
                        if let Some(payload) = trimmed
                            .strip_prefix("data: ")
                            .or_else(|| trimmed.strip_prefix("data:"))
                        {
                            let payload = payload.trim();
                            if payload.is_empty() || payload == "[DONE]" {
                                continue;
                            }
                            out.push(Ok(payload.to_string()));
                        }
                    }
                }
                Err(e) => {
                    out.push(Err(LlmConnectorError::NetworkError(e.to_string())));
                }
            }
            std::future::ready(Some(out))
        })
        .flat_map(futures_util::stream::iter);

    let response_stream = events_stream.scan(ZhipuStreamState::new(), |state, result| {
        let processed = result.and_then(|json_str| {
            let mut response =
                serde_json::from_str::<StreamingResponse>(&json_str).map_err(|e| {
                    LlmConnectorError::ParseError(format!(
                        "Failed to parse Zhipu streaming response: {}. JSON: {}",
                        e, json_str
                    ))
                })?;
            if let Some(first_choice) = response.choices.first_mut()
                && let Some(ref delta_content) = first_choice.delta.content
            {
                let (reasoning_delta, content_delta) = state.process(delta_content);
                if let Some(reasoning) = reasoning_delta {
                    first_choice.delta.reasoning_content = Some(reasoning);
                }
                if let Some(content) = content_delta {
                    first_choice.delta.content = Some(content.clone());
                    response.content = content;
                } else {
                    // The delta was pure reasoning: expose no answer text.
                    first_choice.delta.content = None;
                    response.content = String::new();
                }
            }
            Ok(response)
        });
        std::future::ready(Some(processed))
    });

    Ok(Box::pin(response_stream))
}
}
/// Request body for the Zhipu chat-completions endpoint.
///
/// Every `Option` field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuRequest {
/// Model identifier.
pub model: String,
/// Conversation messages in request order.
pub messages: Vec<ZhipuMessage>,
/// Optional `max_tokens` setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional `temperature` setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional `top_p` setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Optional `stream` flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Optional tool definitions.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Optional tool-choice setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
/// A single chat message in Zhipu wire format, used both in requests and
/// inside response choices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuMessage {
/// Role name as a string (`"system"`, `"user"`, `"assistant"`, `"tool"`
/// are the values produced by `build_request`).
pub role: String,
/// Message content as raw JSON: `build_request` emits either a string or
/// an array of typed parts. Falls back to JSON `null` when the field is
/// absent during deserialization.
#[serde(default)]
pub content: Value,
/// Tool calls as raw JSON values; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<serde_json::Value>>,
/// Tool call id; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Optional name; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
/// Non-streaming response body. Every field is optional so that a partial
/// body still deserializes; `parse_response` fills in fallbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuResponse {
/// Response id; `parse_response` falls back to `"unknown"`.
pub id: Option<String>,
/// `created` value; `parse_response` falls back to `0`.
pub created: Option<u64>,
/// Model name; `parse_response` falls back to `"unknown"`.
pub model: Option<String>,
/// Returned choices; `parse_response` reads only the first one.
pub choices: Option<Vec<ZhipuChoice>>,
/// Usage information kept as raw JSON.
pub usage: Option<serde_json::Value>,
}
/// One choice inside a [`ZhipuResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuChoice {
/// Choice index; `parse_response` falls back to `0`.
pub index: Option<u32>,
/// The message carried by this choice.
pub message: ZhipuMessage,
/// Optional `finish_reason` string.
pub finish_reason: Option<String>,
}
/// Provider type for Zhipu: [`GenericProvider`] parameterised with [`ZhipuProtocol`].
pub type ZhipuProvider = GenericProvider<ZhipuProtocol>;
/// Creates a Zhipu provider with default settings: OpenAI-compatible flag
/// off, no base URL override, no timeout and no proxy.
///
/// # Errors
///
/// Forwards any error from [`zhipu_with_config`].
pub fn zhipu(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
    let openai_compatible = false;
    let (base_url, timeout_secs, proxy) = (None, None, None);
    zhipu_with_config(api_key, openai_compatible, base_url, timeout_secs, proxy)
}
/// Creates a Zhipu provider with the OpenAI-compatible flag on and every
/// other setting left at its default.
///
/// # Errors
///
/// Forwards any error from [`zhipu_with_config`].
pub fn zhipu_openai_compatible(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
    let openai_compatible = true;
    let (base_url, timeout_secs, proxy) = (None, None, None);
    zhipu_with_config(api_key, openai_compatible, base_url, timeout_secs, proxy)
}
/// Creates a Zhipu provider with explicit settings.
///
/// * `openai_compatible` selects which [`ZhipuProtocol`] constructor is used.
/// * `base_url` defaults to `https://open.bigmodel.cn` when `None`.
/// * `timeout_secs` and `proxy` are handed to [`HttpClient::with_config`].
///
/// The protocol's auth headers are installed on the HTTP client.
///
/// # Errors
///
/// Forwards any error from [`HttpClient::with_config`].
pub fn zhipu_with_config(
    api_key: &str,
    openai_compatible: bool,
    base_url: Option<&str>,
    timeout_secs: Option<u64>,
    proxy: Option<&str>,
) -> Result<ZhipuProvider, LlmConnectorError> {
    const DEFAULT_BASE_URL: &str = "https://open.bigmodel.cn";

    let protocol = match openai_compatible {
        true => ZhipuProtocol::new_openai_compatible(api_key),
        false => ZhipuProtocol::new(api_key),
    };

    let endpoint = base_url.unwrap_or(DEFAULT_BASE_URL);
    let bare_client = HttpClient::with_config(endpoint, timeout_secs, proxy)?;

    let headers = protocol
        .auth_headers()
        .into_iter()
        .collect::<HashMap<String, String>>();
    let client = bare_client.with_headers(headers);

    Ok(GenericProvider::new(protocol, client))
}
/// Creates a Zhipu provider with a request timeout in seconds.
///
/// The OpenAI-compatible flag is switched on here, unlike in [`zhipu`].
/// NOTE(review): confirm that difference is intended.
///
/// # Errors
///
/// Forwards any error from [`zhipu_with_config`].
pub fn zhipu_with_timeout(
    api_key: &str,
    timeout_secs: u64,
) -> Result<ZhipuProvider, LlmConnectorError> {
    let openai_compatible = true;
    let (base_url, proxy) = (None, None);
    zhipu_with_config(
        api_key,
        openai_compatible,
        base_url,
        Some(timeout_secs),
        proxy,
    )
}
/// Creates a Zhipu provider that talks to a custom endpoint.
///
/// `enterprise_endpoint` replaces the default base URL. The OpenAI-compatible
/// flag is switched on here, unlike in [`zhipu`].
/// NOTE(review): confirm that difference is intended.
///
/// # Errors
///
/// Forwards any error from [`zhipu_with_config`].
pub fn zhipu_enterprise(
    api_key: &str,
    enterprise_endpoint: &str,
) -> Result<ZhipuProvider, LlmConnectorError> {
    let openai_compatible = true;
    let (timeout_secs, proxy) = (None, None);
    zhipu_with_config(
        api_key,
        openai_compatible,
        Some(enterprise_endpoint),
        timeout_secs,
        proxy,
    )
}
/// Cheap plausibility check for an API key: it must be longer than ten bytes.
///
/// This only looks at the length; it does not contact the API.
pub fn validate_zhipu_key(api_key: &str) -> bool {
    // Keys of exactly this many bytes (or fewer, including empty) are rejected.
    const LONGEST_REJECTED_LEN: usize = 10;
    api_key.len() > LONGEST_REJECTED_LEN
}
// Unit tests for the provider constructors, key validation, the
// `###Thinking` / `###Response` splitting and request building.
#[cfg(test)]
mod tests {
use super::*;
// The default constructor succeeds and reports the "zhipu" protocol name.
#[test]
fn test_zhipu_provider_creation() {
let provider = zhipu("test-key");
assert!(provider.is_ok());
let provider = provider.unwrap();
assert_eq!(provider.protocol().name(), "zhipu");
}
// The OpenAI-compatible constructor sets the compatibility flag.
#[test]
fn test_zhipu_openai_compatible() {
let provider = zhipu_openai_compatible("test-key");
assert!(provider.is_ok());
let provider = provider.unwrap();
assert_eq!(provider.protocol().name(), "zhipu");
assert!(provider.protocol().is_openai_compatible());
}
// A custom base URL is stored on the client unchanged.
#[test]
fn test_zhipu_with_config() {
let provider = zhipu_with_config(
"test-key",
true,
Some("https://custom.bigmodel.cn"),
Some(60),
None,
);
assert!(provider.is_ok());
let provider = provider.unwrap();
assert_eq!(provider.client().base_url(), "https://custom.bigmodel.cn");
assert!(provider.protocol().is_openai_compatible());
}
// Construction with a timeout succeeds.
#[test]
fn test_zhipu_with_timeout() {
let provider = zhipu_with_timeout("test-key", 120);
assert!(provider.is_ok());
}
// The enterprise endpoint replaces the default base URL.
#[test]
fn test_zhipu_enterprise() {
let provider = zhipu_enterprise("test-key", "https://enterprise.bigmodel.cn");
assert!(provider.is_ok());
let provider = provider.unwrap();
assert_eq!(
provider.client().base_url(),
"https://enterprise.bigmodel.cn"
);
}
// Keys longer than ten bytes pass; short and empty keys fail.
#[test]
fn test_validate_zhipu_key() {
assert!(validate_zhipu_key("valid-test-key"));
assert!(validate_zhipu_key("another-valid-key-12345"));
assert!(!validate_zhipu_key("short"));
assert!(!validate_zhipu_key(""));
}
// Covers four shapes of input: both markers, no markers, only the thinking
// marker, and an empty thinking section.
#[test]
fn test_extract_zhipu_reasoning_content() {
let content_with_thinking = "###Thinking\nthis_is_reasoning_process\nanalysis_step_1\nanalysis_step_2\n###Response\nthis_is_final_answer";
let (reasoning, answer) = extract_zhipu_reasoning_content(content_with_thinking);
assert!(reasoning.is_some());
assert_eq!(
reasoning.unwrap(),
"this_is_reasoning_process\nanalysis_step_1\nanalysis_step_2"
);
assert_eq!(answer, "this_is_final_answer");
let content_without_thinking = "this_is_a_normal_answer";
let (reasoning, answer) = extract_zhipu_reasoning_content(content_without_thinking);
assert!(reasoning.is_none());
assert_eq!(answer, "this_is_a_normal_answer");
// Without a response marker the text is returned untouched.
let content_only_thinking = "###Thinking\nthis_is_reasoning_process";
let (reasoning, answer) = extract_zhipu_reasoning_content(content_only_thinking);
assert!(reasoning.is_none());
assert_eq!(answer, "###Thinking\nthis_is_reasoning_process");
// An empty thinking section also returns the text untouched, markers included.
let content_empty_thinking = "###Thinking\n\n###Response\nanswer";
let (reasoning, answer) = extract_zhipu_reasoning_content(content_empty_thinking);
assert!(reasoning.is_none());
assert_eq!(answer, "###Thinking\n\n###Response\nanswer");
}
// Text-only messages are serialized as a plain JSON string.
#[test]
fn test_zhipu_build_request_text_only() {
use crate::types::Message;
let protocol = ZhipuProtocol::new("test-key");
let request = ChatRequest {
model: "glm-4".to_string(),
messages: vec![Message::user("Hello")],
..Default::default()
};
let zhipu_req = protocol.build_request(&request).unwrap();
assert_eq!(zhipu_req.messages[0].content, json!("Hello"));
}
// A message with an image URL is serialized as an array of typed parts.
#[test]
fn test_zhipu_build_request_with_image_url() {
use crate::types::Message;
let protocol = ZhipuProtocol::new("test-key");
let mut msg = Message::user("Describe this image");
msg.content
.push(MessageBlock::image_url("https://example.com/cat.jpg"));
let request = ChatRequest {
model: "glm-4v".to_string(),
messages: vec![msg],
..Default::default()
};
let zhipu_req = protocol.build_request(&request).unwrap();
let content = &zhipu_req.messages[0].content;
assert!(content.is_array());
let arr = content.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["type"], "text");
assert_eq!(arr[0]["text"], "Describe this image");
assert_eq!(arr[1]["type"], "image_url");
assert_eq!(arr[1]["image_url"]["url"], "https://example.com/cat.jpg");
}
// A base64 image is sent as a `data:` URL.
#[test]
fn test_zhipu_build_request_with_base64_image() {
use crate::types::Message;
let protocol = ZhipuProtocol::new("test-key");
let mut msg = Message::user("What is this?");
msg.content
.push(MessageBlock::image_base64("image/jpeg", "abc123"));
let request = ChatRequest {
model: "glm-4v".to_string(),
messages: vec![msg],
..Default::default()
};
let zhipu_req = protocol.build_request(&request).unwrap();
let content = &zhipu_req.messages[0].content;
assert!(content.is_array());
let arr = content.as_array().unwrap();
assert_eq!(arr[1]["type"], "image_url");
assert_eq!(arr[1]["image_url"]["url"], "data:image/jpeg;base64,abc123");
}
// Walks the stream state through thinking deltas, the switch to the
// response, and a follow-up answer delta.
#[cfg(feature = "streaming")]
#[test]
fn test_zhipu_stream_state() {
let mut state = ZhipuStreamState::new();
let (reasoning, content) = state.process("###Thinking\nstart");
assert_eq!(reasoning, Some("start".to_string()));
assert_eq!(content, None);
let (reasoning, content) = state.process("reasoning");
assert_eq!(reasoning, Some("reasoning".to_string()));
assert_eq!(content, None);
let (reasoning, content) = state.process("process\n###Response\nanswer");
assert_eq!(reasoning, Some("process".to_string()));
assert_eq!(content, Some("answer".to_string()));
let (reasoning, content) = state.process("continue");
assert_eq!(reasoning, None);
assert_eq!(content, Some("continue".to_string()));
}
// Without markers every delta is passed through as answer content.
#[cfg(feature = "streaming")]
#[test]
fn test_zhipu_stream_state_non_reasoning() {
let mut state = ZhipuStreamState::new();
let (reasoning, content) = state.process("this_is");
assert_eq!(reasoning, None);
assert_eq!(content, Some("this_is".to_string()));
let (reasoning, content) = state.process("normal_answer");
assert_eq!(reasoning, None);
assert_eq!(content, Some("normal_answer".to_string()));
}
// Both markers in a single delta yield reasoning and answer at once.
#[cfg(feature = "streaming")]
#[test]
fn test_zhipu_stream_state_complete_in_one_chunk() {
let mut state = ZhipuStreamState::new();
let (reasoning, content) =
state.process("###Thinking\nReasoning process\n###Response\nanswer");
assert_eq!(reasoning, Some("Reasoning process".to_string()));
assert_eq!(content, Some("answer".to_string()));
}
}