1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use once_cell::sync::Lazy;
15use regex::Regex;
16use reqwest::Client;
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::collections::HashMap;
20
/// Default OpenAI-compatible endpoint for Z.AI chat completions.
pub const DEFAULT_BASE_URL: &str = "https://api.z.ai/api/paas/v4";
/// Dedicated coding-plan endpoint; selected per-request in `request_base_url`.
const CODING_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
/// Model id that is always routed to the coding endpoint.
const PONY_ALPHA_2_MODEL: &str = "pony-alpha-2";
24
/// Provider backed by Z.AI's OpenAI-compatible chat-completions API.
pub struct ZaiProvider {
    // HTTP client cloned from the process-wide shared pool (see `with_base_url`).
    client: Client,
    // Bearer token sent on every request; the manual Debug impl redacts it.
    api_key: String,
    // Base URL for most models; `pony-alpha-2` is rerouted to CODING_BASE_URL.
    base_url: String,
}
30
/// Per-tool-call bookkeeping used while decoding a streamed response.
#[derive(Debug, Default)]
struct ZaiStreamToolState {
    // Id emitted on downstream StreamChunks; starts as a "zai-tool-{index}"
    // placeholder until the server supplies a real call id.
    stream_id: String,
    // Function name, once any delta has carried it.
    name: Option<String>,
    // Whether ToolCallStart has already been emitted for this call.
    started: bool,
    // Whether ToolCallEnd has already been emitted for this call.
    finished: bool,
}
38
39impl std::fmt::Debug for ZaiProvider {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("ZaiProvider")
42 .field("base_url", &self.base_url)
43 .field("api_key", &"<REDACTED>")
44 .finish()
45 }
46}
47
48impl ZaiProvider {
    /// Builds a provider that talks to `base_url` with the given bearer key.
    ///
    /// The HTTP client is cloned from the process-wide shared pool rather than
    /// constructed per provider. Currently infallible; the `Result` return
    /// leaves room for future validation without breaking callers.
    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
        tracing::debug!(
            provider = "zai",
            base_url = %base_url,
            // Log only the key length, never the key itself.
            api_key_len = api_key.len(),
            "Creating Z.AI provider with custom base URL"
        );
        Ok(Self {
            client: crate::provider::shared_http::shared_client().clone(),
            api_key,
            base_url,
        })
    }
62
63 fn request_base_url(&self, model: &str) -> &str {
64 if model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL) {
65 CODING_BASE_URL
66 } else {
67 &self.base_url
68 }
69 }
70
    /// Best-effort model discovery via `GET {base}/models`.
    ///
    /// Never fails: any transport, HTTP, or parse problem is logged at debug
    /// level and an empty list is returned so the caller can fall back to a
    /// static model list.
    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
        // The /models endpoint lives on the general API, so strip the
        // "/coding/" path segment when this provider targets the coding base.
        let discovery_url = if self.base_url.contains("/coding/") {
            self.base_url.replace("/coding/", "/")
        } else {
            self.base_url.clone()
        };
        let url = format!("{discovery_url}/models");
        let response = match self
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    url = %url,
                    error = %e,
                    "Z.AI /models discovery request failed"
                );
                return Vec::new();
            }
        };

        if !response.status().is_success() {
            tracing::debug!(
                url = %url,
                status = %response.status(),
                "Z.AI /models endpoint returned non-success"
            );
            return Vec::new();
        }

        let payload: Value = match response.json().await {
            Ok(p) => p,
            Err(e) => {
                tracing::debug!(
                    url = %url,
                    error = %e,
                    "Failed to parse Z.AI /models response"
                );
                return Vec::new();
            }
        };

        // `data` entries may be bare id strings or objects carrying "id"/"name".
        let models = payload
            .get("data")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|entry| {
                let id = match entry {
                    Value::String(s) => s.trim().to_string(),
                    Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim().to_string(),
                    _ => return None,
                };
                if id.is_empty() {
                    return None;
                }
                // Fall back to the id when no (non-empty) display name exists.
                let name = entry
                    .get("name")
                    .and_then(Value::as_str)
                    .map(str::trim)
                    .filter(|n| !n.is_empty())
                    .unwrap_or(&id)
                    .to_string();
                // NOTE(review): limits/capabilities below are blanket defaults
                // applied to every discovered model — confirm against Z.AI docs.
                Some(ModelInfo {
                    id,
                    name,
                    provider: "zai".to_string(),
                    context_window: 200_000,
                    max_output_tokens: Some(128_000),
                    supports_vision: false,
                    supports_tools: true,
                    supports_streaming: true,
                    input_cost_per_million: None,
                    output_cost_per_million: None,
                })
            })
            .collect::<Vec<_>>();

        if models.is_empty() {
            tracing::debug!(url = %url, "Z.AI /models returned no model ids");
        } else {
            tracing::info!(count = models.len(), "Z.AI /models discovery succeeded");
        }
        models
    }
