Skip to main content

codetether_agent/provider/
zai.rs

1//! Z.AI provider implementation (direct API)
2//!
3//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
4//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
5//! reasoning/thinking support via the `reasoning_content` field.
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use once_cell::sync::Lazy;
15use regex::Regex;
16use reqwest::Client;
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::collections::HashMap;
20
/// Default base URL for the Z.AI API.
21pub const DEFAULT_BASE_URL: &str = "https://api.z.ai/api/paas/v4";
/// Base URL of the Z.AI coding endpoint; requests for `PONY_ALPHA_2_MODEL` are sent here.
22const CODING_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
/// Model id that is always routed to `CODING_BASE_URL`, compared case-insensitively.
23const PONY_ALPHA_2_MODEL: &str = "pony-alpha-2";
24
/// Z.AI provider: an HTTP client plus the credentials and endpoint used for requests.
25pub struct ZaiProvider {
    /// HTTP client used for every request this provider sends.
26    client: Client,
    /// API key sent as a `Bearer` token in the `Authorization` header.
27    api_key: String,
    /// Base URL that endpoint paths such as `/models` are appended to.
28    base_url: String,
29}
30
31fn should_disable_thinking_for_request(request: &CompletionRequest) -> bool {
32    if request
33        .messages
34        .iter()
35        .any(contains_explicit_short_answer_instruction)
36    {
37        return true;
38    }
39
40    request.max_tokens.is_some_and(|max| max <= 256)
41        && request.messages.iter().any(|message| {
42            message.content.iter().any(|part| match part {
43                ContentPart::Text { text } => text.chars().count() <= 512,
44                _ => false,
45            })
46        })
47}
48
49fn contains_explicit_short_answer_instruction(message: &Message) -> bool {
50    message.content.iter().any(|part| match part {
51        ContentPart::Text { text } => {
52            let lower = text.to_ascii_lowercase();
53            lower.contains("reply with exactly")
54                || lower.contains("say only")
55                || lower.contains("respond with exactly")
56                || lower.contains("return exactly")
57        }
58        _ => false,
59    })
60}
61
62fn thinking_config_for_request(request: &CompletionRequest) -> Value {
63    if should_disable_thinking_for_request(request) {
64        json!({"type": "disabled"})
65    } else {
66        json!({
67            "type": "enabled",
68            "clear_thinking": true
69        })
70    }
71}
72
/// Per-tool-call bookkeeping used while turning streamed tool-call fragments
/// into `StreamChunk` events.
73#[derive(Debug, Default)]
74struct ZaiStreamToolState {
    /// Id attached to the `StreamChunk` events emitted for this tool call.
75    stream_id: String,
    /// Function name, once a non-empty one has been seen.
76    name: Option<String>,
    /// Set once `ToolCallStart` has been emitted for this call.
77    started: bool,
    /// Set once `ToolCallEnd` has been emitted for this call.
78    finished: bool,
79}
80
81impl std::fmt::Debug for ZaiProvider {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("ZaiProvider")
84            .field("base_url", &self.base_url)
85            .field("api_key", &"<REDACTED>")
86            .finish()
87    }
88}
89
90impl ZaiProvider {
91    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
92        tracing::debug!(
93            provider = "zai",
94            base_url = %base_url,
95            api_key_len = api_key.len(),
96            "Creating Z.AI provider with custom base URL"
97        );
98        Ok(Self {
99            client: crate::provider::shared_http::shared_client().clone(),
100            api_key,
101            base_url,
102        })
103    }
104
105    fn request_base_url(&self, model: &str) -> &str {
106        if model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL) {
107            CODING_BASE_URL
108        } else {
109            &self.base_url
110        }
111    }
112
113    /// Fetch available models from the Z.AI /models endpoint.
114    /// Returns an empty vec on any failure (network, parse, auth) so callers
115    /// can fall back to the hardcoded catalog.
116    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
117        // Always hit the standard API endpoint for model discovery,
118        // even if base_url points at the coding endpoint.
119        let discovery_url = if self.base_url.contains("/coding/") {
120            self.base_url.replace("/coding/", "/")
121        } else {
122            self.base_url.clone()
123        };
124        let url = format!("{discovery_url}/models");
125        let response = match self
126            .client
127            .get(&url)
128            .header("Authorization", format!("Bearer {}", self.api_key))
129            .send()
130            .await
131        {
132            Ok(r) => r,
133            Err(e) => {
134                tracing::debug!(
135                    url = %url,
136                    error = %e,
137                    "Z.AI /models discovery request failed"
138                );
139                return Vec::new();
140            }
141        };
142
143        if !response.status().is_success() {
144            tracing::debug!(
145                url = %url,
146                status = %response.status(),
147                "Z.AI /models endpoint returned non-success"
148            );
149            return Vec::new();
150        }
151
152        let payload: Value = match response.json().await {
153            Ok(p) => p,
154            Err(e) => {
155                tracing::debug!(
156                    url = %url,
157                    error = %e,
158                    "Failed to parse Z.AI /models response"
159                );
160                return Vec::new();
161            }
162        };
163
164        let models = payload
165            .get("data")
166            .and_then(Value::as_array)
167            .into_iter()
168            .flatten()
169            .filter_map(|entry| {
170                let id = match entry {
171                    Value::String(s) => s.trim().to_string(),
172                    Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim().to_string(),
173                    _ => return None,
174                };
175                if id.is_empty() {
176                    return None;
177                }
178                let name = entry
179                    .get("name")
180                    .and_then(Value::as_str)
181                    .map(str::trim)
182                    .filter(|n| !n.is_empty())
183                    .unwrap_or(&id)
184                    .to_string();
185                Some(ModelInfo {
186                    id,
187                    name,
188                    provider: "zai".to_string(),
189                    context_window: 200_000,
190                    max_output_tokens: Some(128_000),
191                    supports_vision: false,
192                    supports_tools: true,
193                    supports_streaming: true,
194                    input_cost_per_million: None,
195                    output_cost_per_million: None,
196                })
197            })
198            .collect::<Vec<_>>();
199
200        if models.is_empty() {
201            tracing::debug!(url = %url, "Z.AI /models returned no model ids");
202        } else {
203            tracing::info!(count = models.len(), "Z.AI /models discovery succeeded");
204        }
205        models
206    }
207
208    fn normalize_tool_arguments(arguments: &str) -> String {
209        // The live Z.AI endpoint rejects object-typed historical arguments in
210        // assistant.tool_calls and accepts OpenAI-style JSON strings instead.
211        if let Ok(parsed) = serde_json::from_str::<Value>(arguments) {
212            if parsed.is_object() {
213                return serde_json::to_string(&parsed).unwrap_or_else(|_| "{}".to_string());
214            }
215            return json!({"input": parsed}).to_string();
216        }
217
218        if let Some(salvaged) = Self::salvage_json_object(arguments) {
219            return serde_json::to_string(&salvaged).unwrap_or_else(|_| "{}".to_string());
220        }
221
222        json!({"input": arguments}).to_string()
223    }
224
225    fn salvage_json_object(arguments: &str) -> Option<Value> {
226        let trimmed = arguments.trim();
227        if !trimmed.starts_with('{') {
228            return None;
229        }
230
231        static RE_SIMPLE_PAIR: Lazy<Regex> = Lazy::new(|| {
232            // Matches simple JSON key/value pairs where the value is a primitive
233            // or a quoted string. This is intentionally conservative.
234            Regex::new(
235                r#"(?s)\"(?P<k>[^\"\\]*(?:\\.[^\"\\]*)*)\"\s*:\s*(?P<v>\"(?:\\.|[^\"])*\"|true|false|null|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"#,
236            )
237            .expect("invalid regex")
238        });
239
240        let mut map = serde_json::Map::new();
241        for caps in RE_SIMPLE_PAIR.captures_iter(trimmed) {
242            let key = caps.name("k")?.as_str();
243            let val_str = caps.name("v")?.as_str();
244            if let Ok(val) = serde_json::from_str::<Value>(val_str) {
245                map.insert(key.to_string(), val);
246            }
247        }
248
249        if map.is_empty() {
250            None
251        } else {
252            Some(Value::Object(map))
253        }
254    }
255
256    fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
257        messages
258            .iter()
259            .map(|msg| {
260                let role = match msg.role {
261                    Role::System => "system",
262                    Role::User => "user",
263                    Role::Assistant => "assistant",
264                    Role::Tool => "tool",
265                };
266
267                match msg.role {
268                    Role::Tool => {
269                        if let Some(ContentPart::ToolResult {
270                            tool_call_id,
271                            content,
272                        }) = msg.content.first()
273                        {
274                            json!({
275                                "role": "tool",
276                                "tool_call_id": tool_call_id,
277                                "content": content
278                            })
279                        } else {
280                            json!({"role": role, "content": ""})
281                        }
282                    }
283                    Role::Assistant => {
284                        let text: String = msg
285                            .content
286                            .iter()
287                            .filter_map(|p| match p {
288                                ContentPart::Text { text } => Some(text.clone()),
289                                _ => None,
290                            })
291                            .collect::<Vec<_>>()
292                            .join("");
293
294                        let tool_calls: Vec<Value> = msg
295                            .content
296                            .iter()
297                            .filter_map(|p| match p {
298                                ContentPart::ToolCall {
299                                    id,
300                                    name,
301                                    arguments,
302                                    ..
303                                } => {
304                                    let args_string = Self::normalize_tool_arguments(arguments);
305                                    Some(json!({
306                                        "id": id,
307                                        "type": "function",
308                                        "function": {
309                                            "name": name,
310                                            "arguments": args_string
311                                        }
312                                    }))
313                                }
314                                _ => None,
315                            })
316                            .collect();
317
318                        let mut msg_json = if tool_calls.is_empty() {
319                            json!({
320                                "role": "assistant",
321                                "content": text,
322                            })
323                        } else {
324                            // Z.AI rejects content: "" when tool_calls is present;
325                            // use null instead (same pattern as openai_codex.rs).
326                            json!({
327                                "role": "assistant",
328                                "content": if text.is_empty() { Value::Null } else { json!(text) },
329                                "tool_calls": tool_calls,
330                            })
331                        };
332                        if include_reasoning_content {
333                            let reasoning: String = msg
334                                .content
335                                .iter()
336                                .filter_map(|p| match p {
337                                    ContentPart::Thinking { text } => Some(text.clone()),
338                                    _ => None,
339                                })
340                                .collect::<Vec<_>>()
341                                .join("");
342                            if !reasoning.is_empty() {
343                                msg_json["reasoning_content"] = json!(reasoning);
344                            }
345                        }
346                        msg_json
347                    }
348                    _ => {
349                        let text: String = msg
350                            .content
351                            .iter()
352                            .filter_map(|p| match p {
353                                ContentPart::Text { text } => Some(text.clone()),
354                                _ => None,
355                            })
356                            .collect::<Vec<_>>()
357                            .join("\n");
358
359                        json!({"role": role, "content": text})
360                    }
361                }
362            })
363            .collect()
364    }
365
366    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
367        tools
368            .iter()
369            .map(|t| {
370                json!({
371                    "type": "function",
372                    "function": {
373                        "name": t.name,
374                        "description": t.description,
375                        "parameters": t.parameters
376                    }
377                })
378            })
379            .collect()
380    }
381
382    fn model_supports_tool_stream(model: &str) -> bool {
383        model.contains("glm-5")
384            || model.contains("glm-4.7")
385            || model.contains("glm-4.6")
386            || model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL)
387    }
388
389    fn preview_text(text: &str, max_chars: usize) -> &str {
390        if max_chars == 0 {
391            return "";
392        }
393        if let Some((idx, _)) = text.char_indices().nth(max_chars) {
394            &text[..idx]
395        } else {
396            text
397        }
398    }
399
400    fn stream_tool_arguments_fragment(arguments: &Value) -> String {
401        match arguments {
402            Value::Null => String::new(),
403            Value::String(s) => s.clone(),
404            other => serde_json::to_string(other).unwrap_or_default(),
405        }
406    }
407
    /// Translates streamed tool-call fragments into `StreamChunk` events.
    ///
    /// For every fragment in `tool_calls` this resolves which tool call it
    /// belongs to, updates that call's entry in `tool_states`, and pushes
    /// `ToolCallStart` / `ToolCallDelta` events onto `chunks`. No
    /// `ToolCallEnd` is emitted here.
    ///
    /// `next_fallback_index` and `last_seen_index` are caller-owned cursors
    /// that this function reads and updates.
