1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Provider for Z.AI's OpenAI-compatible chat-completions API.
pub struct ZaiProvider {
    /// Shared HTTP client, reused across requests.
    client: Client,
    // Bearer token; intentionally redacted by the manual Debug impl.
    api_key: String,
    /// API root; `/chat/completions` is appended per request.
    base_url: String,
}
23
24impl std::fmt::Debug for ZaiProvider {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("ZaiProvider")
27 .field("base_url", &self.base_url)
28 .field("api_key", &"<REDACTED>")
29 .finish()
30 }
31}
32
33impl ZaiProvider {
34 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "zai",
37 base_url = %base_url,
38 api_key_len = api_key.len(),
39 "Creating Z.AI provider with custom base URL"
40 );
41 Ok(Self {
42 client: Client::new(),
43 api_key,
44 base_url,
45 })
46 }
47
48 fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({"role": role, "content": ""})
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let tool_calls: Vec<Value> = msg
87 .content
88 .iter()
89 .filter_map(|p| match p {
90 ContentPart::ToolCall {
91 id,
92 name,
93 arguments,
94 } => {
95 let args_string = serde_json::from_str::<Value>(arguments)
98 .map(|parsed| {
99 serde_json::to_string(&parsed)
100 .unwrap_or_else(|_| "{}".to_string())
101 })
102 .unwrap_or_else(|_| {
103 json!({"input": arguments}).to_string()
104 });
105 Some(json!({
106 "id": id,
107 "type": "function",
108 "function": {
109 "name": name,
110 "arguments": args_string
111 }
112 }))
113 }
114 _ => None,
115 })
116 .collect();
117
118 let mut msg_json = json!({
119 "role": "assistant",
120 "content": if text.is_empty() { "".to_string() } else { text },
121 });
122 if include_reasoning_content {
123 let reasoning: String = msg
124 .content
125 .iter()
126 .filter_map(|p| match p {
127 ContentPart::Thinking { text } => Some(text.clone()),
128 _ => None,
129 })
130 .collect::<Vec<_>>()
131 .join("");
132 if !reasoning.is_empty() {
133 msg_json["reasoning_content"] = json!(reasoning);
134 }
135 }
136 if !tool_calls.is_empty() {
137 msg_json["tool_calls"] = json!(tool_calls);
138 }
139 msg_json
140 }
141 _ => {
142 let text: String = msg
143 .content
144 .iter()
145 .filter_map(|p| match p {
146 ContentPart::Text { text } => Some(text.clone()),
147 _ => None,
148 })
149 .collect::<Vec<_>>()
150 .join("\n");
151
152 json!({"role": role, "content": text})
153 }
154 }
155 })
156 .collect()
157 }
158
159 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
160 tools
161 .iter()
162 .map(|t| {
163 json!({
164 "type": "function",
165 "function": {
166 "name": t.name,
167 "description": t.description,
168 "parameters": t.parameters
169 }
170 })
171 })
172 .collect()
173 }
174
175 fn model_supports_tool_stream(model: &str) -> bool {
176 model.contains("glm-5") || model.contains("glm-4.7") || model.contains("glm-4.6")
177 }
178}
179
/// Top-level non-streaming chat-completion response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    // Token accounting; tolerated as missing.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}
186
/// One completion candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    // e.g. "stop", "length", "tool_calls", "sensitive" (see mapping in complete()).
    #[serde(default)]
    finish_reason: Option<String>,
}
193
/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    // Visible answer text, if any.
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Reasoning ("thinking") text, present when thinking is enabled.
    #[serde(default)]
    reasoning_content: Option<String>,
}
203
/// A completed tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}
209
/// Function name plus arguments for a tool call.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    // Kept as Value: complete() handles both a JSON-encoded string and an
    // inline JSON object here.
    arguments: Value,
}
215
/// Token usage block; all counters default to 0 when omitted.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}
227
/// Prompt-token breakdown; presumably reports prompt-cache hits — mapped to
/// `Usage::cache_read_tokens` in complete().
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
233
/// Error envelope returned on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}
238
/// Human-readable error message plus optional machine-readable type.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // "type" is a Rust keyword, hence the rename.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
245
/// One SSE `data:` payload from the streaming endpoint.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}
251
/// Incremental choice delta within a streaming payload.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    // Present only on the final chunk for this choice.
    #[serde(default)]
    finish_reason: Option<String>,
}
258
/// Delta fields; any subset may be present in a given chunk.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}
268
/// Streamed tool-call fragment; `id`/`function` may arrive piecemeal.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}
275
/// Streamed function fragment: `name` appears on the first fragment of a
/// call; `arguments` carries incremental argument data.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
283
284#[async_trait]
285impl Provider for ZaiProvider {
286 fn name(&self) -> &str {
287 "zai"
288 }
289
290 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
291 Ok(vec![
292 ModelInfo {
293 id: "glm-5".to_string(),
294 name: "GLM-5".to_string(),
295 provider: "zai".to_string(),
296 context_window: 200_000,
297 max_output_tokens: Some(128_000),
298 supports_vision: false,
299 supports_tools: true,
300 supports_streaming: true,
301 input_cost_per_million: None,
302 output_cost_per_million: None,
303 },
304 ModelInfo {
305 id: "glm-4.7".to_string(),
306 name: "GLM-4.7".to_string(),
307 provider: "zai".to_string(),
308 context_window: 128_000,
309 max_output_tokens: Some(128_000),
310 supports_vision: false,
311 supports_tools: true,
312 supports_streaming: true,
313 input_cost_per_million: None,
314 output_cost_per_million: None,
315 },
316 ModelInfo {
317 id: "glm-4.7-flash".to_string(),
318 name: "GLM-4.7 Flash".to_string(),
319 provider: "zai".to_string(),
320 context_window: 128_000,
321 max_output_tokens: Some(128_000),
322 supports_vision: false,
323 supports_tools: true,
324 supports_streaming: true,
325 input_cost_per_million: None,
326 output_cost_per_million: None,
327 },
328 ModelInfo {
329 id: "glm-4.6".to_string(),
330 name: "GLM-4.6".to_string(),
331 provider: "zai".to_string(),
332 context_window: 128_000,
333 max_output_tokens: Some(128_000),
334 supports_vision: false,
335 supports_tools: true,
336 supports_streaming: true,
337 input_cost_per_million: None,
338 output_cost_per_million: None,
339 },
340 ModelInfo {
341 id: "glm-4.5".to_string(),
342 name: "GLM-4.5".to_string(),
343 provider: "zai".to_string(),
344 context_window: 128_000,
345 max_output_tokens: Some(96_000),
346 supports_vision: false,
347 supports_tools: true,
348 supports_streaming: true,
349 input_cost_per_million: None,
350 output_cost_per_million: None,
351 },
352 ])
353 }
354
355 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
356 let messages = Self::convert_messages(&request.messages, false);
360 let tools = Self::convert_tools(&request.tools);
361
362 let temperature = request.temperature.unwrap_or(1.0);
364
365 let mut body = json!({
366 "model": request.model,
367 "messages": messages,
368 "temperature": temperature,
369 });
370
371 body["thinking"] = json!({
374 "type": "enabled"
375 });
376
377 if !tools.is_empty() {
378 body["tools"] = json!(tools);
379 }
380 if let Some(max) = request.max_tokens {
381 body["max_tokens"] = json!(max);
382 }
383
384 tracing::debug!(model = %request.model, "Z.AI request");
385
386 let response = self
387 .client
388 .post(format!("{}/chat/completions", self.base_url))
389 .header("Authorization", format!("Bearer {}", self.api_key))
390 .header("Content-Type", "application/json")
391 .json(&body)
392 .send()
393 .await
394 .context("Failed to send request to Z.AI")?;
395
396 let status = response.status();
397 let text = response
398 .text()
399 .await
400 .context("Failed to read Z.AI response")?;
401
402 if !status.is_success() {
403 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
404 anyhow::bail!(
405 "Z.AI API error: {} ({:?})",
406 err.error.message,
407 err.error.error_type
408 );
409 }
410 anyhow::bail!("Z.AI API error: {} {}", status, text);
411 }
412
413 let response: ZaiResponse = serde_json::from_str(&text).context(format!(
414 "Failed to parse Z.AI response: {}",
415 &text[..text.len().min(200)]
416 ))?;
417
418 let choice = response
419 .choices
420 .first()
421 .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
422
423 if let Some(ref reasoning) = choice.message.reasoning_content {
425 if !reasoning.is_empty() {
426 tracing::info!(
427 reasoning_len = reasoning.len(),
428 "Z.AI reasoning content received"
429 );
430 }
431 }
432
433 let mut content = Vec::new();
434 let mut has_tool_calls = false;
435
436 if let Some(ref reasoning) = choice.message.reasoning_content {
438 if !reasoning.is_empty() {
439 content.push(ContentPart::Thinking {
440 text: reasoning.clone(),
441 });
442 }
443 }
444
445 if let Some(text) = &choice.message.content {
446 if !text.is_empty() {
447 content.push(ContentPart::Text { text: text.clone() });
448 }
449 }
450
451 if let Some(tool_calls) = &choice.message.tool_calls {
452 has_tool_calls = !tool_calls.is_empty();
453 for tc in tool_calls {
454 let arguments = match &tc.function.arguments {
456 Value::String(s) => s.clone(),
457 other => serde_json::to_string(other).unwrap_or_default(),
458 };
459 content.push(ContentPart::ToolCall {
460 id: tc.id.clone(),
461 name: tc.function.name.clone(),
462 arguments,
463 });
464 }
465 }
466
467 let finish_reason = if has_tool_calls {
468 FinishReason::ToolCalls
469 } else {
470 match choice.finish_reason.as_deref() {
471 Some("stop") => FinishReason::Stop,
472 Some("length") => FinishReason::Length,
473 Some("tool_calls") => FinishReason::ToolCalls,
474 Some("sensitive") => FinishReason::ContentFilter,
475 _ => FinishReason::Stop,
476 }
477 };
478
479 Ok(CompletionResponse {
480 message: Message {
481 role: Role::Assistant,
482 content,
483 },
484 usage: Usage {
485 prompt_tokens: response
486 .usage
487 .as_ref()
488 .map(|u| u.prompt_tokens)
489 .unwrap_or(0),
490 completion_tokens: response
491 .usage
492 .as_ref()
493 .map(|u| u.completion_tokens)
494 .unwrap_or(0),
495 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
496 cache_read_tokens: response
497 .usage
498 .as_ref()
499 .and_then(|u| u.prompt_tokens_details.as_ref())
500 .map(|d| d.cached_tokens)
501 .filter(|&t| t > 0),
502 cache_write_tokens: None,
503 },
504 finish_reason,
505 })
506 }
507
508 async fn complete_stream(
509 &self,
510 request: CompletionRequest,
511 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
512 let messages = Self::convert_messages(&request.messages, false);
516 let tools = Self::convert_tools(&request.tools);
517
518 let temperature = request.temperature.unwrap_or(1.0);
519
520 let mut body = json!({
521 "model": request.model,
522 "messages": messages,
523 "temperature": temperature,
524 "stream": true,
525 });
526
527 body["thinking"] = json!({
528 "type": "enabled"
529 });
530
531 if !tools.is_empty() {
532 body["tools"] = json!(tools);
533 if Self::model_supports_tool_stream(&request.model) {
534 body["tool_stream"] = json!(true);
536 }
537 }
538 if let Some(max) = request.max_tokens {
539 body["max_tokens"] = json!(max);
540 }
541
542 tracing::debug!(model = %request.model, "Z.AI streaming request");
543
544 let response = self
545 .client
546 .post(format!("{}/chat/completions", self.base_url))
547 .header("Authorization", format!("Bearer {}", self.api_key))
548 .header("Content-Type", "application/json")
549 .json(&body)
550 .send()
551 .await
552 .context("Failed to send streaming request to Z.AI")?;
553
554 if !response.status().is_success() {
555 let status = response.status();
556 let text = response.text().await.unwrap_or_default();
557 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
558 anyhow::bail!(
559 "Z.AI API error: {} ({:?})",
560 err.error.message,
561 err.error.error_type
562 );
563 }
564 anyhow::bail!("Z.AI streaming error: {} {}", status, text);
565 }
566
567 let stream = response.bytes_stream();
568 let mut buffer = String::new();
569
570 Ok(stream
571 .flat_map(move |chunk_result| {
572 let mut chunks: Vec<StreamChunk> = Vec::new();
573 match chunk_result {
574 Ok(bytes) => {
575 let text = String::from_utf8_lossy(&bytes);
576 buffer.push_str(&text);
577
578 let mut text_buf = String::new();
579 while let Some(line_end) = buffer.find('\n') {
580 let line = buffer[..line_end].trim().to_string();
581 buffer = buffer[line_end + 1..].to_string();
582
583 if line == "data: [DONE]" {
584 if !text_buf.is_empty() {
585 chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
586 }
587 chunks.push(StreamChunk::Done { usage: None });
588 continue;
589 }
590 if let Some(data) = line.strip_prefix("data: ") {
591 if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
592 {
593 if let Some(choice) = parsed.choices.first() {
594 if let Some(ref reasoning) = choice.delta.reasoning_content
596 {
597 if !reasoning.is_empty() {
598 text_buf.push_str(reasoning);
599 }
600 }
601 if let Some(ref content) = choice.delta.content {
602 text_buf.push_str(content);
603 }
604 if let Some(ref tool_calls) = choice.delta.tool_calls {
606 if !text_buf.is_empty() {
607 chunks.push(StreamChunk::Text(std::mem::take(
608 &mut text_buf,
609 )));
610 }
611 for tc in tool_calls {
612 if let Some(ref func) = tc.function {
613 if let Some(ref name) = func.name {
614 chunks.push(StreamChunk::ToolCallStart {
616 id: tc.id.clone().unwrap_or_default(),
617 name: name.clone(),
618 });
619 }
620 if let Some(ref args) = func.arguments {
621 let delta = match args {
622 Value::String(s) => s.clone(),
623 other => serde_json::to_string(other)
624 .unwrap_or_default(),
625 };
626 if !delta.is_empty() {
627 chunks.push(
628 StreamChunk::ToolCallDelta {
629 id: tc
630 .id
631 .clone()
632 .unwrap_or_default(),
633 arguments_delta: delta,
634 },
635 );
636 }
637 }
638 }
639 }
640 }
641 if let Some(ref reason) = choice.finish_reason {
643 if !text_buf.is_empty() {
644 chunks.push(StreamChunk::Text(std::mem::take(
645 &mut text_buf,
646 )));
647 }
648 if reason == "tool_calls" {
649 if let Some(ref tcs) = choice.delta.tool_calls {
651 if let Some(tc) = tcs.last() {
652 chunks.push(StreamChunk::ToolCallEnd {
653 id: tc.id.clone().unwrap_or_default(),
654 });
655 }
656 }
657 }
658 }
659 }
660 }
661 }
662 }
663 if !text_buf.is_empty() {
664 chunks.push(StreamChunk::Text(text_buf));
665 }
666 }
667 Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
668 }
669 futures::stream::iter(chunks)
670 })
671 .boxed())
672 }
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a one-message history: an assistant turn carrying a single
    // tool call with the given raw arguments.
    fn assistant_tool_call(arguments: &str) -> Vec<Message> {
        vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: arguments.to_string(),
            }],
        }]
    }

    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = assistant_tool_call("{\"city\":\"Beijing\"}");
        let converted = ZaiProvider::convert_messages(&messages, true);

        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        assert_eq!(args, "{\"city\":\"Beijing\"}");
    }

    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = assistant_tool_call("city=Beijing");
        let converted = ZaiProvider::convert_messages(&messages, true);

        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value = serde_json::from_str(args).expect("arguments must contain valid JSON");
        assert_eq!(parsed, json!({"input": "city=Beijing"}));
    }
}
717}