Skip to main content

limit_llm/
config.rs

1use crate::error::ConfigError;
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::{env, fs, io};
5
6#[derive(Debug, Deserialize, PartialEq, Clone)]
7pub struct Config {
8    pub provider: String,
9    pub providers: HashMap<String, ProviderConfig>,
10}
11
12#[derive(Debug, Deserialize, PartialEq, Clone)]
13pub struct ProviderConfig {
14    pub api_key: Option<String>,
15    pub model: String,
16    #[serde(default)]
17    pub base_url: Option<String>,
18    #[serde(default = "default_max_tokens")]
19    pub max_tokens: u32,
20    #[serde(default = "default_timeout")]
21    pub timeout: u64,
22}
23
/// Model used for the Anthropic entry of `Config::default()`.
fn default_model() -> String {
    String::from("claude-3-5-sonnet-20241022")
}
27
/// Serde default for `ProviderConfig::max_tokens` when the field is omitted.
fn default_max_tokens() -> u32 {
    const DEFAULT_MAX_TOKENS: u32 = 4_096;
    DEFAULT_MAX_TOKENS
}
31
/// Serde default for `ProviderConfig::timeout` when the field is omitted.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 60;
    DEFAULT_TIMEOUT
}
35
36impl ProviderConfig {
37    pub fn api_key_or_env(&self, provider: &str) -> Option<String> {
38        if let Some(key) = &self.api_key {
39            return Some(key.clone());
40        }
41
42        match provider {
43            "anthropic" => env::var("ANTHROPIC_API_KEY").ok(),
44            "openai" => env::var("OPENAI_API_KEY")
45                .ok()
46                .or_else(|| env::var("ZAI_API_KEY").ok()),
47            "zai" => env::var("ZAI_API_KEY").ok(),
48            _ => None,
49        }
50    }
51}
52
53impl Config {
54    pub fn validate(&self) -> Result<(), ConfigError> {
55        // Check provider field is valid
56        if !["anthropic", "openai", "zai"].contains(&self.provider.as_str()) {
57            return Err(ConfigError::InvalidProvider(self.provider.clone()));
58        }
59
60        // Check provider config exists
61        if !self.providers.contains_key(&self.provider) {
62            return Err(ConfigError::MissingProvider(self.provider.clone()));
63        }
64
65        // Check active provider has required fields
66        let provider_config = self.providers.get(&self.provider).unwrap();
67        if provider_config.api_key.is_none() {
68            // Check if env var exists
69            let env_var = match self.provider.as_str() {
70                "anthropic" => "ANTHROPIC_API_KEY",
71                "openai" => "OPENAI_API_KEY",
72                "zai" => "ZAI_API_KEY",
73                _ => "API_KEY",
74            };
75            if env::var(env_var).is_err() {
76                // For openai, also check ZAI_API_KEY as fallback
77                let env_var_display = if self.provider == "openai" {
78                    "OPENAI_API_KEY or ZAI_API_KEY"
79                } else {
80                    env_var
81                };
82                return Err(ConfigError::MissingApiKey {
83                    provider: self.provider.clone(),
84                    env_var: env_var_display.to_string(),
85                });
86            }
87        }
88
89        Ok(())
90    }
91}
92
93impl Config {
94    pub fn load() -> Result<Self, io::Error> {
95        let config_path = config_path();
96
97        if !config_path.exists() {
98            return Ok(Config::default());
99        }
100
101        let config_content = fs::read_to_string(&config_path)?;
102
103        // Check for old format (detect by presence of 'api_key' at top level)
104        if config_content.contains("api_key") && !config_content.contains("[providers.") {
105            return Err(io::Error::new(
106                io::ErrorKind::InvalidData,
107                "Old config format detected. Please update to multi-provider format.\n\nNew format:\nprovider = \"anthropic\"\n\n[providers.anthropic]\nmodel = \"claude-3-5-sonnet-20241022\"\napi_key = \"...\"  # optional, falls back to ANTHROPIC_API_KEY env var"
108            ));
109        }
110
111        let config: Config = toml::from_str(&config_content)
112            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
113
114        config
115            .validate()
116            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
117
118        Ok(config)
119    }
120}
121
122impl Default for Config {
123    fn default() -> Self {
124        let mut providers = HashMap::new();
125        providers.insert(
126            "anthropic".to_string(),
127            ProviderConfig {
128                api_key: None,
129                model: default_model(),
130                base_url: None,
131                max_tokens: default_max_tokens(),
132                timeout: default_timeout(),
133            },
134        );
135        Config {
136            provider: "anthropic".to_string(),
137            providers,
138        }
139    }
140}
141
142fn config_path() -> std::path::PathBuf {
143    let home_dir = dirs::home_dir().expect("Failed to get home directory");
144    home_dir.join(".limit").join("config.toml")
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or writes process environment variables.
    ///
    /// The default test harness runs tests on several threads, and the
    /// environment is process-global. Without this lock, one test removing
    /// `ZAI_API_KEY` can race another test that has just set it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failing
    /// test does not cascade into failures of every other env-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets or clears an environment variable and restores its previous value
    /// when dropped, including when the test panics.
    ///
    /// This keeps tests from leaking state into each other or clobbering the
    /// developer's real environment.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = env::var_os(key);
            env::set_var(key, value);
            Self { key, previous }
        }

        fn unset(key: &'static str) -> Self {
            let previous = env::var_os(key);
            env::remove_var(key);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.previous {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    /// Builds a `ProviderConfig` with `max_tokens = 4096`, `timeout = 60`
    /// and no `base_url`.
    fn make_provider_config(api_key: Option<&str>, model: &str) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.map(str::to_string),
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
        }
    }

    #[test]
    fn test_load_valid_config() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test123".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_load_partial_config_uses_defaults() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-partial"
