// codetether_agent/provider/zai.rs

1//! Z.AI provider implementation (direct API)
2//!
3//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
4//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
5//! reasoning/thinking support via the `reasoning_content` field.
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use once_cell::sync::Lazy;
15use regex::Regex;
16use reqwest::Client;
17use serde::Deserialize;
18use serde_json::{Value, json};
19use std::collections::HashMap;
20
/// Default Z.AI OpenAI-compatible API root.
pub const DEFAULT_BASE_URL: &str = "https://api.z.ai/api/paas/v4";
/// Dedicated coding endpoint; some models are only served from here.
const CODING_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4";
/// Model id that must be routed to the coding endpoint (see `request_base_url`).
const PONY_ALPHA_2_MODEL: &str = "pony-alpha-2";
24
/// Z.AI chat-completions provider speaking the OpenAI-compatible protocol.
pub struct ZaiProvider {
    // Shared HTTP client (cloned from `shared_http::shared_client`).
    client: Client,
    // Bearer token sent on every request; redacted by the manual Debug impl.
    api_key: String,
    // Configured API root; overridden per-request for coding-only models.
    base_url: String,
}
30
/// Per-tool-call bookkeeping while parsing an SSE stream. Z.AI may deliver a
/// tool call's id, name, and argument fragments across several deltas, so we
/// accumulate until a Start/Delta/End event sequence can be emitted.
#[derive(Debug, Default)]
struct ZaiStreamToolState {
    // Id used in emitted StreamChunk events; starts as a synthetic
    // "zai-tool-{index}" placeholder until the real id arrives.
    stream_id: String,
    // Function name, once (if ever) seen in a delta.
    name: Option<String>,
    // Whether ToolCallStart has been emitted for this call.
    started: bool,
    // Whether ToolCallEnd has been emitted for this call.
    finished: bool,
}
38
39impl std::fmt::Debug for ZaiProvider {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("ZaiProvider")
42            .field("base_url", &self.base_url)
43            .field("api_key", &"<REDACTED>")
44            .finish()
45    }
46}
47
48impl ZaiProvider {
49    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
50        tracing::debug!(
51            provider = "zai",
52            base_url = %base_url,
53            api_key_len = api_key.len(),
54            "Creating Z.AI provider with custom base URL"
55        );
56        Ok(Self {
57            client: crate::provider::shared_http::shared_client().clone(),
58            api_key,
59            base_url,
60        })
61    }
62
63    fn request_base_url(&self, model: &str) -> &str {
64        if model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL) {
65            CODING_BASE_URL
66        } else {
67            &self.base_url
68        }
69    }
70
    /// Fetch available models from the Z.AI /models endpoint.
    /// Returns an empty vec on any failure (network, parse, auth) so callers
    /// can fall back to the hardcoded catalog.
    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
        // Always hit the standard API endpoint for model discovery,
        // even if base_url points at the coding endpoint.
        let discovery_url = if self.base_url.contains("/coding/") {
            self.base_url.replace("/coding/", "/")
        } else {
            self.base_url.clone()
        };
        let url = format!("{discovery_url}/models");
        let response = match self
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    url = %url,
                    error = %e,
                    "Z.AI /models discovery request failed"
                );
                return Vec::new();
            }
        };

        if !response.status().is_success() {
            tracing::debug!(
                url = %url,
                status = %response.status(),
                "Z.AI /models endpoint returned non-success"
            );
            return Vec::new();
        }

        let payload: Value = match response.json().await {
            Ok(p) => p,
            Err(e) => {
                tracing::debug!(
                    url = %url,
                    error = %e,
                    "Failed to parse Z.AI /models response"
                );
                return Vec::new();
            }
        };

        // Accept both bare-string entries and OpenAI-style `{ "id": ... }`
        // objects under `data`; anything else is skipped.
        let models = payload
            .get("data")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|entry| {
                let id = match entry {
                    Value::String(s) => s.trim().to_string(),
                    Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim().to_string(),
                    _ => return None,
                };
                if id.is_empty() {
                    return None;
                }
                // A display name is optional; fall back to the id.
                let name = entry
                    .get("name")
                    .and_then(Value::as_str)
                    .map(str::trim)
                    .filter(|n| !n.is_empty())
                    .unwrap_or(&id)
                    .to_string();
                // NOTE(review): context/output limits and capability flags are
                // optimistic defaults — /models does not report them here.
                // Confirm against the live API if they start to matter.
                Some(ModelInfo {
                    id,
                    name,
                    provider: "zai".to_string(),
                    context_window: 200_000,
                    max_output_tokens: Some(128_000),
                    supports_vision: false,
                    supports_tools: true,
                    supports_streaming: true,
                    input_cost_per_million: None,
                    output_cost_per_million: None,
                })
            })
            .collect::<Vec<_>>();

        if models.is_empty() {
            tracing::debug!(url = %url, "Z.AI /models returned no model ids");
        } else {
            tracing::info!(count = models.len(), "Z.AI /models discovery succeeded");
        }
        models
    }
