1use super::schema::Config;
4use super::validate::validate_config;
5use serde_json::{Map, Value};
6use std::path::{Path, PathBuf};
7
/// Locates, loads, and saves the agent-diva JSON configuration file.
///
/// The loader only stores paths; the file itself is read by
/// [`ConfigLoader::load`] and written by [`ConfigLoader::save`].
#[derive(Clone, Debug)]
pub struct ConfigLoader {
    /// Directory that holds the config file; created on save if missing.
    config_dir: PathBuf,
    /// Full path of the JSON config file.
    config_path: PathBuf,
}
14
15impl ConfigLoader {
16 pub fn new() -> Self {
18 let config_dir = dirs::home_dir()
19 .map(|h| h.join(".agent-diva"))
20 .unwrap_or_else(|| PathBuf::from(".agent-diva"));
21
22 let config_path = config_dir.join("config.json");
23
24 Self {
25 config_dir,
26 config_path,
27 }
28 }
29
30 pub fn with_dir<P: AsRef<Path>>(dir: P) -> Self {
32 let config_dir = dir.as_ref().to_path_buf();
33 Self {
34 config_path: config_dir.join("config.json"),
35 config_dir,
36 }
37 }
38
39 pub fn with_file<P: AsRef<Path>>(path: P) -> Self {
41 let config_path = path.as_ref().to_path_buf();
42 let config_dir = config_path
43 .parent()
44 .map(Path::to_path_buf)
45 .unwrap_or_else(|| PathBuf::from("."));
46
47 Self {
48 config_dir,
49 config_path,
50 }
51 }
52
53 pub fn load(&self) -> crate::Result<Config> {
55 let mut merged = serde_json::to_value(Config::default())?;
56
57 if self.config_path.exists() {
58 let content = std::fs::read_to_string(&self.config_path)?;
59 let file_value: Value = serde_json::from_str(&content)?;
60 merge_values(&mut merged, file_value);
61 }
62
63 apply_alias_overrides(&mut merged);
64 apply_path_overrides(&mut merged);
65 normalize_alias_keys(&mut merged);
66
67 let config: Config = serde_json::from_value(merged)?;
68 validate_config(&config)?;
69 Ok(config)
70 }
71
72 pub fn save(&self, config: &Config) -> crate::Result<()> {
74 std::fs::create_dir_all(&self.config_dir)?;
75 let content = serde_json::to_string_pretty(config)?;
76 std::fs::write(&self.config_path, content)?;
77 Ok(())
78 }
79
80 pub fn config_dir(&self) -> &Path {
82 &self.config_dir
83 }
84
85 pub fn config_path(&self) -> &Path {
87 &self.config_path
88 }
89}
90
91impl Default for ConfigLoader {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97fn merge_values(base: &mut Value, overlay: Value) {
98 match (base, overlay) {
99 (Value::Object(base_map), Value::Object(overlay_map)) => {
100 for (key, value) in overlay_map {
101 if let Some(existing) = base_map.get_mut(&key) {
102 merge_values(existing, value);
103 } else {
104 base_map.insert(key, value);
105 }
106 }
107 }
108 (base_value, overlay_value) => {
109 *base_value = overlay_value;
110 }
111 }
112}
113
114fn parse_env_value(raw: &str) -> Value {
115 if let Ok(v) = serde_json::from_str::<Value>(raw) {
116 return v;
117 }
118 if raw.eq_ignore_ascii_case("true") {
119 return Value::Bool(true);
120 }
121 if raw.eq_ignore_ascii_case("false") {
122 return Value::Bool(false);
123 }
124 if let Ok(v) = raw.parse::<i64>() {
125 return Value::Number(v.into());
126 }
127 if let Ok(v) = raw.parse::<f64>() {
128 if let Some(n) = serde_json::Number::from_f64(v) {
129 return Value::Number(n);
130 }
131 }
132 Value::String(raw.to_string())
133}
134
135fn set_path_value(root: &mut Value, path: &[String], value: Value) {
136 if path.is_empty() {
137 *root = value;
138 return;
139 }
140
141 let mut current = root;
142 for segment in &path[..path.len() - 1] {
143 if !current.is_object() {
144 *current = Value::Object(Map::new());
145 }
146 let map = current.as_object_mut().expect("object ensured");
147 current = map
148 .entry(segment.clone())
149 .or_insert_with(|| Value::Object(Map::new()));
150 }
151
152 if !current.is_object() {
153 *current = Value::Object(Map::new());
154 }
155 if let Some(map) = current.as_object_mut() {
156 map.insert(path[path.len() - 1].clone(), value);
157 }
158}
159
160fn apply_alias_overrides(config: &mut Value) {
161 let aliases = [
162 ("ANTHROPIC_API_KEY", "providers.anthropic.api_key"),
163 ("OPENAI_API_KEY", "providers.openai.api_key"),
164 ("OPENROUTER_API_KEY", "providers.openrouter.api_key"),
165 ("DEEPSEEK_API_KEY", "providers.deepseek.api_key"),
166 ("GROQ_API_KEY", "providers.groq.api_key"),
167 ("GEMINI_API_KEY", "providers.gemini.api_key"),
168 ("DASHSCOPE_API_KEY", "providers.dashscope.api_key"),
169 ("MOONSHOT_API_KEY", "providers.moonshot.api_key"),
170 ("MINIMAX_API_KEY", "providers.minimax.api_key"),
171 ("HOSTED_VLLM_API_KEY", "providers.vllm.api_key"),
172 ("AIHUBMIX_API_KEY", "providers.aihubmix.api_key"),
173 ("ZAI_API_KEY", "providers.zhipu.api_key"),
174 ("ZHIPUAI_API_KEY", "providers.zhipu.api_key"),
175 ];
176
177 for (env_key, target_path) in aliases {
178 if let Ok(value) = std::env::var(env_key) {
179 let path: Vec<String> = target_path.split('.').map(ToString::to_string).collect();
180 set_path_value(config, &path, Value::String(value));
181 }
182 }
183}
184
185fn apply_path_overrides(config: &mut Value) {
186 const PREFIX: &str = "AGENT_DIVA__";
187 for (key, value) in std::env::vars() {
188 if !key.starts_with(PREFIX) {
189 continue;
190 }
191 let suffix = &key[PREFIX.len()..];
192 if suffix.is_empty() {
193 continue;
194 }
195 let segments: Vec<String> = suffix
196 .split("__")
197 .filter(|s| !s.is_empty())
198 .map(|s| s.to_ascii_lowercase())
199 .collect();
200 if segments.is_empty() {
201 continue;
202 }
203 set_path_value(config, &segments, parse_env_value(&value));
204 }
205}
206
207fn object_at_path_mut<'a>(
208 root: &'a mut Value,
209 path: &[&str],
210) -> Option<&'a mut Map<String, Value>> {
211 let mut current = root;
212 for segment in path {
213 current = current.get_mut(*segment)?;
214 }
215 current.as_object_mut()
216}
217
218fn coalesce_alias_keys(
219 root: &mut Value,
220 object_path: &[&str],
221 canonical_key: &str,
222 alias_keys: &[&str],
223) {
224 let Some(map) = object_at_path_mut(root, object_path) else {
225 return;
226 };
227
228 let mut merged_value = map.remove(canonical_key);
229 for alias_key in alias_keys {
230 if let Some(alias_value) = map.remove(*alias_key) {
231 merged_value = Some(alias_value);
233 }
234 }
235
236 if let Some(value) = merged_value {
237 map.insert(canonical_key.to_string(), value);
238 }
239}
240
241fn normalize_alias_keys(config: &mut Value) {
242 coalesce_alias_keys(
243 config,
244 &["channels"],
245 "neuro-link",
246 &["neuro_link", "generic_pipe"],
247 );
248 coalesce_alias_keys(config, &["tools"], "mcpServers", &["mcp_servers"]);
249 coalesce_alias_keys(config, &["tools"], "mcpManager", &["mcp_manager"]);
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes the tests in this module that read or mutate the
    /// process-wide environment; the test harness runs tests in parallel.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets or clears one environment variable. When dropped, it restores the
    /// previous value, including non-Unicode values.
    struct EnvVarGuard {
        key: String,
        original: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &str, value: &str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: every test in this module that touches the environment
            // holds `ENV_LOCK`, so no other test here accesses it concurrently.
            unsafe { std::env::set_var(key, value) };
            Self {
                key: key.to_string(),
                original,
            }
        }

        fn unset(key: &str) -> Self {
            let original = std::env::var_os(key);
            // SAFETY: as in `EnvVarGuard::set`, the caller holds `ENV_LOCK`.
            unsafe { std::env::remove_var(key) };
            Self {
                key: key.to_string(),
                original,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: guards are always declared after `lock_env()`. Locals drop
            // in reverse order, so the restore happens while `ENV_LOCK` is still held.
            match &self.original {
                Some(value) => unsafe { std::env::set_var(&self.key, value) },
                None => unsafe { std::env::remove_var(&self.key) },
            }
        }
    }

    fn lock_env() -> MutexGuard<'static, ()> {
        // A failing test poisons the mutex. The guarded data is `()`, so there is
        // no broken invariant behind the poison and it can be ignored.
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Writes `json` as `config.json` inside `dir`.
    fn write_config(dir: &TempDir, json: &str) {
        std::fs::write(dir.path().join("config.json"), json).unwrap();
    }

    #[test]
    fn test_load_default_config() {
        let _lock = lock_env();
        // Shield the default-value assertions from ambient overrides.
        let _provider = EnvVarGuard::unset("AGENT_DIVA__AGENTS__DEFAULTS__PROVIDER");
        let _model = EnvVarGuard::unset("AGENT_DIVA__AGENTS__DEFAULTS__MODEL");
        let _max_tokens = EnvVarGuard::unset("AGENT_DIVA__AGENTS__DEFAULTS__MAX_TOKENS");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.agents.defaults.provider.as_deref(), Some("deepseek"));
        assert_eq!(config.agents.defaults.model, "deepseek-chat");
        assert_eq!(config.agents.defaults.max_tokens, 8192);
    }

    #[test]
    fn test_save_and_load_config() {
        let _lock = lock_env();
        let _model = EnvVarGuard::unset("AGENT_DIVA__AGENTS__DEFAULTS__MODEL");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());

        let mut config = Config::default();
        config.agents.defaults.model = "test-model".to_string();

        loader.save(&config).unwrap();
        let loaded = loader.load().unwrap();

        assert_eq!(loaded.agents.defaults.model, "test-model");
    }

    #[test]
    fn test_load_applies_alias_env_overrides() {
        let _lock = lock_env();
        let _openai_path = EnvVarGuard::unset("AGENT_DIVA__PROVIDERS__OPENAI__API_KEY");
        let _minimax_path = EnvVarGuard::unset("AGENT_DIVA__PROVIDERS__MINIMAX__API_KEY");
        let _api_key_guard = EnvVarGuard::set("OPENAI_API_KEY", "sk-openai-from-env");
        let _minimax_guard = EnvVarGuard::set("MINIMAX_API_KEY", "mini-key");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.providers.openai.api_key, "sk-openai-from-env");
        assert_eq!(config.providers.minimax.api_key, "mini-key");
    }

    #[test]
    fn test_load_applies_path_env_overrides() {
        let _lock = lock_env();
        let _model_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__MODEL", "openai/gpt-4o");
        let _temp_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__TEMPERATURE", "0.9");
        let _iter_guard =
            EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__MAX_TOOL_ITERATIONS", "42");
        let _enabled_guard = EnvVarGuard::set("AGENT_DIVA__CHANNELS__TELEGRAM__ENABLED", "true");
        let _token_guard = EnvVarGuard::set("AGENT_DIVA__CHANNELS__TELEGRAM__TOKEN", "tg-token");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let config = loader.load().unwrap();

        assert_eq!(config.agents.defaults.model, "openai/gpt-4o");
        assert!((config.agents.defaults.temperature - 0.9).abs() < f32::EPSILON);
        assert_eq!(config.agents.defaults.max_tool_iterations, 42);
        assert!(config.channels.telegram.enabled);
        assert_eq!(config.channels.telegram.token, "tg-token");
    }

    #[test]
    fn test_path_env_overrides_alias_and_file() {
        let _lock = lock_env();
        let _alias_guard = EnvVarGuard::set("OPENAI_API_KEY", "sk-openai-alias");
        let _path_guard = EnvVarGuard::set(
            "AGENT_DIVA__PROVIDERS__OPENAI__API_KEY",
            "sk-openai-path-override",
        );

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        write_config(
            &temp_dir,
            r#"{"providers":{"openai":{"api_key":"sk-openai-file"}}}"#,
        );

        let config = loader.load().unwrap();
        assert_eq!(config.providers.openai.api_key, "sk-openai-path-override");
    }

    #[test]
    fn test_validation_rejects_invalid_temperature() {
        let _lock = lock_env();
        let _temp_guard = EnvVarGuard::set("AGENT_DIVA__AGENTS__DEFAULTS__TEMPERATURE", "2.5");

        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        let err = loader.load().unwrap_err();
        assert!(err.to_string().contains("temperature"));
    }

    #[test]
    fn test_load_supports_mcp_servers_camel_case() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        write_config(
            &temp_dir,
            r#"{
  "tools": {
    "mcpServers": {
      "filesystem": {
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-filesystem", "."]
      }
    }
  }
}"#,
        );

        let config = loader.load().unwrap();
        let server = config.tools.mcp_servers.get("filesystem").unwrap();
        assert_eq!(server.command, "npx");
        assert_eq!(server.args.len(), 3);
    }

    #[test]
    fn test_load_supports_generic_pipe_alias_without_duplicate_field_error() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        write_config(
            &temp_dir,
            r#"{
  "channels": {
    "generic_pipe": {
      "enabled": true,
      "host": "127.0.0.1",
      "port": 9200
    }
  }
}"#,
        );

        let config = loader.load().unwrap();
        assert!(config.channels.neuro_link.enabled);
        assert_eq!(config.channels.neuro_link.host, "127.0.0.1");
        assert_eq!(config.channels.neuro_link.port, 9200);
    }

    #[test]
    fn test_load_supports_mcp_servers_snake_case_alias() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let loader = ConfigLoader::with_dir(temp_dir.path());
        write_config(
            &temp_dir,
            r#"{
  "tools": {
    "mcp_servers": {
      "filesystem": {
        "command": "uvx",
        "args": ["mcp-server-filesystem", "."]
      }
    }
  }
}"#,
        );

        let config = loader.load().unwrap();
        let server = config.tools.mcp_servers.get("filesystem").unwrap();
        assert_eq!(server.command, "uvx");
        assert_eq!(server.args.len(), 2);
    }

    #[test]
    fn test_with_file_uses_parent_as_config_dir() {
        // Pure path logic: no environment access, so no lock needed.
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("instances").join("alpha.json");
        let loader = ConfigLoader::with_file(&config_path);

        assert_eq!(loader.config_path(), config_path.as_path());
        assert_eq!(loader.config_dir(), config_path.parent().unwrap());
    }

    #[test]
    fn test_parse_env_value_fallbacks() {
        assert_eq!(parse_env_value("42"), Value::from(42));
        assert_eq!(parse_env_value("+7"), Value::from(7));
        assert_eq!(parse_env_value("TRUE"), Value::Bool(true));
        assert_eq!(parse_env_value("False"), Value::Bool(false));
        assert_eq!(parse_env_value(".5"), Value::from(0.5));
        assert_eq!(parse_env_value("[1,2]"), serde_json::json!([1, 2]));
        assert_eq!(
            parse_env_value("tg-token"),
            Value::String("tg-token".to_string())
        );
    }

    #[test]
    fn test_merge_values_deep_merges_objects_and_replaces_others() {
        let mut base = serde_json::json!({"a": {"x": 1, "y": 2}, "b": [1, 2]});
        merge_values(
            &mut base,
            serde_json::json!({"a": {"y": 3}, "b": [9], "c": true}),
        );
        assert_eq!(
            base,
            serde_json::json!({"a": {"x": 1, "y": 3}, "b": [9], "c": true})
        );
    }

    #[test]
    fn test_set_path_value_replaces_non_object_parents() {
        let mut root = serde_json::json!({"a": 1});
        let path = vec!["a".to_string(), "b".to_string()];
        set_path_value(&mut root, &path, serde_json::json!("x"));
        assert_eq!(root, serde_json::json!({"a": {"b": "x"}}));
    }
}