1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Provider for Z.AI's OpenAI-compatible chat-completions API (GLM models).
pub struct ZaiProvider {
    /// Shared HTTP client, reused across requests.
    client: Client,
    /// Bearer token for the `Authorization` header; redacted from `Debug` output.
    api_key: String,
    /// API root URL; `/chat/completions` is appended per request.
    base_url: String,
}
23
// Manual `Debug` impl (instead of `#[derive(Debug)]`) so the API key can
// never leak into logs or panic messages.
impl std::fmt::Debug for ZaiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZaiProvider")
            .field("base_url", &self.base_url)
            // Deliberately print a placeholder, not the secret itself.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
32
33impl ZaiProvider {
34 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "zai",
37 base_url = %base_url,
38 api_key_len = api_key.len(),
39 "Creating Z.AI provider with custom base URL"
40 );
41 Ok(Self {
42 client: Client::new(),
43 api_key,
44 base_url,
45 })
46 }
47
48 fn convert_messages(messages: &[Message]) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({"role": role, "content": ""})
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let reasoning: String = msg
88 .content
89 .iter()
90 .filter_map(|p| match p {
91 ContentPart::Thinking { text } => Some(text.clone()),
92 _ => None,
93 })
94 .collect::<Vec<_>>()
95 .join("");
96
97 let tool_calls: Vec<Value> = msg
98 .content
99 .iter()
100 .filter_map(|p| match p {
101 ContentPart::ToolCall {
102 id,
103 name,
104 arguments,
105 } => {
106 let args_string = serde_json::from_str::<Value>(arguments)
109 .map(|parsed| {
110 serde_json::to_string(&parsed)
111 .unwrap_or_else(|_| "{}".to_string())
112 })
113 .unwrap_or_else(|_| {
114 json!({"input": arguments}).to_string()
115 });
116 Some(json!({
117 "id": id,
118 "type": "function",
119 "function": {
120 "name": name,
121 "arguments": args_string
122 }
123 }))
124 }
125 _ => None,
126 })
127 .collect();
128
129 let mut msg_json = json!({
130 "role": "assistant",
131 "content": if text.is_empty() { "".to_string() } else { text },
132 });
133 if !reasoning.is_empty() {
135 msg_json["reasoning_content"] = json!(reasoning);
136 }
137 if !tool_calls.is_empty() {
138 msg_json["tool_calls"] = json!(tool_calls);
139 }
140 msg_json
141 }
142 _ => {
143 let text: String = msg
144 .content
145 .iter()
146 .filter_map(|p| match p {
147 ContentPart::Text { text } => Some(text.clone()),
148 _ => None,
149 })
150 .collect::<Vec<_>>()
151 .join("\n");
152
153 json!({"role": role, "content": text})
154 }
155 }
156 })
157 .collect()
158 }
159
160 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
161 tools
162 .iter()
163 .map(|t| {
164 json!({
165 "type": "function",
166 "function": {
167 "name": t.name,
168 "description": t.description,
169 "parameters": t.parameters
170 }
171 })
172 })
173 .collect()
174 }
175}
176
/// Top-level body of a non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload within a choice.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // GLM "thinking" output; surfaced as `ContentPart::Thinking`.
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    // Kept as `Value`: the server may send either a JSON string or an
    // object here; `complete` normalizes both to a string.
    arguments: Value,
}