165
166 fn normalize_tool_arguments(arguments: &str) -> String {
167 if let Ok(parsed) = serde_json::from_str::<Value>(arguments) {
170 if parsed.is_object() {
171 return serde_json::to_string(&parsed).unwrap_or_else(|_| "{}".to_string());
172 }
173 return json!({"input": parsed}).to_string();
174 }
175
176 if let Some(salvaged) = Self::salvage_json_object(arguments) {
177 return serde_json::to_string(&salvaged).unwrap_or_else(|_| "{}".to_string());
178 }
179
180 json!({"input": arguments}).to_string()
181 }
182
    /// Attempts to recover a JSON object from a malformed arguments payload
    /// (e.g. truncated or trailing-garbage model output).
    ///
    /// Scans for simple `"key": <scalar>` pairs and rebuilds an object from
    /// whatever parses; returns `None` when the input does not start with `{`
    /// or no pair could be salvaged.
    fn salvage_json_object(arguments: &str) -> Option<Value> {
        let trimmed = arguments.trim();
        if !trimmed.starts_with('{') {
            return None;
        }

        // Matches one key/value pair whose value is a JSON string, boolean,
        // null, or number. Compiled once and reused across calls.
        static RE_SIMPLE_PAIR: Lazy<Regex> = Lazy::new(|| {
            Regex::new(
                r#"(?s)\"(?P<k>[^\"\\]*(?:\\.[^\"\\]*)*)\"\s*:\s*(?P<v>\"(?:\\.|[^\"])*\"|true|false|null|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"#,
            )
            .expect("invalid regex")
        });

        let mut map = serde_json::Map::new();
        for caps in RE_SIMPLE_PAIR.captures_iter(trimmed) {
            // Both named groups are mandatory in the pattern, so `?` cannot
            // actually abort here; it is defensive only.
            let key = caps.name("k")?.as_str();
            let val_str = caps.name("v")?.as_str();
            // Pairs whose value still fails to parse are skipped; duplicate
            // keys keep the last occurrence (Map::insert overwrites).
            if let Ok(val) = serde_json::from_str::<Value>(val_str) {
                map.insert(key.to_string(), val);
            }
        }

        if map.is_empty() {
            None
        } else {
            Some(Value::Object(map))
        }
    }
