// codetether_agent/provider/zai.rs
1//! Z.AI provider implementation (direct API)
2//!
3//! GLM-5, GLM-4.7, and other Z.AI foundation models via api.z.ai.
4//! Z.AI (formerly ZhipuAI) offers OpenAI-compatible chat completions with
5//! reasoning/thinking support via the `reasoning_content` field.
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Provider for Z.AI's OpenAI-compatible chat-completions API.
pub struct ZaiProvider {
    /// Shared HTTP client (connection pooling via reqwest).
    client: Client,
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// API root; `/chat/completions` is appended per request.
    base_url: String,
}
23
// Manual Debug impl so the API key is never written to logs.
impl std::fmt::Debug for ZaiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZaiProvider")
            .field("base_url", &self.base_url)
            // Redact the secret; only its presence is interesting.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
32
33impl ZaiProvider {
34    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35        tracing::debug!(
36            provider = "zai",
37            base_url = %base_url,
38            api_key_len = api_key.len(),
39            "Creating Z.AI provider with custom base URL"
40        );
41        Ok(Self {
42            client: Client::new(),
43            api_key,
44            base_url,
45        })
46    }
47
48    fn convert_messages(messages: &[Message]) -> Vec<Value> {
49        messages
50            .iter()
51            .map(|msg| {
52                let role = match msg.role {
53                    Role::System => "system",
54                    Role::User => "user",
55                    Role::Assistant => "assistant",
56                    Role::Tool => "tool",
57                };
58
59                match msg.role {
60                    Role::Tool => {
61                        if let Some(ContentPart::ToolResult {
62                            tool_call_id,
63                            content,
64                        }) = msg.content.first()
65                        {
66                            json!({
67                                "role": "tool",
68                                "tool_call_id": tool_call_id,
69                                "content": content
70                            })
71                        } else {
72                            json!({"role": role, "content": ""})
73                        }
74                    }
75                    Role::Assistant => {
76                        let text: String = msg
77                            .content
78                            .iter()
79                            .filter_map(|p| match p {
80                                ContentPart::Text { text } => Some(text.clone()),
81                                _ => None,
82                            })
83                            .collect::<Vec<_>>()
84                            .join("");
85
86                        // Preserve reasoning_content for interleaved/preserved thinking
87                        let reasoning: String = msg
88                            .content
89                            .iter()
90                            .filter_map(|p| match p {
91                                ContentPart::Thinking { text } => Some(text.clone()),
92                                _ => None,
93                            })
94                            .collect::<Vec<_>>()
95                            .join("");
96
97                        let tool_calls: Vec<Value> = msg
98                            .content
99                            .iter()
100                            .filter_map(|p| match p {
101                                ContentPart::ToolCall {
102                                    id,
103                                    name,
104                                    arguments,
105                                } => {
106                                    // Z.AI expects arguments as an object, but we store as string.
107                                    // Try to parse as JSON; fall back to wrapping in a string field.
108                                    let args_value = serde_json::from_str::<Value>(arguments)
109                                        .unwrap_or_else(|_| json!({"input": arguments}));
110                                    Some(json!({
111                                        "id": id,
112                                        "type": "function",
113                                        "function": {
114                                            "name": name,
115                                            "arguments": args_value
116                                        }
117                                    }))
118                                }
119                                _ => None,
120                            })
121                            .collect();
122
123                        let mut msg_json = json!({
124                            "role": "assistant",
125                            "content": if text.is_empty() { "".to_string() } else { text },
126                        });
127                        // Return reasoning_content for preserved thinking (clear_thinking: false)
128                        if !reasoning.is_empty() {
129                            msg_json["reasoning_content"] = json!(reasoning);
130                        }
131                        if !tool_calls.is_empty() {
132                            msg_json["tool_calls"] = json!(tool_calls);
133                        }
134                        msg_json
135                    }
136                    _ => {
137                        let text: String = msg
138                            .content
139                            .iter()
140                            .filter_map(|p| match p {
141                                ContentPart::Text { text } => Some(text.clone()),
142                                _ => None,
143                            })
144                            .collect::<Vec<_>>()
145                            .join("\n");
146
147                        json!({"role": role, "content": text})
148                    }
149                }
150            })
151            .collect()
152    }
153
154    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
155        tools
156            .iter()
157            .map(|t| {
158                json!({
159                    "type": "function",
160                    "function": {
161                        "name": t.name,
162                        "description": t.description,
163                        "parameters": t.parameters
164                    }
165                })
166            })
167            .collect()
168    }
169}
170
/// Top-level non-streaming chat-completion response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    /// Token accounting; absent on some error-shaped payloads.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}
177
/// One completion candidate; in practice Z.AI returns a single choice.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    /// e.g. "stop", "length", "tool_calls", "sensitive".
    #[serde(default)]
    finish_reason: Option<String>,
}
184
/// Assistant message payload of a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    /// Visible answer text; may be absent for pure tool-call responses.
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    /// Thinking/reasoning text (Z.AI-specific extension field).
    #[serde(default)]
    reasoning_content: Option<String>,
}
194
/// A completed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}
200
/// Function name + arguments of a tool call.
#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    /// Usually a JSON object, but kept as `Value` since some models
    /// return a JSON-encoded string instead.
    arguments: Value,
}
206
/// Token usage accounting; all counters default to 0 when omitted.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    /// Breakdown used to surface prompt-cache hits.
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}
218
/// Prompt-token breakdown; `cached_tokens` counts prompt-cache hits.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
224
/// Envelope of an error response body.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}
229
/// Error message and optional machine-readable type.
#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    message: String,
    // `type` is a Rust keyword, hence the rename.
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
236
// SSE stream types

