1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Provider for the Z.AI (Zhipu) OpenAI-compatible chat-completions API.
///
/// Holds a shared HTTP client plus the credentials and endpoint needed to
/// issue `{base_url}/chat/completions` requests.
pub struct ZaiProvider {
    // Reused reqwest client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the `Authorization` header; redacted in Debug.
    api_key: String,
    // API root that "/chat/completions" is appended to — assumes no
    // trailing slash; TODO confirm expected form with callers.
    base_url: String,
}
23
24impl std::fmt::Debug for ZaiProvider {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("ZaiProvider")
27 .field("base_url", &self.base_url)
28 .field("api_key", &"<REDACTED>")
29 .finish()
30 }
31}
32
33impl ZaiProvider {
34 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "zai",
37 base_url = %base_url,
38 api_key_len = api_key.len(),
39 "Creating Z.AI provider with custom base URL"
40 );
41 Ok(Self {
42 client: Client::new(),
43 api_key,
44 base_url,
45 })
46 }
47
48 fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({"role": role, "content": ""})
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let tool_calls: Vec<Value> = msg
87 .content
88 .iter()
89 .filter_map(|p| match p {
90 ContentPart::ToolCall {
91 id,
92 name,
93 arguments,
94 ..
95 } => {
96 let args_string = serde_json::from_str::<Value>(arguments)
99 .map(|parsed| {
100 serde_json::to_string(&parsed)
101 .unwrap_or_else(|_| "{}".to_string())
102 })
103 .unwrap_or_else(|_| {
104 json!({"input": arguments}).to_string()
105 });
106 Some(json!({
107 "id": id,
108 "type": "function",
109 "function": {
110 "name": name,
111 "arguments": args_string
112 }
113 }))
114 }
115 _ => None,
116 })
117 .collect();
118
119 let mut msg_json = json!({
120 "role": "assistant",
121 "content": if text.is_empty() { "".to_string() } else { text },
122 });
123 if include_reasoning_content {
124 let reasoning: String = msg
125 .content
126 .iter()
127 .filter_map(|p| match p {
128 ContentPart::Thinking { text } => Some(text.clone()),
129 _ => None,
130 })
131 .collect::<Vec<_>>()
132 .join("");
133 if !reasoning.is_empty() {
134 msg_json["reasoning_content"] = json!(reasoning);
135 }
136 }
137 if !tool_calls.is_empty() {
138 msg_json["tool_calls"] = json!(tool_calls);
139 }
140 msg_json
141 }
142 _ => {
143 let text: String = msg
144 .content
145 .iter()
146 .filter_map(|p| match p {
147 ContentPart::Text { text } => Some(text.clone()),
148 _ => None,
149 })
150 .collect::<Vec<_>>()
151 .join("\n");
152
153 json!({"role": role, "content": text})
154 }
155 }
156 })
157 .collect()
158 }
159
160 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
161 tools
162 .iter()
163 .map(|t| {
164 json!({
165 "type": "function",
166 "function": {
167 "name": t.name,
168 "description": t.description,
169 "parameters": t.parameters
170 }
171 })
172 })
173 .collect()
174 }
175
176 fn model_supports_tool_stream(model: &str) -> bool {
177 model.contains("glm-5") || model.contains("glm-4.7") || model.contains("glm-4.6")
178 }
179
180 fn preview_text(text: &str, max_chars: usize) -> &str {
181 if max_chars == 0 {
182 return "";
183 }
184 if let Some((idx, _)) = text.char_indices().nth(max_chars) {
185 &text[..idx]
186 } else {
187 text
188 }
189 }
190}
191
/// Top-level payload of a non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    // Candidate completions; only the first is consumed by `complete`.
    choices: Vec<ZaiChoice>,
    // Token accounting; tolerated as absent, hence Option + default.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}
198
/// One completion candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    // Observed values handled downstream: "stop", "length", "tool_calls",
    // "sensitive"; anything else maps to Stop.
    #[serde(default)]
    finish_reason: Option<String>,
}
205
/// Assistant message body returned by the API.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    // Plain text output; may be absent when only tool calls are returned.
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Reasoning text produced when the "thinking" option is enabled.
    #[serde(default)]
    reasoning_content: Option<String>,
}
215
/// A tool invocation requested by the model (non-streaming form).
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}
221
/// Function name plus arguments for a tool call.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    // Deserialized as a generic Value: usually a JSON-encoded string, but
    // `complete` also tolerates a structured object and re-serializes it.
    arguments: Value,
}
227
/// Token usage block; every counter defaults to 0 when omitted.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    // Present when prompt caching applies; carries the cached-token count.
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}
239
/// Breakdown of prompt tokens; currently only the cached-token count.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
245
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}
250
/// Error details: human-readable message plus an optional machine type.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // The wire field is "type", a Rust keyword, hence the rename.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
257
/// One decoded SSE `data:` event from a streaming response.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}
263
/// Streaming choice: an incremental delta plus an optional finish reason.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}
270
/// Incremental message fragment; any field may be absent in a given event.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}
280
/// Partial tool call carried inside a stream delta.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    // Presumably only present on the first fragment of a call — the
    // consumer falls back to an empty id when absent; verify with traces.
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}
287
/// Partial function payload inside a streamed tool call: the name tends to
/// arrive first, argument fragments thereafter.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
295
#[async_trait]
impl Provider for ZaiProvider {
    /// Stable identifier for this provider ("zai").
    fn name(&self) -> &str {
        "zai"
    }