213
    /// Converts internal `Message`s into the OpenAI-style JSON array that
    /// Z.AI expects.
    ///
    /// * Tool messages forward only their first `ToolResult` part; any
    ///   additional parts are dropped.
    /// * Assistant messages concatenate all `Text` parts, attach normalized
    ///   `tool_calls`, and — when `include_reasoning_content` is set — echo
    ///   concatenated `Thinking` parts back as `reasoning_content`.
    /// * System/user messages join their text parts with newlines.
    fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
        messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    Role::Tool => "tool",
                };

                match msg.role {
                    Role::Tool => {
                        if let Some(ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        }) = msg.content.first()
                        {
                            json!({
                                "role": "tool",
                                "tool_call_id": tool_call_id,
                                "content": content
                            })
                        } else {
                            // Tool message without a ToolResult part: send an
                            // empty body to keep the transcript well-formed.
                            json!({"role": role, "content": ""})
                        }
                    }
                    Role::Assistant => {
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("");

                        // Re-serialize arguments so the API always receives a
                        // valid JSON-object string (see normalize_tool_arguments).
                        let tool_calls: Vec<Value> = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::ToolCall {
                                    id,
                                    name,
                                    arguments,
                                    ..
                                } => {
                                    let args_string = Self::normalize_tool_arguments(arguments);
                                    Some(json!({
                                        "id": id,
                                        "type": "function",
                                        "function": {
                                            "name": name,
                                            "arguments": args_string
                                        }
                                    }))
                                }
                                _ => None,
                            })
                            .collect();

                        let mut msg_json = json!({
                            "role": "assistant",
                            "content": text,
                        });
                        if include_reasoning_content {
                            let reasoning: String = msg
                                .content
                                .iter()
                                .filter_map(|p| match p {
                                    ContentPart::Thinking { text } => Some(text.clone()),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join("");
                            // Only attach the field when there is something to
                            // say; an empty string is omitted entirely.
                            if !reasoning.is_empty() {
                                msg_json["reasoning_content"] = json!(reasoning);
                            }
                        }
                        if !tool_calls.is_empty() {
                            msg_json["tool_calls"] = json!(tool_calls);
                        }
                        msg_json
                    }
                    _ => {
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("\n");

                        json!({"role": role, "content": text})
                    }
                }
            })
            .collect()
    }
316
317 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
318 tools
319 .iter()
320 .map(|t| {
321 json!({
322 "type": "function",
323 "function": {
324 "name": t.name,
325 "description": t.description,
326 "parameters": t.parameters
327 }
328 })
329 })
330 .collect()
331 }
332
333 fn model_supports_tool_stream(model: &str) -> bool {
334 model.contains("glm-5")
335 || model.contains("glm-4.7")
336 || model.contains("glm-4.6")
337 || model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL)
338 }
339
340 fn preview_text(text: &str, max_chars: usize) -> &str {
341 if max_chars == 0 {
342 return "";
343 }
344 if let Some((idx, _)) = text.char_indices().nth(max_chars) {
345 &text[..idx]
346 } else {
347 text
348 }
349 }
350
351 fn stream_tool_arguments_fragment(arguments: &Value) -> String {
352 match arguments {
353 Value::Null => String::new(),
354 Value::String(s) => s.clone(),
355 other => serde_json::to_string(other).unwrap_or_default(),
356 }
357 }
358
    /// Translates one batch of streamed tool-call deltas into `StreamChunk`s,
    /// updating per-call state so fragments of the same call share one stable
    /// id.
    ///
    /// Index resolution order for each delta: the explicit `index` field,
    /// then an existing state whose id matches the delta's `id`, then the
    /// last index seen, and finally a fresh fallback counter.
    fn append_stream_tool_call_chunks(
        chunks: &mut Vec<StreamChunk>,
        tool_calls: &[ZaiStreamToolCall],
        tool_states: &mut HashMap<usize, ZaiStreamToolState>,
        next_fallback_index: &mut usize,
        last_seen_index: &mut Option<usize>,
    ) {
        for tc in tool_calls {
            let index = tc
                .index
                .or_else(|| {
                    tc.id.as_ref().and_then(|id| {
                        tool_states
                            .iter()
                            .find_map(|(idx, state)| (state.stream_id == *id).then_some(*idx))
                    })
                })
                .or(*last_seen_index)
                .unwrap_or_else(|| {
                    let idx = *next_fallback_index;
                    *next_fallback_index += 1;
                    idx
                });
            *last_seen_index = Some(index);

            // First sight of this index: seed the state with the server id, or
            // a synthetic "zai-tool-{index}" placeholder when none arrived yet.
            let state = tool_states
                .entry(index)
                .or_insert_with(|| ZaiStreamToolState {
                    stream_id: tc.id.clone().unwrap_or_else(|| format!("zai-tool-{index}")),
                    ..Default::default()
                });

            // Upgrade a placeholder id to the real one, but only before the
            // start chunk is emitted — ids must stay stable downstream.
            if let Some(id) = &tc.id
                && !state.started
                && state.stream_id.starts_with("zai-tool-")
            {
                state.stream_id = id.clone();
            }

            if let Some(func) = &tc.function {
                if let Some(name) = &func.name
                    && !name.is_empty()
                {
                    state.name = Some(name.clone());
                }

                // Emit ToolCallStart as soon as a name is known.
                if !state.started
                    && let Some(name) = &state.name
                {
                    chunks.push(StreamChunk::ToolCallStart {
                        id: state.stream_id.clone(),
                        name: name.clone(),
                    });
                    state.started = true;
                }

                if let Some(arguments) = &func.arguments {
                    let delta = Self::stream_tool_arguments_fragment(arguments);
                    if !delta.is_empty() {
                        // Arguments arrived before any name: force a start
                        // chunk with a generic name so deltas are never orphaned.
                        if !state.started {
                            chunks.push(StreamChunk::ToolCallStart {
                                id: state.stream_id.clone(),
                                name: state.name.clone().unwrap_or_else(|| "tool".to_string()),
                            });
                            state.started = true;
                        }
                        chunks.push(StreamChunk::ToolCallDelta {
                            id: state.stream_id.clone(),
                            arguments_delta: delta,
                        });
                    }
                }
            }
        }
    }
