Skip to main content

codetether_agent/provider/
zai.rs

1//! Z.AI provider implementation (direct API)
2//!
3//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
4//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
5//! reasoning/thinking support via the `reasoning_content` field.
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Client for Z.AI's OpenAI-compatible chat-completions API.
pub struct ZaiProvider {
    /// HTTP client used for every request issued by this provider.
    client: Client,
    /// Bearer token sent in the `Authorization` header; redacted from `Debug` output.
    api_key: String,
    /// API root; requests are sent to `{base_url}/chat/completions`.
    base_url: String,
}
23
24impl std::fmt::Debug for ZaiProvider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("ZaiProvider")
27            .field("base_url", &self.base_url)
28            .field("api_key", &"<REDACTED>")
29            .finish()
30    }
31}
32
33impl ZaiProvider {
34    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "zai",
37            base_url = %base_url,
38            api_key_len = api_key.len(),
39            "Creating Z.AI provider with custom base URL"
40        );
41        Ok(Self {
42            client: Client::new(),
43            api_key,
44            base_url,
45        })
46    }
47
48    fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
49        messages
50            .iter()
51            .map(|msg| {
52                let role = match msg.role {
53                    Role::System => "system",
54                    Role::User => "user",
55                    Role::Assistant => "assistant",
56                    Role::Tool => "tool",
57                };
58
59                match msg.role {
60                    Role::Tool => {
61                        if let Some(ContentPart::ToolResult {
62                            tool_call_id,
63                            content,
64                        }) = msg.content.first()
65                        {
66                            json!({
67                                "role": "tool",
68                                "tool_call_id": tool_call_id,
69                                "content": content
70                            })
71                        } else {
72                            json!({"role": role, "content": ""})
73                        }
74                    }
75                    Role::Assistant => {
76                        let text: String = msg
77                            .content
78                            .iter()
79                            .filter_map(|p| match p {
80                                ContentPart::Text { text } => Some(text.clone()),
81                                _ => None,
82                            })
83                            .collect::<Vec<_>>()
84                            .join("");
85
86                        let tool_calls: Vec<Value> = msg
87                            .content
88                            .iter()
89                            .filter_map(|p| match p {
90                                ContentPart::ToolCall {
91                                    id,
92                                    name,
93                                    arguments,
94                                    ..
95                                } => {
96                                    // Z.AI request schema expects assistant.tool_calls[*].function.arguments
97                                    // to be a JSON-format string. Normalize to a valid JSON string.
98                                    let args_string = serde_json::from_str::<Value>(arguments)
99                                        .map(|parsed| {
100                                            serde_json::to_string(&parsed)
101                                                .unwrap_or_else(|_| "{}".to_string())
102                                        })
103                                        .unwrap_or_else(|_| {
104                                            json!({"input": arguments}).to_string()
105                                        });
106                                    Some(json!({
107                                        "id": id,
108                                        "type": "function",
109                                        "function": {
110                                            "name": name,
111                                            "arguments": args_string
112                                        }
113                                    }))
114                                }
115                                _ => None,
116                            })
117                            .collect();
118
119                        let mut msg_json = json!({
120                            "role": "assistant",
121                            "content": if text.is_empty() { "".to_string() } else { text },
122                        });
123                        if include_reasoning_content {
124                            let reasoning: String = msg
125                                .content
126                                .iter()
127                                .filter_map(|p| match p {
128                                    ContentPart::Thinking { text } => Some(text.clone()),
129                                    _ => None,
130                                })
131                                .collect::<Vec<_>>()
132                                .join("");
133                            if !reasoning.is_empty() {
134                                msg_json["reasoning_content"] = json!(reasoning);
135                            }
136                        }
137                        if !tool_calls.is_empty() {
138                            msg_json["tool_calls"] = json!(tool_calls);
139                        }
140                        msg_json
141                    }
142                    _ => {
143                        let text: String = msg
144                            .content
145                            .iter()
146                            .filter_map(|p| match p {
147                                ContentPart::Text { text } => Some(text.clone()),
148                                _ => None,
149                            })
150                            .collect::<Vec<_>>()
151                            .join("\n");
152
153                        json!({"role": role, "content": text})
154                    }
155                }
156            })
157            .collect()
158    }
159
160    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
161        tools
162            .iter()
163            .map(|t| {
164                json!({
165                    "type": "function",
166                    "function": {
167                        "name": t.name,
168                        "description": t.description,
169                        "parameters": t.parameters
170                    }
171                })
172            })
173            .collect()
174    }
175
176    fn model_supports_tool_stream(model: &str) -> bool {
177        model.contains("glm-5") || model.contains("glm-4.7") || model.contains("glm-4.6")
178    }
179
180    fn preview_text(text: &str, max_chars: usize) -> &str {
181        if max_chars == 0 {
182            return "";
183        }
184        if let Some((idx, _)) = text.char_indices().nth(max_chars) {
185            &text[..idx]
186        } else {
187            text
188        }
189    }
190}
191
/// Body of a successful non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    /// Completion candidates; `complete` only uses the first one.
    choices: Vec<ZaiChoice>,
    /// Token accounting; when absent, `complete` reports zero tokens.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}
