Skip to main content

agent_diva_core/config/
loader.rs

1//! Configuration loading and management
2
3use super::schema::Config;
4use super::validate::validate_config;
5use serde_json::{Map, Value};
6use std::path::{Path, PathBuf};
7
/// Locates, loads, and saves the agent-diva JSON configuration file.
///
/// The loader only remembers *where* the configuration lives; each call to
/// [`ConfigLoader::load`] reads the file and the environment afresh.
#[derive(Clone, Debug)]
pub struct ConfigLoader {
    /// Directory holding the config file; created by `save` if it is missing.
    config_dir: PathBuf,
    /// Full path of the JSON config file.
    config_path: PathBuf,
}
14
15impl ConfigLoader {
16    /// Create a new config loader with the default config directory
17    pub fn new() -> Self {
18        let config_dir = dirs::home_dir()
19            .map(|h| h.join(".agent-diva"))
20            .unwrap_or_else(|| PathBuf::from(".agent-diva"));
21
22        let config_path = config_dir.join("config.json");
23
24        Self {
25            config_dir,
26            config_path,
27        }
28    }
29
30    /// Create a new config loader with a custom config directory
31    pub fn with_dir<P: AsRef<Path>>(dir: P) -> Self {
32        let config_dir = dir.as_ref().to_path_buf();
33        Self {
34            config_path: config_dir.join("config.json"),
35            config_dir,
36        }
37    }
38
39    /// Create a new config loader with an explicit config file path
40    pub fn with_file<P: AsRef<Path>>(path: P) -> Self {
41        let config_path = path.as_ref().to_path_buf();
42        let config_dir = config_path
43            .parent()
44            .map(Path::to_path_buf)
45            .unwrap_or_else(|| PathBuf::from("."));
46
47        Self {
48            config_dir,
49            config_path,
50        }
51    }
52
53    /// Load configuration from file and environment
54    pub fn load(&self) -> crate::Result<Config> {
55        let mut merged = serde_json::to_value(Config::default())?;
56
57        if self.config_path.exists() {
58            let content = std::fs::read_to_string(&self.config_path)?;
59            let file_value: Value = serde_json::from_str(&content)?;
60            merge_values(&mut merged, file_value);
61        }
62
63        apply_alias_overrides(&mut merged);
64        apply_path_overrides(&mut merged);
65        normalize_alias_keys(&mut merged);
66
67        let config: Config = serde_json::from_value(merged)?;
68        validate_config(&config)?;
69        Ok(config)
70    }
71
72    /// Save configuration to file
73    pub fn save(&self, config: &Config) -> crate::Result<()> {
74        std::fs::create_dir_all(&self.config_dir)?;
75        let content = serde_json::to_string_pretty(config)?;
76        std::fs::write(&self.config_path, content)?;
77        Ok(())
78    }
79
80    /// Get the config directory path
81    pub fn config_dir(&self) -> &Path {
82        &self.config_dir
83    }
84
85    /// Get the config file path
86    pub fn config_path(&self) -> &Path {
87        &self.config_path
88    }
89}
90
91impl Default for ConfigLoader {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97fn merge_values(base: &mut Value, overlay: Value) {
98    match (base, overlay) {
99        (Value::Object(base_map), Value::Object(overlay_map)) => {
100            for (key, value) in overlay_map {
101                if let Some(existing) = base_map.get_mut(&key) {
102                    merge_values(existing, value);
103                } else {
104                    base_map.insert(key, value);
105                }
106            }
107        }
108        (base_value, overlay_value) => {
109            *base_value = overlay_value;
110        }
111    }
112}
113
114fn parse_env_value(raw: &str) -> Value {
115    if let Ok(v) = serde_json::from_str::<Value>(raw) {
116        return v;
117    }
118    if raw.eq_ignore_ascii_case("true") {
119        return Value::Bool(true);
120    }
121    if raw.eq_ignore_ascii_case("false") {
122        return Value::Bool(false);
123    }
124    if let Ok(v) = raw.parse::<i64>() {
125        return Value::Number(v.into());
126    }
127    if let Ok(v) = raw.parse::<f64>() {
128        if let Some(n) = serde_json::Number::from_f64(v) {
129            return Value::Number(n);
130        }
131    }
132    Value::String(raw.to_string())
133}
134
135fn set_path_value(root: &mut Value, path: &[String], value: Value) {
136    if path.is_empty() {
137        *root = value;
138        return;
139    }
140
141    let mut current = root;
142    for segment in &path[..path.len() - 1] {
143        if !current.is_object() {
144            *current = Value::Object(Map::new());
145        }
146        let map = current.as_object_mut().expect("object ensured");
147        current = map
148            .entry(segment.clone())
149            .or_insert_with(|| Value::Object(Map::new()));
150    }
151
152    if !current.is_object() {
153        *current = Value::Object(Map::new());
154    }
155    if let Some(map) = current.as_object_mut() {
156        map.insert(path[path.len() - 1].clone(), value);
157    }
158}
159
160fn apply_alias_overrides(config: &mut Value) {
161    let aliases = [
162        ("ANTHROPIC_API_KEY", "providers.anthropic.api_key"),
163        ("OPENAI_API_KEY", "providers.openai.api_key"),
164        ("OPENROUTER_API_KEY", "providers.openrouter.api_key"),
165        ("DEEPSEEK_API_KEY", "providers.deepseek.api_key"),
166        ("GROQ_API_KEY", "providers.groq.api_key"),
167        ("GEMINI_API_KEY", "providers.gemini.api_key"),
168        ("DASHSCOPE_API_KEY", "providers.dashscope.api_key"),
169        ("MOONSHOT_API_KEY", "providers.moonshot.api_key"),
170        ("MINIMAX_API_KEY", "providers.minimax.api_key"),
171        ("HOSTED_VLLM_API_KEY", "providers.vllm.api_key"),
172        ("AIHUBMIX_API_KEY", "providers.aihubmix.api_key"),
173        ("ZAI_API_KEY", "providers.zhipu.api_key"),
174        ("ZHIPUAI_API_KEY", "providers.zhipu.api_key"),
175    ];
176
177    for (env_key, target_path) in aliases {
178        if let Ok(value) = std::env::var(env_key) {
179            let path: Vec<String> = target_path.split('.').map(ToString::to_string).collect();
180            set_path_value(config, &path, Value::String(value));
181        }
182    }
183}
184
185fn apply_path_overrides(config: &mut Value) {
186    const PREFIX: &str = "AGENT_DIVA__";
187    for (key, value) in std::env::vars() {
188        if !key.starts_with(PREFIX) {
189            continue;
190        }
191        let suffix = &key[PREFIX.len()..];
192        if suffix.is_empty() {
193            continue;
194        }
195        let segments: Vec<String> = suffix
196            .split("__")
197            .filter(|s| !s.is_empty())
198            .map(|s| s.to_ascii_lowercase())
199            .collect();
200        if segments.is_empty() {
201            continue;
202        }
203        set_path_value(config, &segments, parse_env_value(&value));
204    }
205}
206
207fn object_at_path_mut<'a>(
208    root: &'a mut Value,
209    path: &[&str],
210) -> Option<&'a mut Map<String, Value>> {
211    let mut current = root;
212    for segment in path {
213        current = current.get_mut(*segment)?;
214    }
215    current.as_object_mut()
216}
217
218fn coalesce_alias_keys(
219    root: &mut Value,
220    object_path: &[&str],
221    canonical_key: &str,
222    alias_keys: &[&str],
223) {
224    let Some(map) = object_at_path_mut(root, object_path) else {
225        return;
226    };
227
228    let mut merged_value = map.remove(canonical_key);
229    for alias_key in alias_keys {
230        if let Some(alias_value) = map.remove(*alias_key) {
231            // Alias keys represent explicit user input and should override defaults.
232            merged_value = Some(alias_value);
233        }
234    }
235
236    if let Some(value) = merged_value {
237        map.insert(canonical_key.to_string(), value);
238    }
239}
240
241fn normalize_alias_keys(config: &mut Value) {
242    coalesce_alias_keys(
243        config,
244        &["channels"],
245        "neuro-link",
246        &["neuro_link", "generic_pipe"],
247    );
248    coalesce_alias_keys(config, &["tools"], "mcpServers", &["mcp_servers"]);
249    coalesce_alias_keys(config, &["tools"], "mcpManager", &["mcp_manager"]);
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes tests that read or mutate the process environment.
    ///
    /// `ConfigLoader::load` reads environment variables, so tests must hold this
    /// lock to avoid observing each other's overrides. `Mutex::new` is a
    /// `const fn`, so a plain `static` works without a lazy-initialization crate.
    /// NOTE(review): this only serializes tests in this module.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets an environment variable and restores its previous state on drop.
    ///
    /// The previous value is captured with `std::env::var`. If the variable was
    /// unset (or not valid Unicode), it is removed on drop. Create these only
    /// while holding the guard returned by [`lock_env`].
    struct EnvVarGuard {
        key: String,
        original: Option<String>,
    }

    impl EnvVarGuard {
        fn set(key: &str, value: &str) -> Self {
            let original = std::env::var(key).ok();
            // SAFETY: tests serialize env mutations with ENV_LOCK.
            unsafe { std::env::set_var(key, value) };
            Self {
                key: key.to_string(),
                original,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            if let Some(value) = &self.original {
                // SAFETY: tests serialize env mutations with ENV_LOCK.
                unsafe { std::env::set_var(&self.key, value) };
            } else {
                // SAFETY: tests serialize env mutations with ENV_LOCK.
                unsafe { std::env::remove_var(&self.key) };
            }
        }
    }

    /// Acquire [`ENV_LOCK`], recovering from poisoning so that one panicking
    /// test does not make every later env-dependent test fail as well.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_load_default_config() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.agents.defaults.provider.as_deref(), Some("deepseek"));
        assert_eq!(config.agents.defaults.model, "deepseek-chat");
        assert_eq!(config.agents.defaults.max_tokens, 8192);
    }

    #[test]
    fn test_save_and_load_config() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let mut config = Config::default();
        config.agents.defaults.model = "test-model".to_string();

        loader.save(&config).unwrap();
        let loaded = loader.load().unwrap();

        assert_eq!(loaded.agents.defaults.model, "test-model");
    }

    #[test]
    fn test_load_applies_alias_env_overrides() {
        let _lock = lock_env();
        let _api_key_guard = EnvVarGuard::set("OPENAI_API_KEY", "sk-openai-from-env");
        let _minimax_guard = EnvVarGuard::set("MINIMAX_API_KEY", "mini-key");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.providers.openai.api_key, "sk-openai-from-env");
        assert_eq!(config.providers.minimax.api_key, "mini-key");
    }

    #[test]
    fn test_load_applies_path_env_overrides() {
        let _lock = lock_env();
        let _model_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__MODEL", "openai/gpt-4o");
        let _temp_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__TEMPERATURE", "0.9");
        let _iter_guard =
            EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__MAX_TOOL_ITERATIONS", "42");
        let _enabled_guard = EnvVarGuard::set("AGENT_DIVA__CHANNELS__TELEGRAM__ENABLED", "true");
        let _token_guard = EnvVarGuard::set("AGENT_DIVA__CHANNELS__TELEGRAM__TOKEN", "tg-token");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.agents.defaults.model, "openai/gpt-4o");
        assert!((config.agents.defaults.temperature - 0.9).abs() < f32::EPSILON);
        assert_eq!(config.agents.defaults.max_tool_iterations, 42);
        assert!(config.channels.telegram.enabled);
        assert_eq!(config.channels.telegram.token, "tg-token");
    }

    #[test]
    fn test_path_env_overrides_alias_and_file() {
        let _lock = lock_env();
        let _alias_guard = EnvVarGuard::set("OPENAI_API_KEY", "sk-openai-alias");
        let _path_guard = EnvVarGuard::set(
            "AGENT_DIVA__PROVIDERS__OPENAI__API_KEY",
            "sk-openai-path-override",
        );

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let config_path = temp_dir.path().join("config.json");
        std::fs::write(
            &config_path,
            r#"{"providers":{"openai":{"api_key":"sk-openai-file"}}}"#,
        )
        .unwrap();

        let config = loader.load().unwrap();
        assert_eq!(config.providers.openai.api_key, "sk-openai-path-override");
    }

    #[test]
    fn test_validation_rejects_invalid_temperature() {
        let _lock = lock_env();
        let _temp_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__TEMPERATURE", "2.5");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let err = loader.load().unwrap_err();
        assert!(err.to_string().contains("temperature"));
    }

    #[test]
    fn test_load_supports_mcp_servers_camel_case() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let config_path = temp_dir.path().join("config.json");
        std::fs::write(
            &config_path,
            r#"{
  "tools": {
    "mcpServers": {
      "filesystem": {
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]
      }
    }
  }
}"#,
        )
        .unwrap();

        let config = loader.load().unwrap();
        let server = config.tools.mcp_servers.get("filesystem").unwrap();
        assert_eq!(server.command, "npx");
        assert_eq!(server.args.len(), 3);
    }

    #[test]
    fn test_load_supports_generic_pipe_alias_without_duplicate_field_error() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let config_path = temp_dir.path().join("config.json");
        std::fs::write(
            &config_path,
            r#"{
  "channels": {
    "generic_pipe": {
      "enabled": true,
      "host": "127.0.0.1",
      "port": 9200
    }
  }
}"#,
        )
        .unwrap();

        let config = loader.load().unwrap();
        assert!(config.channels.neuro_link.enabled);
        assert_eq!(config.channels.neuro_link.host, "127.0.0.1");
        assert_eq!(config.channels.neuro_link.port, 9200);
    }

    #[test]
    fn test_load_supports_mcp_servers_snake_case_alias() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let config_path = temp_dir.path().join("config.json");
        std::fs::write(
            &config_path,
            r#"{
  "tools": {
    "mcp_servers": {
      "filesystem": {
        "command": "uvx",
        "args": ["mcp-server-filesystem", "."]
      }
    }
  }
}"#,
        )
        .unwrap();

        let config = loader.load().unwrap();
        let server = config.tools.mcp_servers.get("filesystem").unwrap();
        assert_eq!(server.command, "uvx");
        assert_eq!(server.args.len(), 2);
    }

    #[test]
    fn test_with_file_uses_parent_as_config_dir() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("instances").join("alpha.json");
        let loader = ConfigLoader::with_file(&config_path);

        assert_eq!(loader.config_path(), config_path.as_path());
        assert_eq!(loader.config_dir(), config_path.parent().unwrap());
    }
}