model = "custom-model"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-partial".to_string()));
        assert_eq!(anthropic.model, "custom-model");
        assert!(anthropic.base_url.is_none()); // default
        assert_eq!(anthropic.max_tokens, 4096); // default
        assert_eq!(anthropic.timeout, 60); // default
    }

    #[test]
    fn test_load_config_with_base_url() {
        let config_content = r#"
provider = "openai"

[providers.openai]
api_key = "sk-test123"
model = "gpt-4"
base_url = "https://api.z.ai/api/paas/v4/chat/completions"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "openai");
        let openai = config.providers.get("openai").unwrap();
        assert_eq!(openai.api_key, Some("sk-test123".to_string()));
        assert_eq!(openai.model, "gpt-4");
        assert_eq!(
            openai.base_url,
            Some("https://api.z.ai/api/paas/v4/chat/completions".to_string())
        );
    }

    #[test]
    fn test_load_config_without_base_url() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test456"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test456".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.base_url.is_none()); // default
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.api_key.is_none());
        assert!(anthropic.base_url.is_none());
    }

    #[test]
    fn test_old_format_detection() {
        let config_content = r#"
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let result: Result<Config, _> = toml::from_str(config_content);
        assert!(result.is_err(), "Old format should fail to parse");
    }

    #[test]
    fn test_api_key_or_env_from_config() {
        let provider_config =
            make_provider_config(Some("sk-from-config"), "claude-3-5-sonnet-20241022");

        let key = provider_config.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-config".to_string()));
    }

    #[test]
    fn test_api_key_or_env_from_env() {
        let _lock = lock_env();
        let _anthropic = EnvVarGuard::set("ANTHROPIC_API_KEY", "sk-from-env");
        let provider_config = make_provider_config(None, "claude-3-5-sonnet-20241022");

        let key = provider_config.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-env".to_string()));
    }

    #[test]
    fn test_openai_fallback_to_zai_api_key() {
        let _lock = lock_env();
        // The fallback only applies when OPENAI_API_KEY is absent, so clear it
        // in case the developer's shell exports one.
        let _openai = EnvVarGuard::unset("OPENAI_API_KEY");
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "sk-zai-key");
        let provider_config = make_provider_config(None, "gpt-4");

        let key = provider_config.api_key_or_env("openai");
        assert_eq!(key, Some("sk-zai-key".to_string()));
    }

    #[test]
    fn test_unknown_provider_no_env_var() {
        let provider_config = make_provider_config(None, "test-model");

        let key = provider_config.api_key_or_env("unknown");
        assert_eq!(key, None);
    }

    #[test]
    fn test_zai_config_validation() {
        let _lock = lock_env();
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "test-zai-key");
        let config_content = r#"
provider = "zai"

[providers.zai]
model = "glm-4.7"
"#;
        let config: Config = toml::from_str(config_content).unwrap();
        config.validate().unwrap();
    }

    #[test]
    fn test_zai_api_key_env_var() {
        let _lock = lock_env();
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "test-zai-key");
        let provider_config = make_provider_config(None, "glm-4.7");

        let key = provider_config.api_key_or_env("zai");
        assert_eq!(key, Some("test-zai-key".to_string()));
    }

    #[test]
    fn test_validate_rejects_unknown_provider() {
        let config = Config {
            provider: "unknown".to_string(),
            providers: HashMap::new(),
        };

        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidProvider(name)) if name == "unknown"
        ));
    }

    #[test]
    fn test_validate_rejects_missing_provider_table() {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            make_provider_config(Some("sk-ant"), "claude-3-5-sonnet-20241022"),
        );
        let config = Config {
            provider: "openai".to_string(),
            providers,
        };

        assert!(matches!(
            config.validate(),
            Err(ConfigError::MissingProvider(name)) if name == "openai"
        ));
    }

    #[test]
    fn test_validate_accepts_configured_api_key() {
        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            make_provider_config(Some("sk-ant"), "claude-3-5-sonnet-20241022"),
        );
        let config = Config {
            provider: "anthropic".to_string(),
            providers,
        };

        assert!(config.validate().is_ok());
    }
}