198
/// One completion candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    /// Assistant message produced for this candidate.
    message: ZaiMessage,
    /// Raw finish reason. `complete` maps `"stop"`, `"length"`, `"tool_calls"` and
    /// `"sensitive"`; any other or missing value is treated as a normal stop.
    #[serde(default)]
    finish_reason: Option<String>,
}
205
/// Assistant message inside a non-streaming choice; every field may be absent.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    /// Visible answer text.
    #[serde(default)]
    content: Option<String>,
    /// Tool invocations requested by the model.
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    /// Reasoning ("thinking") text, surfaced by `complete` as `ContentPart::Thinking`.
    #[serde(default)]
    reasoning_content: Option<String>,
}
215
/// A tool invocation returned in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    /// Call identifier; copied into `ContentPart::ToolCall::id`.
    id: String,
    /// Function name and arguments of the call.
    function: ZaiFunction,
}
221
/// Function part of a non-streaming tool call.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    /// Name of the tool the model wants to invoke.
    name: String,
    /// Arguments as sent by the API — accepted either as a JSON-encoded string or as
    /// an inline JSON value; `complete` turns both into a string.
    arguments: Value,
}
227
/// Token usage block of a non-streaming response. Missing counters default to 0.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    /// Tokens consumed by the request messages.
    #[serde(default)]
    prompt_tokens: usize,
    /// Tokens generated in the reply.
    #[serde(default)]
    completion_tokens: usize,
    /// Total as reported by the API (not recomputed locally).
    #[serde(default)]
    total_tokens: usize,
    /// Optional breakdown of prompt tokens (used for the cache-hit count).
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}
239
/// Prompt-token breakdown inside `usage`.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    /// Prompt tokens served from cache; reported as `cache_read_tokens` when non-zero.
    #[serde(default)]
    cached_tokens: usize,
}
245
/// Structured error envelope (`{"error": {...}}`) parsed from non-success responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    /// The error payload itself.
    error: ZaiErrorDetail,
}
250
/// Details of a Z.AI API error.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    /// Human-readable error message, included in the returned `anyhow` error.
    message: String,
    /// Error category from the JSON `type` field, if provided.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
