// Source: codetether_agent/provider/zai.rs
//! Z.AI provider implementation (direct API)
//!
//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
//! reasoning/thinking support via the `reasoning_content` field.

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
/// Client for the Z.AI OpenAI-compatible chat-completions API.
pub struct ZaiProvider {
    // Shared reqwest client (connection pooling across requests).
    client: Client,
    // Bearer token sent in the Authorization header; redacted from Debug output.
    api_key: String,
    // API root (no trailing slash expected); "/chat/completions" is appended.
    base_url: String,
}
23
24impl std::fmt::Debug for ZaiProvider {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("ZaiProvider")
27            .field("base_url", &self.base_url)
28            .field("api_key", &"<REDACTED>")
29            .finish()
30    }
31}
32
33impl ZaiProvider {
34    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "zai",
37            base_url = %base_url,
38            api_key_len = api_key.len(),
39            "Creating Z.AI provider with custom base URL"
40        );
41        Ok(Self {
42            client: Client::new(),
43            api_key,
44            base_url,
45        })
46    }
47
48    fn convert_messages(messages: &[Message]) -> Vec<Value> {
49        messages
50            .iter()
51            .map(|msg| {
52                let role = match msg.role {
53                    Role::System => "system",
54                    Role::User => "user",
55                    Role::Assistant => "assistant",
56                    Role::Tool => "tool",
57                };
58
59                match msg.role {
60                    Role::Tool => {
61                        if let Some(ContentPart::ToolResult {
62                            tool_call_id,
63                            content,
64                        }) = msg.content.first()
65                        {
66                            json!({
67                                "role": "tool",
68                                "tool_call_id": tool_call_id,
69                                "content": content
70                            })
71                        } else {
72                            json!({"role": role, "content": ""})
73                        }
74                    }
75                    Role::Assistant => {
76                        let text: String = msg
77                            .content
78                            .iter()
79                            .filter_map(|p| match p {
80                                ContentPart::Text { text } => Some(text.clone()),
81                                _ => None,
82                            })
83                            .collect::<Vec<_>>()
84                            .join("");
85
86                        // Preserve reasoning_content for interleaved/preserved thinking
87                        let reasoning: String = msg
88                            .content
89                            .iter()
90                            .filter_map(|p| match p {
91                                ContentPart::Thinking { text } => Some(text.clone()),
92                                _ => None,
93                            })
94                            .collect::<Vec<_>>()
95                            .join("");
96
97                        let tool_calls: Vec<Value> = msg
98                            .content
99                            .iter()
100                            .filter_map(|p| match p {
101                                ContentPart::ToolCall {
102                                    id,
103                                    name,
104                                    arguments,
105                                } => {
106                                    // Z.AI request schema expects assistant.tool_calls[*].function.arguments
107                                    // to be a JSON-format string. Normalize to a valid JSON string.
108                                    let args_string = serde_json::from_str::<Value>(arguments)
109                                        .map(|parsed| {
110                                            serde_json::to_string(&parsed)
111                                                .unwrap_or_else(|_| "{}".to_string())
112                                        })
113                                        .unwrap_or_else(|_| {
114                                            json!({"input": arguments}).to_string()
115                                        });
116                                    Some(json!({
117                                        "id": id,
118                                        "type": "function",
119                                        "function": {
120                                            "name": name,
121                                            "arguments": args_string
122                                        }
123                                    }))
124                                }
125                                _ => None,
126                            })
127                            .collect();
128
129                        let mut msg_json = json!({
130                            "role": "assistant",
131                            "content": if text.is_empty() { "".to_string() } else { text },
132                        });
133                        // Return reasoning_content for preserved thinking (clear_thinking: false)
134                        if !reasoning.is_empty() {
135                            msg_json["reasoning_content"] = json!(reasoning);
136                        }
137                        if !tool_calls.is_empty() {
138                            msg_json["tool_calls"] = json!(tool_calls);
139                        }
140                        msg_json
141                    }
142                    _ => {
143                        let text: String = msg
144                            .content
145                            .iter()
146                            .filter_map(|p| match p {
147                                ContentPart::Text { text } => Some(text.clone()),
148                                _ => None,
149                            })
150                            .collect::<Vec<_>>()
151                            .join("\n");
152
153                        json!({"role": role, "content": text})
154                    }
155                }
156            })
157            .collect()
158    }
159
160    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
161        tools
162            .iter()
163            .map(|t| {
164                json!({
165                    "type": "function",
166                    "function": {
167                        "name": t.name,
168                        "description": t.description,
169                        "parameters": t.parameters
170                    }
171                })
172            })
173            .collect()
174    }
175}
176
/// Non-streaming chat-completions response body.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    // Token accounting; optional because some responses omit it.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion candidate; the first element is the one consumed.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload from a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Thinking text when the request enabled `thinking`.
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A single tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    // `Value` (not `String`): the API may return arguments as either a JSON
    // object or a pre-serialized string; callers normalize both cases.
    arguments: Value,
}