165
166    fn normalize_tool_arguments(arguments: &str) -> String {
167        // The live Z.AI endpoint rejects object-typed historical arguments in
168        // assistant.tool_calls and accepts OpenAI-style JSON strings instead.
169        if let Ok(parsed) = serde_json::from_str::<Value>(arguments) {
170            if parsed.is_object() {
171                return serde_json::to_string(&parsed).unwrap_or_else(|_| "{}".to_string());
172            }
173            return json!({"input": parsed}).to_string();
174        }
175
176        if let Some(salvaged) = Self::salvage_json_object(arguments) {
177            return serde_json::to_string(&salvaged).unwrap_or_else(|_| "{}".to_string());
178        }
179
180        json!({"input": arguments}).to_string()
181    }
182
    /// Best-effort recovery of a JSON object from malformed argument text.
    ///
    /// Scans for simple `"key": <string|number|bool|null>` pairs and rebuilds
    /// an object from whatever pairs re-parse cleanly. Nested objects/arrays
    /// are deliberately not matched. Returns `None` when the text does not
    /// start with `{` or when no pair could be recovered.
    fn salvage_json_object(arguments: &str) -> Option<Value> {
        let trimmed = arguments.trim();
        if !trimmed.starts_with('{') {
            return None;
        }

        static RE_SIMPLE_PAIR: Lazy<Regex> = Lazy::new(|| {
            // Matches simple JSON key/value pairs where the value is a primitive
            // or a quoted string. This is intentionally conservative.
            Regex::new(
                r#"(?s)\"(?P<k>[^\"\\]*(?:\\.[^\"\\]*)*)\"\s*:\s*(?P<v>\"(?:\\.|[^\"])*\"|true|false|null|-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"#,
            )
            .expect("invalid regex")
        });

        let mut map = serde_json::Map::new();
        for caps in RE_SIMPLE_PAIR.captures_iter(trimmed) {
            // Both named groups always participate in a match, so these `?`s
            // are effectively unreachable (they would abort the whole salvage).
            // NOTE(review): the key is taken verbatim from the regex capture,
            // so backslash escapes inside keys are NOT decoded — confirm this
            // is acceptable for the keys models actually emit.
            let key = caps.name("k")?.as_str();
            let val_str = caps.name("v")?.as_str();
            if let Ok(val) = serde_json::from_str::<Value>(val_str) {
                map.insert(key.to_string(), val);
            }
        }

        if map.is_empty() {
            None
        } else {
            Some(Value::Object(map))
        }
    }
213
    /// Convert internal `Message`s into Z.AI (OpenAI-style) chat messages.
    ///
    /// `include_reasoning_content` controls whether historical assistant
    /// `Thinking` parts are forwarded as `reasoning_content`; callers pass
    /// `false` for endpoint variants that reject the field.
    fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
        messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    Role::Tool => "tool",
                };

                match msg.role {
                    Role::Tool => {
                        // Only the first ToolResult part is forwarded.
                        // NOTE(review): any additional content parts on a tool
                        // message are dropped — confirm upstream always sends
                        // exactly one result per tool message.
                        if let Some(ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        }) = msg.content.first()
                        {
                            json!({
                                "role": "tool",
                                "tool_call_id": tool_call_id,
                                "content": content
                            })
                        } else {
                            json!({"role": role, "content": ""})
                        }
                    }
                    Role::Assistant => {
                        // Concatenate all Text parts into one content string.
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("");

                        // Re-emit historical tool calls with arguments coerced
                        // to the JSON-string form Z.AI accepts.
                        let tool_calls: Vec<Value> = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::ToolCall {
                                    id,
                                    name,
                                    arguments,
                                    ..
                                } => {
                                    let args_string = Self::normalize_tool_arguments(arguments);
                                    Some(json!({
                                        "id": id,
                                        "type": "function",
                                        "function": {
                                            "name": name,
                                            "arguments": args_string
                                        }
                                    }))
                                }
                                _ => None,
                            })
                            .collect();

                        let mut msg_json = json!({
                            "role": "assistant",
                            "content": text,
                        });
                        if include_reasoning_content {
                            // Join all Thinking parts; only attach the field
                            // when there is actual reasoning text.
                            let reasoning: String = msg
                                .content
                                .iter()
                                .filter_map(|p| match p {
                                    ContentPart::Thinking { text } => Some(text.clone()),
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join("");
                            if !reasoning.is_empty() {
                                msg_json["reasoning_content"] = json!(reasoning);
                            }
                        }
                        if !tool_calls.is_empty() {
                            msg_json["tool_calls"] = json!(tool_calls);
                        }
                        msg_json
                    }
                    _ => {
                        // System/user messages: text parts only, newline-joined.
                        // Non-text parts (e.g. images) are dropped here.
                        let text: String = msg
                            .content
                            .iter()
                            .filter_map(|p| match p {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>()
                            .join("\n");

                        json!({"role": role, "content": text})
                    }
                }
            })
            .collect()
    }
