Skip to main content

limit_llm/
config.rs

1use crate::error::ConfigError;
2use serde::Deserialize;
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::{env, fs, io};
6
7#[derive(Debug, Deserialize, PartialEq, Clone)]
8pub struct Config {
9    pub provider: String,
10    pub providers: HashMap<String, ProviderConfig>,
11    #[serde(default)]
12    pub browser: BrowserConfigSection,
13}
14
15/// Browser configuration section in config.toml
16#[derive(Debug, Deserialize, PartialEq, Clone)]
17pub struct BrowserConfigSection {
18    /// Enable browser tool (default: false)
19    #[serde(default)]
20    pub enabled: bool,
21    /// Path to agent-browser binary (default: "agent-browser" in PATH)
22    #[serde(default)]
23    pub binary_path: Option<PathBuf>,
24    /// Browser engine to use (default: "chrome")
25    #[serde(default = "default_browser_engine")]
26    pub engine: String,
27    /// Run in headless mode (default: true)
28    #[serde(default = "default_true")]
29    pub headless: bool,
30    /// Timeout for operations in milliseconds (default: 30000)
31    #[serde(default = "default_browser_timeout")]
32    pub timeout_ms: u64,
33}
34
35impl Default for BrowserConfigSection {
36    fn default() -> Self {
37        Self {
38            enabled: false,
39            binary_path: None,
40            engine: default_browser_engine(),
41            headless: default_true(),
42            timeout_ms: default_browser_timeout(),
43        }
44    }
45}
46
/// Fallback value for `BrowserConfigSection::engine`.
fn default_browser_engine() -> String {
    String::from("chrome")
}
50
/// Serde default helper yielding `true` (used for `BrowserConfigSection::headless`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
54
/// Fallback for `BrowserConfigSection::timeout_ms`: 30 seconds, expressed in milliseconds.
fn default_browser_timeout() -> u64 {
    30 * 1_000
}
58
59#[derive(Debug, Deserialize, PartialEq, Clone)]
60pub struct ProviderConfig {
61    pub api_key: Option<String>,
62    pub model: String,
63    #[serde(default)]
64    pub base_url: Option<String>,
65    #[serde(default = "default_max_tokens")]
66    pub max_tokens: u32,
67    #[serde(default = "default_timeout")]
68    pub timeout: u64,
69    /// Maximum iterations for agent loop (0 = unlimited, default: 100)
70    #[serde(default = "default_max_iterations")]
71    pub max_iterations: usize,
72    /// Enable thinking/reasoning mode (Z.AI only, default: false)
73    #[serde(default)]
74    pub thinking_enabled: bool,
75    /// Preserve thinking between turns (Z.AI only, default: true)
76    /// Set to false for Preserved Thinking in multi-turn conversations
77    #[serde(default = "default_clear_thinking")]
78    pub clear_thinking: bool,
79}
80
/// Model used by the built-in `anthropic` provider entry of [`Config::default`].
fn default_model() -> String {
    String::from("claude-3-5-sonnet-20241022")
}
84
/// Serde fallback for `ProviderConfig::max_tokens`.
fn default_max_tokens() -> u32 {
    4_096
}
88
/// Serde fallback for `ProviderConfig::timeout`.
fn default_timeout() -> u64 {
    const FALLBACK_TIMEOUT: u64 = 60;
    FALLBACK_TIMEOUT
}
92
/// Serde fallback for `ProviderConfig::max_iterations`.
fn default_max_iterations() -> usize {
    const FALLBACK_ITERATIONS: usize = 100;
    FALLBACK_ITERATIONS
}
96
/// Serde fallback for `ProviderConfig::clear_thinking`: thinking is cleared unless the
/// config opts out.
fn default_clear_thinking() -> bool {
    const CLEAR_BY_DEFAULT: bool = true;
    CLEAR_BY_DEFAULT
}
100
101impl ProviderConfig {
102    pub fn api_key_or_env(&self, provider: &str) -> Option<String> {
103        if let Some(key) = &self.api_key {
104            return Some(key.clone());
105        }
106
107        match provider {
108            "anthropic" => env::var("ANTHROPIC_API_KEY").ok(),
109            "openai" => env::var("OPENAI_API_KEY")
110                .ok()
111                .or_else(|| env::var("ZAI_API_KEY").ok()),
112            "zai" => env::var("ZAI_API_KEY").ok(),
113            // Local providers don't require API key, return placeholder
114            "local" | "ollama" | "lmstudio" | "vllm" => Some("local".to_string()),
115            _ => None,
116        }
117    }
118}
119
120impl Config {
121    pub fn validate(&self) -> Result<(), ConfigError> {
122        // Valid provider names (includes local LLM aliases)
123        let valid_providers = [
124            "anthropic",
125            "openai",
126            "zai",
127            "local",
128            "ollama",
129            "lmstudio",
130            "vllm",
131        ];
132
133        // Check provider field is valid
134        if !valid_providers.contains(&self.provider.as_str()) {
135            return Err(ConfigError::InvalidProvider(self.provider.clone()));
136        }
137
138        // Check provider config exists
139        if !self.providers.contains_key(&self.provider) {
140            return Err(ConfigError::MissingProvider(self.provider.clone()));
141        }
142
143        // Local providers don't require API key
144        let local_providers = ["local", "ollama", "lmstudio", "vllm"];
145        if local_providers.contains(&self.provider.as_str()) {
146            return Ok(());
147        }
148
149        // Check active provider has required fields
150        let provider_config = self.providers.get(&self.provider).unwrap();
151        if provider_config.api_key.is_none() {
152            // Check if env var exists
153            let env_var = match self.provider.as_str() {
154                "anthropic" => "ANTHROPIC_API_KEY",
155                "openai" => "OPENAI_API_KEY",
156                "zai" => "ZAI_API_KEY",
157                _ => "API_KEY",
158            };
159            if env::var(env_var).is_err() {
160                // For openai, also check ZAI_API_KEY as fallback
161                let env_var_display = if self.provider == "openai" {
162                    "OPENAI_API_KEY or ZAI_API_KEY"
163                } else {
164                    env_var
165                };
166                return Err(ConfigError::MissingApiKey {
167                    provider: self.provider.clone(),
168                    env_var: env_var_display.to_string(),
169                });
170            }
171        }
172
173        Ok(())
174    }
175}
176
177impl Config {
178    pub fn load() -> Result<Self, io::Error> {
179        let config_path = config_path();
180
181        if !config_path.exists() {
182            return Ok(Config::default());
183        }
184
185        let config_content = fs::read_to_string(&config_path)?;
186
187        // Check for old format (detect by presence of 'api_key' at top level)
188        if config_content.contains("api_key") && !config_content.contains("[providers.") {
189            return Err(io::Error::new(
190                io::ErrorKind::InvalidData,
191                "Old config format detected. Please update to multi-provider format.\n\nNew format:\nprovider = \"anthropic\"\n\n[providers.anthropic]\nmodel = \"claude-3-5-sonnet-20241022\"\napi_key = \"...\"  # optional, falls back to ANTHROPIC_API_KEY env var"
192            ));
193        }
194
195        let config: Config = toml::from_str(&config_content)
196            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
197
198        config
199            .validate()
200            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
201
202        Ok(config)
203    }
204}
205
206impl Default for Config {
207    fn default() -> Self {
208        let mut providers = HashMap::new();
209        providers.insert(
210            "anthropic".to_string(),
211            ProviderConfig {
212                api_key: None,
213                model: default_model(),
214                base_url: None,
215                max_tokens: default_max_tokens(),
216                timeout: default_timeout(),
217                max_iterations: default_max_iterations(),
218                thinking_enabled: false,
219                clear_thinking: true,
220            },
221        );
222        Config {
223            provider: "anthropic".to_string(),
224            providers,
225            browser: BrowserConfigSection::default(),
226        }
227    }
228}
229
230fn config_path() -> std::path::PathBuf {
231    let home_dir = dirs::home_dir().expect("Failed to get home directory");
232    home_dir.join(".limit").join("config.toml")
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that reads or writes process environment variables.
    ///
    /// The test harness runs tests on parallel threads, but the environment is shared by the
    /// whole process. Without this lock, one test removing `ZAI_API_KEY` can race with another
    /// test that has just set it.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failing test does not
    /// cascade into failures of every other env-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides one environment variable for the guard's lifetime and restores the previous
    /// value (or its absence) on drop — including when the test panics mid-way.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = env::var_os(key);
            env::set_var(key, value);
            Self { key, previous }
        }

        fn unset(key: &'static str) -> Self {
            let previous = env::var_os(key);
            env::remove_var(key);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match &self.previous {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    /// Builds a `ProviderConfig` for `model` with no inline API key and the same limits the
    /// module's default helpers produce.
    fn provider_without_key(model: &str) -> ProviderConfig {
        ProviderConfig {
            api_key: None,
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            timeout: 60,
            max_iterations: 100,
            thinking_enabled: false,
            clear_thinking: true,
        }
    }

    #[test]
    #[ignore = "depends on the developer's real ~/.limit/config.toml and API-key env vars"]
    fn test_load_from_actual_config() {
        // `load` validates, which reads API-key env vars.
        let _env = lock_env();
        let config = Config::load().unwrap();

        // Should have loaded a valid config (provider and config entry exist)
        assert!(!config.provider.is_empty());
        assert!(config.providers.contains_key(&config.provider));
    }

    #[test]
    fn test_load_valid_config() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test123".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_load_partial_config_uses_defaults() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-partial"
model = "custom-model"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-partial".to_string()));
        assert_eq!(anthropic.model, "custom-model");
        assert!(anthropic.base_url.is_none()); // default
    }

    #[test]
    fn test_load_config_with_base_url() {
        let config_content = r#"
provider = "openai"

[providers.openai]
api_key = "sk-test123"
model = "gpt-4"
base_url = "https://api.z.ai/api/paas/v4/chat/completions"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "openai");
        let openai = config.providers.get("openai").unwrap();
        assert_eq!(openai.api_key, Some("sk-test123".to_string()));
        assert_eq!(openai.model, "gpt-4");
        assert_eq!(
            openai.base_url,
            Some("https://api.z.ai/api/paas/v4/chat/completions".to_string())
        );
    }

    #[test]
    fn test_load_config_without_base_url() {
        let config_content = r#"
provider = "anthropic"

[providers.anthropic]
api_key = "sk-ant-test456"
model = "claude-3-5-sonnet-20241022"
"#;

        let config: Config = toml::from_str(config_content).unwrap();

        assert_eq!(config.provider, "anthropic");
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.api_key, Some("sk-ant-test456".to_string()));
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.base_url.is_none()); // default
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();

        assert_eq!(config.provider, "anthropic");
        assert!(config.providers.contains_key("anthropic"));
        let anthropic = config.providers.get("anthropic").unwrap();
        assert_eq!(anthropic.model, "claude-3-5-sonnet-20241022");
        assert!(anthropic.api_key.is_none());
        assert!(anthropic.base_url.is_none());
    }

    #[test]
    fn test_old_format_detection() {
        let config_content = r#"
api_key = "sk-ant-test123"
model = "claude-3-5-sonnet-20241022"
"#;

        let result: Result<Config, _> = toml::from_str(config_content);
        assert!(result.is_err(), "Old format should fail to parse");
    }

    #[test]
    fn test_api_key_or_env_from_config() {
        // An inline key short-circuits before any env lookup, so no env lock is needed.
        let provider_config = ProviderConfig {
            api_key: Some("sk-from-config".to_string()),
            ..provider_without_key("claude-3-5-sonnet-20241022")
        };

        let key = provider_config.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-config".to_string()));
    }

    #[test]
    fn test_api_key_or_env_from_env() {
        let _env = lock_env();
        let _anthropic = EnvVarGuard::set("ANTHROPIC_API_KEY", "sk-from-env");
        let provider_config = provider_without_key("claude-3-5-sonnet-20241022");

        let key = provider_config.api_key_or_env("anthropic");
        assert_eq!(key, Some("sk-from-env".to_string()));
    }

    #[test]
    fn test_openai_fallback_to_zai_api_key() {
        let _env = lock_env();
        // The fallback only applies when OPENAI_API_KEY is absent; make that explicit rather
        // than depending on the developer's shell.
        let _openai = EnvVarGuard::unset("OPENAI_API_KEY");
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "sk-zai-key");
        let provider_config = provider_without_key("gpt-4");

        let key = provider_config.api_key_or_env("openai");
        assert_eq!(key, Some("sk-zai-key".to_string()));
    }

    #[test]
    fn test_unknown_provider_no_env_var() {
        // Unknown providers never consult the environment.
        let provider_config = provider_without_key("test-model");

        let key = provider_config.api_key_or_env("unknown");
        assert_eq!(key, None);
    }

    #[test]
    fn test_zai_config_validation() {
        let _env = lock_env();
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "test-zai-key");
        let config_content = r#"
provider = "zai"

[providers.zai]
model = "glm-4.7"
"#;
        let config: Config = toml::from_str(config_content).unwrap();
        config.validate().unwrap();
    }

    #[test]
    fn test_zai_api_key_env_var() {
        let _env = lock_env();
        let _zai = EnvVarGuard::set("ZAI_API_KEY", "test-zai-key");
        let provider_config = provider_without_key("glm-4.7");

        let key = provider_config.api_key_or_env("zai");
        assert_eq!(key, Some("test-zai-key".to_string()));
    }
}