434
435 fn finish_stream_tool_call_chunks(
436 chunks: &mut Vec<StreamChunk>,
437 tool_states: &mut HashMap<usize, ZaiStreamToolState>,
438 ) {
439 let mut ordered_indexes: Vec<_> = tool_states.keys().copied().collect();
440 ordered_indexes.sort_unstable();
441
442 for index in ordered_indexes {
443 if let Some(state) = tool_states.get_mut(&index)
444 && state.started
445 && !state.finished
446 {
447 chunks.push(StreamChunk::ToolCallEnd {
448 id: state.stream_id.clone(),
449 });
450 state.finished = true;
451 }
452 }
453 }
454}
455
/// Non-streaming chat-completions response envelope.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload: visible text, tool calls, and optional
/// reasoning (`reasoning_content`).
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A complete (non-streamed) tool call.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

/// Function name plus arguments; arguments may arrive as a JSON string or a
/// structured value, hence the loose `Value` type.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    arguments: Value,
}

/// Token accounting block; every counter defaults to 0 when absent.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens served from the provider-side cache.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}

/// Error envelope returned with non-2xx statuses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

/// Error payload: human-readable message and optional machine-readable `type`.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
521
/// One parsed SSE `data:` payload from the streaming endpoint.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

/// Streaming choice: an incremental delta plus an optional terminal reason.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message fragment; every field may be absent on any chunk.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// Fragment of a streamed tool call; `index` and `id` may be omitted on
/// follow-up deltas (handled by `append_stream_tool_call_chunks`).
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    index: Option<usize>,
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

