use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

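/// LLM provider for Z.AI (Zhipu AI), speaking the OpenAI-style
/// `/chat/completions` protocol against `open.bigmodel.cn` by default.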
pub struct ZaiProvider {
    client: Client,
    api_key: String,
    base_url: String,
}

impl ZaiProvider {
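    /// Creates a provider with the given API key and the default Z.AI endpoint.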
    pub fn new(api_key: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: "https://open.bigmodel.cn/api/paas/v4".to_string(),
        }
    }

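    /// Like `new`, but with a caller-supplied base URL (for example a
    /// proxy or self-hosted gateway).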
    pub fn with_base_url(api_key: String, base_url: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url,
        }
    }
}

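/// Builds a provider from the environment: `ZAI_API_KEY` is preferred, with
/// `ZHIPU_API_KEY` as a fallback. A missing key becomes an empty string, so
/// misconfiguration surfaces later as an authentication error at request time.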
impl Default for ZaiProvider {
    fn default() -> Self {
        let api_key = std::env::var("ZAI_API_KEY")
            .or_else(|_| std::env::var("ZHIPU_API_KEY"))
            .unwrap_or_default();
        Self::new(api_key)
    }
}

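// Note: `complete` below parses the response body as a single JSON document,
// so even though `request.stream` is forwarded to the API, streaming
// responses are not handled on this path.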
#[async_trait]
impl LlmProvider for ZaiProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let url = format!("{}/chat/completions", self.base_url);

        let model = if request.model.is_empty() || request.model == "default" {
            "glm-4-plus".to_string()
        } else {
            request.model.clone()
        };

        let zai_request = ZaiChatRequest {
            model,
            messages: request.messages.into_iter().map(Into::into).collect(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            stream: request.stream,
        };

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&zai_request)
            .send()
            .await
            .context("Failed to send request to Z.AI")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            return Err(anyhow::anyhow!("Z.AI API error ({status}): {error_text}"));
        }

        let zai_response: ZaiChatResponse = response
            .json()
            .await
            .context("Failed to parse Z.AI response body")?;

        let content = zai_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(LlmResponse {
            content,
            model: zai_response.model.unwrap_or_else(|| "glm".to_string()),
            usage: zai_response.usage.map(Into::into),
        })
    }

    fn name(&self) -> &'static str {
        "ZAI"
    }
}

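// ---- Wire types for Z.AI's chat completions request/response bodies ----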
#[derive(Debug, Serialize)]
struct ZaiChatRequest {
    model: String,
    messages: Vec<ZaiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}

#[derive(Debug, Serialize, Deserialize)]
struct ZaiMessage {
    role: String,
    content: String,
}

impl From<LlmMessage> for ZaiMessage {
    fn from(msg: LlmMessage) -> Self {
        Self {
            role: match msg.role {
                LlmRole::System => "system".to_string(),
                LlmRole::User => "user".to_string(),
                LlmRole::Assistant => "assistant".to_string(),
            },
            content: msg.content,
        }
    }
}

#[derive(Debug, Deserialize)]
struct ZaiChatResponse {
    model: Option<String>,
    choices: Vec<ZaiChoice>,
    usage: Option<ZaiUsage>,
}

#[derive(Debug, Deserialize)]
struct ZaiChoice {
    message: ZaiMessage,
}

#[derive(Debug, Deserialize)]
struct ZaiUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

impl From<ZaiUsage> for LlmUsage {
    fn from(u: ZaiUsage) -> Self {
        Self {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        }
    }
}

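/// Convenience constants for GLM model identifiers used with this provider.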
pub mod models {
    pub const GLM_4_PLUS: &str = "glm-4-plus";
    pub const GLM_4_7: &str = "glm-4.7";
    pub const GLM_4_6: &str = "glm-4.6";
    pub const GLM_4_AIR: &str = "glm-4-air";
    pub const GLM_4_FLASH: &str = "glm-4-flash";
}
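
// A minimal test sketch for the pure conversion/serialization logic above.
// Assumptions not confirmed by this file: `LlmMessage` has exactly the
// `role` and `content` fields the `From` impl uses, and `serde_json` is
// available as a dev-dependency.
#[cfg(test)]
mod tests {
    use super::*;

    // Checks the role mapping performed by `From<LlmMessage> for ZaiMessage`.
    #[test]
    fn maps_llm_roles_to_zai_role_strings() {
        let msg = LlmMessage {
            role: LlmRole::User,
            content: "hello".to_string(),
        };
        let zai: ZaiMessage = msg.into();
        assert_eq!(zai.role, "user");
        assert_eq!(zai.content, "hello");
    }

    // Unset optional fields must be omitted from the request body, per the
    // `skip_serializing_if` attributes on `ZaiChatRequest`.
    #[test]
    fn omits_unset_optional_fields() {
        let req = ZaiChatRequest {
            model: models::GLM_4_PLUS.to_string(),
            messages: vec![ZaiMessage {
                role: "user".to_string(),
                content: "ping".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            stream: false,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
        assert!(json.contains("\"stream\":false"));
    }
}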