1use crate::core::{GenericProvider, HttpClient, Protocol};
6use crate::error::LlmConnectorError;
7use crate::types::{ChatRequest, ChatResponse, Role, Tool, ToolChoice, Choice, Message as TypeMessage};
8
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Splits a complete Zhipu reply into its reasoning part and its answer part.
///
/// A reply is treated as a reasoning reply when it contains both the
/// `###Thinking` and the `###Response` markers. In that case:
///
/// * the text before the first `###Response` (with every `###Thinking`
///   occurrence removed and surrounding whitespace trimmed) is the reasoning;
/// * everything after the first `###Response`, trimmed, is the answer.
///
/// Returns `(Some(reasoning), answer)` for a reasoning reply with non-empty
/// reasoning. In every other case (a marker is missing, or the reasoning is
/// blank) the input is returned unchanged as `(None, content)`.
fn extract_zhipu_reasoning_content(content: &str) -> (Option<String>, String) {
    const THINKING_MARKER: &str = "###Thinking";
    const RESPONSE_MARKER: &str = "###Response";

    // Fallback result: no reasoning, original text passed through verbatim.
    let passthrough = || -> (Option<String>, String) { (None, content.to_string()) };

    if !content.contains(THINKING_MARKER) {
        return passthrough();
    }

    match content.split_once(RESPONSE_MARKER) {
        Some((before, after)) => {
            let stripped = before.replace(THINKING_MARKER, "");
            let thinking = stripped.trim();
            if thinking.is_empty() {
                // Blank reasoning: treat the reply as an ordinary answer.
                passthrough()
            } else {
                (Some(thinking.to_string()), after.trim().to_string())
            }
        }
        None => passthrough(),
    }
}
44
/// Section of a streamed reply that `ZhipuStreamState` is currently reading.
#[cfg(feature = "streaming")]
#[derive(Debug, Clone, PartialEq)]
enum ZhipuStreamPhase {
    /// Starting phase: no `###Thinking` marker has been seen yet.
    Initial,
    /// A `###Thinking` marker was seen; text is reported as reasoning until
    /// a `###Response` marker arrives.
    InThinking,
    /// A `###Response` marker was seen; all further text is answer content.
    InResponse,
}
56
/// Incremental splitter that routes streamed delta text to either the
/// reasoning channel or the answer channel, based on the `###Thinking` and
/// `###Response` markers.
///
/// Markers may arrive split across several deltas (for example `###` followed
/// by `Thinking`). To still recognise them, text that could be the beginning
/// of a marker is held back in `buffer` until the next delta disambiguates it.
///
/// Limitation: there is no end-of-stream hook, so text that is still being
/// held back when the stream ends (a trailing partial marker) is not emitted.
#[cfg(feature = "streaming")]
struct ZhipuStreamState {
    /// Text received but not yet emitted (a possible partial marker).
    buffer: String,
    /// Section of the reply currently being read.
    phase: ZhipuStreamPhase,
    /// Set once non-whitespace answer text has been emitted while still in
    /// the `Initial` phase. From then on nothing is held back in that phase,
    /// so ordinary answers ending in `#` are never delayed or lost.
    answer_started: bool,
}

#[cfg(feature = "streaming")]
impl ZhipuStreamState {
    const THINKING_MARKER: &'static str = "###Thinking";
    const RESPONSE_MARKER: &'static str = "###Response";

    /// Creates a state positioned at the start of a reply.
    fn new() -> Self {
        Self {
            buffer: String::new(),
            phase: ZhipuStreamPhase::Initial,
            answer_started: false,
        }
    }