316
317    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
318        tools
319            .iter()
320            .map(|t| {
321                json!({
322                    "type": "function",
323                    "function": {
324                        "name": t.name,
325                        "description": t.description,
326                        "parameters": t.parameters
327                    }
328                })
329            })
330            .collect()
331    }
332
333    fn model_supports_tool_stream(model: &str) -> bool {
334        model.contains("glm-5")
335            || model.contains("glm-4.7")
336            || model.contains("glm-4.6")
337            || model.eq_ignore_ascii_case(PONY_ALPHA_2_MODEL)
338    }
339
340    fn preview_text(text: &str, max_chars: usize) -> &str {
341        if max_chars == 0 {
342            return "";
343        }
344        if let Some((idx, _)) = text.char_indices().nth(max_chars) {
345            &text[..idx]
346        } else {
347            text
348        }
349    }
350
351    fn stream_tool_arguments_fragment(arguments: &Value) -> String {
352        match arguments {
353            Value::Null => String::new(),
354            Value::String(s) => s.clone(),
355            other => serde_json::to_string(other).unwrap_or_default(),
356        }
357    }
358
    /// Translate one SSE delta's `tool_calls` into StreamChunk events,
    /// updating per-index state so ToolCallStart is emitted exactly once per
    /// call and argument fragments attach to the right call.
    fn append_stream_tool_call_chunks(
        chunks: &mut Vec<StreamChunk>,
        tool_calls: &[ZaiStreamToolCall],
        tool_states: &mut HashMap<usize, ZaiStreamToolState>,
        next_fallback_index: &mut usize,
        last_seen_index: &mut Option<usize>,
    ) {
        for tc in tool_calls {
            // Resolve the logical index for this fragment, in priority order:
            // explicit `index`, then a previously-seen state with the same id,
            // then the index of the previous fragment (continuation deltas may
            // omit both), and finally a fresh synthetic index.
            let index = tc
                .index
                .or_else(|| {
                    tc.id.as_ref().and_then(|id| {
                        tool_states
                            .iter()
                            .find_map(|(idx, state)| (state.stream_id == *id).then_some(*idx))
                    })
                })
                .or(*last_seen_index)
                .unwrap_or_else(|| {
                    let idx = *next_fallback_index;
                    *next_fallback_index += 1;
                    idx
                });
            *last_seen_index = Some(index);

            // First fragment for this index: seed state with the real id, or a
            // "zai-tool-{index}" placeholder until one arrives.
            let state = tool_states
                .entry(index)
                .or_insert_with(|| ZaiStreamToolState {
                    stream_id: tc.id.clone().unwrap_or_else(|| format!("zai-tool-{index}")),
                    ..Default::default()
                });

            // Upgrade a placeholder id to the real one, but only before
            // ToolCallStart has been emitted — the id must stay stable once
            // downstream consumers have seen it.
            if let Some(id) = &tc.id
                && !state.started
                && state.stream_id.starts_with("zai-tool-")
            {
                state.stream_id = id.clone();
            }

            if let Some(func) = &tc.function {
                if let Some(name) = &func.name
                    && !name.is_empty()
                {
                    state.name = Some(name.clone());
                }

                // Emit Start as soon as a name is known.
                if !state.started
                    && let Some(name) = &state.name
                {
                    chunks.push(StreamChunk::ToolCallStart {
                        id: state.stream_id.clone(),
                        name: name.clone(),
                    });
                    state.started = true;
                }

                if let Some(arguments) = &func.arguments {
                    let delta = Self::stream_tool_arguments_fragment(arguments);
                    if !delta.is_empty() {
                        // Arguments arrived before any name: force a Start with
                        // a placeholder name so the Delta has a parent event.
                        if !state.started {
                            chunks.push(StreamChunk::ToolCallStart {
                                id: state.stream_id.clone(),
                                name: state.name.clone().unwrap_or_else(|| "tool".to_string()),
                            });
                            state.started = true;
                        }
                        chunks.push(StreamChunk::ToolCallDelta {
                            id: state.stream_id.clone(),
                            arguments_delta: delta,
                        });
                    }
                }
            }
        }
    }