/// One parsed `data:` event of the SSE stream.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}
242
/// Streaming choice: an incremental delta plus an optional finish reason.
#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    /// Set on the final chunk of a turn (e.g. "stop", "tool_calls").
    #[serde(default)]
    finish_reason: Option<String>,
}
249
/// Incremental content delta; any combination of fields may be present.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    /// Streaming thinking/reasoning text (Z.AI extension).
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}
259
/// Streaming tool-call fragment; `id` typically appears only on the
/// first fragment of a given call.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}
266
/// Streaming function fragment: `name` on the first fragment,
/// `arguments` accumulated across fragments.
#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<Value>,
}
274
275#[async_trait]
276impl Provider for ZaiProvider {
277    fn name(&self) -> &str {
278        "zai"
279    }
280
281    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282        Ok(vec![
283            ModelInfo {
284                id: "glm-5".to_string(),
285                name: "GLM-5".to_string(),
286                provider: "zai".to_string(),
287                context_window: 200_000,
288                max_output_tokens: Some(128_000),
289                supports_vision: false,
290                supports_tools: true,
291                supports_streaming: true,
292                input_cost_per_million: None,
293                output_cost_per_million: None,
294            },
295            ModelInfo {
296                id: "glm-4.7".to_string(),
297                name: "GLM-4.7".to_string(),
298                provider: "zai".to_string(),
299                context_window: 128_000,
300                max_output_tokens: Some(128_000),
301                supports_vision: false,
302                supports_tools: true,
303                supports_streaming: true,
304                input_cost_per_million: None,
305                output_cost_per_million: None,
306            },
307            ModelInfo {
308                id: "glm-4.7-flash".to_string(),
309                name: "GLM-4.7 Flash".to_string(),
310                provider: "zai".to_string(),
311                context_window: 128_000,
312                max_output_tokens: Some(128_000),
313                supports_vision: false,
314                supports_tools: true,
315                supports_streaming: true,
316                input_cost_per_million: None,
317                output_cost_per_million: None,
318            },
319            ModelInfo {
320                id: "glm-4.6".to_string(),
321                name: "GLM-4.6".to_string(),
322                provider: "zai".to_string(),
323                context_window: 128_000,
324                max_output_tokens: Some(128_000),
325                supports_vision: false,
326                supports_tools: true,
327                supports_streaming: true,
328                input_cost_per_million: None,
329                output_cost_per_million: None,
330            },
331            ModelInfo {
332                id: "glm-4.5".to_string(),
333                name: "GLM-4.5".to_string(),
334                provider: "zai".to_string(),
335                context_window: 128_000,
336                max_output_tokens: Some(96_000),
337                supports_vision: false,
338                supports_tools: true,
339                supports_streaming: true,
340                input_cost_per_million: None,
341                output_cost_per_million: None,
342            },
343        ])
344    }
345
346    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
347        let messages = Self::convert_messages(&request.messages);
348        let tools = Self::convert_tools(&request.tools);
349
350        // GLM-5 and GLM-4.7 default to temperature 1.0
351        let temperature = request.temperature.unwrap_or(1.0);
352
353        let mut body = json!({
354            "model": request.model,
355            "messages": messages,
356            "temperature": temperature,
357        });
358
359        // Enable thinking with preserved reasoning for coding/agent workflows.
360        // clear_thinking: false retains reasoning_content across turns for better
361        // coherence and improved cache hit rates.
362        body["thinking"] = json!({
363            "type": "enabled",
364            "clear_thinking": false
365        });
366
367        if !tools.is_empty() {
368            body["tools"] = json!(tools);
369        }
370        if let Some(max) = request.max_tokens {
371            body["max_tokens"] = json!(max);
372        }
373
374        tracing::debug!(model = %request.model, "Z.AI request");
375
376        let response = self
377            .client
378            .post(format!("{}/chat/completions", self.base_url))
379            .header("Authorization", format!("Bearer {}", self.api_key))
380            .header("Content-Type", "application/json")
381            .json(&body)
382            .send()
383            .await
384            .context("Failed to send request to Z.AI")?;
385
386        let status = response.status();
387        let text = response
388            .text()
389            .await
390            .context("Failed to read Z.AI response")?;
391
392        if !status.is_success() {
393            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
394                anyhow::bail!(
395                    "Z.AI API error: {} ({:?})",
396                    err.error.message,
397                    err.error.error_type
398                );
399            }
400            anyhow::bail!("Z.AI API error: {} {}", status, text);
401        }
402
403        let response: ZaiResponse = serde_json::from_str(&text).context(format!(
404            "Failed to parse Z.AI response: {}",
405            &text[..text.len().min(200)]
406        ))?;
407
408        let choice = response
409            .choices
410            .first()
411            .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
412
413        // Log thinking/reasoning content if present
414        if let Some(ref reasoning) = choice.message.reasoning_content {
415            if !reasoning.is_empty() {
416                tracing::info!(
417                    reasoning_len = reasoning.len(),
418                    "Z.AI reasoning content received"
419                );
420            }
421        }
422
423        let mut content = Vec::new();
424        let mut has_tool_calls = false;
425
426        // Emit thinking content as a Thinking part
427        if let Some(ref reasoning) = choice.message.reasoning_content {
428            if !reasoning.is_empty() {
429                content.push(ContentPart::Thinking {
430                    text: reasoning.clone(),
431                });
432            }
433        }
434
435        if let Some(text) = &choice.message.content {
436            if !text.is_empty() {
437                content.push(ContentPart::Text { text: text.clone() });
438            }
439        }
440
441        if let Some(tool_calls) = &choice.message.tool_calls {
442            has_tool_calls = !tool_calls.is_empty();
443            for tc in tool_calls {
444                // Z.AI returns arguments as an object; serialize to string for our ContentPart
445                let arguments = match &tc.function.arguments {
446                    Value::String(s) => s.clone(),
447                    other => serde_json::to_string(other).unwrap_or_default(),
448                };
449                content.push(ContentPart::ToolCall {
450                    id: tc.id.clone(),
451                    name: tc.function.name.clone(),
452                    arguments,
453                });
454            }
455        }
456
457        let finish_reason = if has_tool_calls {
458            FinishReason::ToolCalls
459        } else {
460            match choice.finish_reason.as_deref() {
461                Some("stop") => FinishReason::Stop,
462                Some("length") => FinishReason::Length,
463                Some("tool_calls") => FinishReason::ToolCalls,
464                Some("sensitive") => FinishReason::ContentFilter,
465                _ => FinishReason::Stop,
466            }
467        };
468
469        Ok(CompletionResponse {
470            message: Message {
471                role: Role::Assistant,
472                content,
473            },
474            usage: Usage {
475                prompt_tokens: response
476                    .usage
477                    .as_ref()
478                    .map(|u| u.prompt_tokens)
479                    .unwrap_or(0),
480                completion_tokens: response
481                    .usage
482                    .as_ref()
483                    .map(|u| u.completion_tokens)
484                    .unwrap_or(0),
485                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
486                cache_read_tokens: response
487                    .usage
488                    .as_ref()
489                    .and_then(|u| u.prompt_tokens_details.as_ref())
490                    .map(|d| d.cached_tokens)
491                    .filter(|&t| t > 0),
492                cache_write_tokens: None,
493            },
494            finish_reason,
495        })
496    }
497
498    async fn complete_stream(
499        &self,
500        request: CompletionRequest,
501    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
502        let messages = Self::convert_messages(&request.messages);
503        let tools = Self::convert_tools(&request.tools);
504
505        let temperature = request.temperature.unwrap_or(1.0);
506
507        let mut body = json!({
508            "model": request.model,
509            "messages": messages,
510            "temperature": temperature,
511            "stream": true,
512        });
513
514        body["thinking"] = json!({
515            "type": "enabled",
516            "clear_thinking": false
517        });
518
519        if !tools.is_empty() {
520            body["tools"] = json!(tools);
521            // Enable streaming tool calls (supported by glm-4.6+)
522            body["tool_stream"] = json!(true);
523        }
524        if let Some(max) = request.max_tokens {
525            body["max_tokens"] = json!(max);
526        }
527
528        tracing::debug!(model = %request.model, "Z.AI streaming request");
529
530        let response = self
531            .client
532            .post(format!("{}/chat/completions", self.base_url))
533            .header("Authorization", format!("Bearer {}", self.api_key))
534            .header("Content-Type", "application/json")
535            .json(&body)
536            .send()
537            .await
538            .context("Failed to send streaming request to Z.AI")?;
539
540        if !response.status().is_success() {
541            let status = response.status();
542            let text = response.text().await.unwrap_or_default();
543            if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
544                anyhow::bail!(
545                    "Z.AI API error: {} ({:?})",
546                    err.error.message,
547                    err.error.error_type
548                );
549            }
550            anyhow::bail!("Z.AI streaming error: {} {}", status, text);
551        }
552
553        let stream = response.bytes_stream();
554        let mut buffer = String::new();
555
556        Ok(stream
557            .flat_map(move |chunk_result| {
558                let mut chunks: Vec<StreamChunk> = Vec::new();
559                match chunk_result {
560                    Ok(bytes) => {
561                        let text = String::from_utf8_lossy(&bytes);
562                        buffer.push_str(&text);
563
564                        let mut text_buf = String::new();
565                        while let Some(line_end) = buffer.find('\n') {
566                            let line = buffer[..line_end].trim().to_string();
567                            buffer = buffer[line_end + 1..].to_string();
568
569                            if line == "data: [DONE]" {
570                                if !text_buf.is_empty() {
571                                    chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
572                                }
573                                chunks.push(StreamChunk::Done { usage: None });
574                                continue;
575                            }
576                            if let Some(data) = line.strip_prefix("data: ") {
577                                if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
578                                {
579                                    if let Some(choice) = parsed.choices.first() {
580                                        // Reasoning content streamed as text (prefixed for TUI rendering)
581                                        if let Some(ref reasoning) = choice.delta.reasoning_content
582                                        {
583                                            if !reasoning.is_empty() {
584                                                text_buf.push_str(reasoning);
585                                            }
586                                        }
587                                        if let Some(ref content) = choice.delta.content {
588                                            text_buf.push_str(content);
589                                        }
590                                        // Streaming tool calls
591                                        if let Some(ref tool_calls) = choice.delta.tool_calls {
592                                            if !text_buf.is_empty() {
593                                                chunks.push(StreamChunk::Text(std::mem::take(
594                                                    &mut text_buf,
595                                                )));
596                                            }
597                                            for tc in tool_calls {
598                                                if let Some(ref func) = tc.function {
599                                                    if let Some(ref name) = func.name {
600                                                        // New tool call starting
601                                                        chunks.push(StreamChunk::ToolCallStart {
602                                                            id: tc.id.clone().unwrap_or_default(),
603                                                            name: name.clone(),
604                                                        });
605                                                    }
606                                                    if let Some(ref args) = func.arguments {
607                                                        let delta = match args {
608                                                            Value::String(s) => s.clone(),
609                                                            other => serde_json::to_string(other)
610                                                                .unwrap_or_default(),
611                                                        };
612                                                        if !delta.is_empty() {
613                                                            chunks.push(
614                                                                StreamChunk::ToolCallDelta {
615                                                                    id: tc
616                                                                        .id
617                                                                        .clone()
618                                                                        .unwrap_or_default(),
619                                                                    arguments_delta: delta,
620                                                                },
621                                                            );
622                                                        }
623                                                    }
624                                                }
625                                            }
626                                        }
627                                        // finish_reason signals end of a tool call or completion
628                                        if let Some(ref reason) = choice.finish_reason {
629                                            if !text_buf.is_empty() {
630                                                chunks.push(StreamChunk::Text(std::mem::take(
631                                                    &mut text_buf,
632                                                )));
633                                            }
634                                            if reason == "tool_calls" {
635                                                // Emit ToolCallEnd for the last tool call
636                                                if let Some(ref tcs) = choice.delta.tool_calls {
637                                                    if let Some(tc) = tcs.last() {
638                                                        chunks.push(StreamChunk::ToolCallEnd {
639                                                            id: tc.id.clone().unwrap_or_default(),
640                                                        });
641                                                    }
642                                                }
643                                            }
644                                        }
645                                    }
646                                }
647                            }
648                        }
649                        if !text_buf.is_empty() {
650                            chunks.push(StreamChunk::Text(text_buf));
651                        }
652                    }
653                    Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
654                }
655                futures::stream::iter(chunks)
656            })
657            .boxed())
658    }
659}