//! limit_llm/config.rs — configuration types, loading, and validation
//! for the multi-provider LLM config file at `~/.limit/config.toml`.
use crate::error::ConfigError;
use serde::Deserialize;
use std::collections::HashMap;
use std::{env, fs, io};

/// Top-level application configuration loaded from `~/.limit/config.toml`.
#[derive(Debug, Deserialize, PartialEq, Clone)]
pub struct Config {
    // Name of the active provider; must be one of "anthropic", "openai",
    // "zai" and must have a matching entry in `providers` (see `validate`).
    pub provider: String,
    // Per-provider settings, keyed by provider name
    // (TOML tables `[providers.<name>]`).
    pub providers: HashMap<String, ProviderConfig>,
}

/// Settings for a single LLM provider entry in the config file.
#[derive(Debug, Deserialize, PartialEq, Clone)]
pub struct ProviderConfig {
    // API key; when absent, it is resolved from the provider's
    // conventional environment variable (see `api_key_or_env`).
    pub api_key: Option<String>,
    // Model identifier sent to the provider (required, no default).
    pub model: String,
    // Optional override for the provider's API endpoint (default: none).
    #[serde(default)]
    pub base_url: Option<String>,
    // Maximum tokens per response (default: 4096).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    // Request timeout (default: 60; presumably seconds — confirm at the
    // HTTP-client call site).
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Maximum iterations for agent loop (0 = unlimited, default: 100)
    #[serde(default = "default_max_iterations")]
    pub max_iterations: usize,
    /// Enable thinking/reasoning mode (Z.AI only, default: false)
    #[serde(default)]
    pub thinking_enabled: bool,
    /// Clear thinking between turns (Z.AI only, default: true).
    /// Set to false to keep ("preserve") thinking across multi-turn
    /// conversations.
    #[serde(default = "default_clear_thinking")]
    pub clear_thinking: bool,
}

/// Default model used by `Config::default()` when nothing is configured.
fn default_model() -> String {
    String::from("claude-3-5-sonnet-20241022")
}

/// Serde default for `ProviderConfig::max_tokens`.
fn default_max_tokens() -> u32 {
    4096
}

/// Serde default for `ProviderConfig::timeout`.
fn default_timeout() -> u64 {
    60
}

/// Serde default for `ProviderConfig::max_iterations`.
fn default_max_iterations() -> usize {
    100
}

/// Serde default for `ProviderConfig::clear_thinking`.
fn default_clear_thinking() -> bool {
    true
}