/// Streamed function fragment: optional name and partial arguments.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
561
562#[async_trait]
563impl Provider for ZaiProvider {
    /// Stable provider identifier; matches the `provider` field on the
    /// `ModelInfo` values this provider returns.
    fn name(&self) -> &str {
        "zai"
    }
567
568 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
569 let discovered = self.discover_models_from_api().await;
573 if !discovered.is_empty() {
574 let mut models = discovered;
577 if !models.iter().any(|m| m.id == PONY_ALPHA_2_MODEL) {
578 models.push(ModelInfo {
579 id: PONY_ALPHA_2_MODEL.to_string(),
580 name: "Pony Alpha 2".to_string(),
581 provider: "zai".to_string(),
582 context_window: 128_000,
583 max_output_tokens: Some(16_384),
584 supports_vision: false,
585 supports_tools: true,
586 supports_streaming: true,
587 input_cost_per_million: None,
588 output_cost_per_million: None,
589 });
590 }
591 if !models.iter().any(|m| m.id == "glm-4.7-flash") {
592 models.push(ModelInfo {
593 id: "glm-4.7-flash".to_string(),
594 name: "GLM-4.7 Flash".to_string(),
595 provider: "zai".to_string(),
596 context_window: 128_000,
597 max_output_tokens: Some(128_000),
598 supports_vision: false,
599 supports_tools: true,
600 supports_streaming: true,
601 input_cost_per_million: None,
602 output_cost_per_million: None,
603 });
604 }
605 return Ok(models);
606 }
607
608 Ok(vec![
611 ModelInfo {
612 id: "glm-5.1".to_string(),
613 name: "GLM-5.1".to_string(),
614 provider: "zai".to_string(),
615 context_window: 200_000,
616 max_output_tokens: Some(128_000),
617 supports_vision: false,
618 supports_tools: true,
619 supports_streaming: true,
620 input_cost_per_million: None,
621 output_cost_per_million: None,
622 },
623 ModelInfo {
624 id: "glm-5".to_string(),
625 name: "GLM-5".to_string(),
626 provider: "zai".to_string(),
627 context_window: 200_000,
628 max_output_tokens: Some(128_000),
629 supports_vision: false,
630 supports_tools: true,
631 supports_streaming: true,
632 input_cost_per_million: None,
633 output_cost_per_million: None,
634 },
635 ModelInfo {
636 id: "glm-4.7".to_string(),
637 name: "GLM-4.7".to_string(),
638 provider: "zai".to_string(),
639 context_window: 128_000,
640 max_output_tokens: Some(128_000),
641 supports_vision: false,
642 supports_tools: true,
643 supports_streaming: true,
644 input_cost_per_million: None,
645 output_cost_per_million: None,
646 },
647 ModelInfo {
648 id: "glm-4.7-flash".to_string(),
649 name: "GLM-4.7 Flash".to_string(),
650 provider: "zai".to_string(),
651 context_window: 128_000,
652 max_output_tokens: Some(128_000),
653 supports_vision: false,
654 supports_tools: true,
655 supports_streaming: true,
656 input_cost_per_million: None,
657 output_cost_per_million: None,
658 },
659 ModelInfo {
660 id: "glm-4.6".to_string(),
661 name: "GLM-4.6".to_string(),
662 provider: "zai".to_string(),
663 context_window: 128_000,
664 max_output_tokens: Some(128_000),
665 supports_vision: false,
666 supports_tools: true,
667 supports_streaming: true,
668 input_cost_per_million: None,
669 output_cost_per_million: None,
670 },
671 ModelInfo {
672 id: "glm-4.5".to_string(),
673 name: "GLM-4.5".to_string(),
674 provider: "zai".to_string(),
675 context_window: 128_000,
676 max_output_tokens: Some(96_000),
677 supports_vision: false,
678 supports_tools: true,
679 supports_streaming: true,
680 input_cost_per_million: None,
681 output_cost_per_million: None,
682 },
683 ModelInfo {
684 id: "glm-5-turbo".to_string(),
685 name: "GLM-5 Turbo".to_string(),
686 provider: "zai".to_string(),
687 context_window: 200_000,
688 max_output_tokens: Some(128_000),
689 supports_vision: false,
690 supports_tools: true,
691 supports_streaming: true,
692 input_cost_per_million: Some(0.96),
693 output_cost_per_million: Some(3.20),
694 },
695 ModelInfo {
696 id: PONY_ALPHA_2_MODEL.to_string(),
697 name: "Pony Alpha 2".to_string(),
698 provider: "zai".to_string(),
699 context_window: 128_000,
700 max_output_tokens: Some(16_384),
701 supports_vision: false,
702 supports_tools: true,
703 supports_streaming: true,
704 input_cost_per_million: None,
705 output_cost_per_million: None,
706 },
707 ])
708 }
709
    /// Performs one non-streaming chat completion.
    ///
    /// Messages are converted without echoing prior `reasoning_content`
    /// (second argument to `convert_messages` is `false`), thinking mode is
    /// always requested, and cached prompt tokens are subtracted from
    /// `prompt_tokens` so callers count only freshly billed input.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
        });

        // Thinking is unconditionally enabled for this provider.
        body["thinking"] = json!({
            "type": "enabled",
            "clear_thinking": true
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI request");
        tracing::trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "Z.AI request body");
        let request_base_url = self.request_base_url(&request.model);

        // Retry wrapper; the body text is read eagerly so the error path
        // below can surface it verbatim.
        let (text, status) = super::retry::send_with_retry(|| async {
            let resp = self
                .client
                .post(format!("{}/chat/completions", request_base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .context("Failed to send request to Z.AI")?;
            let status = resp.status();
            let text = resp.text().await.context("Failed to read Z.AI response")?;
            Ok((text, status))
        })
        .await?;

        if !status.is_success() {
            tracing::debug!(status = %status, body = %text, "Z.AI error response");
            // Prefer the structured error payload; fall back to the raw body.
            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
                anyhow::bail!(
                    "Z.AI API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("Z.AI API error: {status} {text}");
        }

        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
            "Failed to parse Z.AI response: {}",
            Self::preview_text(&text, 200)
        ))?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;

        if let Some(ref reasoning) = choice.message.reasoning_content
            && !reasoning.is_empty()
        {
            tracing::info!(
                reasoning_len = reasoning.len(),
                "Z.AI reasoning content received"
            );
        }

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        // Parts are pushed in a fixed order: thinking, then visible text,
        // then tool calls.
        if let Some(ref reasoning) = choice.message.reasoning_content
            && !reasoning.is_empty()
        {
            content.push(ContentPart::Thinking {
                text: reasoning.clone(),
            });
        }

        if let Some(text) = &choice.message.content
            && !text.is_empty()
        {
            content.push(ContentPart::Text { text: text.clone() });
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Arguments may be a JSON string or structured JSON: keep the
                // string verbatim, serialize everything else.
                let arguments = match &tc.function.arguments {
                    Value::String(s) => s.clone(),
                    other => serde_json::to_string(other).unwrap_or_default(),
                };
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments,
                    thought_signature: None,
                });
            }
        }

        // Actual tool calls override whatever finish reason was reported;
        // unknown reasons collapse to Stop.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                Some("sensitive") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                // Report uncached prompt tokens; the cached portion is exposed
                // separately via cache_read_tokens below.
                prompt_tokens: {
                    let u = response.usage.as_ref();
                    let total = u.map(|u| u.prompt_tokens).unwrap_or(0);
                    let cached = u
                        .and_then(|u| u.prompt_tokens_details.as_ref())
                        .map(|d| d.cached_tokens)
                        .unwrap_or(0);
                    total.saturating_sub(cached)
                },
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: response
                    .usage
                    .as_ref()
                    .and_then(|u| u.prompt_tokens_details.as_ref())
                    .map(|d| d.cached_tokens)
                    .filter(|&t| t > 0),
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }
877
    /// Performs a streaming chat completion, translating the SSE byte stream
    /// into `StreamChunk`s.
    ///
    /// NOTE(review): `reasoning_content` deltas are folded into the same
    /// buffer as regular content and surfaced as plain `StreamChunk::Text`,
    /// unlike `complete` which emits a separate `Thinking` part — confirm
    /// this asymmetry is intentional.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
            "stream": true,
        });

        body["thinking"] = json!({
            "type": "enabled",
            "clear_thinking": true
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
            // Incremental tool-call deltas are only requested from model
            // families known to support them.
            if Self::model_supports_tool_stream(&request.model) {
                body["tool_stream"] = json!(true);
            }
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI streaming request");
        let request_base_url = self.request_base_url(&request.model);

        let response = super::retry::send_response_with_retry(|| async {
            self.client
                .post(format!("{}/chat/completions", request_base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .context("Failed to send streaming request to Z.AI")
        })
        .await?;

        let stream = response.bytes_stream();
        // State moved into the flat_map closure and carried across network
        // chunks: partial-line buffer plus per-tool-call bookkeeping.
        let mut buffer = String::new();
        let mut tool_states = HashMap::<usize, ZaiStreamToolState>::new();
        let mut next_fallback_tool_index = 0usize;
        let mut last_seen_tool_index = None;

        Ok(stream
            .flat_map(move |chunk_result| {
                let mut chunks: Vec<StreamChunk> = Vec::new();
                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);

                        // Process every complete SSE line; a trailing partial
                        // line stays in `buffer` for the next network chunk.
                        let mut text_buf = String::new();
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line == "data: [DONE]" {
                                if !text_buf.is_empty() {
                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                }
                                chunks.push(StreamChunk::Done { usage: None });
                                continue;
                            }
                            // Non-data and unparseable lines are silently skipped.
                            if let Some(data) = line.strip_prefix("data: ")
                                && let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
                                && let Some(choice) = parsed.choices.first()
                            {
                                if let Some(ref reasoning) = choice.delta.reasoning_content
                                    && !reasoning.is_empty()
                                {
                                    text_buf.push_str(reasoning);
                                }
                                if let Some(ref content) = choice.delta.content {
                                    text_buf.push_str(content);
                                }
                                if let Some(ref tool_calls) = choice.delta.tool_calls {
                                    // Flush buffered text first so ordering
                                    // relative to tool chunks is preserved.
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    Self::append_stream_tool_call_chunks(
                                        &mut chunks,
                                        tool_calls,
                                        &mut tool_states,
                                        &mut next_fallback_tool_index,
                                        &mut last_seen_tool_index,
                                    );
                                }
                                if let Some(ref reason) = choice.finish_reason {
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    // Close any still-open tool calls once the
                                    // server declares tool_calls finished.
                                    if reason == "tool_calls" {
                                        Self::finish_stream_tool_call_chunks(
                                            &mut chunks,
                                            &mut tool_states,
                                        );
                                    }
                                }
                            }
                        }
                        if !text_buf.is_empty() {
                            chunks.push(StreamChunk::Text(text_buf));
                        }
                    }
                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
                }
                futures::stream::iter(chunks)
            })
            .boxed())
    }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010 use super::*;