/// Token accounting block; all counts default to 0 when absent.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; only the cache-hit count is used.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
230
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // `type` is a Rust keyword, hence the rename.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
242
/// One SSE `data:` frame of a streaming completion.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental delta; every field is optional because frames may carry
/// text, reasoning, tool-call fragments, or nothing at all.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// Fragment of a streamed tool call; `id`/`name` may appear only on the
/// first fragment, with later frames carrying argument deltas.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
280
281#[async_trait]
282impl Provider for ZaiProvider {
283 fn name(&self) -> &str {
284 "zai"
285 }
286
287 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
288 Ok(vec![
289 ModelInfo {
290 id: "glm-5".to_string(),
291 name: "GLM-5".to_string(),
292 provider: "zai".to_string(),
293 context_window: 200_000,
294 max_output_tokens: Some(128_000),
295 supports_vision: false,
296 supports_tools: true,
297 supports_streaming: true,
298 input_cost_per_million: None,
299 output_cost_per_million: None,
300 },
301 ModelInfo {
302 id: "glm-4.7".to_string(),
303 name: "GLM-4.7".to_string(),
304 provider: "zai".to_string(),
305 context_window: 128_000,
306 max_output_tokens: Some(128_000),
307 supports_vision: false,
308 supports_tools: true,
309 supports_streaming: true,
310 input_cost_per_million: None,
311 output_cost_per_million: None,
312 },
313 ModelInfo {
314 id: "glm-4.7-flash".to_string(),
315 name: "GLM-4.7 Flash".to_string(),
316 provider: "zai".to_string(),
317 context_window: 128_000,
318 max_output_tokens: Some(128_000),
319 supports_vision: false,
320 supports_tools: true,
321 supports_streaming: true,
322 input_cost_per_million: None,
323 output_cost_per_million: None,
324 },
325 ModelInfo {
326 id: "glm-4.6".to_string(),
327 name: "GLM-4.6".to_string(),
328 provider: "zai".to_string(),
329 context_window: 128_000,
330 max_output_tokens: Some(128_000),
331 supports_vision: false,
332 supports_tools: true,
333 supports_streaming: true,
334 input_cost_per_million: None,
335 output_cost_per_million: None,
336 },
337 ModelInfo {
338 id: "glm-4.5".to_string(),
339 name: "GLM-4.5".to_string(),
340 provider: "zai".to_string(),
341 context_window: 128_000,
342 max_output_tokens: Some(96_000),
343 supports_vision: false,
344 supports_tools: true,
345 supports_streaming: true,
346 input_cost_per_million: None,
347 output_cost_per_million: None,
348 },
349 ])
350 }
351
352 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
353 let messages = Self::convert_messages(&request.messages);
354 let tools = Self::convert_tools(&request.tools);
355
356 let temperature = request.temperature.unwrap_or(1.0);
358
359 let mut body = json!({
360 "model": request.model,
361 "messages": messages,
362 "temperature": temperature,
363 });
364
365 body["thinking"] = json!({
369 "type": "enabled",
370 "clear_thinking": false
371 });
372
373 if !tools.is_empty() {
374 body["tools"] = json!(tools);
375 }
376 if let Some(max) = request.max_tokens {
377 body["max_tokens"] = json!(max);
378 }
379
380 tracing::debug!(model = %request.model, "Z.AI request");
381
382 let response = self
383 .client
384 .post(format!("{}/chat/completions", self.base_url))
385 .header("Authorization", format!("Bearer {}", self.api_key))
386 .header("Content-Type", "application/json")
387 .json(&body)
388 .send()
389 .await
390 .context("Failed to send request to Z.AI")?;
391
392 let status = response.status();
393 let text = response
394 .text()
395 .await
396 .context("Failed to read Z.AI response")?;
397
398 if !status.is_success() {
399 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
400 anyhow::bail!(
401 "Z.AI API error: {} ({:?})",
402 err.error.message,
403 err.error.error_type
404 );
405 }
406 anyhow::bail!("Z.AI API error: {} {}", status, text);
407 }
408
409 let response: ZaiResponse = serde_json::from_str(&text).context(format!(
410 "Failed to parse Z.AI response: {}",
411 &text[..text.len().min(200)]
412 ))?;
413
414 let choice = response
415 .choices
416 .first()
417 .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
418
419 if let Some(ref reasoning) = choice.message.reasoning_content {
421 if !reasoning.is_empty() {
422 tracing::info!(
423 reasoning_len = reasoning.len(),
424 "Z.AI reasoning content received"
425 );
426 }
427 }
428
429 let mut content = Vec::new();
430 let mut has_tool_calls = false;
431
432 if let Some(ref reasoning) = choice.message.reasoning_content {
434 if !reasoning.is_empty() {
435 content.push(ContentPart::Thinking {
436 text: reasoning.clone(),
437 });
438 }
439 }
440
441 if let Some(text) = &choice.message.content {
442 if !text.is_empty() {
443 content.push(ContentPart::Text { text: text.clone() });
444 }
445 }
446
447 if let Some(tool_calls) = &choice.message.tool_calls {
448 has_tool_calls = !tool_calls.is_empty();
449 for tc in tool_calls {
450 let arguments = match &tc.function.arguments {
452 Value::String(s) => s.clone(),
453 other => serde_json::to_string(other).unwrap_or_default(),
454 };
455 content.push(ContentPart::ToolCall {
456 id: tc.id.clone(),
457 name: tc.function.name.clone(),
458 arguments,
459 });
460 }
461 }
462
463 let finish_reason = if has_tool_calls {
464 FinishReason::ToolCalls
465 } else {
466 match choice.finish_reason.as_deref() {
467 Some("stop") => FinishReason::Stop,
468 Some("length") => FinishReason::Length,
469 Some("tool_calls") => FinishReason::ToolCalls,
470 Some("sensitive") => FinishReason::ContentFilter,
471 _ => FinishReason::Stop,
472 }
473 };
474
475 Ok(CompletionResponse {
476 message: Message {
477 role: Role::Assistant,
478 content,
479 },
480 usage: Usage {
481 prompt_tokens: response
482 .usage
483 .as_ref()
484 .map(|u| u.prompt_tokens)
485 .unwrap_or(0),
486 completion_tokens: response
487 .usage
488 .as_ref()
489 .map(|u| u.completion_tokens)
490 .unwrap_or(0),
491 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
492 cache_read_tokens: response
493 .usage
494 .as_ref()
495 .and_then(|u| u.prompt_tokens_details.as_ref())
496 .map(|d| d.cached_tokens)
497 .filter(|&t| t > 0),
498 cache_write_tokens: None,
499 },
500 finish_reason,
501 })
502 }
503
504 async fn complete_stream(
505 &self,
506 request: CompletionRequest,
507 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
508 let messages = Self::convert_messages(&request.messages);
509 let tools = Self::convert_tools(&request.tools);
510
511 let temperature = request.temperature.unwrap_or(1.0);
512
513 let mut body = json!({
514 "model": request.model,
515 "messages": messages,
516 "temperature": temperature,
517 "stream": true,
518 });
519
520 body["thinking"] = json!({
521 "type": "enabled",
522 "clear_thinking": false
523 });
524
525 if !tools.is_empty() {
526 body["tools"] = json!(tools);
527 body["tool_stream"] = json!(true);
529 }
530 if let Some(max) = request.max_tokens {
531 body["max_tokens"] = json!(max);
532 }
533
534 tracing::debug!(model = %request.model, "Z.AI streaming request");
535
536 let response = self
537 .client
538 .post(format!("{}/chat/completions", self.base_url))
539 .header("Authorization", format!("Bearer {}", self.api_key))
540 .header("Content-Type", "application/json")
541 .json(&body)
542 .send()
543 .await
544 .context("Failed to send streaming request to Z.AI")?;
545
546 if !response.status().is_success() {
547 let status = response.status();
548 let text = response.text().await.unwrap_or_default();
549 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
550 anyhow::bail!(
551 "Z.AI API error: {} ({:?})",
552 err.error.message,
553 err.error.error_type
554 );
555 }
556 anyhow::bail!("Z.AI streaming error: {} {}", status, text);
557 }
558
559 let stream = response.bytes_stream();
560 let mut buffer = String::new();
561
562 Ok(stream
563 .flat_map(move |chunk_result| {
564 let mut chunks: Vec<StreamChunk> = Vec::new();
565 match chunk_result {
566 Ok(bytes) => {
567 let text = String::from_utf8_lossy(&bytes);
568 buffer.push_str(&text);
569
570 let mut text_buf = String::new();
571 while let Some(line_end) = buffer.find('\n') {
572 let line = buffer[..line_end].trim().to_string();
573 buffer = buffer[line_end + 1..].to_string();
574
575 if line == "data: [DONE]" {
576 if !text_buf.is_empty() {
577 chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
578 }
579 chunks.push(StreamChunk::Done { usage: None });
580 continue;
581 }
582 if let Some(data) = line.strip_prefix("data: ") {
583 if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
584 {
585 if let Some(choice) = parsed.choices.first() {
586 if let Some(ref reasoning) = choice.delta.reasoning_content
588 {
589 if !reasoning.is_empty() {
590 text_buf.push_str(reasoning);
591 }
592 }
593 if let Some(ref content) = choice.delta.content {
594 text_buf.push_str(content);
595 }
596 if let Some(ref tool_calls) = choice.delta.tool_calls {
598 if !text_buf.is_empty() {
599 chunks.push(StreamChunk::Text(std::mem::take(
600 &mut text_buf,
601 )));
602 }
603 for tc in tool_calls {
604 if let Some(ref func) = tc.function {
605 if let Some(ref name) = func.name {
606 chunks.push(StreamChunk::ToolCallStart {
608 id: tc.id.clone().unwrap_or_default(),
609 name: name.clone(),
610 });
611 }
612 if let Some(ref args) = func.arguments {
613 let delta = match args {
614 Value::String(s) => s.clone(),
615 other => serde_json::to_string(other)
616 .unwrap_or_default(),
617 };
618 if !delta.is_empty() {
619 chunks.push(
620 StreamChunk::ToolCallDelta {
621 id: tc
622 .id
623 .clone()
624 .unwrap_or_default(),
625 arguments_delta: delta,
626 },
627 );
628 }
629 }
630 }
631 }
632 }
633 if let Some(ref reason) = choice.finish_reason {
635 if !text_buf.is_empty() {
636 chunks.push(StreamChunk::Text(std::mem::take(
637 &mut text_buf,
638 )));
639 }
640 if reason == "tool_calls" {
641 if let Some(ref tcs) = choice.delta.tool_calls {
643 if let Some(tc) = tcs.last() {
644 chunks.push(StreamChunk::ToolCallEnd {
645 id: tc.id.clone().unwrap_or_default(),
646 });
647 }
648 }
649 }
650 }
651 }
652 }
653 }
654 }
655 if !text_buf.is_empty() {
656 chunks.push(StreamChunk::Text(text_buf));
657 }
658 }
659 Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
660 }
661 futures::stream::iter(chunks)
662 })
663 .boxed())
664 }
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    // Well-formed JSON arguments must survive as a canonical JSON string.
    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".into(),
                name: "get_weather".into(),
                arguments: r#"{"city":"Beijing"}"#.into(),
            }],
        };

        let out = ZaiProvider::convert_messages(std::slice::from_ref(&msg));

        let args = out[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        assert_eq!(args, r#"{"city":"Beijing"}"#);
    }

    // Non-JSON arguments must be wrapped so the field still holds valid JSON.
    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let msg = Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".into(),
                name: "get_weather".into(),
                arguments: "city=Beijing".into(),
            }],
        };

        let out = ZaiProvider::convert_messages(std::slice::from_ref(&msg));

        let args = out[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value = serde_json::from_str(args).expect("arguments must contain valid JSON");
        assert_eq!(parsed, json!({"input": "city=Beijing"}));
    }
}