// codetether_agent/provider/zai.rs
//! Z.AI provider implementation (direct API)
//!
//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
//! reasoning/thinking support via the `reasoning_content` field.

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};

/// Client for the Z.AI (ZhipuAI) chat-completions API.
///
/// Holds a reusable HTTP client plus the credentials and endpoint needed to
/// reach an OpenAI-compatible `/chat/completions` route.
pub struct ZaiProvider {
    // Shared reqwest client; reused across requests for connection pooling.
    client: Client,
    // Bearer token sent in the Authorization header. Never logged (see Debug impl).
    api_key: String,
    // API root, e.g. "https://api.z.ai/api/paas/v4" — "/chat/completions" is appended.
    base_url: String,
}

24impl std::fmt::Debug for ZaiProvider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("ZaiProvider")
27            .field("base_url", &self.base_url)
28            .field("api_key", &"<REDACTED>")
29            .finish()
30    }
31}
32
33impl ZaiProvider {
34    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "zai",
37            base_url = %base_url,
38            api_key_len = api_key.len(),
39            "Creating Z.AI provider with custom base URL"
40        );
41        Ok(Self {
42            client: Client::new(),
43            api_key,
44            base_url,
45        })
46    }
47
48    fn convert_messages(messages: &[Message], include_reasoning_content: bool) -> Vec<Value> {
49        messages
50            .iter()
51            .map(|msg| {
52                let role = match msg.role {
53                    Role::System => "system",
54                    Role::User => "user",
55                    Role::Assistant => "assistant",
56                    Role::Tool => "tool",
57                };
58
59                match msg.role {
60                    Role::Tool => {
61                        if let Some(ContentPart::ToolResult {
62                            tool_call_id,
63                            content,
64                        }) = msg.content.first()
65                        {
66                            json!({
67                                "role": "tool",
68                                "tool_call_id": tool_call_id,
69                                "content": content
70                            })
71                        } else {
72                            json!({"role": role, "content": ""})
73                        }
74                    }
75                    Role::Assistant => {
76                        let text: String = msg
77                            .content
78                            .iter()
79                            .filter_map(|p| match p {
80                                ContentPart::Text { text } => Some(text.clone()),
81                                _ => None,
82                            })
83                            .collect::<Vec<_>>()
84                            .join("");
85
86                        let tool_calls: Vec<Value> = msg
87                            .content
88                            .iter()
89                            .filter_map(|p| match p {
90                                ContentPart::ToolCall {
91                                    id,
92                                    name,
93                                    arguments,
94                                } => {
95                                    // Z.AI request schema expects assistant.tool_calls[*].function.arguments
96                                    // to be a JSON-format string. Normalize to a valid JSON string.
97                                    let args_string = serde_json::from_str::<Value>(arguments)
98                                        .map(|parsed| {
99                                            serde_json::to_string(&parsed)
100                                                .unwrap_or_else(|_| "{}".to_string())
101                                        })
102                                        .unwrap_or_else(|_| {
103                                            json!({"input": arguments}).to_string()
104                                        });
105                                    Some(json!({
106                                        "id": id,
107                                        "type": "function",
108                                        "function": {
109                                            "name": name,
110                                            "arguments": args_string
111                                        }
112                                    }))
113                                }
114                                _ => None,
115                            })
116                            .collect();
117
118                        let mut msg_json = json!({
119                            "role": "assistant",
120                            "content": if text.is_empty() { "".to_string() } else { text },
121                        });
122                        if include_reasoning_content {
123                            let reasoning: String = msg
124                                .content
125                                .iter()
126                                .filter_map(|p| match p {
127                                    ContentPart::Thinking { text } => Some(text.clone()),
128                                    _ => None,
129                                })
130                                .collect::<Vec<_>>()
131                                .join("");
132                            if !reasoning.is_empty() {
133                                msg_json["reasoning_content"] = json!(reasoning);
134                            }
135                        }
136                        if !tool_calls.is_empty() {
137                            msg_json["tool_calls"] = json!(tool_calls);
138                        }
139                        msg_json
140                    }
141                    _ => {
142                        let text: String = msg
143                            .content
144                            .iter()
145                            .filter_map(|p| match p {
146                                ContentPart::Text { text } => Some(text.clone()),
147                                _ => None,
148                            })
149                            .collect::<Vec<_>>()
150                            .join("\n");
151
152                        json!({"role": role, "content": text})
153                    }
154                }
155            })
156            .collect()
157    }
158
159    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
160        tools
161            .iter()
162            .map(|t| {
163                json!({
164                    "type": "function",
165                    "function": {
166                        "name": t.name,
167                        "description": t.description,
168                        "parameters": t.parameters
169                    }
170                })
171            })
172            .collect()
173    }
174
175    fn model_supports_tool_stream(model: &str) -> bool {
176        model.contains("glm-5") || model.contains("glm-4.7") || model.contains("glm-4.6")
177    }
178}
179
/// Top-level non-streaming chat-completions response body.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    // Token accounting; defaulted because some responses omit it.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion candidate.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    // e.g. "stop", "length", "tool_calls", "sensitive".
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload of a choice.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Thinking/reasoning text, Z.AI's extension to the OpenAI schema.
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