/// Token usage; all counts default to 0 when a field is absent.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; used to surface prompt-cache hits.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
230
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // Renamed: `type` is a Rust keyword.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
242
// SSE stream types — one `ZaiStreamResponse` per `data:` event.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    // Set on the final event of a turn (e.g. "stop", "tool_calls").
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental fields; each event carries only the pieces that changed.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// Streaming tool-call fragment; `id`/`name` appear on the first fragment,
/// later fragments may carry only argument deltas.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    // `Value`: argument deltas may arrive as strings or JSON fragments.
    #[serde(default)]
    arguments: Option<Value>,
}
280
281#[async_trait]
282impl Provider for ZaiProvider {
283    fn name(&self) -> &str {
284        "zai"
285    }
286
287    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
288        Ok(vec![
289            ModelInfo {
290                id: "glm-5".to_string(),
291                name: "GLM-5".to_string(),
292                provider: "zai".to_string(),
293                context_window: 200_000,
294                max_output_tokens: Some(128_000),
295                supports_vision: false,
296                supports_tools: true,
297                supports_streaming: true,
298                input_cost_per_million: None,
299                output_cost_per_million: None,
300            },
301            ModelInfo {
302                id: "glm-4.7".to_string(),
303                name: "GLM-4.7".to_string(),
304                provider: "zai".to_string(),
305                context_window: 128_000,
306                max_output_tokens: Some(128_000),
307                supports_vision: false,
308                supports_tools: true,
309                supports_streaming: true,
310                input_cost_per_million: None,
311                output_cost_per_million: None,
312            },
313            ModelInfo {
314                id: "glm-4.7-flash".to_string(),
315                name: "GLM-4.7 Flash".to_string(),
316                provider: "zai".to_string(),
317                context_window: 128_000,
318                max_output_tokens: Some(128_000),
319                supports_vision: false,
320                supports_tools: true,
321                supports_streaming: true,
322                input_cost_per_million: None,
323                output_cost_per_million: None,
324            },
325            ModelInfo {
326                id: "glm-4.6".to_string(),
327                name: "GLM-4.6".to_string(),
328                provider: "zai".to_string(),
329                context_window: 128_000,
330                max_output_tokens: Some(128_000),
331                supports_vision: false,
332                supports_tools: true,
333                supports_streaming: true,
334                input_cost_per_million: None,
335                output_cost_per_million: None,
336            },
337            ModelInfo {
338                id: "glm-4.5".to_string(),
339                name: "GLM-4.5".to_string(),
340                provider: "zai".to_string(),
341                context_window: 128_000,
342                max_output_tokens: Some(96_000),
343                supports_vision: false,
344                supports_tools: true,
345                supports_streaming: true,
346                input_cost_per_million: None,
347                output_cost_per_million: None,
348            },
349        ])
350    }
351
352    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
353        let messages = Self::convert_messages(&request.messages);
354        let tools = Self::convert_tools(&request.tools);
355
356        // GLM-5 and GLM-4.7 default to temperature 1.0
357        let temperature = request.temperature.unwrap_or(1.0);
358
359        let mut body = json!({
360            "model": request.model,
361            "messages": messages,
362            "temperature": temperature,
363        });
364
365        // Enable thinking with preserved reasoning for coding/agent workflows.
366        // clear_thinking: false retains reasoning_content across turns for better
367        // coherence and improved cache hit rates.
368        body["thinking"] = json!({
369            "type": "enabled",
370            "clear_thinking": false
371        });
372
373        if !tools.is_empty() {
374            body["tools"] = json!(tools);
375        }
376        if let Some(max) = request.max_tokens {
377            body["max_tokens"] = json!(max);
378        }
379
380        tracing::debug!(model = %request.model, "Z.AI request");
381
382        let response = self
383            .client
384            .post(format!("{}/chat/completions", self.base_url))
385            .header("Authorization", format!("Bearer {}", self.api_key))
386            .header("Content-Type", "application/json")
387            .json(&body)
388            .send()
389            .await
390            .context("Failed to send request to Z.AI")?;
391
392        let status = response.status();
393        let text = response
394            .text()
395            .await
396            .context("Failed to read Z.AI response")?;
397
398        if !status.is_success() {
399            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
400                anyhow::bail!(
401                    "Z.AI API error: {} ({:?})",
402                    err.error.message,
403                    err.error.error_type
404                );
405            }
406            anyhow::bail!("Z.AI API error: {} {}", status, text);
407        }
408
409        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
410            "Failed to parse Z.AI response: {}",
411            &text[..text.len().min(200)]
412        ))?;
413
414        let choice = response
415            .choices
416            .first()
417            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
418
419        // Log thinking/reasoning content if present
420        if let Some(ref reasoning) = choice.message.reasoning_content {
421            if !reasoning.is_empty() {
422                tracing::info!(
423                    reasoning_len = reasoning.len(),
424                    "Z.AI reasoning content received"
425                );
426            }
427        }
428
429        let mut content = Vec::new();
430        let mut has_tool_calls = false;
431
432        // Emit thinking content as a Thinking part
433        if let Some(ref reasoning) = choice.message.reasoning_content {
434            if !reasoning.is_empty() {
435                content.push(ContentPart::Thinking {
436                    text: reasoning.clone(),
437                });
438            }
439        }
440
441        if let Some(text) = &choice.message.content {
442            if !text.is_empty() {
443                content.push(ContentPart::Text { text: text.clone() });
444            }
445        }
446
447        if let Some(tool_calls) = &choice.message.tool_calls {
448            has_tool_calls = !tool_calls.is_empty();
449            for tc in tool_calls {
450                // Z.AI returns arguments as an object; serialize to string for our ContentPart
451                let arguments = match &tc.function.arguments {
452                    Value::String(s) => s.clone(),
453                    other => serde_json::to_string(other).unwrap_or_default(),
454                };
455                content.push(ContentPart::ToolCall {
456                    id: tc.id.clone(),
457                    name: tc.function.name.clone(),
458                    arguments,
459                });
460            }
461        }
462
463        let finish_reason = if has_tool_calls {
464            FinishReason::ToolCalls
465        } else {
466            match choice.finish_reason.as_deref() {
467                Some("stop") => FinishReason::Stop,
468                Some("length") => FinishReason::Length,
469                Some("tool_calls") => FinishReason::ToolCalls,
470                Some("sensitive") => FinishReason::ContentFilter,
471                _ => FinishReason::Stop,
472            }
473        };
474
475        Ok(CompletionResponse {
476            message: Message {
477                role: Role::Assistant,
478                content,
479            },
480            usage: Usage {
481                prompt_tokens: response
482                    .usage
483                    .as_ref()
484                    .map(|u| u.prompt_tokens)
485                    .unwrap_or(0),
486                completion_tokens: response
487                    .usage
488                    .as_ref()
489                    .map(|u| u.completion_tokens)
490                    .unwrap_or(0),
491                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
492                cache_read_tokens: response
493                    .usage
494                    .as_ref()
495                    .and_then(|u| u.prompt_tokens_details.as_ref())
496                    .map(|d| d.cached_tokens)
497                    .filter(|&t| t > 0),
498                cache_write_tokens: None,
499            },
500            finish_reason,
501        })
502    }
503
504    async fn complete_stream(
505        &self,
506        request: CompletionRequest,
507    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
508        let messages = Self::convert_messages(&request.messages);
509        let tools = Self::convert_tools(&request.tools);
510
511        let temperature = request.temperature.unwrap_or(1.0);
512
513        let mut body = json!({
514            "model": request.model,
515            "messages": messages,
516            "temperature": temperature,
517            "stream": true,
518        });
519
520        body["thinking"] = json!({
521            "type": "enabled",
522            "clear_thinking": false
523        });
524
525        if !tools.is_empty() {
526            body["tools"] = json!(tools);
527            // Enable streaming tool calls (supported by glm-4.6+)
528            body["tool_stream"] = json!(true);
529        }
530        if let Some(max) = request.max_tokens {
531            body["max_tokens"] = json!(max);
532        }
533
534        tracing::debug!(model = %request.model, "Z.AI streaming request");
535
536        let response = self
537            .client
538            .post(format!("{}/chat/completions", self.base_url))
539            .header("Authorization", format!("Bearer {}", self.api_key))
540            .header("Content-Type", "application/json")
541            .json(&body)
542            .send()
543            .await
544            .context("Failed to send streaming request to Z.AI")?;
545
546        if !response.status().is_success() {
547            let status = response.status();
548            let text = response.text().await.unwrap_or_default();
549            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
550                anyhow::bail!(
551                    "Z.AI API error: {} ({:?})",
552                    err.error.message,
553                    err.error.error_type
554                );
555            }
556            anyhow::bail!("Z.AI streaming error: {} {}", status, text);
557        }
558
559        let stream = response.bytes_stream();
560        let mut buffer = String::new();
561
562        Ok(stream
563            .flat_map(move |chunk_result| {
564                let mut chunks: Vec<StreamChunk> = Vec::new();
565                match chunk_result {
566                    Ok(bytes) => {
567                        let text = String::from_utf8_lossy(&bytes);
568                        buffer.push_str(&text);
569
570                        let mut text_buf = String::new();
571                        while let Some(line_end) = buffer.find('\n') {
572                            let line = buffer[..line_end].trim().to_string();
573                            buffer = buffer[line_end + 1..].to_string();
574
575                            if line == "data: [DONE]" {
576                                if !text_buf.is_empty() {
577                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
578                                }
579                                chunks.push(StreamChunk::Done { usage: None });
580                                continue;
581                            }
582                            if let Some(data) = line.strip_prefix("data: ") {
583                                if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
584                                {
585                                    if let Some(choice) = parsed.choices.first() {
586                                        // Reasoning content streamed as text (prefixed for TUI rendering)
587                                        if let Some(ref reasoning) = choice.delta.reasoning_content
588                                        {
589                                            if !reasoning.is_empty() {
590                                                text_buf.push_str(reasoning);
591                                            }
592                                        }
593                                        if let Some(ref content) = choice.delta.content {
594                                            text_buf.push_str(content);
595                                        }
596                                        // Streaming tool calls
597                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
598                                            if !text_buf.is_empty() {
599                                                chunks.push(StreamChunk::Text(std::mem::take(
600                                                    &mut text_buf,
601                                                )));
602                                            }
603                                            for tc in tool_calls {
604                                                if let Some(ref func) = tc.function {
605                                                    if let Some(ref name) = func.name {
606                                                        // New tool call starting
607                                                        chunks.push(StreamChunk::ToolCallStart {
608                                                            id: tc.id.clone().unwrap_or_default(),
609                                                            name: name.clone(),
610                                                        });
611                                                    }
612                                                    if let Some(ref args) = func.arguments {
613                                                        let delta = match args {
614                                                            Value::String(s) => s.clone(),
615                                                            other => serde_json::to_string(other)
616                                                                .unwrap_or_default(),
617                                                        };
618                                                        if !delta.is_empty() {
619                                                            chunks.push(
620                                                                StreamChunk::ToolCallDelta {
621                                                                    id: tc
622                                                                        .id
623                                                                        .clone()
624                                                                        .unwrap_or_default(),
625                                                                    arguments_delta: delta,
626                                                                },
627                                                            );
628                                                        }
629                                                    }
630                                                }
631                                            }
632                                        }
633                                        // finish_reason signals end of a tool call or completion
634                                        if let Some(ref reason) = choice.finish_reason {
635                                            if !text_buf.is_empty() {
636                                                chunks.push(StreamChunk::Text(std::mem::take(
637                                                    &mut text_buf,
638                                                )));
639                                            }
640                                            if reason == "tool_calls" {
641                                                // Emit ToolCallEnd for the last tool call
642                                                if let Some(ref tcs) = choice.delta.tool_calls {
643                                                    if let Some(tc) = tcs.last() {
644                                                        chunks.push(StreamChunk::ToolCallEnd {
645                                                            id: tc.id.clone().unwrap_or_default(),
646                                                        });
647                                                    }
648                                                }
649                                            }
650                                        }
651                                    }
652                                }
653                            }
654                        }
655                        if !text_buf.is_empty() {
656                            chunks.push(StreamChunk::Text(text_buf));
657                        }
658                    }
659                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
660                }
661                futures::stream::iter(chunks)
662            })
663            .boxed())
664    }
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    // Valid JSON arguments must round-trip unchanged as a JSON string.
    #[test]
    fn convert_messages_serializes_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "{\"city\":\"Beijing\"}".to_string(),
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");

        assert_eq!(args, "{\"city\":\"Beijing\"}");
    }

    // Non-JSON arguments must be wrapped as {"input": <raw>} so the wire
    // payload is still a valid JSON string.
    #[test]
    fn convert_messages_wraps_invalid_tool_arguments_as_json_string() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: vec![ContentPart::ToolCall {
                id: "call_1".to_string(),
                name: "get_weather".to_string(),
                arguments: "city=Beijing".to_string(),
            }],
        }];

        let converted = ZaiProvider::convert_messages(&messages);
        let args = converted[0]["tool_calls"][0]["function"]["arguments"]
            .as_str()
            .expect("arguments must be a string");
        let parsed: Value = serde_json::from_str(args).expect("arguments must contain valid JSON");

        assert_eq!(parsed, json!({"input": "city=Beijing"}));
    }
}
709}