    /// Feeds one delta into the splitter.
    ///
    /// Returns `(reasoning_delta, content_delta)`. Either side is `None` when
    /// this delta contributes nothing to that channel; both are `None` while
    /// a possible partial `###Thinking` marker at the very start of the reply
    /// is being held back.
    fn process(&mut self, delta_content: &str) -> (Option<String>, Option<String>) {
        self.buffer.push_str(delta_content);

        match self.phase {
            ZhipuStreamPhase::Initial => {
                if self.buffer.contains(Self::THINKING_MARKER) {
                    // Drop the marker itself and any whitespace that follows it.
                    self.buffer = self
                        .buffer
                        .replace(Self::THINKING_MARKER, "")
                        .trim_start()
                        .to_string();
                    self.phase = ZhipuStreamPhase::InThinking;

                    // Reasoning and answer may arrive together in one delta.
                    if self.buffer.contains(Self::RESPONSE_MARKER) {
                        return self.handle_response_marker();
                    }
                    return (Some(self.take_reasoning()), None);
                }

                // Before any real answer text has been emitted, a buffer that
                // is a proper prefix of `###Thinking` may be a marker split
                // across deltas: wait for more input instead of leaking it.
                if !self.answer_started
                    && Self::is_marker_prefix(self.buffer.trim_start(), Self::THINKING_MARKER)
                {
                    return (None, None);
                }

                let content = std::mem::take(&mut self.buffer);
                if !content.trim().is_empty() {
                    self.answer_started = true;
                }
                (None, Some(content))
            }
            ZhipuStreamPhase::InThinking => {
                if self.buffer.contains(Self::RESPONSE_MARKER) {
                    self.handle_response_marker()
                } else {
                    (Some(self.take_reasoning()), None)
                }
            }
            ZhipuStreamPhase::InResponse => (None, Some(std::mem::take(&mut self.buffer))),
        }
    }

    /// Splits the buffer at the first `###Response` marker and switches to
    /// the `InResponse` phase.
    ///
    /// The part before the marker (trimmed) becomes the final reasoning
    /// delta; the part after it (leading whitespace removed) becomes the
    /// first answer delta. Empty parts are reported as `None`. If the buffer
    /// holds no marker, nothing changes and `(None, None)` is returned.
    fn handle_response_marker(&mut self) -> (Option<String>, Option<String>) {
        let (reasoning, content) = match self.buffer.split_once(Self::RESPONSE_MARKER) {
            Some((thinking, answer)) => {
                let thinking = thinking.trim();
                let answer = answer.trim_start();
                let reasoning = if thinking.is_empty() {
                    None
                } else {
                    Some(thinking.to_string())
                };
                let content = if answer.is_empty() {
                    None
                } else {
                    Some(answer.to_string())
                };
                (reasoning, content)
            }
            None => return (None, None),
        };

        self.buffer.clear();
        self.phase = ZhipuStreamPhase::InResponse;
        (reasoning, content)
    }

    /// Removes and returns the buffered reasoning text, leaving behind only a
    /// trailing fragment that could be the start of `###Response`.
    fn take_reasoning(&mut self) -> String {
        let held = Self::partial_suffix_len(&self.buffer, Self::RESPONSE_MARKER);
        // The held suffix is pure ASCII, so the cut is on a char boundary.
        let cut = self.buffer.len() - held;
        let tail = self.buffer.split_off(cut);
        std::mem::replace(&mut self.buffer, tail)
    }

    /// Returns `true` when `candidate` is a non-empty, proper prefix of `marker`.
    fn is_marker_prefix(candidate: &str, marker: &str) -> bool {
        !candidate.is_empty() && candidate.len() < marker.len() && marker.starts_with(candidate)
    }

    /// Length in bytes of the longest proper prefix of `marker` that `text`
    /// ends with, or `0` when there is none.
    fn partial_suffix_len(text: &str, marker: &str) -> usize {
        (1..marker.len())
            .rev()
            .find(|&len| marker.is_char_boundary(len) && text.ends_with(&marker[..len]))
            .unwrap_or(0)
    }
}
156
/// Protocol settings for talking to the Zhipu chat API.
///
/// Holds the API key and a flag recording whether the instance was created
/// through the OpenAI-compatible constructor. Within this type the flag is
/// only stored and reported back by [`ZhipuProtocol::is_openai_compatible`].
#[derive(Clone, Debug)]
pub struct ZhipuProtocol {
    api_key: String,
    use_openai_format: bool,
}

impl ZhipuProtocol {
    /// Creates a protocol with the OpenAI-compatibility flag cleared.
    pub fn new(api_key: &str) -> Self {
        Self::with_format(api_key, false)
    }

    /// Creates a protocol with the OpenAI-compatibility flag set.
    pub fn new_openai_compatible(api_key: &str) -> Self {
        Self::with_format(api_key, true)
    }