1011 use crate::provider::Provider;
1012
    // With a fake key the /models discovery fails, so the static fallback
    // list is used — it must still offer the coding-endpoint model.
    #[tokio::test]
    async fn list_models_includes_pony_alpha_2() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        assert!(models.iter().any(|model| model.id == PONY_ALPHA_2_MODEL));
    }
1022
    // Fallback metadata for glm-5-turbo, including its published pricing.
    #[tokio::test]
    async fn list_models_includes_glm_5_turbo() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        let turbo = models
            .iter()
            .find(|m| m.id == "glm-5-turbo")
            .expect("glm-5-turbo should be in model list");
        assert_eq!(turbo.context_window, 200_000);
        assert_eq!(turbo.max_output_tokens, Some(128_000));
        assert!(turbo.supports_tools);
        assert!(turbo.supports_streaming);
        assert_eq!(turbo.input_cost_per_million, Some(0.96));
        assert_eq!(turbo.output_cost_per_million, Some(3.20));
    }
1041
    // Fallback metadata for glm-5.1.
    #[tokio::test]
    async fn list_models_includes_glm_5_1() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        let glm51 = models
            .iter()
            .find(|m| m.id == "glm-5.1")
            .expect("glm-5.1 should be in model list");
        assert_eq!(glm51.context_window, 200_000);
        assert_eq!(glm51.max_output_tokens, Some(128_000));
        assert!(glm51.supports_tools);
        assert!(glm51.supports_streaming);
    }
