1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use once_cell::sync::Lazy;
15use regex::Regex;
16use reqwest::Client;
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::collections::HashMap;
20
/// Default Z.AI API base URL (the general `paas/v4` endpoint).
pub const DEFAULT_BASE_URL: &str = "https://api.z.ai/api/paas/v4";
/// Base URL of the coding endpoint; `request_base_url` selects it for `PONY_ALPHA_2_MODEL`.
const CODING_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
/// Model id that is always routed to `CODING_BASE_URL`, whatever base URL is configured.
const PONY_ALPHA_2_MODEL: &str = "pony-alpha-2";
24
/// Provider client for the Z.AI chat-completions API.
pub struct ZaiProvider {
    /// HTTP client used for every request.
    client: Client,
    /// Bearer token sent in the `Authorization` header; redacted in `Debug` output.
    api_key: String,
    /// Configured base URL; `request_base_url` may override it per model.
    base_url: String,
}
30
31fn should_disable_thinking_for_request(request: &CompletionRequest) -> bool {
32 if request
33 .messages
34 .iter()
35 .any(contains_explicit_short_answer_instruction)
36 {
37 return true;
38 }
39
40 request.max_tokens.is_some_and(|max| max <= 256)
41 && request.messages.iter().any(|message| {
42 message.content.iter().any(|part| match part {
43 ContentPart::Text { text } => text.chars().count() <= 512,
44 _ => false,
45 })
46 })
47}
48
49fn contains_explicit_short_answer_instruction(message: &Message) -> bool {
50 message.content.iter().any(|part| match part {
51 ContentPart::Text { text } => {
52 let lower = text.to_ascii_lowercase();
53 lower.contains("reply with exactly")
54 || lower.contains("say only")
55 || lower.contains("respond with exactly")
56 || lower.contains("return exactly")
57 }
58 _ => false,
59 })
60}
61
62fn thinking_config_for_request(request: &CompletionRequest) -> Value {
63 if should_disable_thinking_for_request(request) {
64 json!({"type": "disabled"})
65 } else {
66 json!({
67 "type": "enabled",
68 "clear_thinking": true
69 })
70 }
71}
72
/// Per-tool-call bookkeeping kept while a streamed response is being consumed.
#[derive(Debug, Default)]
struct ZaiStreamToolState {
    /// Id used in every `StreamChunk` emitted for this call.
    stream_id: String,
    /// Most recent non-empty function name seen for this call, if any.
    name: Option<String>,
    /// Set once a `ToolCallStart` chunk has been emitted.
    started: bool,
    /// Set once a `ToolCallEnd` chunk has been emitted.
    finished: bool,
}
80
81impl std::fmt::Debug for ZaiProvider {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("ZaiProvider")
84 .field("base_url", &self.base_url)
85 .field("api_key", &"<REDACTED>")
86 .finish()
87 }
88}
89
90impl ZaiProvider {
91 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
92 tracing::debug!(
93 provider = "zai",
94 base_url = %base_url,
95 api_key_len = api_key.len(),
96 "Creating Z.AI provider with custom base URL"
97 );
98 Ok(Self {
99 client: crate::provider::shared_http::shared_client().clone(),
100 api_key,
101 base_url,
102 })
103 }
104
105 fn request_base_url(&self, model: &str) -> &str {
106 if model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL) {
107 CODING_BASE_URL
108 } else {
109 &self.base_url
110 }
111 }
112
113 async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
117 let discovery_url = if self.base_url.contains("/coding/") {
120 self.base_url.replace("/coding/", "/")
121 } else {
122 self.base_url.clone()
123 };
124 let url = format!("{discovery_url}/models");
125 let response = match self
126 .client
127 .get(&url)
128 .header("Authorization", format!("Bearer {}", self.api_key))
129 .send()
130 .await
131 {
132 Ok(r) => r,
133 Err(e) => {
134 tracing::debug!(
135 url = %url,
136 error = %e,
137 "Z.AI /models discovery request failed"
138 );
139 return Vec::new();
140 }
141 };
142
143 if !response.status().is_success() {
144 tracing::debug!(
145 url = %url,
146 status = %response.status(),
147 "Z.AI /models endpoint returned non-success"
148 );
149 return Vec::new();
150 }
151
152 let payload: Value = match response.json().await {
153 Ok(p) => p,
154 Err(e) => {
155 tracing::debug!(
156 url = %url,
157 error = %e,
158 "Failed to parse Z.AI /models response"
159 );
160 return Vec::new();
161 }
162 };
163
164 let models = payload
165 .get("data")
166 .and_then(Value::as_array)
167 .into_iter()
168 .flatten()
169 .filter_map(|entry| {
170 let id = match entry {
171 Value::String(s) => s.trim().to_string(),
172 Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim().to_string(),
173 _ => return None,
174 };
175 if id.is_empty() {
176 return None;
177 }
178 let name = entry
179 .get("name")
180 .and_then(Value::as_str)
181 .map(str::trim)
182 .filter(|n| !n.is_empty())
183 .unwrap_or(&id)
184 .to_string();
185 Some(ModelInfo {
186 id,
187 name,
188 provider: "zai".to_string(),
189 context_window: 200_000,
190 max_output_tokens: Some(128_000),
191 supports_vision: false,
192 supports_tools: true,
193 supports_streaming: true,
194 input_cost_per_million: None,
195 output_cost_per_million: None,
196 })
197 })
198 .collect::<Vec<_>>();
199
200 if models.is_empty() {
201 tracing::debug!(url = %url, "Z.AI /models returned no model ids");
202 } else {
203 tracing::info!(count = models.len(), "Z.AI /models discovery succeeded");
204 }
205 models
206 }
207
208 fn normalize_tool_arguments(arguments: &str) -> String {
209 if let Ok(parsed) = serde_json::from_str::<Value>(arguments) {
212 if parsed.is_object() {
213 return serde_json::to_string(&parsed).unwrap_or_else(|_| "{}".to_string());
214 }
215 return json!({"input": parsed}).to_string();
216 }
217
218 if let Some(salvaged) = Self::salvage_json_object(arguments) {
219 return serde_json::to_string(&salvaged).unwrap_or_else(|_| "{}".to_string());
220 }
221
222 json!({"input": arguments}).to_string()
223 }
224
225 fn salvage_json_object(arguments: &str) -> Option<Value> {
226 let trimmed = arguments.trim();
227 if !trimmed.starts_with('{') {
228 return None;
229 }
230
231 static RE_SIMPLE_PAIR: Lazy<Regex> = Lazy::new(|| {
232 Regex::new(
235 r#"(?s)\"(?P<k>[^\"\\]*(?:\\.[^\"\\]*)*)\"\s*:\s*(?P<v>\"(?:\\.|[^\"])*\"|true|false|null|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"#,
236 )
237 .expect("invalid regex")
238 });
239
240 let mut map = serde_json::Map::new();
241 for caps in RE_SIMPLE_PAIR.captures_iter(trimmed) {
242 let key = caps.name("k")?.as_str();
243 let val_str = caps.name("v")?.as_str();
244 if let Ok(val) = serde_json::from_str::<Value>(val_str) {
245 map.insert(key.to_string(), val);
246 }
247 }
248
249 if map.is_empty() {
250 None
251 } else {
252 Some(Value::Object(map))
253 }
254 }
255
256 fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
257 messages
258 .iter()
259 .map(|msg| {
260 let role = match msg.role {
261 Role::System => "system",
262 Role::User => "user",
263 Role::Assistant => "assistant",
264 Role::Tool => "tool",
265 };
266
267 match msg.role {
268 Role::Tool => {
269 if let Some(ContentPart::ToolResult {
270 tool_call_id,
271 content,
272 }) = msg.content.first()
273 {
274 json!({
275 "role": "tool",
276 "tool_call_id": tool_call_id,
277 "content": content
278 })
279 } else {
280 json!({"role": role, "content": ""})
281 }
282 }
283 Role::Assistant => {
284 let text: String = msg
285 .content
286 .iter()
287 .filter_map(|p| match p {
288 ContentPart::Text { text } => Some(text.clone()),
289 _ => None,
290 })
291 .collect::<Vec<_>>()
292 .join("");
293
294 let tool_calls: Vec<Value> = msg
295 .content
296 .iter()
297 .filter_map(|p| match p {
298 ContentPart::ToolCall {
299 id,
300 name,
301 arguments,
302 ..
303 } => {
304 let args_string = Self::normalize_tool_arguments(arguments);
305 Some(json!({
306 "id": id,
307 "type": "function",
308 "function": {
309 "name": name,
310 "arguments": args_string
311 }
312 }))
313 }
314 _ => None,
315 })
316 .collect();
317
318 let mut msg_json = if tool_calls.is_empty() {
319 json!({
320 "role": "assistant",
321 "content": text,
322 })
323 } else {
324 json!({
327 "role": "assistant",
328 "content": if text.is_empty() { Value::Null } else { json!(text) },
329 "tool_calls": tool_calls,
330 })
331 };
332 if include_reasoning_content {
333 let reasoning: String = msg
334 .content
335 .iter()
336 .filter_map(|p| match p {
337 ContentPart::Thinking { text } => Some(text.clone()),
338 _ => None,
339 })
340 .collect::<Vec<_>>()
341 .join("");
342 if !reasoning.is_empty() {
343 msg_json["reasoning_content"] = json!(reasoning);
344 }
345 }
346 msg_json
347 }
348 _ => {
349 let text: String = msg
350 .content
351 .iter()
352 .filter_map(|p| match p {
353 ContentPart::Text { text } => Some(text.clone()),
354 _ => None,
355 })
356 .collect::<Vec<_>>()
357 .join("\n");
358
359 json!({"role": role, "content": text})
360 }
361 }
362 })
363 .collect()
364 }
365
366 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
367 tools
368 .iter()
369 .map(|t| {
370 json!({
371 "type": "function",
372 "function": {
373 "name": t.name,
374 "description": t.description,
375 "parameters": t.parameters
376 }
377 })
378 })
379 .collect()
380 }
381
382 fn model_supports_tool_stream(model: &str) -> bool {
383 model.contains("glm-5")
384 || model.contains("glm-4.7")
385 || model.contains("glm-4.6")
386 || model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL)
387 }
388
/// Returns at most the first `max_chars` characters of `text`.
///
/// The cut is always made on a character boundary, so slicing never panics on
/// multi-byte text. A `max_chars` of zero yields an empty string.
fn preview_text(text: &str, max_chars: usize) -> &str {
    // The byte offset of character number `max_chars` is where the preview ends;
    // if no such character exists the whole text already fits.
    text.char_indices()
        .nth(max_chars)
        .map_or(text, |(end, _)| &text[..end])
}
399
400 fn stream_tool_arguments_fragment(arguments: &Value) -> String {
401 match arguments {
402 Value::Null => String::new(),
403 Value::String(s) => s.clone(),
404 other => serde_json::to_string(other).unwrap_or_default(),
405 }
406 }
407
/// Translates one batch of streamed tool-call deltas into `StreamChunk`s.
///
/// * `chunks` - output; receives `ToolCallStart` / `ToolCallDelta` chunks.
/// * `tool_calls` - the deltas from a single stream event.
/// * `tool_states` - per-call state keyed by tool index, kept across events.
/// * `next_fallback_index` - next index to hand out when nothing else identifies a call.
/// * `last_seen_index` - index used by the most recent delta, kept across events.
fn append_stream_tool_call_chunks(
    chunks: &mut Vec<StreamChunk>,
    tool_calls: &[ZaiStreamToolCall],
    tool_states: &mut HashMap<usize, ZaiStreamToolState>,
    next_fallback_index: &mut usize,
    last_seen_index: &mut Option<usize>,
) {
    for tc in tool_calls {
        // Resolve which call this delta belongs to, in priority order:
        // explicit index, a known state with the same id, the previously used index,
        // and finally a freshly allocated fallback index.
        // NOTE(review): a delta with no index and an id that matches no existing state
        // falls through to `last_seen_index`, so it looks like it would be merged into
        // the previous call - confirm Z.AI always sends `index` in that situation.
        let index = tc
            .index
            .or_else(|| {
                tc.id.as_ref().and_then(|id| {
                    tool_states
                        .iter()
                        .find_map(|(idx, state)| (state.stream_id == *id).then_some(*idx))
                })
            })
            .or(*last_seen_index)
            .unwrap_or_else(|| {
                let idx = *next_fallback_index;
                *next_fallback_index += 1;
                idx
            });
        *last_seen_index = Some(index);

        // First delta for this index: remember the provider id, or synthesize one.
        let state = tool_states
            .entry(index)
            .or_insert_with(|| ZaiStreamToolState {
                stream_id: tc.id.clone().unwrap_or_else(|| format!("zai-tool-{index}")),
                ..Default::default()
            });

        // Swap a synthesized id for the real one, but only while nothing has been
        // emitted under the synthesized id yet.
        if let Some(id) = &tc.id
            && !state.started
            && state.stream_id.starts_with("zai-tool-")
        {
            state.stream_id = id.clone();
        }

        if let Some(func) = &tc.function {
            if let Some(name) = &func.name
                && !name.is_empty()
            {
                state.name = Some(name.clone());
            }

            // Emit the start marker as soon as a name is known.
            if !state.started
                && let Some(name) = &state.name
            {
                chunks.push(StreamChunk::ToolCallStart {
                    id: state.stream_id.clone(),
                    name: name.clone(),
                });
                state.started = true;
            }

            if let Some(arguments) = &func.arguments {
                let delta = Self::stream_tool_arguments_fragment(arguments);
                if !delta.is_empty() {
                    // Arguments arrived before any name: start the call under the
                    // placeholder name "tool" so the delta has a start to attach to.
                    if !state.started {
                        chunks.push(StreamChunk::ToolCallStart {
                            id: state.stream_id.clone(),
                            name: state.name.clone().unwrap_or_else(|| "tool".to_string()),
                        });
                        state.started = true;
                    }
                    chunks.push(StreamChunk::ToolCallDelta {
                        id: state.stream_id.clone(),
                        arguments_delta: delta,
                    });
                }
            }
        }
    }
}
483
484 fn finish_stream_tool_call_chunks(
485 chunks: &mut Vec<StreamChunk>,
486 tool_states: &mut HashMap<usize, ZaiStreamToolState>,
487 ) {
488 let mut ordered_indexes: Vec<_> = tool_states.keys().copied().collect();
489 ordered_indexes.sort_unstable();
490
491 for index in ordered_indexes {
492 if let Some(state) = tool_states.get_mut(&index)
493 && state.started
494 && !state.finished
495 {
496 chunks.push(StreamChunk::ToolCallEnd {
497 id: state.stream_id.clone(),
498 });
499 state.finished = true;
500 }
501 }
502 }
503}
504
/// Body of a non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    /// Returned choices; `complete` only reads the first one.
    choices: Vec<ZaiChoice>,
    /// Token accounting; `None` when the field is absent.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}