54impl ProviderConfig {
55    pub fn api_key_or_env(&self, provider: &str) -> Option<String> {
56        if let Some(key) = &self.api_key {
57            return Some(key.clone());
58        }
59
60        match provider {
61            "anthropic" => env::var("ANTHROPIC_API_KEY").ok(),
62            "openai" => env::var("OPENAI_API_KEY")
63                .ok()
64                .or_else(|| env::var("ZAI_API_KEY").ok()),
65            "zai" => env::var("ZAI_API_KEY").ok(),
66            _ => None,
67        }
68    }
69}
70
71impl Config {
72    pub fn validate(&self) -> Result<(), ConfigError> {
73        // Check provider field is valid
74        if !["anthropic", "openai", "zai"].contains(&self.provider.as_str()) {
75            return Err(ConfigError::InvalidProvider(self.provider.clone()));
76        }
77
78        // Check provider config exists
79        if !self.providers.contains_key(&self.provider) {
80            return Err(ConfigError::MissingProvider(self.provider.clone()));
81        }
82
83        // Check active provider has required fields
84        let provider_config = self.providers.get(&self.provider).unwrap();
85        if provider_config.api_key.is_none() {
86            // Check if env var exists
87            let env_var = match self.provider.as_str() {
88                "anthropic" => "ANTHROPIC_API_KEY",
89                "openai" => "OPENAI_API_KEY",
90                "zai" => "ZAI_API_KEY",
91                _ => "API_KEY",
92            };
93            if env::var(env_var).is_err() {
94                // For openai, also check ZAI_API_KEY as fallback
95                let env_var_display = if self.provider == "openai" {
96                    "OPENAI_API_KEY or ZAI_API_KEY"
97                } else {
98                    env_var
99                };
100                return Err(ConfigError::MissingApiKey {
101                    provider: self.provider.clone(),
102                    env_var: env_var_display.to_string(),
103                });
104            }
105        }
106
107        Ok(())
108    }
109}
110
impl Config {
    /// Load configuration from `~/.limit/config.toml`.
    ///
    /// Returns `Config::default()` when the file does not exist.
    /// Fails with `io::ErrorKind::InvalidData` when the file uses the
    /// old single-provider format, fails to parse as TOML, or fails
    /// `validate`.
    pub fn load() -> Result<Self, io::Error> {
        let config_path = config_path();

        // No config file yet: fall back to built-in defaults rather
        // than erroring (first-run experience).
        if !config_path.exists() {
            return Ok(Config::default());
        }

        let config_content = fs::read_to_string(&config_path)?;

        // Check for old format (detect by presence of 'api_key' at top level)
        // NOTE(review): this is a text heuristic — e.g. a commented-out
        // "api_key" line with no [providers.*] table would also trip it.
        if config_content.contains("api_key") && !config_content.contains("[providers.") {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Old config format detected. Please update to multi-provider format.\n\nNew format:\nprovider = \"anthropic\"\n\n[providers.anthropic]\nmodel = \"claude-3-5-sonnet-20241022\"\napi_key = \"...\"  # optional, falls back to ANTHROPIC_API_KEY env var"
            ));
        }

        let config: Config = toml::from_str(&config_content)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        // Surface semantic config errors as InvalidData so callers deal
        // with a single error type (io::Error).
        config
            .validate()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        Ok(config)
    }
}

140impl Default for Config {
141    fn default() -> Self {
142        let mut providers = HashMap::new();
143        providers.insert(
144            "anthropic".to_string(),
145            ProviderConfig {
146                api_key: None,
147                model: default_model(),
148                base_url: None,
149                max_tokens: default_max_tokens(),
150                timeout: default_timeout(),
151                max_iterations: default_max_iterations(),
152                thinking_enabled: false,
153                clear_thinking: true,
154            },
155        );
156        Config {
157            provider: "anthropic".to_string(),
158            providers,
159        }
160    }
161}
162
163fn config_path() -> std::path::PathBuf {
164    let home_dir = dirs::home_dir().expect("Failed to get home directory");
165    home_dir.join(".limit").join("config.toml")
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate process-wide environment variables.
    /// `cargo test` runs tests in parallel threads, so unguarded
    /// set/remove of the same variable (ZAI_API_KEY is touched by three
    /// tests here) makes them flaky.
    fn env_lock() -> &'static Mutex<()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
    }

    // NOTE(review): the previous `test_load_from_actual_config` read the
    // developer's real ~/.limit/config.toml (machine-dependent, fails on
    // any other machine/CI) and asserted a hard-coded real API key,
    // leaking a secret into source control. It is replaced by this
    // parse-based test covering the same shape with a dummy key.
    // The leaked key should be revoked.
    #[test]
    fn test_load_openai_zai_endpoint_config() {
        let config_content = r#"
provider = "openai"

[providers.openai]
api_key = "sk-test-dummy"
model = "glm-4.7"
base_url = "https://api.z.ai/api/coding/paas/v4/chat/completions"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "openai");
        assert!(config.providers.contains_key("openai"));
        let openai = config.providers.get("openai").unwrap();
        assert_eq!(openai.model, "glm-4.7");
        assert_eq!(openai.api_key, Some("sk-test-dummy".to_string()));
        assert_eq!(
            openai.base_url,
            Some("https://api.z.ai/api/coding/paas/v4/chat/completions".to_string())
        );
    }

    #[test]
    fn test_load_valid_config() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test123".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_load_partial_config_uses_defaults() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-partial"
