limit_llm/
zai_provider.rs1use crate::error::LlmError;
2use crate::openai_provider::OpenAiProvider;
3use crate::providers::{LlmProvider, ProviderResponseChunk};
4use crate::types::{Message, Tool};
5use async_trait::async_trait;
6use futures::Stream;
7use std::pin::Pin;
8use tracing::debug;
9
/// Settings for the `thinking` object that [`ZaiProvider::new`] places in the
/// request's extra body.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example with `assert_eq!`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThinkingConfig {
    /// Selects the `thinking.type` value: `"enabled"` when `true`, `"disabled"` when `false`.
    pub thinking_enabled: bool,
    /// Sent as `thinking.clear_thinking`; only included when `thinking_enabled` is `true`.
    pub clear_thinking: bool,
}
15
16impl Default for ThinkingConfig {
17 fn default() -> Self {
18 Self {
19 thinking_enabled: false,
20 clear_thinking: true,
21 }
22 }
23}
24
25#[derive(Clone)]
26pub struct ZaiProvider {
27 openai: OpenAiProvider,
28 #[allow(dead_code)]
29 thinking_config: ThinkingConfig,
30}
31impl ZaiProvider {
32 pub fn new(
33 api_key: String,
34 base_url: Option<&str>,
35 model: &str,
36 max_tokens: u32,
37 timeout: u64,
38 thinking_config: ThinkingConfig,
39 ) -> Self {
40 let default_url = "https://api.z.ai/api/coding/paas/v4/chat/completions";
41
42 debug!(
43 "ZAI provider config: thinking_enabled={}, clear_thinking={}",
44 thinking_config.thinking_enabled, thinking_config.clear_thinking
45 );
46
47 let extra_body = if thinking_config.thinking_enabled {
49 let mut body = serde_json::Map::new();
50 let thinking = serde_json::json!({
51 "type": "enabled",
52 "clear_thinking": thinking_config.clear_thinking
53 });
54 body.insert("thinking".to_string(), thinking);
55 debug!("ZAI provider: building extra_body with thinking ENABLED");
56 Some(body)
57 } else {
58 let mut body = serde_json::Map::new();
60 let thinking = serde_json::json!({
61 "type": "disabled"
62 });
63 body.insert("thinking".to_string(), thinking);
64 debug!("ZAI provider: building extra_body with thinking DISABLED");
65 Some(body)
66 };
67
68 Self {
69 openai: OpenAiProvider::with_extra_body(
70 api_key,
71 base_url.or(Some(default_url)),
72 model,
73 max_tokens,
74 timeout,
75 extra_body,
76 ),
77 thinking_config,
78 }
79 }
80}
81
82#[async_trait]
87impl LlmProvider for ZaiProvider {
88 #[allow(clippy::type_complexity)]
89 async fn send(
90 &self,
91 messages: Vec<Message>,
92 tools: Vec<Tool>,
93 ) -> Result<
94 Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
95 LlmError,
96 > {
97 self.openai.send(messages, tools).await
101 }
102
103 fn provider_name(&self) -> &str {
104 "zai"
105 }
106
107 fn model_name(&self) -> &str {
108 self.openai.model_name()
109 }
110
111 fn clone_box(&self) -> Box<dyn LlmProvider> {
112 Box::new(self.clone())
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a provider with the dummy API key shared by every test here.
    fn build_provider(
        base_url: Option<&str>,
        model: &str,
        max_tokens: u32,
        timeout: u64,
        thinking: ThinkingConfig,
    ) -> ZaiProvider {
        ZaiProvider::new(
            String::from("test-key"),
            base_url,
            model,
            max_tokens,
            timeout,
            thinking,
        )
    }

    /// Checks that the provider reports the `"zai"` name and the expected model.
    fn assert_identity(provider: &dyn LlmProvider, expected_model: &str) {
        assert_eq!(provider.provider_name(), "zai");
        assert_eq!(provider.model_name(), expected_model);
    }

    #[test]
    fn test_zai_provider_creation() {
        let provider = build_provider(None, "glm-4.7", 4096, 60, ThinkingConfig::default());
        assert_identity(&provider, "glm-4.7");
    }

    #[test]
    fn test_zai_provider_with_custom_url() {
        let custom_url = "https://custom.api.com/chat";
        let provider = build_provider(
            Some(custom_url),
            "glm-5",
            8192,
            120,
            ThinkingConfig::default(),
        );
        assert_identity(&provider, "glm-5");
    }

    #[test]
    fn test_thinking_config_default() {
        let ThinkingConfig {
            thinking_enabled,
            clear_thinking,
        } = ThinkingConfig::default();
        assert!(!thinking_enabled);
        assert!(clear_thinking);
    }

    #[test]
    fn test_zai_provider_clone() {
        let thinking = ThinkingConfig {
            thinking_enabled: true,
            clear_thinking: false,
        };
        let provider = build_provider(None, "glm-4.7", 4096, 60, thinking);
        let cloned = provider.clone_box();
        assert_identity(&*cloned, "glm-4.7");
    }
}