1use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use futures::StreamExt;
14use reqwest::Client;
15use serde::Deserialize;
16use serde_json::{Value, json};
17
/// Provider for the Z.AI (OpenAI-compatible) chat-completions API.
pub struct ZaiProvider {
    // Shared HTTP client reused across all requests.
    client: Client,
    // Bearer token for the Authorization header; redacted in Debug output.
    api_key: String,
    // API root; request URLs are formed as `{base_url}/chat/completions`.
    base_url: String,
}
23
/// Manual `Debug` impl so the API key can never leak into logs or panics.
impl std::fmt::Debug for ZaiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZaiProvider")
            .field("base_url", &self.base_url)
            // Deliberately print a placeholder instead of the real key.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
32
33impl ZaiProvider {
34 pub fn with_base_url(api_key: String, base_url: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "zai",
37 base_url = %base_url,
38 api_key_len = api_key.len(),
39 "Creating Z.AI provider with custom base URL"
40 );
41 Ok(Self {
42 client: Client::new(),
43 api_key,
44 base_url,
45 })
46 }
47
48 fn convert_messages(messages: &[Message]) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({"role": role, "content": ""})
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let reasoning: String = msg
88 .content
89 .iter()
90 .filter_map(|p| match p {
91 ContentPart::Thinking { text } => Some(text.clone()),
92 _ => None,
93 })
94 .collect::<Vec<_>>()
95 .join("");
96
97 let tool_calls: Vec<Value> = msg
98 .content
99 .iter()
100 .filter_map(|p| match p {
101 ContentPart::ToolCall {
102 id,
103 name,
104 arguments,
105 } => {
106 let args_value = serde_json::from_str::<Value>(arguments)
109 .unwrap_or_else(|_| json!({"input": arguments}));
110 Some(json!({
111 "id": id,
112 "type": "function",
113 "function": {
114 "name": name,
115 "arguments": args_value
116 }
117 }))
118 }
119 _ => None,
120 })
121 .collect();
122
123 let mut msg_json = json!({
124 "role": "assistant",
125 "content": if text.is_empty() { "".to_string() } else { text },
126 });
127 if !reasoning.is_empty() {
129 msg_json["reasoning_content"] = json!(reasoning);
130 }
131 if !tool_calls.is_empty() {
132 msg_json["tool_calls"] = json!(tool_calls);
133 }
134 msg_json
135 }
136 _ => {
137 let text: String = msg
138 .content
139 .iter()
140 .filter_map(|p| match p {
141 ContentPart::Text { text } => Some(text.clone()),
142 _ => None,
143 })
144 .collect::<Vec<_>>()
145 .join("\n");
146
147 json!({"role": role, "content": text})
148 }
149 }
150 })
151 .collect()
152 }
153
154 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
155 tools
156 .iter()
157 .map(|t| {
158 json!({
159 "type": "function",
160 "function": {
161 "name": t.name,
162 "description": t.description,
163 "parameters": t.parameters
164 }
165 })
166 })
167 .collect()
168 }
169}
170
/// Top-level body of a non-streaming `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ZaiResponse {
    choices: Vec<ZaiChoice>,
    // Token accounting; absent on some responses, hence `Option`.
    #[serde(default)]
    usage: Option<ZaiUsage>,
}

/// One completion candidate; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
    // Observed values handled downstream: "stop", "length", "tool_calls",
    // "sensitive".
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message payload of a non-streaming choice.
#[derive(Debug, Deserialize)]
struct ZaiMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiToolCall>>,
    // Thinking/reasoning text returned when the request enables "thinking".
    #[serde(default)]
    reasoning_content: Option<String>,
}

/// A fully-formed tool invocation in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ZaiToolCall {
    id: String,
    function: ZaiFunction,
}

#[derive(Debug, Deserialize)]
struct ZaiFunction {
    name: String,
    // Kept as a raw `Value`: the API may send either a JSON string or an
    // object; `complete` normalizes both to a string.
    arguments: Value,
}

/// Usage block; all counts default to 0 when a field is missing.
#[derive(Debug, Deserialize)]
struct ZaiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    #[serde(default)]
    prompt_tokens_details: Option<ZaiPromptTokensDetails>,
}