    /// Returns the API key this protocol was created with.
    pub fn api_key(&self) -> &str {
        self.api_key.as_str()
    }

    /// Reports whether the OpenAI-compatibility flag is set.
    pub fn is_openai_compatible(&self) -> bool {
        self.use_openai_format
    }

    /// Shared constructor behind both public constructors.
    fn with_format(api_key: &str, use_openai_format: bool) -> Self {
        Self {
            api_key: String::from(api_key),
            use_openai_format,
        }
    }
}
198
199#[async_trait::async_trait]
200impl Protocol for ZhipuProtocol {
201 type Request = ZhipuRequest;
202 type Response = ZhipuResponse;
203
204 fn name(&self) -> &str {
205 "zhipu"
206 }
207
208 fn chat_endpoint(&self, base_url: &str) -> String {
209 format!("{}/api/paas/v4/chat/completions", base_url)
210 }
211
212 fn auth_headers(&self) -> Vec<(String, String)> {
213 vec![
214 (
215 "Authorization".to_string(),
216 format!("Bearer {}", self.api_key),
217 ),
218 ]
221 }
222
223 fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError> {
224 let messages: Vec<ZhipuMessage> = request
226 .messages
227 .iter()
228 .map(|msg| ZhipuMessage {
229 role: match msg.role {
230 Role::System => "system".to_string(),
231 Role::User => "user".to_string(),
232 Role::Assistant => "assistant".to_string(),
233 Role::Tool => "tool".to_string(),
234 },
235 content: msg.content_as_text(),
237 tool_calls: msg.tool_calls.as_ref().map(|calls| {
238 calls.iter().map(|c| serde_json::to_value(c).unwrap_or_default()).collect()
239 }),
240 tool_call_id: msg.tool_call_id.clone(),
241 name: msg.name.clone(),
242 })
243 .collect();
244
245 Ok(ZhipuRequest {
246 model: request.model.clone(),
247 messages,
248 max_tokens: request.max_tokens,
249 temperature: request.temperature,
250 top_p: request.top_p,
251 stream: request.stream,
252 tools: request.tools.clone(),
253 tool_choice: request.tool_choice.clone(),
254 })
255 }
256
257 fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError> {
258 let parsed: ZhipuResponse = serde_json::from_str(response).map_err(|e| {
259 LlmConnectorError::InvalidRequest(format!("Failed to parse response: {}", e))
260 })?;
261
262 if let Some(choices) = parsed.choices {
263 if let Some(first_choice) = choices.first() {
264 let (reasoning_content, final_content) =
267 extract_zhipu_reasoning_content(&first_choice.message.content);
268
269 let type_message = TypeMessage {
270 role: match first_choice.message.role.as_str() {
271 "system" => Role::System,
272 "user" => Role::User,
273 "assistant" => Role::Assistant,
274 "tool" => Role::Tool,
275 _ => Role::Assistant,
276 },
277 content: vec![crate::types::MessageBlock::text(&final_content)],
278 tool_calls: first_choice.message.tool_calls.as_ref().map(|calls| {
279 calls.iter().filter_map(|v| {
280 serde_json::from_value(v.clone()).ok()
281 }).collect()
282 }),
283 ..Default::default()
284 };
285
286 let choice = Choice {
287 index: first_choice.index.unwrap_or(0),
288 message: type_message,
289 finish_reason: first_choice.finish_reason.clone(),
290 logprobs: None,
291 };
292
293 return Ok(ChatResponse {
294 id: parsed.id.unwrap_or_else(|| "unknown".to_string()),
295 object: "chat.completion".to_string(),
296 created: parsed.created.unwrap_or(0),
297 model: parsed.model.unwrap_or_else(|| "unknown".to_string()),
298 content: final_content,
299 reasoning_content,
300 choices: vec![choice],
301 usage: parsed.usage.and_then(|v| serde_json::from_value(v).ok()),
302 system_fingerprint: None,
303 });
304 }
305 }
306
307 Err(LlmConnectorError::InvalidRequest(
308 "Empty or invalid response".to_string(),
309 ))
310 }
311
312 fn map_error(&self, status: u16, body: &str) -> LlmConnectorError {
313 LlmConnectorError::from_status_code(status, format!("Zhipu API error: {}", body))
314 }
315
316 #[cfg(feature = "streaming")]
321 async fn parse_stream_response(
322 &self,
323 response: reqwest::Response,
324 ) -> Result<crate::types::ChatStream, LlmConnectorError> {
325 use crate::types::StreamingResponse;
326 use futures_util::StreamExt;
327
328 let stream = response.bytes_stream();
329
330 let events_stream = stream
331 .scan(String::new(), |buffer, chunk_result| {
332 let mut out: Vec<Result<String, LlmConnectorError>> = Vec::new();
333 match chunk_result {
334 Ok(chunk) => {
335 let chunk_str = String::from_utf8_lossy(&chunk).replace("\r\n", "\n");
336 buffer.push_str(&chunk_str);
337
338 while let Some(newline_idx) = buffer.find('\n') {
340 let line: String = buffer.drain(..newline_idx + 1).collect();
341 let trimmed = line.trim();
342
343 if trimmed.is_empty() {
345 continue;
346 }
347
348 if let Some(payload) = trimmed
350 .strip_prefix("data: ")
351 .or_else(|| trimmed.strip_prefix("data:"))
352 {
353 let payload = payload.trim();
354
355 if payload == "[DONE]" {
357 continue;
358 }
359
360 if payload.is_empty() {
362 continue;
363 }
364
365 out.push(Ok(payload.to_string()));
366 }
367 }
368 }
369 Err(e) => {
370 out.push(Err(LlmConnectorError::NetworkError(e.to_string())));
371 }
372 }
373 std::future::ready(Some(out))
374 })
375 .flat_map(futures_util::stream::iter);
376
377 let response_stream = events_stream.scan(
380 ZhipuStreamState::new(),
381 |state, result| {
382 let processed = result.and_then(|json_str| {
383 let mut response = serde_json::from_str::<StreamingResponse>(&json_str).map_err(|e| {
384 LlmConnectorError::ParseError(format!(
385 "Failed to parse Zhipu streaming response: {}. JSON: {}",
386 e, json_str
387 ))
388 })?;
389
390 if let Some(first_choice) = response.choices.first_mut() {
392 if let Some(ref delta_content) = first_choice.delta.content {
393 let (reasoning_delta, content_delta) = state.process(delta_content);
395
396 if let Some(reasoning) = reasoning_delta {
398 first_choice.delta.reasoning_content = Some(reasoning);
399 }
400
401 if let Some(content) = content_delta {
402 first_choice.delta.content = Some(content.clone());
403 response.content = content;
405 } else {
406 first_choice.delta.content = None;
408 response.content = String::new();
409 }
410 }
411 }
412
413 Ok(response)
414 });
415
416 std::future::ready(Some(processed))
417 }
418 );
419
420 Ok(Box::pin(response_stream))
421 }
422}
423
/// Request body sent to the Zhipu chat-completions endpoint.
///
/// Every optional field is omitted from the serialised JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuRequest {
    /// Model identifier, copied from the generic request.
    pub model: String,
    /// Conversation messages in Zhipu wire format.
    pub messages: Vec<ZhipuMessage>,
    /// Optional `max_tokens` setting, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional `temperature` setting, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional `top_p` setting, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Optional `stream` flag, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Optional tool definitions, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Optional tool-choice setting, copied from the generic request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