model = "custom-model"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-partial".to_string()));
        assert_eq!(anthropic.model, "custom-model");
        assert!(anthropic.base_url.is_none()); // serde default
        assert_eq!(anthropic.max_tokens, 4096); // serde default
        assert_eq!(anthropic.timeout, 60); // serde default
    }

    #[test]
    fn test_load_config_with_base_url() {
        let config_content = r#"
provider = "openai"

[providers.openai]
api_key = "sk-test123"
model = "gpt-4"
base_url = "https://api.z.ai/api/paas/v4/chat/completions"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "openai");
        let openai = config.providers.get("openai").unwrap();
        assert_eq!(openai.api_key, Some("sk-test123".to_string()));
        assert_eq!(openai.model, "gpt-4");
        assert_eq!(
            openai.base_url,
            Some("https://api.z.ai/api/paas/v4/chat/completions".to_string())
        );
    }

    #[test]
    fn test_load_config_without_base_url() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test456"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test456".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.base_url.is_none()); // serde default
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.api_key.is_none());
        assert!(anthropic.base_url.is_none());
    }

    #[test]
    fn test_old_format_detection() {
        // Old single-provider format: top-level api_key/model with no
        // [providers.*] table must not deserialize into Config.
        let config_content = r#"
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let result: Result<Config, _> = toml::from_str(config_content);
        assert!(result.is_err(), "Old format should fail to parse");
    }

    /// Convenience: a ProviderConfig with the given key and model and
    /// all other fields at their defaults.
    fn provider_config(api_key: Option<&str>, model: &str) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.map(str::to_string),
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
            max_iterations: 100,
            thinking_enabled: false,
            clear_thinking: true,
        }
    }

    #[test]
    fn test_api_key_or_env_from_config() {
        // Explicit config key wins regardless of environment.
        let pc = provider_config(Some("sk-from-config"), "claude-3-5-sonnet-20241022");
        let key = pc.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-config".to_string()));
    }

    #[test]
    fn test_api_key_or_env_from_env() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        env::set_var("ANTHROPIC_API_KEY", "sk-from-env");

        let pc = provider_config(None, "claude-3-5-sonnet-20241022");
        let key = pc.api_key_or_env("anthropic");

        env::remove_var("ANTHROPIC_API_KEY");
        assert_eq!(key, Some("sk-from-env".to_string()));
    }

    #[test]
    fn test_openai_fallback_to_zai_api_key() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        // Ensure the primary variable is absent so the fallback is what
        // gets exercised, even on machines where OPENAI_API_KEY is set.
        env::remove_var("OPENAI_API_KEY");
        env::set_var("ZAI_API_KEY", "sk-zai-key");

        let pc = provider_config(None, "gpt-4");
        let key = pc.api_key_or_env("openai");

        env::remove_var("ZAI_API_KEY");
        assert_eq!(key, Some("sk-zai-key".to_string()));
    }

    #[test]
    fn test_unknown_provider_no_env_var() {
        let pc = provider_config(None, "test-model");
        let key = pc.api_key_or_env("unknown");
        assert_eq!(key, None);
    }

    #[test]
    fn test_zai_config_validation() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        env::set_var("ZAI_API_KEY", "test-zai-key");

        let config_content = r#"
provider = "zai"

[providers.zai]
model = "glm-4.7"
"#;
        let config: Config = toml::from_str(config_content).unwrap();
        let result = config.validate();

        env::remove_var("ZAI_API_KEY");
        result.unwrap();
    }

    #[test]
    fn test_zai_api_key_env_var() {
        let _guard = env_lock().lock().unwrap_or_else(|e| e.into_inner());
        env::set_var("ZAI_API_KEY", "test-zai-key");

        let pc = provider_config(None, "glm-4.7");
        let key = pc.api_key_or_env("zai");

        env::remove_var("ZAI_API_KEY");
        assert_eq!(key, Some("test-zai-key".to_string()));
    }
}