limit_llm/
zai_provider.rs1use crate::error::LlmError;
2use crate::openai_provider::OpenAiProvider;
3use crate::providers::{LlmProvider, ProviderResponseChunk};
4use crate::types::{Message, Tool};
5use async_trait::async_trait;
6use futures::Stream;
7use std::pin::Pin;
8
/// Thinking-mode options that can be supplied to `ZaiProvider::new`.
///
/// NOTE(review): in the code of this module the config is stored by `ZaiProvider`
/// but never read, so neither flag currently changes what is sent to the API.
/// The field descriptions below are inferred from their names — confirm against
/// the Z.ai API documentation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThinkingConfig {
    /// Presumably requests that the model produce reasoning ("thinking") output — TODO confirm.
    pub thinking_enabled: bool,
    /// Presumably asks for prior reasoning content to be cleared — TODO confirm.
    pub clear_thinking: bool,
}

impl Default for ThinkingConfig {
    /// Returns a config with `thinking_enabled = false` and `clear_thinking = true`.
    fn default() -> Self {
        Self {
            thinking_enabled: false,
            clear_thinking: true,
        }
    }
}
23
24#[derive(Clone)]
25pub struct ZaiProvider {
26 openai: OpenAiProvider,
27 #[allow(dead_code)]
28 thinking_config: ThinkingConfig,
29}
30impl ZaiProvider {
31 pub fn new(
32 api_key: String,
33 base_url: Option<&str>,
34 model: &str,
35 max_tokens: u32,
36 timeout: u64,
37 thinking_config: ThinkingConfig,
38 ) -> Self {
39 let default_url = "https://api.z.ai/api/coding/paas/v4/chat/completions";
40 Self {
41 openai: OpenAiProvider::new(
42 api_key,
43 base_url.or(Some(default_url)),
44 model,
45 max_tokens,
46 timeout,
47 ),
48 thinking_config,
49 }
50 }
51}
52
53#[async_trait]
58impl LlmProvider for ZaiProvider {
59 #[allow(clippy::type_complexity)]
60 async fn send(
61 &self,
62 messages: Vec<Message>,
63 tools: Vec<Tool>,
64 ) -> Result<
65 Pin<Box<dyn Stream<Item = Result<ProviderResponseChunk, LlmError>> + Send + '_>>,
66 LlmError,
67 > {
68 self.openai.send(messages, tools).await
72 }
73
74 fn provider_name(&self) -> &str {
75 "zai"
76 }
77
78 fn model_name(&self) -> &str {
79 self.openai.model_name()
80 }
81
82 fn clone_box(&self) -> Box<dyn LlmProvider> {
83 Box::new(self.clone())
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider using the fixed `"test-key"` API key shared by every test.
    fn build(
        base_url: Option<&str>,
        model: &str,
        max_tokens: u32,
        timeout: u64,
        thinking_config: ThinkingConfig,
    ) -> ZaiProvider {
        ZaiProvider::new(
            String::from("test-key"),
            base_url,
            model,
            max_tokens,
            timeout,
            thinking_config,
        )
    }

    #[test]
    fn test_zai_provider_creation() {
        let zai = build(None, "glm-4.7", 4096, 60, ThinkingConfig::default());
        assert_eq!(zai.provider_name(), "zai");
        assert_eq!(zai.model_name(), "glm-4.7");
    }

    #[test]
    fn test_zai_provider_with_custom_url() {
        let endpoint = "https://custom.api.com/chat";
        let zai = build(Some(endpoint), "glm-5", 8192, 120, ThinkingConfig::default());
        assert_eq!(zai.provider_name(), "zai");
        assert_eq!(zai.model_name(), "glm-5");
    }

    #[test]
    fn test_thinking_config_default() {
        let ThinkingConfig {
            thinking_enabled,
            clear_thinking,
        } = ThinkingConfig::default();
        assert!(!thinking_enabled);
        assert!(clear_thinking);
    }

    #[test]
    fn test_zai_provider_clone() {
        let thinking = ThinkingConfig {
            thinking_enabled: true,
            clear_thinking: false,
        };
        let original = build(None, "glm-4.7", 4096, 60, thinking);
        let boxed = original.clone_box();
        assert_eq!(boxed.provider_name(), "zai");
        assert_eq!(boxed.model_name(), "glm-4.7");
    }
}