//! ZAI LLM provider: a thin wrapper around `OpenAiProvider` pointed at
//! ZAI's OpenAI-compatible chat-completions endpoint.

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::error::LlmError;
use crate::openai_provider::OpenAiProvider;
use crate::providers::{LlmProvider, ProviderResponseChunk};
use crate::types::{Message, Tool};

/// Configuration for ZAI's "thinking" (reasoning) output.
///
/// Currently stored on [`ZaiProvider`] but not yet applied to requests
/// (see `ZaiProvider::send`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThinkingConfig {
    /// Request reasoning output from the model.
    /// NOTE(review): semantics inferred from the field name — confirm
    /// against the ZAI API documentation.
    pub thinking_enabled: bool,
    /// Strip reasoning content from the final response.
    /// NOTE(review): semantics inferred from the field name — confirm
    /// against the ZAI API documentation.
    pub clear_thinking: bool,
}

impl Default for ThinkingConfig {
    /// Thinking disabled, and any thinking output cleared.
    fn default() -> Self {
        Self {
            thinking_enabled: false,
            clear_thinking: true,
        }
    }
}

/// LLM provider for ZAI's OpenAI-compatible chat-completions API.
///
/// Delegates all request/streaming work to an inner [`OpenAiProvider`]
/// configured with ZAI's endpoint.
#[derive(Clone)]
pub struct ZaiProvider {
    // Inner provider that performs the actual HTTP/streaming work.
    openai: OpenAiProvider,
    // Stored for future use; not consulted anywhere yet (reasoning_content
    // parsing is deferred — see `send`), hence the dead_code allow.
    #[allow(dead_code)]
    thinking_config: ThinkingConfig,
}
30impl ZaiProvider {
31    pub fn new(
32        api_key: String,
33        base_url: Option<&str>,
34        model: &str,
35        max_tokens: u32,
36        timeout: u64,
37        thinking_config: ThinkingConfig,
38    ) -> Self {
39        let default_url = "https://api.z.ai/api/coding/paas/v4/chat/completions";
40        Self {
41            openai: OpenAiProvider::new(
42                api_key,
43                base_url.or(Some(default_url)),
44                model,
45                max_tokens,
46                timeout,
47            ),
48            thinking_config,
49        }
50    }
51}
52
// ZAI reports numeric error codes, e.g. {"error":{"code":"1214","message":"..."}}.
// Error parsing is handled by OpenAiProvider's do_request() method; full
// ZAI-specific error-code mapping would require modifying OpenAiProvider's
// error handling.

#[async_trait]
impl LlmProvider for ZaiProvider {
    /// Streams a chat completion by delegating to the inner [`OpenAiProvider`].
    #[allow(clippy::type_complexity)]
    async fn send(
        &self,
        messages: Vec<Message>,
        tools: Vec<Tool>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
        LlmError,
    > {
        // Delegate to OpenAiProvider which handles streaming.
        // Note: ZAI reasoning_content parsing would require modifying
        // OpenAiProvider's SSE stream parsing — deferred to follow-up task,
        // which is why `thinking_config` is unused here.
        self.openai.send(messages, tools).await
    }

    /// Static provider identifier used in logs/config.
    fn provider_name(&self) -> &str {
        "zai"
    }

    /// Model name as configured on the inner provider.
    fn model_name(&self) -> &str {
        self.openai.model_name()
    }

    /// Boxed deep copy for use behind `dyn LlmProvider`.
    fn clone_box(&self) -> Box<dyn LlmProvider> {
        Box::new(self.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// With no base-URL override, the provider still reports its name
    /// and the configured model.
    #[test]
    fn test_zai_provider_creation() {
        let p = ZaiProvider::new(
            "test-key".to_string(),
            None,
            "glm-4.7",
            4096,
            60,
            ThinkingConfig::default(),
        );
        assert_eq!(p.provider_name(), "zai");
        assert_eq!(p.model_name(), "glm-4.7");
    }

    /// A custom base URL is accepted without changing name/model reporting.
    #[test]
    fn test_zai_provider_with_custom_url() {
        let p = ZaiProvider::new(
            "test-key".to_string(),
            Some("https://custom.api.com/chat"),
            "glm-5",
            8192,
            120,
            ThinkingConfig::default(),
        );
        assert_eq!(p.provider_name(), "zai");
        assert_eq!(p.model_name(), "glm-5");
    }

    /// Default config: thinking off, clearing on.
    #[test]
    fn test_thinking_config_default() {
        let cfg = ThinkingConfig::default();
        assert!(!cfg.thinking_enabled);
        assert!(cfg.clear_thinking);
    }

    /// `clone_box` preserves provider identity through the trait object.
    #[test]
    fn test_zai_provider_clone() {
        let p = ZaiProvider::new(
            "test-key".to_string(),
            None,
            "glm-4.7",
            4096,
            60,
            ThinkingConfig {
                thinking_enabled: true,
                clear_thinking: false,
            },
        );
        let boxed = p.clone_box();
        assert_eq!(boxed.provider_name(), "zai");
        assert_eq!(boxed.model_name(), "glm-4.7");
    }
}