1058
    // tool_stream support is keyed off model-family substrings; glm-4.5 is
    // explicitly outside the supported set.
    #[test]
    fn model_supports_tool_stream_matches_glm_5_1() {
        assert!(ZaiProvider::model_supports_tool_stream("glm-5.1"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-5"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-5-turbo"));
        assert!(!ZaiProvider::model_supports_tool_stream("glm-4.5"));
    }
1066
    // pony-alpha-2 must be rerouted to the coding endpoint; any other model
    // keeps the configured base URL.
    #[test]
    fn pony_alpha_2_routes_to_coding_endpoint() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");

        assert_eq!(
            provider.request_base_url(PONY_ALPHA_2_MODEL),
            CODING_BASE_URL
        );
        assert_eq!(provider.request_base_url("glm-5"), DEFAULT_BASE_URL);
    }
1079
    // Malformed-but-salvageable argument JSON (trailing garbage after a valid
    // pair) is repaired into a clean JSON-object string.
    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "{\"city\":\"Beijing\".. }".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"city":"Beijing"}));
    }
1101
    // Unsalvageable non-JSON arguments are wrapped as {"input": <raw text>}.
    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "city=Beijing".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"input":"city=Beijing"}));
    }
1123
    // Valid JSON that is not an object (a bare scalar) is wrapped as
    // {"input": <value>}.
    #[test]
    fn convert_messages_wraps_scalar_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "\"Beijing\"".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"input":"Beijing"}));
    }