257
// SSE stream types
/// One JSON payload carried by a `data:` line of the streaming response.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    /// Incremental choices; the stream parser only consumes the first one.
    choices: Vec<ZaiStreamChoice>,
}
263
/// Incremental update for one choice in a streaming response.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    /// New content since the previous event.
    delta: ZaiStreamDelta,
    /// Present on the event that ends the choice (e.g. `"tool_calls"`).
    #[serde(default)]
    finish_reason: Option<String>,
}
270
/// Fragment of assistant output delivered by one stream event; all fields optional.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    /// Answer-text fragment.
    #[serde(default)]
    content: Option<String>,
    /// Reasoning ("thinking") fragment.
    #[serde(default)]
    reasoning_content: Option<String>,
    /// Tool-call fragments.
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}
280
/// One tool-call fragment inside a streaming delta.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    /// Call identifier; optional because a fragment may omit it.
    #[serde(default)]
    id: Option<String>,
    /// Name and/or argument fragment for this call.
    function: Option<ZaiStreamFunction>,
}
287
/// Function fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    /// Tool name; its presence marks the start of a new tool call.
    #[serde(default)]
    name: Option<String>,
    /// Argument fragment, either as a string or as an inline JSON value.
    #[serde(default)]
    arguments: Option<Value>,
}
295
296#[async_trait]
297impl Provider for ZaiProvider {
298    fn name(&self) -> &str {
299        "zai"
300    }
301
302    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
303        Ok(vec![
304            ModelInfo {
305                id: "glm-5".to_string(),
306                name: "GLM-5".to_string(),
307                provider: "zai".to_string(),
308                context_window: 200_000,
309                max_output_tokens: Some(128_000),
310                supports_vision: false,
311                supports_tools: true,
312                supports_streaming: true,
313                input_cost_per_million: None,
314                output_cost_per_million: None,
315            },
316            ModelInfo {
317                id: "glm-4.7".to_string(),
318                name: "GLM-4.7".to_string(),
319                provider: "zai".to_string(),
320                context_window: 128_000,
321                max_output_tokens: Some(128_000),
322                supports_vision: false,
323                supports_tools: true,
324                supports_streaming: true,
325                input_cost_per_million: None,
326                output_cost_per_million: None,
327            },
328            ModelInfo {
329                id: "glm-4.7-flash".to_string(),
330                name: "GLM-4.7 Flash".to_string(),
331                provider: "zai".to_string(),
332                context_window: 128_000,
333                max_output_tokens: Some(128_000),
334                supports_vision: false,
335                supports_tools: true,
336                supports_streaming: true,
337                input_cost_per_million: None,
338                output_cost_per_million: None,
339            },
340            ModelInfo {
341                id: "glm-4.6".to_string(),
342                name: "GLM-4.6".to_string(),
343                provider: "zai".to_string(),
344                context_window: 128_000,
345                max_output_tokens: Some(128_000),
346                supports_vision: false,
347                supports_tools: true,
348                supports_streaming: true,
349                input_cost_per_million: None,
350                output_cost_per_million: None,
351            },
352            ModelInfo {
353                id: "glm-4.5".to_string(),
354                name: "GLM-4.5".to_string(),
355                provider: "zai".to_string(),
356                context_window: 128_000,
357                max_output_tokens: Some(96_000),
358                supports_vision: false,
359                supports_tools: true,
360                supports_streaming: true,
361                input_cost_per_million: None,
362                output_cost_per_million: None,
363            },
364        ])
365    }
366
367    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
368        // Compatibility-first mode: omit historical reasoning_content from
369        // request messages to avoid strict parameter validation errors on some
370        // endpoint variants.
371        let messages = Self::convert_messages(&request.messages, false);
372        let tools = Self::convert_tools(&request.tools);
373
374        // GLM-5 and GLM-4.7 default to temperature 1.0
375        let temperature = request.temperature.unwrap_or(1.0);
376
377        let mut body = json!({
378            "model": request.model,
379            "messages": messages,
380            "temperature": temperature,
381        });
382
383        // Keep thinking enabled, but avoid provider-specific sub-fields that
384        // may be rejected by stricter API variants.
385        body["thinking"] = json!({
386            "type": "enabled"
387        });
388
389        if !tools.is_empty() {
390            body["tools"] = json!(tools);
391        }
392        if let Some(max) = request.max_tokens {
393            body["max_tokens"] = json!(max);
394        }
395
396        tracing::debug!(model = %request.model, "Z.AI request");
397
398        let response = self
399            .client
400            .post(format!("{}/chat/completions", self.base_url))
401            .header("Authorization", format!("Bearer {}", self.api_key))
402            .header("Content-Type", "application/json")
403            .json(&body)
404            .send()
405            .await
406            .context("Failed to send request to Z.AI")?;
407
408        let status = response.status();
409        let text = response
410            .text()
411            .await
412            .context("Failed to read Z.AI response")?;
413
414        if !status.is_success() {
415            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
416                anyhow::bail!(
417                    "Z.AI API error: {} ({:?})",
418                    err.error.message,
419                    err.error.error_type
420                );
421            }
422            anyhow::bail!("Z.AI API error: {} {}", status, text);
423        }
424
425        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
426            "Failed to parse Z.AI response: {}",
427            Self::preview_text(&text, 200)
428        ))?;
429
430        let choice = response
431            .choices
432            .first()
433            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
434
435        // Log thinking/reasoning content if present
436        if let Some(ref reasoning) = choice.message.reasoning_content {
437            if !reasoning.is_empty() {
438                tracing::info!(
439                    reasoning_len = reasoning.len(),
440                    "Z.AI reasoning content received"
441                );
442            }
443        }
444
445        let mut content = Vec::new();
446        let mut has_tool_calls = false;
447
448        // Emit thinking content as a Thinking part
449        if let Some(ref reasoning) = choice.message.reasoning_content {
450            if !reasoning.is_empty() {
451                content.push(ContentPart::Thinking {
452                    text: reasoning.clone(),
453                });
454            }
455        }
456
457        if let Some(text) = &choice.message.content {
458            if !text.is_empty() {
459                content.push(ContentPart::Text { text: text.clone() });
460            }
461        }
462
463        if let Some(tool_calls) = &choice.message.tool_calls {
464            has_tool_calls = !tool_calls.is_empty();
465            for tc in tool_calls {
466                // Z.AI returns arguments as an object; serialize to string for our ContentPart
467                let arguments = match &tc.function.arguments {
468                    Value::String(s) => s.clone(),
469                    other => serde_json::to_string(other).unwrap_or_default(),
470                };
471                content.push(ContentPart::ToolCall {
472                    id: tc.id.clone(),
473                    name: tc.function.name.clone(),
474                    arguments,
475                    thought_signature: None,
476                });
477            }
478        }
479
480        let finish_reason = if has_tool_calls {
481            FinishReason::ToolCalls
482        } else {
483            match choice.finish_reason.as_deref() {
484                Some("stop") => FinishReason::Stop,
485                Some("length") => FinishReason::Length,
486                Some("tool_calls") => FinishReason::ToolCalls,
487                Some("sensitive") => FinishReason::ContentFilter,
488                _ => FinishReason::Stop,
489            }
490        };
491
492        Ok(CompletionResponse {
493            message: Message {
494                role: Role::Assistant,
495                content,
496            },
497            usage: Usage {
498                prompt_tokens: response
499                    .usage
500                    .as_ref()
501                    .map(|u| u.prompt_tokens)
502                    .unwrap_or(0),
503                completion_tokens: response
504                    .usage
505                    .as_ref()
506                    .map(|u| u.completion_tokens)
507                    .unwrap_or(0),
508                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
509                cache_read_tokens: response
510                    .usage
511                    .as_ref()
512                    .and_then(|u| u.prompt_tokens_details.as_ref())
513                    .map(|d| d.cached_tokens)
514                    .filter(|&t| t > 0),
515                cache_write_tokens: None,
516            },
517            finish_reason,
518        })
519    }
520
521    async fn complete_stream(
522        &self,
523        request: CompletionRequest,
524    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
525        // Compatibility-first mode: omit historical reasoning_content from
526        // request messages to avoid strict parameter validation errors on some
527        // endpoint variants.
528        let messages = Self::convert_messages(&request.messages, false);
529        let tools = Self::convert_tools(&request.tools);
530
531        let temperature = request.temperature.unwrap_or(1.0);
532
533        let mut body = json!({
534            "model": request.model,
535            "messages": messages,
536            "temperature": temperature,
537            "stream": true,
538        });
539
540        body["thinking"] = json!({
541            "type": "enabled"
542        });
543
544        if !tools.is_empty() {
545            body["tools"] = json!(tools);
546            if Self::model_supports_tool_stream(&request.model) {
547                // Enable streaming tool calls only on known-compatible models.
548                body["tool_stream"] = json!(true);
549            }
550        }
551        if let Some(max) = request.max_tokens {
552            body["max_tokens"] = json!(max);
553        }
554
555        tracing::debug!(model = %request.model, "Z.AI streaming request");
556
557        let response = self
558            .client
559            .post(format!("{}/chat/completions", self.base_url))
560            .header("Authorization", format!("Bearer {}", self.api_key))
561            .header("Content-Type", "application/json")
562            .json(&body)
563            .send()
564            .await
565            .context("Failed to send streaming request to Z.AI")?;
566
567        if !response.status().is_success() {
568            let status = response.status();
569            let text = response.text().await.unwrap_or_default();
570            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
571                anyhow::bail!(
572                    "Z.AI API error: {} ({:?})",
573                    err.error.message,
574                    err.error.error_type
575                );
576            }
577            anyhow::bail!("Z.AI streaming error: {} {}", status, text);
578        }
579
580        let stream = response.bytes_stream();
581        let mut buffer = String::new();
582
583        Ok(stream
584            .flat_map(move |chunk_result| {
585                let mut chunks: Vec<StreamChunk> = Vec::new();
586                match chunk_result {
587                    Ok(bytes) => {
588                        let text = String::from_utf8_lossy(&bytes);
589                        buffer.push_str(&text);
590
591                        let mut text_buf = String::new();
592                        while let Some(line_end) = buffer.find('\n') {
593                            let line = buffer[..line_end].trim().to_string();
594                            buffer = buffer[line_end + 1..].to_string();
595
596                            if line == "data: [DONE]" {
597                                if !text_buf.is_empty() {
598                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
599                                }
600                                chunks.push(StreamChunk::Done { usage: None });
601                                continue;
602                            }
603                            if let Some(data) = line.strip_prefix("data: ") {
604                                if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
605                                {
606                                    if let Some(choice) = parsed.choices.first() {
607                                        // Reasoning content streamed as text (prefixed for TUI rendering)
608                                        if let Some(ref reasoning) = choice.delta.reasoning_content
609                                        {
610                                            if !reasoning.is_empty() {
611                                                text_buf.push_str(reasoning);
612                                            }
613                                        }
614                                        if let Some(ref content) = choice.delta.content {
615                                            text_buf.push_str(content);
616                                        }
617                                        // Streaming tool calls
618                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
619                                            if !text_buf.is_empty() {
620                                                chunks.push(StreamChunk::Text(std::mem::take(
621                                                    &mut text_buf,
622                                                )));
623                                            }
624                                            for tc in tool_calls {
625                                                if let Some(ref func) = tc.function {
626                                                    if let Some(ref name) = func.name {
627                                                        // New tool call starting
628                                                        chunks.push(StreamChunk::ToolCallStart {
629                                                            id: tc.id.clone().unwrap_or_default(),
630                                                            name: name.clone(),
631                                                        });
632                                                    }
633                                                    if let Some(ref args) = func.arguments {
634                                                        let delta = match args {
635                                                            Value::String(s) => s.clone(),
636                                                            other => serde_json::to_string(other)
637                                                                .unwrap_or_default(),
638                                                        };
639                                                        if !delta.is_empty() {
640                                                            chunks.push(
641                                                                StreamChunk::ToolCallDelta {
642                                                                    id: tc
643                                                                        .id
644                                                                        .clone()
645                                                                        .unwrap_or_default(),
646                                                                    arguments_delta: delta,
647                                                                },
648                                                            );
649                                                        }
650                                                    }
651                                                }
652                                            }
653                                        }
654                                        // finish_reason signals end of a tool call or completion
655                                        if let Some(ref reason) = choice.finish_reason {
656                                            if !text_buf.is_empty() {
657                                                chunks.push(StreamChunk::Text(std::mem::take(
658                                                    &mut text_buf,
659                                                )));
660                                            }
661                                            if reason == "tool_calls" {
662                                                // Emit ToolCallEnd for the last tool call
663                                                if let Some(ref tcs) = choice.delta.tool_calls {
664                                                    if let Some(tc) = tcs.last() {
665                                                        chunks.push(StreamChunk::ToolCallEnd {
666                                                            id: tc.id.clone().unwrap_or_default(),
667                                                        });
668                                                    }
669                                                }
670                                            }
671                                        }
672                                    }
673                                }
674                            }
675                        }
676                        if !text_buf.is_empty() {
677                            chunks.push(StreamChunk::Text(text_buf));
678                        }
679                    }
680                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
681                }
682                futures::stream::iter(chunks)
683            })
684            .boxed())
685    }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                // Valid JSON with extra whitespace: must be re-serialized compactly.
                arguments: "{\"city\": \"Beijing\" }".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");

        assert_eq!(args, "{\"city\":\"Beijing\"}");
    }

    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "city=Beijing".to_string(),
                thought_signature: None,
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages, true);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value = serde_json::from_str(args).expect("arguments must contain valid JSON");

        assert_eq!(parsed, json!({"input": "city=Beijing"}));
    }

    #[test]
    fn convert_messages_includes_reasoning_only_when_requested() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![
                ContentPart::Thinking {
                    text: "plan".to_string(),
                },
                ContentPart::Text {
                    text: "answer".to_string(),
                },
            ],
        }];

        let with_reasoning = ZaiProvider::convert_messages(&messages, true);
        assert_eq!(with_reasoning[0]["reasoning_content"], "plan");
        assert_eq!(with_reasoning[0]["content"], "answer");

        let without_reasoning = ZaiProvider::convert_messages(&messages, false);
        assert!(without_reasoning[0].get("reasoning_content").is_none());
    }

    #[test]
    fn tool_stream_enabled_only_for_known_models() {
        assert!(ZaiProvider::model_supports_tool_stream("glm-5"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-4.7-flash"));
        assert!(ZaiProvider::model_supports_tool_stream("glm-4.6"));
        assert!(!ZaiProvider::model_supports_tool_stream("glm-4.5"));
    }

    #[test]
    fn preview_text_truncates_on_char_boundary() {
        let text = "a😀b";
        assert_eq!(ZaiProvider::preview_text(text, 2), "a😀");
    }

    #[test]
    fn preview_text_handles_zero_and_short_inputs() {
        assert_eq!(ZaiProvider::preview_text("abc", 0), "");
        assert_eq!(ZaiProvider::preview_text("ab", 5), "ab");
        assert_eq!(ZaiProvider::preview_text("", 3), "");
    }
}