442
/// A single chat message in Zhipu wire format, used in both requests and
/// responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuMessage {
    /// Role name; this file produces and recognises `"system"`, `"user"`,
    /// `"assistant"` and `"tool"`.
    pub role: String,
    /// Message text; defaults to an empty string when absent from the JSON.
    #[serde(default)]
    pub content: String,
    /// Tool calls as raw JSON values; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<serde_json::Value>>,
    /// Optional tool-call id, copied from the generic message; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Optional name, copied from the generic message; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
455
/// Non-streaming response body returned by the Zhipu chat endpoint.
///
/// All fields are optional so that partial responses still deserialise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuResponse {
    /// Response identifier, if the server supplied one.
    pub id: Option<String>,
    /// `created` value from the response, if present.
    pub created: Option<u64>,
    /// Model name from the response, if present.
    pub model: Option<String>,
    /// Generated choices; `parse_response` uses only the first one.
    pub choices: Option<Vec<ZhipuChoice>>,
    /// Usage information kept as raw JSON until it is converted.
    pub usage: Option<serde_json::Value>,
}
464
/// One generated choice inside a [`ZhipuResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZhipuChoice {
    /// Position of the choice; treated as `0` by `parse_response` when absent.
    pub index: Option<u32>,
    /// The generated message.
    pub message: ZhipuMessage,
    /// `finish_reason` value from the response, if present.
    pub finish_reason: Option<String>,
}
471
/// Provider type for Zhipu: a [`GenericProvider`] driven by [`ZhipuProtocol`].
pub type ZhipuProvider = GenericProvider<ZhipuProtocol>;
478
479pub fn zhipu(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
494 zhipu_with_config(api_key, false, None, None, None)
495}
496
497pub fn zhipu_openai_compatible(api_key: &str) -> Result<ZhipuProvider, LlmConnectorError> {
512 zhipu_with_config(api_key, true, None, None, None)
513}
514
515pub fn zhipu_with_config(
537 api_key: &str,
538 openai_compatible: bool,
539 base_url: Option<&str>,
540 timeout_secs: Option<u64>,
541 proxy: Option<&str>,
542) -> Result<ZhipuProvider, LlmConnectorError> {
543 let protocol = if openai_compatible {
545 ZhipuProtocol::new_openai_compatible(api_key)
546 } else {
547 ZhipuProtocol::new(api_key)
548 };
549
550 let client = HttpClient::with_config(
552 base_url.unwrap_or("https://open.bigmodel.cn"),
553 timeout_secs,
554 proxy,
555 )?;
556
557 let auth_headers: HashMap<String, String> = protocol.auth_headers().into_iter().collect();
559 let client = client.with_headers(auth_headers);
560
561 Ok(GenericProvider::new(protocol, client))
563}
564
565pub fn zhipu_with_timeout(
579 api_key: &str,
580 timeout_secs: u64,
581) -> Result<ZhipuProvider, LlmConnectorError> {
582 zhipu_with_config(api_key, true, None, Some(timeout_secs), None)
583}
584
585pub fn zhipu_enterprise(
601 api_key: &str,
602 enterprise_endpoint: &str,
603) -> Result<ZhipuProvider, LlmConnectorError> {
604 zhipu_with_config(api_key, true, Some(enterprise_endpoint), None, None)
605}
606
/// Performs a basic plausibility check on an API key.
///
/// Returns `true` only when the key is longer than 10 bytes. This is a length
/// check only; the key is not verified against the service.
pub fn validate_zhipu_key(api_key: &str) -> bool {
    // Shortest accepted key length in bytes (strictly more than 10).
    const MIN_KEY_LEN: usize = 11;

    api_key.len() >= MIN_KEY_LEN
}
625
/// Unit tests for the provider constructors, key validation, reasoning
/// extraction and (with the `streaming` feature) the stream state machine.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default constructor succeeds and reports the `zhipu` protocol name.
    #[test]
    fn test_zhipu_provider_creation() {
        let provider = zhipu("test-key");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.protocol().name(), "zhipu");
    }

    /// The OpenAI-compatible constructor sets the compatibility flag.
    #[test]
    fn test_zhipu_openai_compatible() {
        let provider = zhipu_openai_compatible("test-key");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.protocol().name(), "zhipu");
        assert!(provider.protocol().is_openai_compatible());
    }

    /// A custom base URL and the compatibility flag are carried into the provider.
    #[test]
    fn test_zhipu_with_config() {
        let provider = zhipu_with_config(
            "test-key",
            true,
            Some("https://custom.bigmodel.cn"),
            Some(60),
            None,
        );
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.client().base_url(), "https://custom.bigmodel.cn");
        assert!(provider.protocol().is_openai_compatible());
    }

    /// Building a provider with a custom timeout succeeds.
    #[test]
    fn test_zhipu_with_timeout() {
        let provider = zhipu_with_timeout("test-key", 120);
        assert!(provider.is_ok());
    }

    /// The enterprise constructor uses the given endpoint as base URL.
    #[test]
    fn test_zhipu_enterprise() {
        let provider = zhipu_enterprise("test-key", "https://enterprise.bigmodel.cn");
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(
            provider.client().base_url(),
            "https://enterprise.bigmodel.cn"
        );
    }

    /// Keys longer than 10 bytes are accepted; short and empty keys are rejected.
    #[test]
    fn test_validate_zhipu_key() {
        assert!(validate_zhipu_key("valid-test-key"));
        assert!(validate_zhipu_key("another-valid-key-12345"));
        assert!(!validate_zhipu_key("short"));
        assert!(!validate_zhipu_key(""));
    }

    /// Covers the four input shapes of `extract_zhipu_reasoning_content`.
    #[test]
    fn test_extract_zhipu_reasoning_content() {
        // Both markers present: reasoning and answer are separated and trimmed.
        let content_with_thinking = "###Thinking\n这是推理过程\n分析步骤1\n分析步骤2\n###Response\n这是最终答案";
        let (reasoning, answer) = extract_zhipu_reasoning_content(content_with_thinking);
        assert!(reasoning.is_some());
        assert_eq!(reasoning.unwrap(), "这是推理过程\n分析步骤1\n分析步骤2");
        assert_eq!(answer, "这是最终答案");

        // No markers: the text is returned unchanged with no reasoning.
        let content_without_thinking = "这只是一个普通的回答";
        let (reasoning, answer) = extract_zhipu_reasoning_content(content_without_thinking);
        assert!(reasoning.is_none());
        assert_eq!(answer, "这只是一个普通的回答");

        // Only the thinking marker: treated as plain content, marker included.
        let content_only_thinking = "###Thinking\n这是推理过程";
        let (reasoning, answer) = extract_zhipu_reasoning_content(content_only_thinking);
        assert!(reasoning.is_none());
        assert_eq!(answer, "###Thinking\n这是推理过程");

        // Blank reasoning between the markers: the whole input is returned as-is.
        let content_empty_thinking = "###Thinking\n\n###Response\n答案";
        let (reasoning, answer) = extract_zhipu_reasoning_content(content_empty_thinking);
        assert!(reasoning.is_none());
        assert_eq!(answer, "###Thinking\n\n###Response\n答案");
    }

    /// Walks the state machine through thinking, the response marker and the
    /// answer, one delta at a time.
    #[cfg(feature = "streaming")]
    #[test]
    fn test_zhipu_stream_state() {
        let mut state = ZhipuStreamState::new();

        // The thinking marker is stripped; the rest is reasoning.
        let (reasoning, content) = state.process("###Thinking\n开始");
        assert_eq!(reasoning, Some("开始".to_string()));
        assert_eq!(content, None);

        // Further deltas in the thinking phase are reasoning only.
        let (reasoning, content) = state.process("推理");
        assert_eq!(reasoning, Some("推理".to_string()));
        assert_eq!(content, None);

        // A delta containing the response marker yields both parts at once.
        let (reasoning, content) = state.process("过程\n###Response\n答案");
        assert_eq!(reasoning, Some("过程".to_string()));
        assert_eq!(content, Some("答案".to_string()));

        // After the marker everything is answer content.
        let (reasoning, content) = state.process("继续");
        assert_eq!(reasoning, None);
        assert_eq!(content, Some("继续".to_string()));
    }

    /// Replies without markers pass through unchanged as content.
    #[cfg(feature = "streaming")]
    #[test]
    fn test_zhipu_stream_state_non_reasoning() {
        let mut state = ZhipuStreamState::new();

        let (reasoning, content) = state.process("这是");
        assert_eq!(reasoning, None);
        assert_eq!(content, Some("这是".to_string()));

        let (reasoning, content) = state.process("普通回答");
        assert_eq!(reasoning, None);
        assert_eq!(content, Some("普通回答".to_string()));
    }

    /// A single delta holding both markers is split in one call.
    #[cfg(feature = "streaming")]
    #[test]
    fn test_zhipu_stream_state_complete_in_one_chunk() {
        let mut state = ZhipuStreamState::new();

        let (reasoning, content) = state.process("###Thinking\n推理过程\n###Response\n答案");
        assert_eq!(reasoning, Some("推理过程".to_string()));
        assert_eq!(content, Some("答案".to_string()));
    }
}