1use crate::error::ConfigError;
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::{env, fs, io};
5
6#[derive(Debug, Deserialize, PartialEq, Clone)]
7pub struct Config {
8 pub provider: String,
9 pub providers: HashMap<String, ProviderConfig>,
10}
11
12#[derive(Debug, Deserialize, PartialEq, Clone)]
13pub struct ProviderConfig {
14 pub api_key: Option<String>,
15 pub model: String,
16 #[serde(default)]
17 pub base_url: Option<String>,
18 #[serde(default = "default_max_tokens")]
19 pub max_tokens: u32,
20 #[serde(default = "default_timeout")]
21 pub timeout: u64,
22}
23
/// Model identifier used by the built-in default configuration.
fn default_model() -> String {
    String::from("claude-3-5-sonnet-20241022")
}
27
/// Fallback for `max_tokens` when a provider entry omits it.
fn default_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 4096;
    MAX_TOKENS
}
31
/// Fallback for `timeout` when a provider entry omits it.
fn default_timeout() -> u64 {
    const TIMEOUT: u64 = 60;
    TIMEOUT
}
35
36impl ProviderConfig {
37 pub fn api_key_or_env(&self, provider: &str) -> Option<String> {
38 if let Some(key) = &self.api_key {
39 return Some(key.clone());
40 }
41
42 match provider {
43 "anthropic" => env::var("ANTHROPIC_API_KEY").ok(),
44 "openai" => env::var("OPENAI_API_KEY")
45 .ok()
46 .or_else(|| env::var("ZAI_API_KEY").ok()),
47 "zai" => env::var("ZAI_API_KEY").ok(),
48 _ => None,
49 }
50 }
51}
52
53impl Config {
54 pub fn validate(&self) -> Result<(), ConfigError> {
55 if !["anthropic", "openai", "zai"].contains(&self.provider.as_str()) {
57 return Err(ConfigError::InvalidProvider(self.provider.clone()));
58 }
59
60 if !self.providers.contains_key(&self.provider) {
62 return Err(ConfigError::MissingProvider(self.provider.clone()));
63 }
64
65 let provider_config = self.providers.get(&self.provider).unwrap();
67 if provider_config.api_key.is_none() {
68 let env_var = match self.provider.as_str() {
70 "anthropic" => "ANTHROPIC_API_KEY",
71 "openai" => "OPENAI_API_KEY",
72 "zai" => "ZAI_API_KEY",
73 _ => "API_KEY",
74 };
75 if env::var(env_var).is_err() {
76 let env_var_display = if self.provider == "openai" {
78 "OPENAI_API_KEY or ZAI_API_KEY"
79 } else {
80 env_var
81 };
82 return Err(ConfigError::MissingApiKey {
83 provider: self.provider.clone(),
84 env_var: env_var_display.to_string(),
85 });
86 }
87 }
88
89 Ok(())
90 }
91}
92
93impl Config {
94 pub fn load() -> Result<Self, io::Error> {
95 let config_path = config_path();
96
97 if !config_path.exists() {
98 return Ok(Config::default());
99 }
100
101 let config_content = fs::read_to_string(&config_path)?;
102
103 if config_content.contains("api_key") && !config_content.contains("[providers.") {
105 return Err(io::Error::new(
106 io::ErrorKind::InvalidData,
107 "Old config format detected. Please update to multi-provider format.\n\nNew format:\nprovider = \"anthropic\"\n\n[providers.anthropic]\nmodel = \"claude-3-5-sonnet-20241022\"\napi_key = \"...\" # optional, falls back to ANTHROPIC_API_KEY env var"
108 ));
109 }
110
111 let config: Config = toml::from_str(&config_content)
112 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
113
114 config
115 .validate()
116 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
117
118 Ok(config)
119 }
120}
121
122impl Default for Config {
123 fn default() -> Self {
124 let mut providers = HashMap::new();
125 providers.insert(
126 "anthropic".to_string(),
127 ProviderConfig {
128 api_key: None,
129 model: default_model(),
130 base_url: None,
131 max_tokens: default_max_tokens(),
132 timeout: default_timeout(),
133 },
134 );
135 Config {
136 provider: "anthropic".to_string(),
137 providers,
138 }
139 }
140}
141
142fn config_path() -> std::path::PathBuf {
143 let home_dir = dirs::home_dir().expect("Failed to get home directory");
144 home_dir.join(".limit").join("config.toml")
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or writes process-wide environment
    /// variables, since the test harness may run tests on several threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Temporary environment overrides. Holds `ENV_LOCK` for its lifetime and
    /// restores the previous values on drop, including when a test panics.
    struct EnvGuard {
        saved: Vec<(String, Option<OsString>)>,
        // Declared last so the lock is released only after `Drop::drop` has
        // restored the environment.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Sets each variable paired with `Some(value)` and removes each one
        /// paired with `None`.
        fn apply(overrides: &[(&str, Option<&str>)]) -> Self {
            // A poisoned lock only means another test panicked; the guard that
            // panicked has already restored the environment in its `Drop`.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            let saved = overrides
                .iter()
                .map(|&(name, _)| (name.to_string(), env::var_os(name)))
                .collect();

            for &(name, value) in overrides {
                match value {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }

            EnvGuard { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, previous) in self.saved.drain(..) {
                match previous {
                    Some(value) => env::set_var(&name, value),
                    None => env::remove_var(&name),
                }
            }
        }
    }

    /// Builds a provider entry with the default token limit and timeout.
    fn provider_config(api_key: Option<&str>, model: &str) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.map(str::to_string),
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
        }
    }

    /// Parses a TOML snippet into a `Config`, panicking on malformed input.
    fn parse(content: &str) -> Config {
        toml::from_str(content).unwrap()
    }

    // Reads the real config file of whoever runs the test, so it is opt-in and
    // asserts only what `Config::load` guarantees. Never assert literal
    // credentials or machine-specific values here.
    #[test]
    #[ignore = "depends on the local ~/.limit/config.toml and environment"]
    fn test_load_from_actual_config() {
        let config = Config::load().unwrap();

        assert!(config.providers.contains_key(&config.provider));
    }

    #[test]
    fn test_load_valid_config() {
        let config = parse(
            r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#,
        );

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test123".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_load_partial_config_uses_defaults() {
        let config = parse(
            r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-partial"
model = "custom-model"
"#,
        );

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-partial".to_string()));
        assert_eq!(anthropic.model, "custom-model");
        assert!(anthropic.base_url.is_none());
        // Omitted keys must pick up the serde defaults.
        assert_eq!(anthropic.max_tokens, default_max_tokens());
        assert_eq!(anthropic.timeout, default_timeout());
    }

    #[test]
    fn test_load_config_with_base_url() {
        let config = parse(
            r#"
provider = "openai"

[providers.openai]
api_key = "sk-test123"
model = "gpt-4"
base_url = "https://api.z.ai/api/paas/v4/chat/completions"
"#,
        );

        assert_eq!(config.provider, "openai");
        let openai = config.providers.get("openai").unwrap();
        assert_eq!(openai.api_key, Some("sk-test123".to_string()));
        assert_eq!(openai.model, "gpt-4");
        assert_eq!(
            openai.base_url,
            Some("https://api.z.ai/api/paas/v4/chat/completions".to_string())
        );
    }

    #[test]
    fn test_load_config_without_base_url() {
        let config = parse(
            r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test456"
model = "claude-3-5-sonnet-20241022"
"#,
        );

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test456".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.base_url.is_none());
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.api_key.is_none());
        assert!(anthropic.base_url.is_none());
    }

    #[test]
    fn test_old_format_detection() {
        let config_content = r#"
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let result: Result<Config, _> = toml::from_str(config_content);
        assert!(result.is_err(), "Old format should fail to parse");
    }

    #[test]
    fn test_api_key_or_env_from_config() {
        // The configured key short-circuits before any environment lookup,
        // so no environment guard is needed.
        let entry = provider_config(Some("sk-from-config"), "claude-3-5-sonnet-20241022");

        let key = entry.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-config".to_string()));
    }

    #[test]
    fn test_api_key_or_env_from_env() {
        let _env = EnvGuard::apply(&[("ANTHROPIC_API_KEY", Some("sk-from-env"))]);
        let entry = provider_config(None, "claude-3-5-sonnet-20241022");

        let key = entry.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-env".to_string()));
    }

    #[test]
    fn test_openai_fallback_to_zai_api_key() {
        // OPENAI_API_KEY must be absent, otherwise it takes precedence over
        // the fallback under test.
        let _env = EnvGuard::apply(&[
            ("OPENAI_API_KEY", None),
            ("ZAI_API_KEY", Some("sk-zai-key")),
        ]);
        let entry = provider_config(None, "gpt-4");

        let key = entry.api_key_or_env("openai");
        assert_eq!(key, Some("sk-zai-key".to_string()));
    }

    #[test]
    fn test_unknown_provider_no_env_var() {
        let entry = provider_config(None, "test-model");

        let key = entry.api_key_or_env("unknown");
        assert_eq!(key, None);
    }

    #[test]
    fn test_zai_config_validation() {
        let _env = EnvGuard::apply(&[("ZAI_API_KEY", Some("test-zai-key"))]);
        let config = parse(
            r#"
provider = "zai"

[providers.zai]
model = "glm-4.7"
"#,
        );

        config.validate().unwrap();
    }

    #[test]
    fn test_zai_api_key_env_var() {
        let _env = EnvGuard::apply(&[("ZAI_API_KEY", Some("test-zai-key"))]);
        let entry = provider_config(None, "glm-4.7");

        let key = entry.api_key_or_env("zai");
        assert_eq!(key, Some("test-zai-key".to_string()));
    }
}