1145
    // A follow-up delta that carries only an index (no id) must be attached
    // to the call started by the first delta, producing Start/Delta/Delta/End
    // all under the same id.
    #[test]
    fn stream_tool_chunks_keep_same_call_id_when_followup_delta_omits_id() {
        let mut chunks = Vec::new();
        let mut tool_states = HashMap::new();
        let mut next_fallback_tool_index = 0usize;
        let mut last_seen_tool_index = None;

        ZaiProvider::append_stream_tool_call_chunks(
            &mut chunks,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: Some("call_1".to_string()),
                function: Some(ZaiStreamFunction {
                    name: Some("bash".to_string()),
                    arguments: Some(Value::String("{\"".to_string())),
                }),
            }],
            &mut tool_states,
            &mut next_fallback_tool_index,
            &mut last_seen_tool_index,
        );

        ZaiProvider::append_stream_tool_call_chunks(
            &mut chunks,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: None,
                function: Some(ZaiStreamFunction {
                    name: None,
                    arguments: Some(Value::String("command\":\"pwd\"}".to_string())),
                }),
            }],
            &mut tool_states,
            &mut next_fallback_tool_index,
            &mut last_seen_tool_index,
        );

        ZaiProvider::finish_stream_tool_call_chunks(&mut chunks, &mut tool_states);

        assert_eq!(chunks.len(), 4);
        assert!(matches!(
            &chunks[0],
            StreamChunk::ToolCallStart { id, name }
            if id == "call_1" && name == "bash"
        ));
        assert!(matches!(
            &chunks[1],
            StreamChunk::ToolCallDelta { id, arguments_delta }
            if id == "call_1" && arguments_delta == "{\""
        ));
        assert!(matches!(
            &chunks[2],
            StreamChunk::ToolCallDelta { id, arguments_delta }
            if id == "call_1" && arguments_delta == "command\":\"pwd\"}"
        ));
        assert!(matches!(
            &chunks[3],
            StreamChunk::ToolCallEnd { id } if id == "call_1"
        ));
    }
1206
    // Truncation must land on a char boundary: the emoji is 4 bytes but
    // counts as one character.
    #[test]
    fn preview_text_truncates_on_char_boundary() {
        let text = "a😀b";
        assert_eq!(ZaiProvider::preview_text(text, 2), "a😀");
    }
1212}