/// Breakdown of prompt tokens; used to report cache reads.
#[derive(Debug, Deserialize)]
struct ZaiPromptTokensDetails {
    #[serde(default)]
    cached_tokens: usize,
}
224
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ZaiError {
    error: ZaiErrorDetail,
}

#[derive(Debug, Deserialize)]
struct ZaiErrorDetail {
    // Human-readable message surfaced verbatim to the caller.
    message: String,
    // Machine-readable category; the wire field is literally named "type".
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
236
/// One SSE `data:` payload of a streaming response.
#[derive(Debug, Deserialize)]
struct ZaiStreamResponse {
    choices: Vec<ZaiStreamChoice>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamChoice {
    delta: ZaiStreamDelta,
    // Present only on the final delta for a choice.
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental fields of a streaming delta; any subset may be present.
#[derive(Debug, Deserialize)]
struct ZaiStreamDelta {
    #[serde(default)]
    content: Option<String>,
    // Thinking/reasoning text streamed alongside normal content.
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ZaiStreamToolCall>>,
}

/// Partial tool call: `id`/`name` arrive on the first fragment, argument
/// text trickles in on later ones.
#[derive(Debug, Deserialize)]
struct ZaiStreamToolCall {
    #[serde(default)]
    id: Option<String>,
    function: Option<ZaiStreamFunction>,
}

#[derive(Debug, Deserialize)]
struct ZaiStreamFunction {
    #[serde(default)]
    name: Option<String>,
    // May be a JSON string or object fragment; normalized to text downstream.
    #[serde(default)]
    arguments: Option<Value>,
}
274
275#[async_trait]
276impl Provider for ZaiProvider {
277 fn name(&self) -> &str {
278 "zai"
279 }
280
281 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282 Ok(vec![
283 ModelInfo {
284 id: "glm-5".to_string(),
285 name: "GLM-5".to_string(),
286 provider: "zai".to_string(),
287 context_window: 200_000,
288 max_output_tokens: Some(128_000),
289 supports_vision: false,
290 supports_tools: true,
291 supports_streaming: true,
292 input_cost_per_million: None,
293 output_cost_per_million: None,
294 },
295 ModelInfo {
296 id: "glm-4.7".to_string(),
297 name: "GLM-4.7".to_string(),
298 provider: "zai".to_string(),
299 context_window: 128_000,
300 max_output_tokens: Some(128_000),
301 supports_vision: false,
302 supports_tools: true,
303 supports_streaming: true,
304 input_cost_per_million: None,
305 output_cost_per_million: None,
306 },
307 ModelInfo {
308 id: "glm-4.7-flash".to_string(),
309 name: "GLM-4.7 Flash".to_string(),
310 provider: "zai".to_string(),
311 context_window: 128_000,
312 max_output_tokens: Some(128_000),
313 supports_vision: false,
314 supports_tools: true,
315 supports_streaming: true,
316 input_cost_per_million: None,
317 output_cost_per_million: None,
318 },
319 ModelInfo {
320 id: "glm-4.6".to_string(),
321 name: "GLM-4.6".to_string(),
322 provider: "zai".to_string(),
323 context_window: 128_000,
324 max_output_tokens: Some(128_000),
325 supports_vision: false,
326 supports_tools: true,
327 supports_streaming: true,
328 input_cost_per_million: None,
329 output_cost_per_million: None,
330 },
331 ModelInfo {
332 id: "glm-4.5".to_string(),
333 name: "GLM-4.5".to_string(),
334 provider: "zai".to_string(),
335 context_window: 128_000,
336 max_output_tokens: Some(96_000),
337 supports_vision: false,
338 supports_tools: true,
339 supports_streaming: true,
340 input_cost_per_million: None,
341 output_cost_per_million: None,
342 },
343 ])
344 }
345
346 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
347 let messages = Self::convert_messages(&request.messages);
348 let tools = Self::convert_tools(&request.tools);
349
350 let temperature = request.temperature.unwrap_or(1.0);
352
353 let mut body = json!({
354 "model": request.model,
355 "messages": messages,
356 "temperature": temperature,
357 });
358
359 body["thinking"] = json!({
363 "type": "enabled",
364 "clear_thinking": false
365 });
366
367 if !tools.is_empty() {
368 body["tools"] = json!(tools);
369 }
370 if let Some(max) = request.max_tokens {
371 body["max_tokens"] = json!(max);
372 }
373
374 tracing::debug!(model = %request.model, "Z.AI request");
375
376 let response = self
377 .client
378 .post(format!("{}/chat/completions", self.base_url))
379 .header("Authorization", format!("Bearer {}", self.api_key))
380 .header("Content-Type", "application/json")
381 .json(&body)
382 .send()
383 .await
384 .context("Failed to send request to Z.AI")?;
385
386 let status = response.status();
387 let text = response
388 .text()
389 .await
390 .context("Failed to read Z.AI response")?;
391
392 if !status.is_success() {
393 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
394 anyhow::bail!(
395 "Z.AI API error: {} ({:?})",
396 err.error.message,
397 err.error.error_type
398 );
399 }
400 anyhow::bail!("Z.AI API error: {} {}", status, text);
401 }
402
403 let response: ZaiResponse = serde_json::from_str(&text).context(format!(
404 "Failed to parse Z.AI response: {}",
405 &text[..text.len().min(200)]
406 ))?;
407
408 let choice = response
409 .choices
410 .first()
411 .ok_or_else(|| anyhow::anyhow!("No choices in Z.AI response"))?;
412
413 if let Some(ref reasoning) = choice.message.reasoning_content {
415 if !reasoning.is_empty() {
416 tracing::info!(
417 reasoning_len = reasoning.len(),
418 "Z.AI reasoning content received"
419 );
420 }
421 }
422
423 let mut content = Vec::new();
424 let mut has_tool_calls = false;
425
426 if let Some(ref reasoning) = choice.message.reasoning_content {
428 if !reasoning.is_empty() {
429 content.push(ContentPart::Thinking {
430 text: reasoning.clone(),
431 });
432 }
433 }
434
435 if let Some(text) = &choice.message.content {
436 if !text.is_empty() {
437 content.push(ContentPart::Text { text: text.clone() });
438 }
439 }
440
441 if let Some(tool_calls) = &choice.message.tool_calls {
442 has_tool_calls = !tool_calls.is_empty();
443 for tc in tool_calls {
444 let arguments = match &tc.function.arguments {
446 Value::String(s) => s.clone(),
447 other => serde_json::to_string(other).unwrap_or_default(),
448 };
449 content.push(ContentPart::ToolCall {
450 id: tc.id.clone(),
451 name: tc.function.name.clone(),
452 arguments,
453 });
454 }
455 }
456
457 let finish_reason = if has_tool_calls {
458 FinishReason::ToolCalls
459 } else {
460 match choice.finish_reason.as_deref() {
461 Some("stop") => FinishReason::Stop,
462 Some("length") => FinishReason::Length,
463 Some("tool_calls") => FinishReason::ToolCalls,
464 Some("sensitive") => FinishReason::ContentFilter,
465 _ => FinishReason::Stop,
466 }
467 };
468
469 Ok(CompletionResponse {
470 message: Message {
471 role: Role::Assistant,
472 content,
473 },
474 usage: Usage {
475 prompt_tokens: response
476 .usage
477 .as_ref()
478 .map(|u| u.prompt_tokens)
479 .unwrap_or(0),
480 completion_tokens: response
481 .usage
482 .as_ref()
483 .map(|u| u.completion_tokens)
484 .unwrap_or(0),
485 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
486 cache_read_tokens: response
487 .usage
488 .as_ref()
489 .and_then(|u| u.prompt_tokens_details.as_ref())
490 .map(|d| d.cached_tokens)
491 .filter(|&t| t > 0),
492 cache_write_tokens: None,
493 },
494 finish_reason,
495 })
496 }
497
498 async fn complete_stream(
499 &self,
500 request: CompletionRequest,
501 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
502 let messages = Self::convert_messages(&request.messages);
503 let tools = Self::convert_tools(&request.tools);
504
505 let temperature = request.temperature.unwrap_or(1.0);
506
507 let mut body = json!({
508 "model": request.model,
509 "messages": messages,
510 "temperature": temperature,
511 "stream": true,
512 });
513
514 body["thinking"] = json!({
515 "type": "enabled",
516 "clear_thinking": false
517 });
518
519 if !tools.is_empty() {
520 body["tools"] = json!(tools);
521 body["tool_stream"] = json!(true);
523 }
524 if let Some(max) = request.max_tokens {
525 body["max_tokens"] = json!(max);
526 }
527
528 tracing::debug!(model = %request.model, "Z.AI streaming request");
529
530 let response = self
531 .client
532 .post(format!("{}/chat/completions", self.base_url))
533 .header("Authorization", format!("Bearer {}", self.api_key))
534 .header("Content-Type", "application/json")
535 .json(&body)
536 .send()
537 .await
538 .context("Failed to send streaming request to Z.AI")?;
539
540 if !response.status().is_success() {
541 let status = response.status();
542 let text = response.text().await.unwrap_or_default();
543 if let Ok(err) = serde_json::from_str::<ZaiError>(&text) {
544 anyhow::bail!(
545 "Z.AI API error: {} ({:?})",
546 err.error.message,
547 err.error.error_type
548 );
549 }
550 anyhow::bail!("Z.AI streaming error: {} {}", status, text);
551 }
552
553 let stream = response.bytes_stream();
554 let mut buffer = String::new();
555
556 Ok(stream
557 .flat_map(move |chunk_result| {
558 let mut chunks: Vec<StreamChunk> = Vec::new();
559 match chunk_result {
560 Ok(bytes) => {
561 let text = String::from_utf8_lossy(&bytes);
562 buffer.push_str(&text);
563
564 let mut text_buf = String::new();
565 while let Some(line_end) = buffer.find('\n') {
566 let line = buffer[..line_end].trim().to_string();
567 buffer = buffer[line_end + 1..].to_string();
568
569 if line == "data: [DONE]" {
570 if !text_buf.is_empty() {
571 chunks.push(StreamChunk::Text(std::mem::take(&mut text_buf)));
572 }
573 chunks.push(StreamChunk::Done { usage: None });
574 continue;
575 }
576 if let Some(data) = line.strip_prefix("data: ") {
577 if let Ok(parsed) = serde_json::from_str::<ZaiStreamResponse>(data)
578 {
579 if let Some(choice) = parsed.choices.first() {
580 if let Some(ref reasoning) = choice.delta.reasoning_content
582 {
583 if !reasoning.is_empty() {
584 text_buf.push_str(reasoning);
585 }
586 }
587 if let Some(ref content) = choice.delta.content {
588 text_buf.push_str(content);
589 }
590 if let Some(ref tool_calls) = choice.delta.tool_calls {
592 if !text_buf.is_empty() {
593 chunks.push(StreamChunk::Text(std::mem::take(
594 &mut text_buf,
595 )));
596 }
597 for tc in tool_calls {
598 if let Some(ref func) = tc.function {
599 if let Some(ref name) = func.name {
600 chunks.push(StreamChunk::ToolCallStart {
602 id: tc.id.clone().unwrap_or_default(),
603 name: name.clone(),
604 });
605 }
606 if let Some(ref args) = func.arguments {
607 let delta = match args {
608 Value::String(s) => s.clone(),
609 other => serde_json::to_string(other)
610 .unwrap_or_default(),
611 };
612 if !delta.is_empty() {
613 chunks.push(
614 StreamChunk::ToolCallDelta {
615 id: tc
616 .id
617 .clone()
618 .unwrap_or_default(),
619 arguments_delta: delta,
620 },
621 );
622 }
623 }
624 }
625 }
626 }
627 if let Some(ref reason) = choice.finish_reason {
629 if !text_buf.is_empty() {
630 chunks.push(StreamChunk::Text(std::mem::take(
631 &mut text_buf,
632 )));
633 }
634 if reason == "tool_calls" {
635 if let Some(ref tcs) = choice.delta.tool_calls {
637 if let Some(tc) = tcs.last() {
638 chunks.push(StreamChunk::ToolCallEnd {
639 id: tc.id.clone().unwrap_or_default(),
640 });
641 }
642 }
643 }
644 }
645 }
646 }
647 }
648 }
649 if !text_buf.is_empty() {
650 chunks.push(StreamChunk::Text(text_buf));
651 }
652 }
653 Err(e) => chunks.push(StreamChunk::Error(e.to_string())),
654 }
655 futures::stream::iter(chunks)
656 })
657 .boxed())
658 }
659}