Skip to main content

limit_llm/
zai_provider.rs

1use crate::error::LlmError;
2use crate::openai_provider::OpenAiProvider;
3use crate::providers::{LlmProvider, ProviderResponseChunk};
4use crate::types::{Message, Tool};
5use async_trait::async_trait;
6use futures::Stream;
7use std::pin::Pin;
8use tracing::debug;
9
/// Controls the Z.AI `thinking` request parameter.
///
/// [`ZaiProvider::new`] always sends a `thinking` object built from this config:
/// `{"type": "enabled", "clear_thinking": <clear_thinking>}` when thinking is
/// enabled, otherwise `{"type": "disabled"}`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThinkingConfig {
    /// Whether to request model thinking/reasoning. Defaults to `false`.
    pub thinking_enabled: bool,
    /// Value forwarded as `thinking.clear_thinking`; it is only sent when
    /// `thinking_enabled` is `true`. Defaults to `true`.
    ///
    /// NOTE(review): presumably controls whether earlier reasoning is cleared
    /// from the conversation context — confirm against the Z.AI API docs.
    pub clear_thinking: bool,
}

impl Default for ThinkingConfig {
    /// Thinking disabled, with `clear_thinking` set to `true`.
    fn default() -> Self {
        Self {
            thinking_enabled: false,
            clear_thinking: true,
        }
    }
}
24
25#[derive(Clone)]
26pub struct ZaiProvider {
27    openai: OpenAiProvider,
28    #[allow(dead_code)]
29    thinking_config: ThinkingConfig,
30}
31impl ZaiProvider {
32    pub fn new(
33        api_key: String,
34        base_url: Option<&str>,
35        model: &str,
36        max_tokens: u32,
37        timeout: u64,
38        thinking_config: ThinkingConfig,
39    ) -> Self {
40        let default_url = "https://api.z.ai/api/coding/paas/v4/chat/completions";
41
42        debug!(
43            "ZAI provider config: thinking_enabled={}, clear_thinking={}",
44            thinking_config.thinking_enabled, thinking_config.clear_thinking
45        );
46
47        // Build extra_body with thinking config
48        let extra_body = if thinking_config.thinking_enabled {
49            let mut body = serde_json::Map::new();
50            let thinking = serde_json::json!({
51                "type": "enabled",
52                "clear_thinking": thinking_config.clear_thinking
53            });
54            body.insert("thinking".to_string(), thinking);
55            debug!("ZAI provider: building extra_body with thinking ENABLED");
56            Some(body)
57        } else {
58            // Disabled thinking - explicitly disable to avoid default interleaved thinking
59            let mut body = serde_json::Map::new();
60            let thinking = serde_json::json!({
61                "type": "disabled"
62            });
63            body.insert("thinking".to_string(), thinking);
64            debug!("ZAI provider: building extra_body with thinking DISABLED");
65            Some(body)
66        };
67
68        Self {
69            openai: OpenAiProvider::with_extra_body(
70                api_key,
71                base_url.or(Some(default_url)),
72                model,
73                max_tokens,
74                timeout,
75                extra_body,
76            ),
77            thinking_config,
78        }
79    }
80}
81
82// ZAI uses numeric error codes like {"error":{"code":"1214","message":"..."}}
83// Error parsing is handled by OpenAiProvider's do_request() method
84// Full ZAI-specific error parsing would require modifying OpenAiProvider's error handling
85
86#[async_trait]
87impl LlmProvider for ZaiProvider {
88    #[allow(clippy::type_complexity)]
89    async fn send(
90        &self,
91        messages: Vec<Message>,
92        tools: Vec<Tool>,
93    ) -> Result<
94        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
95        LlmError,
96    > {
97        // Delegate to OpenAiProvider which handles streaming
98        // Note: reasoning_content parsing would require modifying OpenAiProvider's
99        // SSE stream parsing - deferred to follow-up task
100        self.openai.send(messages, tools).await
101    }
102
103    fn provider_name(&self) -> &str {
104        "zai"
105    }
106
107    fn model_name(&self) -> &str {
108        self.openai.model_name()
109    }
110
111    fn clone_box(&self) -> Box<dyn LlmProvider> {
112        Box::new(self.clone())
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a provider using the shared dummy API key.
    fn build(
        base_url: Option<&str>,
        model: &str,
        max_tokens: u32,
        timeout: u64,
        thinking: ThinkingConfig,
    ) -> ZaiProvider {
        ZaiProvider::new(
            "test-key".to_string(),
            base_url,
            model,
            max_tokens,
            timeout,
            thinking,
        )
    }

    #[test]
    fn test_zai_provider_creation() {
        let zai = build(None, "glm-4.7", 4096, 60, ThinkingConfig::default());
        assert_eq!(zai.provider_name(), "zai", "provider name");
        assert_eq!(zai.model_name(), "glm-4.7", "model name");
    }

    #[test]
    fn test_zai_provider_with_custom_url() {
        let zai = build(
            Some("https://custom.api.com/chat"),
            "glm-5",
            8192,
            120,
            ThinkingConfig::default(),
        );
        assert_eq!(zai.provider_name(), "zai", "provider name");
        assert_eq!(zai.model_name(), "glm-5", "model name");
    }

    #[test]
    fn test_thinking_config_default() {
        let ThinkingConfig {
            thinking_enabled,
            clear_thinking,
        } = ThinkingConfig::default();
        assert!(!thinking_enabled, "thinking should be off by default");
        assert!(clear_thinking, "clear_thinking should be on by default");
    }

    #[test]
    fn test_zai_provider_clone() {
        let thinking = ThinkingConfig {
            thinking_enabled: true,
            clear_thinking: false,
        };
        let boxed: Box<dyn LlmProvider> = build(None, "glm-4.7", 4096, 60, thinking).clone_box();
        assert_eq!(boxed.provider_name(), "zai", "cloned provider name");
        assert_eq!(boxed.model_name(), "glm-4.7", "cloned model name");
    }
}