/// Function name plus arguments; arguments may arrive as either a JSON
/// string or a JSON object, hence `Value`.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    arguments: Value,
}

/// Token usage block. All counters default to 0 when absent.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; `cached_tokens` reports prompt-cache hits.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}

/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

/// Human-readable message plus an optional machine-readable error type.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // `type` is a Rust keyword, so the field is renamed.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}

// SSE stream types: shapes of the `data:` JSON payloads during streaming.
// Every field is optional/defaulted because deltas are sparse.

/// One parsed SSE event.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

/// Streaming counterpart of `ZaiChoice`; carries a delta, not a full message.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    // Present only on the terminating chunk of a turn.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental message fragment.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    // Incremental thinking/reasoning text (Z.AI extension).
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// Partial tool call: `id` appears on the first fragment, later fragments
/// may carry only argument deltas.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

/// Partial function payload within a streamed tool call.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}

284#[async_trait]
285impl Provider for ZaiProvider {
286    fn name(&self) -> &str {
287        "zai"
288    }
289
290    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
291        Ok(vec![
292            ModelInfo {
293                id: "glm-5".to_string(),
294                name: "GLM-5".to_string(),
295                provider: "zai".to_string(),
296                context_window: 200_000,
297                max_output_tokens: Some(128_000),
298                supports_vision: false,
299                supports_tools: true,
300                supports_streaming: true,
301                input_cost_per_million: None,
302                output_cost_per_million: None,
303            },
304            ModelInfo {
305                id: "glm-4.7".to_string(),
306                name: "GLM-4.7".to_string(),
307                provider: "zai".to_string(),
308                context_window: 128_000,
309                max_output_tokens: Some(128_000),
310                supports_vision: false,
311                supports_tools: true,
312                supports_streaming: true,
313                input_cost_per_million: None,
314                output_cost_per_million: None,
315            },
316            ModelInfo {
317                id: "glm-4.7-flash".to_string(),
318                name: "GLM-4.7 Flash".to_string(),
319                provider: "zai".to_string(),
320                context_window: 128_000,
321                max_output_tokens: Some(128_000),
322                supports_vision: false,
323                supports_tools: true,
324                supports_streaming: true,
325                input_cost_per_million: None,
326                output_cost_per_million: None,
327            },
328            ModelInfo {
329                id: "glm-4.6".to_string(),
330                name: "GLM-4.6".to_string(),
331                provider: "zai".to_string(),
332                context_window: 128_000,
333                max_output_tokens: Some(128_000),
334                supports_vision: false,
335                supports_tools: true,
336                supports_streaming: true,
337                input_cost_per_million: None,
338                output_cost_per_million: None,
339            },
340            ModelInfo {
341                id: "glm-4.5".to_string(),
342                name: "GLM-4.5".to_string(),
343                provider: "zai".to_string(),
344                context_window: 128_000,
345                max_output_tokens: Some(96_000),
346                supports_vision: false,
347                supports_tools: true,
348                supports_streaming: true,
349                input_cost_per_million: None,
350                output_cost_per_million: None,
351            },
352        ])
353    }
354
355    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
356        // Compatibility-first mode: omit historical reasoning_content from
357        // request messages to avoid strict parameter validation errors on some
358        // endpoint variants.
359        let messages = Self::convert_messages(&request.messages, false);
360        let tools = Self::convert_tools(&request.tools);
361
362        // GLM-5 and GLM-4.7 default to temperature 1.0
363        let temperature = request.temperature.unwrap_or(1.0);
364
365        let mut body = json!({
366            "model": request.model,
367            "messages": messages,
368            "temperature": temperature,
369        });
370
371        // Keep thinking enabled, but avoid provider-specific sub-fields that
372        // may be rejected by stricter API variants.
373        body["thinking"] = json!({
374            "type": "enabled"
375        });
376
377        if !tools.is_empty() {
378            body["tools"] = json!(tools);
379        }
380        if let Some(max) = request.max_tokens {
381            body["max_tokens"] = json!(max);
382        }
383
384        tracing::debug!(model = %request.model, "Z.AI request");
385
386        let response = self
387            .client
388            .post(format!("{}/chat/completions", self.base_url))
389            .header("Authorization", format!("Bearer {}", self.api_key))
390            .header("Content-Type", "application/json")
391            .json(&body)
392            .send()
393            .await
394            .context("Failed to send request to Z.AI")?;
395
396        let status = response.status();
397        let text = response
398            .text()
399            .await
400            .context("Failed to read Z.AI response")?;
401
402        if !status.is_success() {
403            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
404                anyhow::bail!(
405                    "Z.AI API error: {} ({:?})",
406                    err.error.message,
407                    err.error.error_type
408                );
409            }
410            anyhow::bail!("Z.AI API error: {} {}", status, text);
411        }
412
413        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
414            "Failed to parse Z.AI response: {}",
415            &text[..text.len().min(200)]
416        ))?;
417
418        let choice = response
419            .choices
420            .first()
421            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
422
423        // Log thinking/reasoning content if present
424        if let Some(ref reasoning) = choice.message.reasoning_content {
425            if !reasoning.is_empty() {
426                tracing::info!(
427                    reasoning_len = reasoning.len(),
428                    "Z.AI reasoning content received"
429                );
430            }
431        }
432
433        let mut content = Vec::new();
434        let mut has_tool_calls = false;
435
436        // Emit thinking content as a Thinking part
437        if let Some(ref reasoning) = choice.message.reasoning_content {
438            if !reasoning.is_empty() {
439                content.push(ContentPart::Thinking {
440                    text: reasoning.clone(),
441                });
442            }
443        }
444
445        if let Some(text) = &choice.message.content {
446            if !text.is_empty() {
447                content.push(ContentPart::Text { text: text.clone() });
448            }
449        }
450
451        if let Some(tool_calls) = &choice.message.tool_calls {
452            has_tool_calls = !tool_calls.is_empty();
453            for tc in tool_calls {
454                // Z.AI returns arguments as an object; serialize to string for our ContentPart
455                let arguments = match &tc.function.arguments {
456                    Value::String(s) => s.clone(),
457                    other => serde_json::to_string(other).unwrap_or_default(),
458                };
459                content.push(ContentPart::ToolCall {
460                    id: tc.id.clone(),
461                    name: tc.function.name.clone(),
462                    arguments,
463                });
464            }
465        }
466
467        let finish_reason = if has_tool_calls {
468            FinishReason::ToolCalls
469        } else {
470            match choice.finish_reason.as_deref() {
471                Some("stop") => FinishReason::Stop,
472                Some("length") => FinishReason::Length,
473                Some("tool_calls") => FinishReason::ToolCalls,
474                Some("sensitive") => FinishReason::ContentFilter,
475                _ => FinishReason::Stop,
476            }
477        };
478
479        Ok(CompletionResponse {
480            message: Message {
481                role: Role::Assistant,
482                content,
483            },
484            usage: Usage {
485                prompt_tokens: response
486                    .usage
487                    .as_ref()
488                    .map(|u| u.prompt_tokens)
489                    .unwrap_or(0),
490                completion_tokens: response
491                    .usage
492                    .as_ref()
493                    .map(|u| u.completion_tokens)
494                    .unwrap_or(0),
495                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
496                cache_read_tokens: response
497                    .usage
498                    .as_ref()
499                    .and_then(|u| u.prompt_tokens_details.as_ref())
500                    .map(|d| d.cached_tokens)
501                    .filter(|&t| t > 0),
502                cache_write_tokens: None,
503            },
504            finish_reason,
505        })
506    }
507
508    async fn complete_stream(
509        &self,
510        request: CompletionRequest,
511    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
512        // Compatibility-first mode: omit historical reasoning_content from
513        // request messages to avoid strict parameter validation errors on some
514        // endpoint variants.
515        let messages = Self::convert_messages(&request.messages, false);
516        let tools = Self::convert_tools(&request.tools);
517
518        let temperature = request.temperature.unwrap_or(1.0);
519
520        let mut body = json!({
521            "model": request.model,
522            "messages": messages,
523            "temperature": temperature,
524            "stream": true,
525        });
526
527        body["thinking"] = json!({
528            "type": "enabled"
529        });
530
531        if !tools.is_empty() {
532            body["tools"] = json!(tools);
533            if Self::model_supports_tool_stream(&request.model) {
534                // Enable streaming tool calls only on known-compatible models.
535                body["tool_stream"] = json!(true);
536            }
537        }
538        if let Some(max) = request.max_tokens {
539            body["max_tokens"] = json!(max);
540        }
541
542        tracing::debug!(model = %request.model, "Z.AI streaming request");
543
544        let response = self
545            .client
546            .post(format!("{}/chat/completions", self.base_url))
547            .header("Authorization", format!("Bearer {}", self.api_key))
548            .header("Content-Type", "application/json")
549            .json(&body)
550            .send()
551            .await
552            .context("Failed to send streaming request to Z.AI")?;
553
554        if !response.status().is_success() {
555            let status = response.status();
556            let text = response.text().await.unwrap_or_default();
557            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
558                anyhow::bail!(
559                    "Z.AI API error: {} ({:?})",
560                    err.error.message,
561                    err.error.error_type
562                );
563            }
564            anyhow::bail!("Z.AI streaming error: {} {}", status, text);
565        }
566
567        let stream = response.bytes_stream();
568        let mut buffer = String::new();
569
570        Ok(stream
571            .flat_map(move |chunk_result| {
572                let mut chunks: Vec<StreamChunk> = Vec::new();
573                match chunk_result {
574                    Ok(bytes) => {
575                        let text = String::from_utf8_lossy(&bytes);
576                        buffer.push_str(&text);
577
578                        let mut text_buf = String::new();
579                        while let Some(line_end) = buffer.find('\n') {
580                            let line = buffer[..line_end].trim().to_string();
581                            buffer = buffer[line_end + 1..].to_string();
582
583                            if line == "data: [DONE]" {
584                                if !text_buf.is_empty() {
585                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
586                                }
587                                chunks.push(StreamChunk::Done { usage: None });
588                                continue;
589                            }
590                            if let Some(data) = line.strip_prefix("data: ") {
591                                if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
592                                {
593                                    if let Some(choice) = parsed.choices.first() {
594                                        // Reasoning content streamed as text (prefixed for TUI rendering)
595                                        if let Some(ref reasoning) = choice.delta.reasoning_content
596                                        {
597                                            if !reasoning.is_empty() {
598                                                text_buf.push_str(reasoning);
599                                            }
600                                        }
601                                        if let Some(ref content) = choice.delta.content {
602                                            text_buf.push_str(content);
603                                        }
604                                        // Streaming tool calls
605                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
606                                            if !text_buf.is_empty() {
607                                                chunks.push(StreamChunk::Text(std::mem::take(
608                                                    &mut text_buf,
609                                                )));
610                                            }
611                                            for tc in tool_calls {
612                                                if let Some(ref func) = tc.function {
613                                                    if let Some(ref name) = func.name {
614                                                        // New tool call starting
615                                                        chunks.push(StreamChunk::ToolCallStart {
616                                                            id: tc.id.clone().unwrap_or_default(),
617                                                            name: name.clone(),
618                                                        });
619                                                    }
620                                                    if let Some(ref args) = func.arguments {
621                                                        let delta = match args {
622                                                            Value::String(s) => s.clone(),
623                                                            other => serde_json::to_string(other)
624                                                                .unwrap_or_default(),
625                                                        };
626                                                        if !delta.is_empty() {
627                                                            chunks.push(
628                                                                StreamChunk::ToolCallDelta {
629                                                                    id: tc
630                                                                        .id
631                                                                        .clone()
632                                                                        .unwrap_or_default(),
633                                                                    arguments_delta: delta,
634                                                                },
635                                                            );
636                                                        }
637                                                    }
638                                                }
639                                            }
640                                        }
641                                        // finish_reason signals end of a tool call or completion
642                                        if let Some(ref reason) = choice.finish_reason {
643                                            if !text_buf.is_empty() {
644                                                chunks.push(StreamChunk::Text(std::mem::take(
645                                                    &mut text_buf,
646                                                )));
647                                            }
648                                            if reason == "tool_calls" {
649                                                // Emit ToolCallEnd for the last tool call
650                                                if let Some(ref tcs) = choice.delta.tool_calls {
651                                                    if let Some(tc) = tcs.last() {
652                                                        chunks.push(StreamChunk::ToolCallEnd {
653                                                            id: tc.id.clone().unwrap_or_default(),
654                                                        });
655                                                    }
656                                                }
657                                            }
658                                        }
659                                    }
660                                }
661                            }
662                        }
663                        if !text_buf.is_empty() {
664                            chunks.push(StreamChunk::Text(text_buf));
665                        }
666                    }
667                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
668                }
669                futures::stream::iter(chunks)
670            })
671            .boxed())
672    }
673}
674
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single assistant message carrying one tool call with the
    /// given raw argument string.
    fn assistant_with_tool_call(arguments: &str) -> Vec<Message> {
        vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: arguments.to_string(),
            }],
        }]
    }

    /// Pull the serialized `function.arguments` string out of the first
    /// converted message's first tool call.
    fn arguments_of(converted: &[Value]) -> &str {
        converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string")
    }

    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = assistant_with_tool_call("{\"city\":\"Beijing\"}");
        let converted = ZaiProvider::convert_messages(&messages, true);
        assert_eq!(arguments_of(&converted), "{\"city\":\"Beijing\"}");
    }

    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = assistant_with_tool_call("city=Beijing");
        let converted = ZaiProvider::convert_messages(&messages, true);
        let parsed: Value = serde_json::from_str(arguments_of(&converted))
            .expect("arguments must contain valid JSON");
        assert_eq!(parsed, json!({"input": "city=Beijing"}));
    }
}