434
435    fn finish_stream_tool_call_chunks(
436        chunks: &mut Vec<StreamChunk>,
437        tool_states: &mut HashMap<usize, ZaiStreamToolState>,
438    ) {
439        let mut ordered_indexes: Vec<_> = tool_states.keys().copied().collect();
440        ordered_indexes.sort_unstable();
441
442        for index in ordered_indexes {
443            if let Some(state) = tool_states.get_mut(&index)
444                && state.started
445                && !state.finished
446            {
447                chunks.push(StreamChunk::ToolCallEnd {
448                    id: state.stream_id.clone(),
449                });
450                state.finished = true;
451            }
452        }
453    }
454}
455
/// Top-level non-streaming chat completion response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion choice; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload of a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Z.AI-specific thinking output (not part of the vanilla OpenAI schema).
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A completed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

/// Function name plus arguments; `arguments` is kept as a raw `Value`
/// because Z.AI may send either a JSON object or a string.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    arguments: Value,
}

/// Token accounting block; all counts default to 0 when absent.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; only the cached count is used (for
/// cache-read accounting in `complete`).
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}

/// Error envelope returned on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

/// Human-readable message plus the API's error `type` tag.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
521
// SSE stream types
/// One parsed SSE `data:` payload of a streaming completion.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

/// A streamed choice: an incremental delta plus an optional finish reason.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message fragment; any combination of fields may be present.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    // Z.AI-specific thinking stream (mirrors the non-streaming field).
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// A tool-call fragment: every field is optional because id/name/arguments
/// may arrive spread across separate deltas (see
/// `append_stream_tool_call_chunks`).
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    index: Option<usize>,
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