    /// Returns a hand-maintained catalogue of GLM models.
    ///
    /// No network round-trip is made; context windows, output limits and
    /// capability flags are hard-coded here.
    /// NOTE(review): keep these in sync with Z.AI's published model limits.
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "glm-5".to_string(),
                name: "GLM-5".to_string(),
                provider: "zai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(128_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-4.7".to_string(),
                name: "GLM-4.7".to_string(),
                provider: "zai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(128_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-4.7-flash".to_string(),
                name: "GLM-4.7 Flash".to_string(),
                provider: "zai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(128_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-4.6".to_string(),
                name: "GLM-4.6".to_string(),
                provider: "zai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(128_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
            ModelInfo {
                id: "glm-4.5".to_string(),
                name: "GLM-4.5".to_string(),
                provider: "zai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(96_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            },
        ])
    }

    /// Sends a single non-streaming chat completion request.
    ///
    /// "Thinking" mode is always enabled; any `reasoning_content` in the
    /// reply is surfaced as a `ContentPart::Thinking` before the text.
    ///
    /// # Errors
    /// Fails when the request cannot be sent, the body cannot be read,
    /// the API reports a non-success status, the JSON does not match the
    /// expected schema, or the response contains no choices.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        // Outgoing history excludes reasoning content (flag = false).
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        // Default sampling temperature when the caller leaves it unset.
        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
        });

        // Always request reasoning output.
        body["thinking"] = json!({
            "type": "enabled"
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI request");

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Z.AI")?;

        // Read the raw body first so error payloads can be surfaced verbatim.
        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Z.AI response")?;

        if !status.is_success() {
            // Prefer the structured error envelope; fall back to raw text.
            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
                anyhow::bail!(
                    "Z.AI API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("Z.AI API error: {} {}", status, text);
        }

        // Include a truncated body preview in parse errors for debugging.
        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
            "Failed to parse Z.AI response: {}",
            Self::preview_text(&text, 200)
        ))?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;

        // Log (length only) when the model produced reasoning output.
        if let Some(ref reasoning) = choice.message.reasoning_content {
            if !reasoning.is_empty() {
                tracing::info!(
                    reasoning_len = reasoning.len(),
                    "Z.AI reasoning content received"
                );
            }
        }

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        // Order matters: Thinking first, then Text, then ToolCalls.
        if let Some(ref reasoning) = choice.message.reasoning_content {
            if !reasoning.is_empty() {
                content.push(ContentPart::Thinking {
                    text: reasoning.clone(),
                });
            }
        }

        if let Some(text) = &choice.message.content {
            if !text.is_empty() {
                content.push(ContentPart::Text { text: text.clone() });
            }
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Arguments may arrive as a JSON string or a structured
                // value; normalize to a string either way.
                let arguments = match &tc.function.arguments {
                    Value::String(s) => s.clone(),
                    other => serde_json::to_string(other).unwrap_or_default(),
                };
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments,
                    thought_signature: None,
                });
            }
        }

        // Presence of tool calls overrides whatever finish_reason says.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                Some("sensitive") => FinishReason::ContentFilter,
                // Unknown/absent reasons default to a normal stop.
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
                // Report cache reads only when a positive count is present.
                cache_read_tokens: response
                    .usage
                    .as_ref()
                    .and_then(|u| u.prompt_tokens_details.as_ref())
                    .map(|d| d.cached_tokens)
                    .filter(|&t| t > 0),
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    /// Sends a streaming chat completion request and adapts the SSE body
    /// into a stream of `StreamChunk`s.
    ///
    /// SSE lines are parsed from a persistent `buffer` because a network
    /// chunk may end mid-line; any partial trailing line is kept for the
    /// next chunk. Text deltas are coalesced into a chunk-local buffer
    /// and flushed before tool-call events, finish events, and at the end
    /// of each network chunk.
    ///
    /// NOTE(review): reasoning deltas are appended to the same buffer as
    /// normal content, so stream consumers cannot tell them apart —
    /// confirm this is intended.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Same request construction as `complete`, plus "stream": true.
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
            "stream": true,
        });

        body["thinking"] = json!({
            "type": "enabled"
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
            // Only newer GLM models accept the tool_stream flag.
            if Self::model_supports_tool_stream(&request.model) {
                body["tool_stream"] = json!(true);
            }
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI streaming request");

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send streaming request to Z.AI")?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
                anyhow::bail!(
                    "Z.AI API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("Z.AI streaming error: {} {}", status, text);
        }

        let stream = response.bytes_stream();
        // Carries partial SSE lines across network chunks (moved into the
        // flat_map closure below).
        let mut buffer = String::new();

        Ok(stream
            .flat_map(move |chunk_result| {
                let mut chunks: Vec<StreamChunk> = Vec::new();
                match chunk_result {
                    Ok(bytes) => {
                        // Lossy UTF-8: a multi-byte char split across
                        // network chunks would be replaced, not joined —
                        // NOTE(review): acceptable? verify with traces.
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);

                        // Chunk-local accumulator for text/reasoning deltas.
                        let mut text_buf = String::new();
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line == "data: [DONE]" {
                                // Flush pending text before signalling Done.
                                if !text_buf.is_empty() {
                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                }
                                chunks.push(StreamChunk::Done { usage: None });
                                continue;
                            }
                            if let Some(data) = line.strip_prefix("data: ") {
                                // Unparseable events are silently skipped.
                                if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
                                {
                                    if let Some(choice) = parsed.choices.first() {
                                        // Reasoning merged into the same
                                        // buffer as regular content (see
                                        // NOTE on the method docs).
                                        if let Some(ref reasoning) = choice.delta.reasoning_content
                                        {
                                            if !reasoning.is_empty() {
                                                text_buf.push_str(reasoning);
                                            }
                                        }
                                        if let Some(ref content) = choice.delta.content {
                                            text_buf.push_str(content);
                                        }
                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
                                            // Flush text so ordering vs tool
                                            // events is preserved.
                                            if !text_buf.is_empty() {
                                                chunks.push(StreamChunk::Text(std::mem::take(
                                                    &mut text_buf,
                                                )));
                                            }
                                            for tc in tool_calls {
                                                if let Some(ref func) = tc.function {
                                                    // A named fragment opens
                                                    // a new tool call.
                                                    if let Some(ref name) = func.name {
                                                        chunks.push(StreamChunk::ToolCallStart {
                                                            id: tc.id.clone().unwrap_or_default(),
                                                            name: name.clone(),
                                                        });
                                                    }
                                                    if let Some(ref args) = func.arguments {
                                                        let delta = match args {
                                                            Value::String(s) => s.clone(),
                                                            other => serde_json::to_string(other)
                                                                .unwrap_or_default(),
                                                        };
                                                        if !delta.is_empty() {
                                                            chunks.push(
                                                                StreamChunk::ToolCallDelta {
                                                                    id: tc
                                                                        .id
                                                                        .clone()
                                                                        .unwrap_or_default(),
                                                                    arguments_delta: delta,
                                                                },
                                                            );
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                        if let Some(ref reason) = choice.finish_reason {
                                            if !text_buf.is_empty() {
                                                chunks.push(StreamChunk::Text(std::mem::take(
                                                    &mut text_buf,
                                                )));
                                            }
                                            if reason == "tool_calls" {
                                                // NOTE(review): ToolCallEnd is
                                                // emitted only when the finish
                                                // event itself carries
                                                // tool_calls; a finish_reason
                                                // sent in a separate empty
                                                // delta yields no end marker —
                                                // verify against real Z.AI
                                                // stream traces.
                                                if let Some(ref tcs) = choice.delta.tool_calls {
                                                    if let Some(tc) = tcs.last() {
                                                        chunks.push(StreamChunk::ToolCallEnd {
                                                            id: tc.id.clone().unwrap_or_default(),
                                                        });
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        // Flush any text accumulated during this chunk.
                        if !text_buf.is_empty() {
                            chunks.push(StreamChunk::Text(text_buf));
                        }
                    }
                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
                }
                futures::stream::iter(chunks)
            })
            .boxed())
    }
}
687
688#[cfg(test)]
689mod tests {
690 use super::*;
691
692 #[test]
693 fn convert_messages_serializes_tool_arguments_as_json_string() {
694 let messages = vec![Message {
695 role: Role::Assistant,
696 content: vec![ContentPart::ToolCall {
697 id: "call_1".to_string(),
698 name: "get_weather".to_string(),
699 arguments: "{\"city\":\"Beijing\".. }".to_string(),
700 thought_signature: None,
701 }],
702 }];
703
704 let converted = ZaiProvider::convert_messages(&messages, true);
705 let args = converted[0]["tool_calls"][0]["function"]["arguments"]
706 .as_str()
707 .expect("arguments must be a string");
708
709 assert_eq!(args, "{\"city\":\"Beijing\"}");
710 }
711
712 #[test]
713 fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
714 let messages = vec![Message {
715 role: Role::Assistant,
716 content: vec![ContentPart::ToolCall {
717 id: "call_1".to_string(),
718 name: "get_weather".to_string(),
719 arguments: "city=Beijing".to_string(),
720 thought_signature: None,
721 }],
722 }];
723
724 let converted = ZaiProvider::convert_messages(&messages, true);
725 let args = converted[0]["tool_calls"][0]["function"]["arguments"]
726 .as_str()
727 .expect("arguments must be a string");
728 let parsed: Value = serde_json::from_str(args).expect("arguments must contain valid JSON");
729
730 assert_eq!(parsed, json!({"input": "city=Beijing"}));
731 }
732
733 #[test]
734 fn preview_text_truncates_on_char_boundary() {
735 let text = "a😀b";
736 assert_eq!(ZaiProvider::preview_text(text, 2), "a😀");
737 }
738}