511
/// One choice of a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    /// The assistant message for this choice.
    message: ZaiMessage,
    /// Raw finish reason string as sent by the API, if any.
    #[serde(default)]
    finish_reason: Option<String>,
}
518
/// Assistant message inside a non-streaming choice; every field is optional.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    /// Plain text answer.
    #[serde(default)]
    content: Option<String>,
    /// Tool calls requested by the model.
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    /// Reasoning text; surfaced by `complete` as a `Thinking` content part.
    #[serde(default)]
    reasoning_content: Option<String>,
}
528
/// A complete tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    /// Tool call id assigned by the API.
    id: String,
    /// Function name and arguments.
    function: ZaiFunction,
}
534
/// Function part of a non-streaming tool call.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    /// Name of the function to invoke.
    name: String,
    /// Arguments as arbitrary JSON; `complete` accepts both a JSON string and other values.
    arguments: Value,
}
540
/// Token usage block of a response; missing counters deserialize as zero.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    /// Prompt token count as reported by the API.
    #[serde(default)]
    prompt_tokens: usize,
    /// Completion token count as reported by the API.
    #[serde(default)]
    completion_tokens: usize,
    /// Total token count as reported by the API.
    #[serde(default)]
    total_tokens: usize,
    /// Optional breakdown of the prompt tokens (cached portion).
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}
552
/// Breakdown of prompt tokens.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    /// Cached prompt tokens; `complete` subtracts this from `prompt_tokens`.
    #[serde(default)]
    cached_tokens: usize,
}
558
/// Envelope of a structured error body (`{"error": {...}}`).
#[derive(Debug, Deserialize)]
struct ZaiError {
    /// The error details.
    error: ZaiErrorDetail,
}
563
/// Details of a structured API error.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    /// Error message from the API.
    message: String,
    /// Value of the JSON `type` field, when present.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