/// Streamed function fragment; `arguments` is a raw `Value` because Z.AI may
/// send string fragments or whole JSON values.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
561
562#[async_trait]
563impl Provider for ZaiProvider {
    /// Provider identifier; matches the `provider` value stamped on every
    /// `ModelInfo` this provider returns.
    fn name(&self) -> &str {
        "zai"
    }
567
568    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
569        // Attempt dynamic model discovery from the Z.AI /models endpoint.
570        // When the API is reachable, this returns the authoritative model
571        // catalog including newly-released models without a code change.
572        let discovered = self.discover_models_from_api().await;
573        if !discovered.is_empty() {
574            // Merge in special models that exist outside the /models API
575            // (coding endpoint models, etc.)
576            let mut models = discovered;
577            if !models.iter().any(|m| m.id == PONY_ALPHA_2_MODEL) {
578                models.push(ModelInfo {
579                    id: PONY_ALPHA_2_MODEL.to_string(),
580                    name: "Pony Alpha 2".to_string(),
581                    provider: "zai".to_string(),
582                    context_window: 128_000,
583                    max_output_tokens: Some(16_384),
584                    supports_vision: false,
585                    supports_tools: true,
586                    supports_streaming: true,
587                    input_cost_per_million: None,
588                    output_cost_per_million: None,
589                });
590            }
591            if !models.iter().any(|m| m.id == "glm-4.7-flash") {
592                models.push(ModelInfo {
593                    id: "glm-4.7-flash".to_string(),
594                    name: "GLM-4.7 Flash".to_string(),
595                    provider: "zai".to_string(),
596                    context_window: 128_000,
597                    max_output_tokens: Some(128_000),
598                    supports_vision: false,
599                    supports_tools: true,
600                    supports_streaming: true,
601                    input_cost_per_million: None,
602                    output_cost_per_million: None,
603                });
604            }
605            return Ok(models);
606        }
607
608        // Static catalog used when the /models endpoint is unavailable
609        // (e.g. network partition, auth failure, or non-standard deployments).
610        Ok(vec![
611            ModelInfo {
612                id: "glm-5.1".to_string(),
613                name: "GLM-5.1".to_string(),
614                provider: "zai".to_string(),
615                context_window: 200_000,
616                max_output_tokens: Some(128_000),
617                supports_vision: false,
618                supports_tools: true,
619                supports_streaming: true,
620                input_cost_per_million: None,
621                output_cost_per_million: None,
622            },
623            ModelInfo {
624                id: "glm-5".to_string(),
625                name: "GLM-5".to_string(),
626                provider: "zai".to_string(),
627                context_window: 200_000,
628                max_output_tokens: Some(128_000),
629                supports_vision: false,
630                supports_tools: true,
631                supports_streaming: true,
632                input_cost_per_million: None,
633                output_cost_per_million: None,
634            },
635            ModelInfo {
636                id: "glm-4.7".to_string(),
637                name: "GLM-4.7".to_string(),
638                provider: "zai".to_string(),
639                context_window: 128_000,
640                max_output_tokens: Some(128_000),
641                supports_vision: false,
642                supports_tools: true,
643                supports_streaming: true,
644                input_cost_per_million: None,
645                output_cost_per_million: None,
646            },
647            ModelInfo {
648                id: "glm-4.7-flash".to_string(),
649                name: "GLM-4.7 Flash".to_string(),
650                provider: "zai".to_string(),
651                context_window: 128_000,
652                max_output_tokens: Some(128_000),
653                supports_vision: false,
654                supports_tools: true,
655                supports_streaming: true,
656                input_cost_per_million: None,
657                output_cost_per_million: None,
658            },
659            ModelInfo {
660                id: "glm-4.6".to_string(),
661                name: "GLM-4.6".to_string(),
662                provider: "zai".to_string(),
663                context_window: 128_000,
664                max_output_tokens: Some(128_000),
665                supports_vision: false,
666                supports_tools: true,
667                supports_streaming: true,
668                input_cost_per_million: None,
669                output_cost_per_million: None,
670            },
671            ModelInfo {
672                id: "glm-4.5".to_string(),
673                name: "GLM-4.5".to_string(),
674                provider: "zai".to_string(),
675                context_window: 128_000,
676                max_output_tokens: Some(96_000),
677                supports_vision: false,
678                supports_tools: true,
679                supports_streaming: true,
680                input_cost_per_million: None,
681                output_cost_per_million: None,
682            },
683            ModelInfo {
684                id: "glm-5-turbo".to_string(),
685                name: "GLM-5 Turbo".to_string(),
686                provider: "zai".to_string(),
687                context_window: 200_000,
688                max_output_tokens: Some(128_000),
689                supports_vision: false,
690                supports_tools: true,
691                supports_streaming: true,
692                input_cost_per_million: Some(0.96),
693                output_cost_per_million: Some(3.20),
694            },
695            ModelInfo {
696                id: PONY_ALPHA_2_MODEL.to_string(),
697                name: "Pony Alpha 2".to_string(),
698                provider: "zai".to_string(),
699                context_window: 128_000,
700                max_output_tokens: Some(16_384),
701                supports_vision: false,
702                supports_tools: true,
703                supports_streaming: true,
704                input_cost_per_million: None,
705                output_cost_per_million: None,
706            },
707        ])
708    }
709
    /// Non-streaming chat completion against the Z.AI API.
    ///
    /// Builds an OpenAI-style request body (with Z.AI's `thinking` extension),
    /// sends it with retry, and maps the response back into the internal
    /// `CompletionResponse` shape, including reasoning content, tool calls,
    /// and cache-aware usage accounting.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        // Compatibility-first mode: omit historical reasoning_content from
        // request messages to avoid strict parameter validation errors on some
        // endpoint variants.
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        // GLM-5 and GLM-4.7 default to temperature 1.0
        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
        });

        // Always enable thinking with clear_thinking: true.
        // The Coding endpoint (api.z.ai/api/coding/paas/v4) defaults
        // clear_thinking to false (Preserved Thinking), which requires
        // reasoning_content in historical assistant messages. Since we strip
        // reasoning_content, we must explicitly set clear_thinking: true.
        body["thinking"] = json!({
            "type": "enabled",
            "clear_thinking": true
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI request");
        tracing::trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "Z.AI request body");
        // Route pony-alpha-2 to the coding endpoint; everything else to base_url.
        let request_base_url = self.request_base_url(&request.model);

        // Read the body as text first so error payloads can be surfaced even
        // when they are not valid JSON.
        let (text, status) = super::retry::send_with_retry(|| async {
            let resp = self
                .client
                .post(format!("{}/chat/completions", request_base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .context("Failed to send request to Z.AI")?;
            let status = resp.status();
            let text = resp.text().await.context("Failed to read Z.AI response")?;
            Ok((text, status))
        })
        .await?;

        if !status.is_success() {
            tracing::debug!(status = %status, body = %text, "Z.AI error response");
            // Prefer the structured error envelope; fall back to the raw body.
            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
                anyhow::bail!(
                    "Z.AI API error: {} ({:?})",
                    err.error.message,
                    err.error.error_type
                );
            }
            anyhow::bail!("Z.AI API error: {status} {text}");
        }

        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
            "Failed to parse Z.AI response: {}",
            Self::preview_text(&text, 200)
        ))?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;

        // Log thinking/reasoning content if present
        if let Some(ref reasoning) = choice.message.reasoning_content
            && !reasoning.is_empty()
        {
            tracing::info!(
                reasoning_len = reasoning.len(),
                "Z.AI reasoning content received"
            );
        }

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        // Emit thinking content as a Thinking part
        if let Some(ref reasoning) = choice.message.reasoning_content
            && !reasoning.is_empty()
        {
            content.push(ContentPart::Thinking {
                text: reasoning.clone(),
            });
        }

        if let Some(text) = &choice.message.content
            && !text.is_empty()
        {
            content.push(ContentPart::Text { text: text.clone() });
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                // Z.AI returns arguments as an object; serialize to string for our ContentPart
                let arguments = match &tc.function.arguments {
                    Value::String(s) => s.clone(),
                    other => serde_json::to_string(other).unwrap_or_default(),
                };
                content.push(ContentPart::ToolCall {
                    id: tc.id.clone(),
                    name: tc.function.name.clone(),
                    arguments,
                    thought_signature: None,
                });
            }
        }

        // The presence of tool calls overrides whatever finish_reason the API
        // reported; unknown reasons default to Stop.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason.as_deref() {
                Some("stop") => FinishReason::Stop,
                Some("length") => FinishReason::Length,
                Some("tool_calls") => FinishReason::ToolCalls,
                Some("sensitive") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: {
                    // Subtract cached input so the cost estimator does
                    // not double-count: cache_read_tokens is billed
                    // separately at the discounted rate.
                    let u = response.usage.as_ref();
                    let total = u.map(|u| u.prompt_tokens).unwrap_or(0);
                    let cached = u
                        .and_then(|u| u.prompt_tokens_details.as_ref())
                        .map(|d| d.cached_tokens)
                        .unwrap_or(0);
                    total.saturating_sub(cached)
                },
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens)
                    .unwrap_or(0),
                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: response
                    .usage
                    .as_ref()
                    .and_then(|u| u.prompt_tokens_details.as_ref())
                    .map(|d| d.cached_tokens)
                    .filter(|&t| t > 0),
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }
877
    /// Streams a chat completion from the Z.AI SSE `/chat/completions` endpoint.
    ///
    /// Emits `StreamChunk::Text` for both reasoning (`reasoning_content`) and
    /// regular content deltas, `ToolCallStart`/`ToolCallDelta`/`ToolCallEnd`
    /// chunks for streaming tool calls, `StreamChunk::Done` on `data: [DONE]`,
    /// and `StreamChunk::Error` for transport failures.
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Compatibility-first mode: omit historical reasoning_content from
        // request messages to avoid strict parameter validation errors on some
        // endpoint variants.
        let messages = Self::convert_messages(&request.messages, false);
        let tools = Self::convert_tools(&request.tools);

        // Z.AI default sampling temperature when the caller does not set one.
        let temperature = request.temperature.unwrap_or(1.0);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": temperature,
            "stream": true,
        });

        // Always request thinking output. NOTE(review): `clear_thinking` is
        // assumed to suppress echoed prior-turn reasoning — confirm against
        // the Z.AI API reference.
        body["thinking"] = json!({
            "type": "enabled",
            "clear_thinking": true
        });

        if !tools.is_empty() {
            body["tools"] = json!(tools);
            if Self::model_supports_tool_stream(&request.model) {
                // Enable streaming tool calls only on known-compatible models.
                body["tool_stream"] = json!(true);
            }
        }
        if let Some(max) = request.max_tokens {
            body["max_tokens"] = json!(max);
        }

        tracing::debug!(model = %request.model, "Z.AI streaming request");
        // Routes coding-only models (e.g. pony-alpha-2) to the coding endpoint;
        // everything else uses the configured base URL.
        let request_base_url = self.request_base_url(&request.model);

        let response = super::retry::send_response_with_retry(|| async {
            self.client
                .post(format!("{}/chat/completions", request_base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await
                .context("Failed to send streaming request to Z.AI")
        })
        .await?;

        // SSE parsing state moved into the closure below:
        // - `buffer` holds partial lines carried across network reads,
        // - `tool_states` tracks in-flight streaming tool calls by index,
        // - the two counters let follow-up deltas that omit `index`/`id`
        //   attach to the correct call (see append_stream_tool_call_chunks).
        let stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut tool_states = HashMap::<usize, ZaiStreamToolState>::new();
        let mut next_fallback_tool_index = 0usize;
        let mut last_seen_tool_index = None;

        Ok(stream
            .flat_map(move |chunk_result| {
                let mut chunks: Vec<StreamChunk> = Vec::new();
                match chunk_result {
                    Ok(bytes) => {
                        // Network chunks can split SSE lines (and UTF-8
                        // sequences) arbitrarily; buffer and only process
                        // complete, newline-terminated lines. NOTE(review): a
                        // final line without a trailing newline would remain
                        // unflushed — servers terminate SSE events with
                        // newlines, so this should not occur in practice.
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);

                        // Coalesces consecutive text deltas from this network
                        // read into a single Text chunk to reduce downstream
                        // churn; flushed before any non-text chunk is emitted.
                        let mut text_buf = String::new();
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if line == "data: [DONE]" {
                                if !text_buf.is_empty() {
                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                }
                                chunks.push(StreamChunk::Done { usage: None });
                                continue;
                            }
                            // Non-"data:" lines and unparseable payloads are
                            // silently skipped (keep-alives, comments).
                            if let Some(data) = line.strip_prefix("data: ")
                                && let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
                                && let Some(choice) = parsed.choices.first()
                            {
                                // Reasoning deltas are forwarded as plain text,
                                // interleaved with regular content deltas — no
                                // marker distinguishes them downstream.
                                if let Some(ref reasoning) = choice.delta.reasoning_content
                                    && !reasoning.is_empty()
                                {
                                    text_buf.push_str(reasoning);
                                }
                                if let Some(ref content) = choice.delta.content {
                                    text_buf.push_str(content);
                                }
                                // Streaming tool calls: flush pending text so
                                // ordering relative to tool chunks is preserved.
                                if let Some(ref tool_calls) = choice.delta.tool_calls {
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    Self::append_stream_tool_call_chunks(
                                        &mut chunks,
                                        tool_calls,
                                        &mut tool_states,
                                        &mut next_fallback_tool_index,
                                        &mut last_seen_tool_index,
                                    );
                                }
                                // finish_reason signals end of a tool call or completion
                                if let Some(ref reason) = choice.finish_reason {
                                    if !text_buf.is_empty() {
                                        chunks
                                            .push(StreamChunk::Text(std::mem::take(&mut text_buf)));
                                    }
                                    if reason == "tool_calls" {
                                        // Close out any still-open tool calls.
                                        Self::finish_stream_tool_call_chunks(
                                            &mut chunks,
                                            &mut tool_states,
                                        );
                                    }
                                }
                            }
                        }
                        // Flush text accumulated from this read's complete lines.
                        if !text_buf.is_empty() {
                            chunks.push(StreamChunk::Text(text_buf));
                        }
                    }
                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
                }
                futures::stream::iter(chunks)
            })
            .boxed())
    }
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    /// Constructs a provider pointed at the default API with a dummy key.
    fn test_provider() -> ZaiProvider {
        ZaiProvider::with_base_url("test-key".to_string(), DEFAULT_BASE_URL.to_string())
            .expect("provider should construct")
    }

    /// Builds a single assistant message carrying one tool call whose raw
    /// `arguments` payload is supplied by the caller.
    fn tool_call_messages(arguments: &str) -> Vec<Message> {
        vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: arguments.to_string(),
                thought_signature: None,
            }],
        }]
    }

    /// Runs `convert_messages` and parses the serialized tool-call arguments
    /// string back into a JSON value for comparison.
    fn roundtrip_tool_arguments(messages: Vec<Message>) -> Value {
        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        serde_json::from_str(args).expect("arguments string must contain valid JSON")
    }

    #[tokio::test]
    async fn list_models_includes_pony_alpha_2() {
        let models = test_provider()
            .list_models()
            .await
            .expect("models should list");

        assert!(models.iter().any(|model| model.id == PONY_ALPHA_2_MODEL));
    }

    #[tokio::test]
    async fn list_models_includes_glm_5_turbo() {
        let models = test_provider()
            .list_models()
            .await
            .expect("models should list");

        let turbo = models
            .iter()
            .find(|m| m.id == "glm-5-turbo")
            .expect("glm-5-turbo should be in model list");

        // Capability and pricing metadata for the turbo tier.
        assert_eq!(turbo.context_window, 200_000);
        assert_eq!(turbo.max_output_tokens, Some(128_000));
        assert!(turbo.supports_tools);
        assert!(turbo.supports_streaming);
        assert_eq!(turbo.input_cost_per_million, Some(0.96));
        assert_eq!(turbo.output_cost_per_million, Some(3.20));
    }

    #[tokio::test]
    async fn list_models_includes_glm_5_1() {
        let models = test_provider()
            .list_models()
            .await
            .expect("models should list");

        let glm51 = models
            .iter()
            .find(|m| m.id == "glm-5.1")
            .expect("glm-5.1 should be in model list");

        assert_eq!(glm51.context_window, 200_000);
        assert_eq!(glm51.max_output_tokens, Some(128_000));
        assert!(glm51.supports_tools);
        assert!(glm51.supports_streaming);
    }

    #[test]
    fn model_supports_tool_stream_matches_glm_5_1() {
        // All GLM-5 family models support streaming tool calls; GLM-4.5 does not.
        for model in ["glm-5.1", "glm-5", "glm-5-turbo"] {
            assert!(ZaiProvider::model_supports_tool_stream(model));
        }
        assert!(!ZaiProvider::model_supports_tool_stream("glm-4.5"));
    }

    #[test]
    fn pony_alpha_2_routes_to_coding_endpoint() {
        let provider = test_provider();

        assert_eq!(
            provider.request_base_url(PONY_ALPHA_2_MODEL),
            CODING_BASE_URL
        );
        assert_eq!(provider.request_base_url("glm-5"), DEFAULT_BASE_URL);
    }

    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        // Near-JSON input with trailing junk round-trips to a clean object.
        let parsed = roundtrip_tool_arguments(tool_call_messages("{\"city\":\"Beijing\".. }"));
        assert_eq!(parsed, json!({"city":"Beijing"}));
    }

    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        // Non-JSON input is wrapped under an "input" key.
        let parsed = roundtrip_tool_arguments(tool_call_messages("city=Beijing"));
        assert_eq!(parsed, json!({"input":"city=Beijing"}));
    }

    #[test]
    fn convert_messages_wraps_scalar_tool_arguments_as_json_string() {
        // A bare JSON scalar is also wrapped under an "input" key.
        let parsed = roundtrip_tool_arguments(tool_call_messages("\"Beijing\""));
        assert_eq!(parsed, json!({"input":"Beijing"}));
    }

    #[test]
    fn stream_tool_chunks_keep_same_call_id_when_followup_delta_omits_id() {
        let mut emitted = Vec::new();
        let mut states = HashMap::new();
        let mut fallback_idx = 0usize;
        let mut last_index = None;

        // First delta carries the call id, the tool name, and an opening
        // argument fragment.
        ZaiProvider::append_stream_tool_call_chunks(
            &mut emitted,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: Some("call_1".to_string()),
                function: Some(ZaiStreamFunction {
                    name: Some("bash".to_string()),
                    arguments: Some(Value::String("{\"".to_string())),
                }),
            }],
            &mut states,
            &mut fallback_idx,
            &mut last_index,
        );

        // Follow-up delta omits id and name; it must attach to the same call
        // via the shared index.
        ZaiProvider::append_stream_tool_call_chunks(
            &mut emitted,
            &[ZaiStreamToolCall {
                index: Some(0),
                id: None,
                function: Some(ZaiStreamFunction {
                    name: None,
                    arguments: Some(Value::String("command\":\"pwd\"}".to_string())),
                }),
            }],
            &mut states,
            &mut fallback_idx,
            &mut last_index,
        );

        ZaiProvider::finish_stream_tool_call_chunks(&mut emitted, &mut states);

        // Expect exactly: start, two argument deltas, end — all for call_1.
        assert_eq!(emitted.len(), 4);
        assert!(matches!(
            &emitted[0],
            StreamChunk::ToolCallStart { id, name }
                if id == "call_1" && name == "bash"
        ));
        assert!(matches!(
            &emitted[1],
            StreamChunk::ToolCallDelta { id, arguments_delta }
                if id == "call_1" && arguments_delta == "{\""
        ));
        assert!(matches!(
            &emitted[2],
            StreamChunk::ToolCallDelta { id, arguments_delta }
                if id == "call_1" && arguments_delta == "command\":\"pwd\"}"
        ));
        assert!(matches!(
            &emitted[3],
            StreamChunk::ToolCallEnd { id } if id == "call_1"
        ));
    }

    #[test]
    fn preview_text_truncates_on_char_boundary() {
        // The limit counts chars, not bytes, so the multi-byte emoji survives.
        let sample = "a😀b";
        assert_eq!(ZaiProvider::preview_text(sample, 2), "a😀");
    }
}