408    fn append_stream_tool_call_chunks(
409        chunks: &mut Vec<StreamChunk>,
410        tool_calls: &[ZaiStreamToolCall],
411        tool_states: &mut HashMap<usize, ZaiStreamToolState>,
412        next_fallback_index: &mut usize,
413        last_seen_index: &mut Option<usize>,
414    ) {
415        for tc in tool_calls {
            // Index resolution, in priority order:
            //   1. the explicit `index` on the fragment;
            //   2. the index of an existing state whose stream id equals the fragment's `id`;
            //   3. the most recently used index;
            //   4. a fresh index taken from the fallback counter.
            // NOTE(review): step 3 runs before step 4, so once any call has been
            // seen, a fragment with no `index` and an unknown `id` is attributed
            // to the previous call. Confirm this is intended for parallel calls.
416            let index = tc
417                .index
418                .or_else(|| {
419                    tc.id.as_ref().and_then(|id| {
420                        tool_states
421                            .iter()
422                            .find_map(|(idx, state)| (state.stream_id == *id).then_some(*idx))
423                    })
424                })
425                .or(*last_seen_index)
426                .unwrap_or_else(|| {
427                    let idx = *next_fallback_index;
428                    *next_fallback_index += 1;
429                    idx
430                });
431            *last_seen_index = Some(index);
432
            // First fragment for this index: create its state, using the
            // fragment's id or a synthetic "zai-tool-{index}" placeholder.
433            let state = tool_states
434                .entry(index)
435                .or_insert_with(|| ZaiStreamToolState {
436                    stream_id: tc.id.clone().unwrap_or_else(|| format!("zai-tool-{index}")),
437                    ..Default::default()
438                });
439
            // Replace a placeholder id with the real one, but only before
            // ToolCallStart has been emitted under the placeholder.
440            if let Some(id) = &tc.id
441                && !state.started
442                && state.stream_id.starts_with("zai-tool-")
443            {
444                state.stream_id = id.clone();
445            }
446
447            if let Some(func) = &tc.function {
                // Remember the latest non-empty function name.
448                if let Some(name) = &func.name
449                    && !name.is_empty()
450                {
451                    state.name = Some(name.clone());
452                }
453
                // Emit ToolCallStart once, as soon as a name is known.
454                if !state.started
455                    && let Some(name) = &state.name
456                {
457                    chunks.push(StreamChunk::ToolCallStart {
458                        id: state.stream_id.clone(),
459                        name: name.clone(),
460                    });
461                    state.started = true;
462                }
463
464                if let Some(arguments) = &func.arguments {
465                    let delta = Self::stream_tool_arguments_fragment(arguments);
466                    if !delta.is_empty() {
                        // Arguments arrived before any name: start the call
                        // under the placeholder name "tool" so the delta has
                        // a preceding start event.
467                        if !state.started {
468                            chunks.push(StreamChunk::ToolCallStart {
469                                id: state.stream_id.clone(),
470                                name: state.name.clone().unwrap_or_else(|| "tool".to_string()),
471                            });
472                            state.started = true;
473                        }
474                        chunks.push(StreamChunk::ToolCallDelta {
475                            id: state.stream_id.clone(),
476                            arguments_delta: delta,
477                        });
478                    }
479                }
480            }
481        }
482    }
483
484    fn finish_stream_tool_call_chunks(
485        chunks: &mut Vec<StreamChunk>,
486        tool_states: &mut HashMap<usize, ZaiStreamToolState>,
487    ) {
488        let mut ordered_indexes: Vec<_> = tool_states.keys().copied().collect();
489        ordered_indexes.sort_unstable();
490
491        for index in ordered_indexes {
492            if let Some(state) = tool_states.get_mut(&index)
493                && state.started
494                && !state.finished
495            {
496                chunks.push(StreamChunk::ToolCallEnd {
497                    id: state.stream_id.clone(),
498                });
499                state.finished = true;
500            }
501        }
502    }
503}
504
/// Body of a non-streaming chat completion response.
505#[derive(Debug, Deserialize)]
506struct ZaiResponse {
    /// Returned choices; `complete` reads only the first one.
507    choices: Vec<ZaiChoice>,
    /// Token accounting; `None` when the field is absent.
508    #[serde(default)]
509    usage: Option<ZaiUsage>,
510}
511
/// One completion choice within a `ZaiResponse`.
512#[derive(Debug, Deserialize)]
513struct ZaiChoice {
    /// The assistant message produced for this choice.
514    message: ZaiMessage,
    /// Reason the model stopped, as reported by the API; `None` when absent.
515    #[serde(default)]
516    finish_reason: Option<String>,
517}
518
/// Assistant message payload of a non-streaming response. Every field is
/// optional and deserializes to `None` when absent.
519#[derive(Debug, Deserialize)]
520struct ZaiMessage {
    /// Visible answer text.
521    #[serde(default)]
522    content: Option<String>,
    /// Tool calls requested by the model.
523    #[serde(default)]
524    tool_calls: Option<Vec<ZaiToolCall>>,
    /// Reasoning/thinking text returned alongside the answer.
525    #[serde(default)]
526    reasoning_content: Option<String>,
527}
528
/// A tool call in a non-streaming response. Both fields are required for
/// deserialization to succeed.
529#[derive(Debug, Deserialize)]
530struct ZaiToolCall {
    /// Identifier of this tool call.
531    id: String,
    /// Function name and arguments of the call.
532    function: ZaiFunction,
533}
534
/// Function invocation carried by a `ZaiToolCall`.
535#[derive(Debug, Deserialize)]
536struct ZaiFunction {
    /// Name of the function to call.
537    name: String,
    /// Arguments as raw JSON: `complete` accepts either a JSON string or any
    /// other JSON value, which it serializes to text.
538    arguments: Value,
539}
540
/// Token usage reported with a response. Missing counters deserialize to 0.
541#[derive(Debug, Deserialize)]
542struct ZaiUsage {
    /// Tokens in the prompt.
543    #[serde(default)]
544    prompt_tokens: usize,
    /// Tokens in the completion.
545    #[serde(default)]
546    completion_tokens: usize,
    /// Total token count as reported by the API.
547    #[serde(default)]
548    total_tokens: usize,
    /// Optional breakdown of the prompt tokens (cached portion).
549    #[serde(default)]
550    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
551}
552
/// Breakdown of prompt tokens reported inside `ZaiUsage`.
553#[derive(Debug, Deserialize)]
554struct ZaiPromptTokensDetails {
    /// Cached prompt tokens; 0 when the field is absent.
555    #[serde(default)]
556    cached_tokens: usize,
557}
558
/// Error envelope that `complete` tries to parse from a non-success response body.
559#[derive(Debug, Deserialize)]
560struct ZaiError {
    /// The error details.
561    error: ZaiErrorDetail,
562}
563
/// Details of an API error inside a `ZaiError` envelope.
564#[derive(Debug, Deserialize)]
565struct ZaiErrorDetail {
    /// Error message returned by the API.
566    message: String,
    /// Value of the JSON `type` field; `None` when absent.
567    #[serde(default, rename = "type")]
568    error_type: Option<String>,
569}
570
571// SSE stream types
/// One streamed response payload (an SSE stream type, per this section's heading).
572#[derive(Debug, Deserialize)]
573struct ZaiStreamResponse {
    /// Incremental choices carried by this payload.
574    choices: Vec<ZaiStreamChoice>,
575}
576
/// One choice within a `ZaiStreamResponse`.
577#[derive(Debug, Deserialize)]
578struct ZaiStreamChoice {
    /// Incremental content for this choice.
579    delta: ZaiStreamDelta,
    /// Finish reason as reported by the API; `None` when absent.
580    #[serde(default)]
581    finish_reason: Option<String>,
582}
583
/// Incremental payload of a streamed choice. Every field is optional and
/// deserializes to `None` when absent.
584#[derive(Debug, Deserialize)]
585struct ZaiStreamDelta {
    /// Fragment of visible answer text.
586    #[serde(default)]
587    content: Option<String>,
    /// Fragment of reasoning/thinking text.
588    #[serde(default)]
589    reasoning_content: Option<String>,
    /// Tool-call fragments carried by this delta.
590    #[serde(default)]
591    tool_calls: Option<Vec<ZaiStreamToolCall>>,
592}
593
/// A streamed tool-call fragment. Any field may be missing on a given
/// fragment; `append_stream_tool_call_chunks` works out which call it belongs to.
594#[derive(Debug, Deserialize)]
595struct ZaiStreamToolCall {
    /// Index of the tool call this fragment belongs to, when provided.
596    #[serde(default)]
597    index: Option<usize>,
    /// Tool-call id, when provided.
598    #[serde(default)]
599    id: Option<String>,
    /// Function name and/or argument fragment, when provided.
600    function: Option<ZaiStreamFunction>,
601}
602
/// Function portion of a streamed tool-call fragment.
603#[derive(Debug, Deserialize)]
604struct ZaiStreamFunction {
    /// Function name, when this fragment carries one.
605    #[serde(default)]
606    name: Option<String>,
    /// Argument fragment as raw JSON; `stream_tool_arguments_fragment`
    /// handles null, string and other JSON values.
607    #[serde(default)]
608    arguments: Option<Value>,
609}
610
611#[async_trait]
612impl Provider for ZaiProvider {
    /// Identifier of this provider: always `"zai"`.
613    fn name(&self) -> &str {
614        "zai"
615    }
616
617    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
618        // Attempt dynamic model discovery from the Z.AI /models endpoint.
619        // When the API is reachable, this returns the authoritative model
620        // catalog including newly-released models without a code change.
621        let discovered = self.discover_models_from_api().await;
622        if !discovered.is_empty() {
623            // Merge in special models that exist outside the /models API
624            // (coding endpoint models, etc.)
625            let mut models = discovered;
626            if !models.iter().any(|m| m.id == PONY_ALPHA_2_MODEL) {
627                models.push(ModelInfo {
628                    id: PONY_ALPHA_2_MODEL.to_string(),
629                    name: "Pony Alpha 2".to_string(),
630                    provider: "zai".to_string(),
631                    context_window: 128_000,
632                    max_output_tokens: Some(16_384),
633                    supports_vision: false,
634                    supports_tools: true,
635                    supports_streaming: true,
636                    input_cost_per_million: None,
637                    output_cost_per_million: None,
638                });
639            }
640            if !models.iter().any(|m| m.id == "glm-4.7-flash") {
641                models.push(ModelInfo {
642                    id: "glm-4.7-flash".to_string(),
643                    name: "GLM-4.7 Flash".to_string(),
644                    provider: "zai".to_string(),
645                    context_window: 128_000,
646                    max_output_tokens: Some(128_000),
647                    supports_vision: false,
648                    supports_tools: true,
649                    supports_streaming: true,
650                    input_cost_per_million: None,
651                    output_cost_per_million: None,
652                });
653            }
654            return Ok(models);
655        }
656
657        // Static catalog used when the /models endpoint is unavailable
658        // (e.g. network partition, auth failure, or non-standard deployments).
659        Ok(vec![
660            ModelInfo {
661                id: "glm-5.1".to_string(),
662                name: "GLM-5.1".to_string(),
663                provider: "zai".to_string(),
664                context_window: 200_000,
665                max_output_tokens: Some(128_000),
666                supports_vision: false,
667                supports_tools: true,
668                supports_streaming: true,
669                input_cost_per_million: None,
670                output_cost_per_million: None,
671            },
672            ModelInfo {
673                id: "glm-5".to_string(),
674                name: "GLM-5".to_string(),
675                provider: "zai".to_string(),
676                context_window: 200_000,
677                max_output_tokens: Some(128_000),
678                supports_vision: false,
679                supports_tools: true,
680                supports_streaming: true,
681                input_cost_per_million: None,
682                output_cost_per_million: None,
683            },
684            ModelInfo {
685                id: "glm-4.7".to_string(),
686                name: "GLM-4.7".to_string(),
687                provider: "zai".to_string(),
688                context_window: 128_000,
689                max_output_tokens: Some(128_000),
690                supports_vision: false,
691                supports_tools: true,
692                supports_streaming: true,
693                input_cost_per_million: None,
694                output_cost_per_million: None,
695            },
696            ModelInfo {
697                id: "glm-4.7-flash".to_string(),
698                name: "GLM-4.7 Flash".to_string(),
699                provider: "zai".to_string(),
700                context_window: 128_000,
701                max_output_tokens: Some(128_000),
702                supports_vision: false,
703                supports_tools: true,
704                supports_streaming: true,
705                input_cost_per_million: None,
706                output_cost_per_million: None,
707            },
708            ModelInfo {
709                id: "glm-4.6".to_string(),
710                name: "GLM-4.6".to_string(),
711                provider: "zai".to_string(),
712                context_window: 128_000,
713                max_output_tokens: Some(128_000),
714                supports_vision: false,
715                supports_tools: true,
716                supports_streaming: true,
717                input_cost_per_million: None,
718                output_cost_per_million: None,
719            },
720            ModelInfo {
721                id: "glm-4.5".to_string(),
722                name: "GLM-4.5".to_string(),
723                provider: "zai".to_string(),
724                context_window: 128_000,
725                max_output_tokens: Some(96_000),
726                supports_vision: false,
727                supports_tools: true,
728                supports_streaming: true,
729                input_cost_per_million: None,
730                output_cost_per_million: None,
731            },
732            ModelInfo {
733                id: "glm-5-turbo".to_string(),
734                name: "GLM-5 Turbo".to_string(),
735                provider: "zai".to_string(),
736                context_window: 200_000,
737                max_output_tokens: Some(128_000),
738                supports_vision: false,
739                supports_tools: true,
740                supports_streaming: true,
741                input_cost_per_million: Some(0.96),
742                output_cost_per_million: Some(3.20),
743            },
744            ModelInfo {
745                id: PONY_ALPHA_2_MODEL.to_string(),
746                name: "Pony Alpha 2".to_string(),
747                provider: "zai".to_string(),
748                context_window: 128_000,
749                max_output_tokens: Some(16_384),
750                supports_vision: false,
751                supports_tools: true,
752                supports_streaming: true,
753                input_cost_per_million: None,
754                output_cost_per_million: None,
755            },
756        ])
757    }
758
759    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
760        // Compatibility-first mode: omit historical reasoning_content from
761        // request messages to avoid strict parameter validation errors on some
762        // endpoint variants.
763        let messages = Self::convert_messages(&request.messages, false);
764        let tools = Self::convert_tools(&request.tools);
765
766        // GLM-5 and GLM-4.7 default to temperature 1.0
767        let temperature = request.temperature.unwrap_or(1.0);
768
769        let mut body = json!({
770            "model": request.model,
771            "messages": messages,
772            "temperature": temperature,
773        });
774
775        // Thinking is enabled with clear_thinking: true unless thinking_config_for_request disables it.
776        // The Coding endpoint (api.z.ai/api/coding/paas/v4) defaults
777        // clear_thinking to false (Preserved Thinking), which requires
778        // reasoning_content in historical assistant messages. Since we strip
779        // reasoning_content, we must explicitly set clear_thinking: true.
780        body["thinking"] = thinking_config_for_request(&request);
781
782        if !tools.is_empty() {
783            body["tools"] = json!(tools);
784        }
785        if let Some(max) = request.max_tokens {
786            body["max_tokens"] = json!(max);
787        }
788
789        tracing::debug!(model = %request.model, "Z.AI request");
790        tracing::trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "Z.AI request body");
791        let request_base_url = self.request_base_url(&request.model);
792
793        let (text, status) = super::retry::send_with_retry(|| async {
794            let resp = self
795                .client
796                .post(format!("{}/chat/completions", request_base_url))
797                .header("Authorization", format!("Bearer {}", self.api_key))
798                .header("Content-Type", "application/json")
799                .json(&body)
800                .send()
801                .await
802                .context("Failed to send request to Z.AI")?;
803            let status = resp.status();
804            let text = resp.text().await.context("Failed to read Z.AI response")?;
805            Ok((text, status))
806        })
807        .await?;
808
809        if !status.is_success() {
810            tracing::debug!(status = %status, body = %text, "Z.AI error response");
811            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
812                anyhow::bail!(
813                    "Z.AI API error: {} ({:?})",
814                    err.error.message,
815                    err.error.error_type
816                );
817            }
818            anyhow::bail!("Z.AI API error: {status} {text}");
819        }
820
821        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
822            "Failed to parse Z.AI response: {}",
823            Self::preview_text(&text, 200)
824        ))?;
825
826        let choice = response
827            .choices
828            .first()
829            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
830
831        // Log thinking/reasoning content if present
832        if let Some(ref reasoning) = choice.message.reasoning_content
833            && !reasoning.is_empty()
834        {
835            tracing::info!(
836                reasoning_len = reasoning.len(),
837                "Z.AI reasoning content received"
838            );
839        }
840
841        let mut content = Vec::new();
842        let mut has_tool_calls = false;
843
844        // Emit thinking content as a Thinking part
845        if let Some(ref reasoning) = choice.message.reasoning_content
846            && !reasoning.is_empty()
847        {
848            content.push(ContentPart::Thinking {
849                text: reasoning.clone(),
850            });
851        }
852
853        if let Some(text) = &choice.message.content
854            && !text.is_empty()
855        {
856            content.push(ContentPart::Text { text: text.clone() });
857        }
858
859        if let Some(tool_calls) = &choice.message.tool_calls {
860            has_tool_calls = !tool_calls.is_empty();
861            for tc in tool_calls {
862                // Z.AI returns arguments as an object; serialize to string for our ContentPart
863                let arguments = match &tc.function.arguments {
864                    Value::String(s) => s.clone(),
865                    other => serde_json::to_string(other).unwrap_or_default(),
866                };
867                content.push(ContentPart::ToolCall {
868                    id: tc.id.clone(),
869                    name: tc.function.name.clone(),
870                    arguments,
871                    thought_signature: None,
872                });
873            }
874        }
875
876        let finish_reason = if has_tool_calls {
877            FinishReason::ToolCalls
878        } else {
879            match choice.finish_reason.as_deref() {
880                Some("stop") => FinishReason::Stop,
881                Some("length") => FinishReason::Length,
882                Some("tool_calls") => FinishReason::ToolCalls,
883                Some("sensitive") => FinishReason::ContentFilter,
884                _ => FinishReason::Stop,
885            }
886        };
887
888        Ok(CompletionResponse {
889            message: Message {
890                role: Role::Assistant,
891                content,
892            },
893            usage: Usage {
894                prompt_tokens: {
895                    // Subtract cached input so the cost estimator does
896                    // not double-count: cache_read_tokens is billed
897                    // separately at the discounted rate.
898                    let u = response.usage.as_ref();
899                    let total = u.map(|u| u.prompt_tokens).unwrap_or(0);
900                    let cached = u
901                        .and_then(|u| u.prompt_tokens_details.as_ref())
902                        .map(|d| d.cached_tokens)
903                        .unwrap_or(0);
904                    total.saturating_sub(cached)
905                },
906                completion_tokens: response
907                    .usage
908                    .as_ref()
909                    .map(|u| u.completion_tokens)
910                    .unwrap_or(0),
911                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
912                cache_read_tokens: response
913                    .usage
914                    .as_ref()
915                    .and_then(|u| u.prompt_tokens_details.as_ref())
916                    .map(|d| d.cached_tokens)
917                    .filter(|&t| t > 0),
918                cache_write_tokens: None,
919            },
920            finish_reason,
921        })
922    }
923
924    async fn complete_stream(
925        &self,
926        request: CompletionRequest,
927    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
928        // Compatibility-first mode: omit historical reasoning_content from
929        // request messages to avoid strict parameter validation errors on some
930        // endpoint variants.
931        let messages = Self::convert_messages(&request.messages, false);
932        let tools = Self::convert_tools(&request.tools);
933
934        let temperature = request.temperature.unwrap_or(1.0);
935
936        let mut body = json!({
937            "model": request.model,
938            "messages": messages,
939            "temperature": temperature,
940            "stream": true,
941        });
942
943        body["thinking"] = thinking_config_for_request(&request);
944
945        if !tools.is_empty() {
946            body["tools"] = json!(tools);
947            if Self::model_supports_tool_stream(&request.model) {
948                // Enable streaming tool calls only on known-compatible models.
949                body["tool_stream"] = json!(true);
950            }
951        }
952        if let Some(max) = request.max_tokens {
953            body["max_tokens"] = json!(max);
954        }
955
956        tracing::debug!(model = %request.model, "Z.AI streaming request");
957        let request_base_url = self.request_base_url(&request.model);
958
959        let response = super::retry::send_response_with_retry(|| async {
960            self.client
961                .post(format!("{}/chat/completions", request_base_url))
962                .header("Authorization", format!("Bearer {}", self.api_key))
963                .header("Content-Type", "application/json")
964                .json(&body)
965                .send()
966                .await
967                .context("Failed to send streaming request to Z.AI")
968        })
969        .await?;
970
971        let stream = response.bytes_stream();
972        let mut buffer = String::new();
973        let mut tool_states = HashMap::<usize, ZaiStreamToolState>::new();
974        let mut next_fallback_tool_index = 0usize;
975        let mut last_seen_tool_index = None;
976
977        Ok(stream
978            .flat_map(move |chunk_result| {
979                let mut chunks: Vec<StreamChunk> = Vec::new();
980                match chunk_result {
981                    Ok(bytes) => {
982                        let text = String::from_utf8_lossy(&bytes);
983                        buffer.push_str(&text);
984
985                        let mut text_buf = String::new();
986                        while let Some(line_end) = buffer.find('\n') {
987                            let line = buffer[..line_end].trim().to_string();
988                            buffer = buffer[line_end + 1..].to_string();
989
990                            if line == "data: [DONE]" {
991                                if !text_buf.is_empty() {
992                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
993                                }
994                                chunks.push(StreamChunk::Done { usage: None });
995                                continue;
996                            }
997                            if let Some(data) = line.strip_prefix("data: ")
998                                && let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
999                                && let Some(choice) = parsed.choices.first()
1000                            {
1001                                // Reasoning content streamed as text (prefixed for TUI rendering)
1002                                if let Some(ref reasoning) = choice.delta.reasoning_content
1003                                    && !reasoning.is_empty()
1004                                {
1005                                    text_buf.push_str(reasoning);
1006                                }
1007                                if let Some(ref content) = choice.delta.content {
1008                                    text_buf.push_str(content);
1009                                }
1010                                // Streaming tool calls
1011                                if let Some(ref tool_calls) = choice.delta.tool_calls {
1012                                    if !text_buf.is_empty() {
1013                                        chunks
1014                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
1015                                    }
1016                                    Self::append_stream_tool_call_chunks(
1017                                        &mut chunks,
1018                                        tool_calls,
1019                                        &mut tool_states,
1020                                        &mut next_fallback_tool_index,
1021                                        &mut last_seen_tool_index,
1022                                    );
1023                                }
1024                                // finish_reason signals end of a tool call or completion
1025                                if let Some(ref reason) = choice.finish_reason {
1026                                    if !text_buf.is_empty() {
1027                                        chunks
1028                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
1029                                    }
1030                                    if reason == "tool_calls" {
1031                                        Self::finish_stream_tool_call_chunks(
1032                                            &mut chunks,
1033                                            &mut tool_states,
1034                                        );
1035                                    }
1036                                }
1037                            }
1038                        }
1039                        if !text_buf.is_empty() {
1040                            chunks.push(StreamChunk::Text(text_buf));
1041                        }
1042                    }
1043                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
1044                }
1045                futures::stream::iter(chunks)
1046            })
1047            .boxed())
1048    }
1049}
1050
1051#[cfg(test)]
1052mod tests {
1053    use super::*;
1054    use crate::provider::Provider;
1055
1056    #[tokio::test]
1057    async fn list_models_includes_pony_alpha_2() {
1058        let provider =
1059            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
1060                .expect("provider should construct");
1061        let models = provider.list_models().await.expect("models should list");
1062
1063        assert!(models.iter().any(|model| model.id == PONY_ALPHA_2_MODEL));
1064    }
1065
1066    #[tokio::test]
1067    async fn list_models_includes_glm_5_turbo() {
1068        let provider =
1069            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
1070                .expect("provider should construct");
1071        let models = provider.list_models().await.expect("models should list");
1072
1073        let turbo = models
1074            .iter()
1075            .find(|m| m.id == "glm-5-turbo")
1076            .expect("glm-5-turbo should be in model list");
1077        assert_eq!(turbo.context_window, 200_000);
1078        assert_eq!(turbo.max_output_tokens, Some(128_000));
1079        assert!(turbo.supports_tools);
1080        assert!(turbo.supports_streaming);
1081        assert_eq!(turbo.input_cost_per_million, Some(0.96));
1082        assert_eq!(turbo.output_cost_per_million, Some(3.20));
1083    }
1084
1085    #[tokio::test]
1086    async fn list_models_includes_glm_5_1() {
1087        let provider =
1088            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
1089                .expect("provider should construct");
1090        let models = provider.list_models().await.expect("models should list");
1091
1092        let glm51 = models
1093            .iter()
1094            .find(|m| m.id == "glm-5.1")
1095            .expect("glm-5.1 should be in model list");
1096        assert_eq!(glm51.context_window, 200_000);
1097        assert_eq!(glm51.max_output_tokens, Some(128_000));
1098        assert!(glm51.supports_tools);
1099        assert!(glm51.supports_streaming);
1100    }
1101
1102    #[test]
1103    fn model_supports_tool_stream_matches_glm_5_1() {
1104        assert!(ZaiProvider::model_supports_tool_stream("glm-5.1"));
1105        assert!(ZaiProvider::model_supports_tool_stream("glm-5"));
1106        assert!(ZaiProvider::model_supports_tool_stream("glm-5-turbo"));
1107        assert!(!ZaiProvider::model_supports_tool_stream("glm-4.5"));
1108    }
1109
1110    #[test]
1111    fn pony_alpha_2_routes_to_coding_endpoint() {
1112        let provider =
1113            ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
1114                .expect("provider should construct");
1115
1116        assert_eq!(
1117            provider.request_base_url(PONY_ALPHA_2_MODEL),
1118            CODING_BASE_URL
1119        );
1120        assert_eq!(provider.request_base_url("glm-5"), DEFAULT_BASE_URL);
1121    }
1122
1123    #[test]
1124    fn convert_messages_serializes_tool_arguments_as_json_string() {
1125        let messages = vec![Message {
1126            role: Role::Assistant,
1127            content: vec![ContentPart::ToolCall {
1128                id: "call_1".to_string(),
1129                name: "get_weather".to_string(),
1130                arguments: "{\"city\":\"Beijing\".. }".to_string(),
1131                thought_signature: None,
1132            }],
1133        }];
1134
1135        let converted = ZaiProvider::convert_messages(&messages, true);
1136        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
1137            .as_str()
1138            .expect("arguments must be a string");
1139        let parsed: Value =
1140            serde_json::from_str(args).expect("arguments string must contain valid JSON");
1141
1142        assert_eq!(parsed, json!({"city":"Beijing"}));
1143    }
1144
1145    #[test]
1146    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
1147        let messages = vec![Message {
1148            role: Role::Assistant,
1149            content: vec![ContentPart::ToolCall {
1150                id: "call_1".to_string(),
1151                name: "get_weather".to_string(),
1152                arguments: "city=Beijing".to_string(),
1153                thought_signature: None,
1154            }],
1155        }];
1156
1157        let converted = ZaiProvider::convert_messages(&messages, true);
1158        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
1159            .as_str()
1160            .expect("arguments must be a string");
1161        let parsed: Value =
1162            serde_json::from_str(args).expect("arguments string must contain valid JSON");
1163
1164        assert_eq!(parsed, json!({"input":"city=Beijing"}));
1165    }
1166
1167    #[test]
1168    fn convert_messages_wraps_scalar_tool_arguments_as_json_string() {
1169        let messages = vec![Message {
1170            role: Role::Assistant,
1171            content: vec![ContentPart::ToolCall {
1172                id: "call_1".to_string(),
1173                name: "get_weather".to_string(),
1174                arguments: "\"Beijing\"".to_string(),
1175                thought_signature: None,
1176            }],
1177        }];
1178
1179        let converted = ZaiProvider::convert_messages(&messages, true);
1180        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
1181            .as_str()
1182            .expect("arguments must be a string");
1183        let parsed: Value =
1184            serde_json::from_str(args).expect("arguments string must contain valid JSON");
1185
1186        assert_eq!(parsed, json!({"input":"Beijing"}));
1187    }
1188
1189    #[test]
1190    fn convert_messages_uses_null_content_for_empty_assistant_tool_calls() {
1191        let messages = vec![Message {
1192            role: Role::Assistant,
1193            content: vec![ContentPart::ToolCall {
1194                id: "call_1".to_string(),
1195                name: "get_weather".to_string(),
1196                arguments: r#"{"city":"Beijing"}"#.to_string(),
1197                thought_signature: None,
1198            }],
1199        }];
1200
1201        let converted = ZaiProvider::convert_messages(&messages, false);
1202
1203        assert!(converted[0].get("tool_calls").is_some());
1204        assert_eq!(converted[0]["content"], Value::Null);
1205    }
1206
1207    #[test]
1208    fn convert_messages_keeps_text_content_for_assistant_tool_calls_with_text() {
1209        let messages = vec![Message {
1210            role: Role::Assistant,
1211            content: vec![
1212                ContentPart::Text {
1213                    text: "Using a tool.".to_string(),
1214                },
1215                ContentPart::ToolCall {
1216                    id: "call_1".to_string(),
1217                    name: "get_weather".to_string(),
1218                    arguments: r#"{"city":"Beijing"}"#.to_string(),
1219                    thought_signature: None,
1220                },
1221            ],
1222        }];
1223
1224        let converted = ZaiProvider::convert_messages(&messages, false);
1225
1226        assert!(converted[0].get("tool_calls").is_some());
1227        assert_eq!(converted[0]["content"], json!("Using a tool."));
1228    }
1229
1230    #[test]
1231    fn short_exact_zai_requests_disable_thinking() {
1232        let request = CompletionRequest {
1233            messages: vec![Message {
1234                role: Role::User,
1235                content: vec![ContentPart::Text {
1236                    text: "Reply with exactly: api-model-call-ok".to_string(),
1237                }],
1238            }],
1239            tools: vec![],
1240            model: "glm-5.1".to_string(),
1241            temperature: None,
1242            top_p: None,
1243            max_tokens: Some(20),
1244            stop: vec![],
1245        };
1246
1247        assert!(should_disable_thinking_for_request(&request));
1248        assert_eq!(
1249            thinking_config_for_request(&request),
1250            json!({"type": "disabled"})
1251        );
1252    }
1253
1254    #[test]
1255    fn long_low_budget_zai_requests_keep_thinking_enabled() {
1256        let request = CompletionRequest {
1257            messages: vec![Message {
1258                role: Role::User,
1259                content: vec![ContentPart::Text {
1260                    text: format!(
1261                        "Summarize the following material. {}",
1262                        "context paragraph. ".repeat(40)
1263                    ),
1264                }],
1265            }],
1266            tools: vec![],
1267            model: "glm-5.1".to_string(),
1268            temperature: None,
1269            top_p: None,
1270            max_tokens: Some(128),
1271            stop: vec![],
1272        };
1273
1274        assert!(!should_disable_thinking_for_request(&request));
1275        assert_eq!(thinking_config_for_request(&request)["type"], "enabled");
1276    }
1277
1278    #[test]
1279    fn normal_zai_requests_keep_clear_thinking_enabled() {
1280        let request = CompletionRequest {
1281            messages: vec![Message {
1282                role: Role::User,
1283                content: vec![ContentPart::Text {
1284                    text: "Explain the tradeoffs.".to_string(),
1285                }],
1286            }],
1287            tools: vec![],
1288            model: "glm-5.1".to_string(),
1289            temperature: None,
1290            top_p: None,
1291            max_tokens: Some(1024),
1292            stop: vec![],
1293        };
1294
1295        assert!(!should_disable_thinking_for_request(&request));
1296        assert_eq!(thinking_config_for_request(&request)["type"], "enabled");
1297        assert_eq!(
1298            thinking_config_for_request(&request)["clear_thinking"],
1299            true
1300        );
1301    }
1302
1303    #[test]
1304    fn stream_tool_chunks_keep_same_call_id_when_followup_delta_omits_id() {
1305        let mut chunks = Vec::new();
1306        let mut tool_states = HashMap::new();
1307        let mut next_fallback_tool_index = 0usize;
1308        let mut last_seen_tool_index = None;
1309
1310        ZaiProvider::append_stream_tool_call_chunks(
1311            &mut chunks,
1312            &[ZaiStreamToolCall {
1313                index: Some(0),
1314                id: Some("call_1".to_string()),
1315                function: Some(ZaiStreamFunction {
1316                    name: Some("bash".to_string()),
1317                    arguments: Some(Value::String("{\"".to_string())),
1318                }),
1319            }],
1320            &mut tool_states,
1321            &mut next_fallback_tool_index,
1322            &mut last_seen_tool_index,
1323        );
1324
1325        ZaiProvider::append_stream_tool_call_chunks(
1326            &mut chunks,
1327            &[ZaiStreamToolCall {
1328                index: Some(0),
1329                id: None,
1330                function: Some(ZaiStreamFunction {
1331                    name: None,
1332                    arguments: Some(Value::String("command\":\"pwd\"}".to_string())),
1333                }),
1334            }],
1335            &mut tool_states,
1336            &mut next_fallback_tool_index,
1337            &mut last_seen_tool_index,
1338        );
1339
1340        ZaiProvider::finish_stream_tool_call_chunks(&mut chunks, &mut tool_states);
1341
1342        assert_eq!(chunks.len(), 4);
1343        assert!(matches!(
1344            &chunks[0],
1345            StreamChunk::ToolCallStart { id, name }
1346                if id == "call_1" && name == "bash"
1347        ));
1348        assert!(matches!(
1349            &chunks[1],
1350            StreamChunk::ToolCallDelta { id, arguments_delta }
1351                if id == "call_1" && arguments_delta == "{\""
1352        ));
1353        assert!(matches!(
1354            &chunks[2],
1355            StreamChunk::ToolCallDelta { id, arguments_delta }
1356                if id == "call_1" && arguments_delta == "command\":\"pwd\"}"
1357        ));
1358        assert!(matches!(
1359            &chunks[3],
1360            StreamChunk::ToolCallEnd { id } if id == "call_1"
1361        ));
1362    }
1363
1364    #[test]
1365    fn preview_text_truncates_on_char_boundary() {
1366        let text = "a😀b";
1367        assert_eq!(ZaiProvider::preview_text(text, 2), "a😀");
1368    }
1369}