570
/// Payload of one `data:` event in a streaming response.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    /// Streamed choices; `complete_stream` only reads the first one.
    choices: Vec<ZaiStreamChoice>,
}
576
/// One choice inside a streaming event.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    /// Incremental content of this event.
    delta: ZaiStreamDelta,
    /// Raw finish reason string, when the event carries one.
    #[serde(default)]
    finish_reason: Option<String>,
}
583
/// Incremental update carried by a streaming choice; every field is optional.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    /// Answer text fragment.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning text fragment.
    #[serde(default)]
    reasoning_content: Option<String>,
    /// Tool-call fragments.
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}
593
/// A tool-call fragment in a streaming event; any field may be missing.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    /// Position of the call within the response, when the API sends it.
    #[serde(default)]
    index: Option<usize>,
    /// Tool call id, when the API sends it.
    #[serde(default)]
    id: Option<String>,
    /// Function name and/or argument fragment.
    function: Option<ZaiStreamFunction>,
}
602
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    /// Function name, when this fragment carries it.
    #[serde(default)]
    name: Option<String>,
    /// Argument fragment as arbitrary JSON (see `stream_tool_arguments_fragment`).
    #[serde(default)]
    arguments: Option<Value>,
}
610
611#[async_trait]
612impl Provider for ZaiProvider {
/// Returns the provider identifier, always `"zai"`.
fn name(&self) -> &str {
    "zai"
}
616
617 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
618 let discovered = self.discover_models_from_api().await;
622 if !discovered.is_empty() {
623 let mut models = discovered;
626 if !models.iter().any(|m| m.id == PONY_ALPHA_2_MODEL) {
627 models.push(ModelInfo {
628 id: PONY_ALPHA_2_MODEL.to_string(),
629 name: "Pony Alpha 2".to_string(),
630 provider: "zai".to_string(),
631 context_window: 128_000,
632 max_output_tokens: Some(16_384),
633 supports_vision: false,
634 supports_tools: true,
635 supports_streaming: true,
636 input_cost_per_million: None,
637 output_cost_per_million: None,
638 });
639 }
640 if !models.iter().any(|m| m.id == "glm-4.7-flash") {
641 models.push(ModelInfo {
642 id: "glm-4.7-flash".to_string(),
643 name: "GLM-4.7 Flash".to_string(),
644 provider: "zai".to_string(),
645 context_window: 128_000,
646 max_output_tokens: Some(128_000),
647 supports_vision: false,
648 supports_tools: true,
649 supports_streaming: true,
650 input_cost_per_million: None,
651 output_cost_per_million: None,
652 });
653 }
654 return Ok(models);
655 }
656
657 Ok(vec![
660 ModelInfo {
661 id: "glm-5.1".to_string(),
662 name: "GLM-5.1".to_string(),
663 provider: "zai".to_string(),
664 context_window: 200_000,
665 max_output_tokens: Some(128_000),
666 supports_vision: false,
667 supports_tools: true,
668 supports_streaming: true,
669 input_cost_per_million: None,
670 output_cost_per_million: None,
671 },
672 ModelInfo {
673 id: "glm-5".to_string(),
674 name: "GLM-5".to_string(),
675 provider: "zai".to_string(),
676 context_window: 200_000,
677 max_output_tokens: Some(128_000),
678 supports_vision: false,
679 supports_tools: true,
680 supports_streaming: true,
681 input_cost_per_million: None,
682 output_cost_per_million: None,
683 },
684 ModelInfo {
685 id: "glm-4.7".to_string(),
686 name: "GLM-4.7".to_string(),
687 provider: "zai".to_string(),
688 context_window: 128_000,
689 max_output_tokens: Some(128_000),
690 supports_vision: false,
691 supports_tools: true,
692 supports_streaming: true,
693 input_cost_per_million: None,
694 output_cost_per_million: None,
695 },
696 ModelInfo {
697 id: "glm-4.7-flash".to_string(),
698 name: "GLM-4.7 Flash".to_string(),
699 provider: "zai".to_string(),
700 context_window: 128_000,
701 max_output_tokens: Some(128_000),
702 supports_vision: false,
703 supports_tools: true,
704 supports_streaming: true,
705 input_cost_per_million: None,
706 output_cost_per_million: None,
707 },
708 ModelInfo {
709 id: "glm-4.6".to_string(),
710 name: "GLM-4.6".to_string(),
711 provider: "zai".to_string(),
712 context_window: 128_000,
713 max_output_tokens: Some(128_000),
714 supports_vision: false,
715 supports_tools: true,
716 supports_streaming: true,
717 input_cost_per_million: None,
718 output_cost_per_million: None,
719 },
720 ModelInfo {
721 id: "glm-4.5".to_string(),
722 name: "GLM-4.5".to_string(),
723 provider: "zai".to_string(),
724 context_window: 128_000,
725 max_output_tokens: Some(96_000),
726 supports_vision: false,
727 supports_tools: true,
728 supports_streaming: true,
729 input_cost_per_million: None,
730 output_cost_per_million: None,
731 },
732 ModelInfo {
733 id: "glm-5-turbo".to_string(),
734 name: "GLM-5 Turbo".to_string(),
735 provider: "zai".to_string(),
736 context_window: 200_000,
737 max_output_tokens: Some(128_000),
738 supports_vision: false,
739 supports_tools: true,
740 supports_streaming: true,
741 input_cost_per_million: Some(0.96),
742 output_cost_per_million: Some(3.20),
743 },
744 ModelInfo {
745 id: PONY_ALPHA_2_MODEL.to_string(),
746 name: "Pony Alpha 2".to_string(),
747 provider: "zai".to_string(),
748 context_window: 128_000,
749 max_output_tokens: Some(16_384),
750 supports_vision: false,
751 supports_tools: true,
752 supports_streaming: true,
753 input_cost_per_million: None,
754 output_cost_per_million: None,
755 },
756 ])
757 }
758
759 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
760 let messages = Self::convert_messages(&request.messages, false);
764 let tools = Self::convert_tools(&request.tools);
765
766 let temperature = request.temperature.unwrap_or(1.0);
768
769 let mut body = json!({
770 "model": request.model,
771 "messages": messages,
772 "temperature": temperature,
773 });
774
775 body["thinking"] = thinking_config_for_request(&request);
781
782 if !tools.is_empty() {
783 body["tools"] = json!(tools);
784 }
785 if let Some(max) = request.max_tokens {
786 body["max_tokens"] = json!(max);
787 }
788
789 tracing::debug!(model = %request.model, "Z.AI request");
790 tracing::trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "Z.AI request body");
791 let request_base_url = self.request_base_url(&request.model);
792
793 let (text, status) = super::retry::send_with_retry(|| async {
794 let resp = self
795 .client
796 .post(format!("{}/chat/completions", request_base_url))
797 .header("Authorization", format!("Bearer {}", self.api_key))
798 .header("Content-Type", "application/json")
799 .json(&body)
800 .send()
801 .await
802 .context("Failed to send request to Z.AI")?;
803 let status = resp.status();
804 let text = resp.text().await.context("Failed to read Z.AI response")?;
805 Ok((text, status))
806 })
807 .await?;
808
809 if !status.is_success() {
810 tracing::debug!(status = %status, body = %text, "Z.AI error response");
811 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
812 anyhow::bail!(
813 "Z.AI API error: {} ({:?})",
814 err.error.message,
815 err.error.error_type
816 );
817 }
818 anyhow::bail!("Z.AI API error: {status} {text}");
819 }
820
821 let response: ZaiResponse = serde_json::from_str(&text).context(format!(
822 "Failed to parse Z.AI response: {}",
823 Self::preview_text(&text, 200)
824 ))?;
825
826 let choice = response
827 .choices
828 .first()
829 .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
830
831 if let Some(ref reasoning) = choice.message.reasoning_content
833 && !reasoning.is_empty()
834 {
835 tracing::info!(
836 reasoning_len = reasoning.len(),
837 "Z.AI reasoning content received"
838 );
839 }
840
841 let mut content = Vec::new();
842 let mut has_tool_calls = false;
843
844 if let Some(ref reasoning) = choice.message.reasoning_content
846 && !reasoning.is_empty()
847 {
848 content.push(ContentPart::Thinking {
849 text: reasoning.clone(),
850 });
851 }
852
853 if let Some(text) = &choice.message.content
854 && !text.is_empty()
855 {
856 content.push(ContentPart::Text { text: text.clone() });
857 }
858
859 if let Some(tool_calls) = &choice.message.tool_calls {
860 has_tool_calls = !tool_calls.is_empty();
861 for tc in tool_calls {
862 let arguments = match &tc.function.arguments {
864 Value::String(s) => s.clone(),
865 other => serde_json::to_string(other).unwrap_or_default(),
866 };
867 content.push(ContentPart::ToolCall {
868 id: tc.id.clone(),
869 name: tc.function.name.clone(),
870 arguments,
871 thought_signature: None,
872 });
873 }
874 }
875
876 let finish_reason = if has_tool_calls {
877 FinishReason::ToolCalls
878 } else {
879 match choice.finish_reason.as_deref() {
880 Some("stop") => FinishReason::Stop,
881 Some("length") => FinishReason::Length,
882 Some("tool_calls") => FinishReason::ToolCalls,
883 Some("sensitive") => FinishReason::ContentFilter,
884 _ => FinishReason::Stop,
885 }
886 };
887
888 Ok(CompletionResponse {
889 message: Message {
890 role: Role::Assistant,
891 content,
892 },
893 usage: Usage {
894 prompt_tokens: {
895 let u = response.usage.as_ref();
899 let total = u.map(|u| u.prompt_tokens).unwrap_or(0);
900 let cached = u
901 .and_then(|u| u.prompt_tokens_details.as_ref())
902 .map(|d| d.cached_tokens)
903 .unwrap_or(0);
904 total.saturating_sub(cached)
905 },
906 completion_tokens: response
907 .usage
908 .as_ref()
909 .map(|u| u.completion_tokens)
910 .unwrap_or(0),
911 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
912 cache_read_tokens: response
913 .usage
914 .as_ref()
915 .and_then(|u| u.prompt_tokens_details.as_ref())
916 .map(|d| d.cached_tokens)
917 .filter(|&t| t > 0),
918 cache_write_tokens: None,
919 },
920 finish_reason,
921 })
922 }
923
924 async fn complete_stream(
925 &self,
926 request: CompletionRequest,
927 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
928 let messages = Self::convert_messages(&request.messages, false);
932 let tools = Self::convert_tools(&request.tools);
933
934 let temperature = request.temperature.unwrap_or(1.0);
935
936 let mut body = json!({
937 "model": request.model,
938 "messages": messages,
939 "temperature": temperature,
940 "stream": true,
941 });
942
943 body["thinking"] = thinking_config_for_request(&request);
944
945 if !tools.is_empty() {
946 body["tools"] = json!(tools);
947 if Self::model_supports_tool_stream(&request.model) {
948 body["tool_stream"] = json!(true);
950 }
951 }
952 if let Some(max) = request.max_tokens {
953 body["max_tokens"] = json!(max);
954 }
955
956 tracing::debug!(model = %request.model, "Z.AI streaming request");
957 let request_base_url = self.request_base_url(&request.model);
958
959 let response = super::retry::send_response_with_retry(|| async {
960 self.client
961 .post(format!("{}/chat/completions", request_base_url))
962 .header("Authorization", format!("Bearer {}", self.api_key))
963 .header("Content-Type", "application/json")
964 .json(&body)
965 .send()
966 .await
967 .context("Failed to send streaming request to Z.AI")
968 })
969 .await?;
970
971 let stream = response.bytes_stream();
972 let mut buffer = String::new();
973 let mut tool_states = HashMap::<usize, ZaiStreamToolState>::new();
974 let mut next_fallback_tool_index = 0usize;
975 let mut last_seen_tool_index = None;
976
977 Ok(stream
978 .flat_map(move |chunk_result| {
979 let mut chunks: Vec<StreamChunk> = Vec::new();
980 match chunk_result {
981 Ok(bytes) => {
982 let text = String::from_utf8_lossy(&bytes);
983 buffer.push_str(&text);
984
985 let mut text_buf = String::new();
986 while let Some(line_end) = buffer.find('\n') {
987 let line = buffer[..line_end].trim().to_string();
988 buffer = buffer[line_end + 1..].to_string();
989
990 if line == "data: [DONE]" {
991 if !text_buf.is_empty() {
992 chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
993 }
994 chunks.push(StreamChunk::Done { usage: None });
995 continue;
996 }
997 if let Some(data) = line.strip_prefix("data: ")
998 && let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
999 && let Some(choice) = parsed.choices.first()
1000 {
1001 if let Some(ref reasoning) = choice.delta.reasoning_content
1003 && !reasoning.is_empty()
1004 {
1005 text_buf.push_str(reasoning);
1006 }
1007 if let Some(ref content) = choice.delta.content {
1008 text_buf.push_str(content);
1009 }
1010 if let Some(ref tool_calls) = choice.delta.tool_calls {
1012 if !text_buf.is_empty() {
1013 chunks
1014 .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
1015 }
1016 Self::append_stream_tool_call_chunks(
1017 &mut chunks,
1018 tool_calls,
1019 &mut tool_states,
1020 &mut next_fallback_tool_index,
1021 &mut last_seen_tool_index,
1022 );
1023 }
1024 if let Some(ref reason) = choice.finish_reason {
1026 if !text_buf.is_empty() {
1027 chunks
1028 .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
1029 }
1030 if reason == "tool_calls" {
1031 Self::finish_stream_tool_call_chunks(
1032 &mut chunks,
1033 &mut tool_states,
1034 );
1035 }
1036 }
1037 }
1038 }
1039 if !text_buf.is_empty() {
1040 chunks.push(StreamChunk::Text(text_buf));
1041 }
1042 }
1043 Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
1044 }
1045 futures::stream::iter(chunks)
1046 })
1047 .boxed())
1048 }
1049}
1050
// Unit tests for the Z.AI provider: model listing, endpoint routing, message
// conversion, thinking configuration, streamed tool-call handling and text previews.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    // `pony-alpha-2` must be listed.
    // NOTE(review): `list_models` first attempts API discovery with the dummy key;
    // presumably that fails and the static fallback is what gets checked - confirm
    // this test does not depend on network access.
    #[tokio::test]
    async fn list_models_includes_pony_alpha_2() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        assert!(models.iter().any(|model| model.id == PONY_ALPHA_2_MODEL));
    }

    // `glm-5-turbo` must be listed with its limits and pricing.
    // NOTE(review): only the static fallback list contains this entry, so this
    // appears to rely on discovery returning nothing - confirm.
    #[tokio::test]
    async fn list_models_includes_glm_5_turbo() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        let turbo = models
            .iter()
            .find(|m| m.id == "glm-5-turbo")
            .expect("glm-5-turbo should be in model list");
        assert_eq!(turbo.context_window, 200_000);
        assert_eq!(turbo.max_output_tokens, Some(128_000));
        assert!(turbo.supports_tools);
        assert!(turbo.supports_streaming);
        assert_eq!(turbo.input_cost_per_million, Some(0.96));
        assert_eq!(turbo.output_cost_per_million, Some(3.20));
    }

    // `glm-5.1` must be listed with a 200k context window and 128k output limit.
    #[tokio::test]
    async fn list_models_includes_glm_5_1() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");
        let models = provider.list_models().await.expect("models should list");

        let glm51 = models
            .iter()
            .find(|m| m.id == "glm-5.1")
            .expect("glm-5.1 should be in model list");
        assert_eq!(glm51.context_window, 200_000);
        assert_eq!(glm51.max_output_tokens, Some(128_000));
        assert!(glm51.supports_tools);
        assert!(glm51.supports_streaming);
    }

    // The `glm-5` substring match covers `glm-5.1` and `glm-5-turbo`; `glm-4.5` is excluded.
    #[test]
    fn model_supports_tool_stream_matches_glm_5_1() {
        assert!(ZaiProvider::model_supports_tool_stream("glm-5.1"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-5"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-5-turbo"));
        assert!(!ZaiProvider::model_supports_tool_stream("glm-4.5"));
    }

    // `pony-alpha-2` goes to the coding endpoint even with the default base URL.
    #[test]
    fn pony_alpha_2_routes_to_coding_endpoint() {
        let provider =
            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
                .expect("provider should construct");

        assert_eq!(
            provider.request_base_url(PONY_ALPHA_2_MODEL),
            CODING_BASE_URL
        );
        assert_eq!(provider.request_base_url("glm-5"), DEFAULT_BASE_URL);
    }

    // Malformed JSON arguments are salvaged into a valid JSON object string.
    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "{\"city\":\"Beijing\".. }".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"city":"Beijing"}));
    }

    // Arguments that are not JSON at all are wrapped as `{"input": "<raw>"}`.
    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "city=Beijing".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"input":"city=Beijing"}));
    }

    // Valid JSON that is not an object (here a string) is wrapped as `{"input": <value>}`.
    #[test]
    fn convert_messages_wraps_scalar_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "\"Beijing\"".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value =
            serde_json::from_str(args).expect("arguments string must contain valid JSON");

        assert_eq!(parsed, json!({"input":"Beijing"}));
    }

    // An assistant message with tool calls and no text gets `content: null`.
    #[test]
    fn convert_messages_uses_null_content_for_empty_assistant_tool_calls() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: r#"{"city":"Beijing"}"#.to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, false);

        assert!(converted[0].get("tool_calls").is_some());
        assert_eq!(converted[0]["content"], Value::Null);
    }

    // An assistant message with both text and tool calls keeps the text as content.
    #[test]
    fn convert_messages_keeps_text_content_for_assistant_tool_calls_with_text() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![
                ContentPart::Text {
                    text: "Using a tool.".to_string(),
                },
                ContentPart::ToolCall {
                    id: "call_1".to_string(),
                    name: "get_weather".to_string(),
                    arguments: r#"{"city":"Beijing"}"#.to_string(),
                    thought_signature: None,
                },
            ],
        }];

        let converted = ZaiProvider::convert_messages(&messages, false);

        assert!(converted[0].get("tool_calls").is_some());
        assert_eq!(converted[0]["content"], json!("Using a tool."));
    }

    // A "Reply with exactly" instruction disables thinking.
    #[test]
    fn short_exact_zai_requests_disable_thinking() {
        let request = CompletionRequest {
            messages: vec![Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Reply with exactly: api-model-call-ok".to_string(),
                }],
            }],
            tools: vec![],
            model: "glm-5.1".to_string(),
            temperature: None,
            top_p: None,
            max_tokens: Some(20),
            stop: vec![],
        };

        assert!(should_disable_thinking_for_request(&request));
        assert_eq!(
            thinking_config_for_request(&request),
            json!({"type": "disabled"})
        );
    }

    // A small token budget alone does not disable thinking when the only text is
    // longer than 512 characters.
    #[test]
    fn long_low_budget_zai_requests_keep_thinking_enabled() {
        let request = CompletionRequest {
            messages: vec![Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: format!(
                        "Summarize the following material. {}",
                        "context paragraph. ".repeat(40)
                    ),
                }],
            }],
            tools: vec![],
            model: "glm-5.1".to_string(),
            temperature: None,
            top_p: None,
            max_tokens: Some(128),
            stop: vec![],
        };

        assert!(!should_disable_thinking_for_request(&request));
        assert_eq!(thinking_config_for_request(&request)["type"], "enabled");
    }

    // An ordinary request keeps thinking enabled with `clear_thinking: true`.
    #[test]
    fn normal_zai_requests_keep_clear_thinking_enabled() {
        let request = CompletionRequest {
            messages: vec![Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Explain the tradeoffs.".to_string(),
                }],
            }],
            tools: vec![],
            model: "glm-5.1".to_string(),
            temperature: None,
            top_p: None,
            max_tokens: Some(1024),
            stop: vec![],
        };

        assert!(!should_disable_thinking_for_request(&request));
        assert_eq!(thinking_config_for_request(&request)["type"], "enabled");
        assert_eq!(
            thinking_config_for_request(&request)["clear_thinking"],
            true
        );
    }

    // A follow-up delta without an id must reuse the id of the call at the same index,
    // producing start, two deltas and a single end chunk.
    #[test]
    fn stream_tool_chunks_keep_same_call_id_when_followup_delta_omits_id() {
        let mut chunks = Vec::new();
        let mut tool_states = HashMap::new();
        let mut next_fallback_tool_index = 0usize;
        let mut last_seen_tool_index = None;

        ZaiProvider::append_stream_tool_call_chunks(
            &mut chunks,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: Some("call_1".to_string()),
                function: Some(ZaiStreamFunction {
                    name: Some("bash".to_string()),
                    arguments: Some(Value::String("{\"".to_string())),
                }),
            }],
            &mut tool_states,
            &mut next_fallback_tool_index,
            &mut last_seen_tool_index,
        );

        ZaiProvider::append_stream_tool_call_chunks(
            &mut chunks,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: None,
                function: Some(ZaiStreamFunction {
                    name: None,
                    arguments: Some(Value::String("command\":\"pwd\"}".to_string())),
                }),
            }],
            &mut tool_states,
            &mut next_fallback_tool_index,
            &mut last_seen_tool_index,
        );

        ZaiProvider::finish_stream_tool_call_chunks(&mut chunks, &mut tool_states);

        assert_eq!(chunks.len(), 4);
        assert!(matches!(
            &chunks[0],
            StreamChunk::ToolCallStart { id, name }
                if id == "call_1" && name == "bash"
        ));
        assert!(matches!(
            &chunks[1],
            StreamChunk::ToolCallDelta { id, arguments_delta }
                if id == "call_1" && arguments_delta == "{\""
        ));
        assert!(matches!(
            &chunks[2],
            StreamChunk::ToolCallDelta { id, arguments_delta }
                if id == "call_1" && arguments_delta == "command\":\"pwd\"}"
        ));
        assert!(matches!(
            &chunks[3],
            StreamChunk::ToolCallEnd { id } if id == "call_1"
        ));
    }

    // The preview counts characters, not bytes, so a multi-byte emoji is kept whole.
    #[test]
    fn preview_text_truncates_on_char_boundary() {
        let text = "a😀b";
        assert_eq!(ZaiProvider::preview_text(text, 2), "a😀");
    }
}