1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{LlmError, Result};
7use crate::retry::RetryConfig;
8use crate::types::Tool;
9
/// Fallback value for `AgentConfig::chain_limit`.
///
/// Referenced both by the serde `default = "..."` attribute on the field and by
/// `AgentConfig::default()`, so a missing key and a default-constructed config agree.
fn default_chain_limit() -> usize {
    const DEFAULT_CHAIN_LIMIT: usize = 10;
    DEFAULT_CHAIN_LIMIT
}
17
/// Fallback value for `AgentConfig::parallel_tools`.
///
/// Referenced both by the serde `default = "..."` attribute on the field and by
/// `AgentConfig::default()`.
fn default_parallel_tools() -> bool {
    const DEFAULT_PARALLEL_TOOLS: bool = true;
    DEFAULT_PARALLEL_TOOLS
}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct AgentConfig {
25 #[serde(default)]
27 pub model: Option<String>,
28
29 #[serde(default)]
31 pub system_prompt: Option<String>,
32
33 #[serde(default)]
35 pub tools: Vec<String>,
36
37 #[serde(default = "default_chain_limit")]
39 pub chain_limit: usize,
40
41 #[serde(default)]
43 pub options: HashMap<String, serde_json::Value>,
44
45 #[serde(default)]
47 pub budget: Option<BudgetConfig>,
48
49 #[serde(default)]
51 pub retry: Option<RetryConfig>,
52
53 #[serde(default = "default_parallel_tools")]
56 pub parallel_tools: bool,
57
58 #[serde(default)]
60 pub max_parallel_tools: Option<usize>,
61}
62
63impl Default for AgentConfig {
64 fn default() -> Self {
65 Self {
66 model: None,
67 system_prompt: None,
68 tools: Vec::new(),
69 chain_limit: default_chain_limit(),
70 options: HashMap::new(),
71 budget: None,
72 retry: None,
73 parallel_tools: default_parallel_tools(),
74 max_parallel_tools: None,
75 }
76 }
77}
78
79impl AgentConfig {
80 pub fn load(path: &Path) -> Result<Self> {
83 let contents = std::fs::read_to_string(path).map_err(|e| {
84 if e.kind() == std::io::ErrorKind::NotFound {
85 LlmError::Config(format!("agent config not found: {}", path.display()))
86 } else {
87 LlmError::Io(e)
88 }
89 })?;
90 toml::from_str(&contents).map_err(|e| LlmError::Config(e.to_string()))
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct BudgetConfig {
97 #[serde(default)]
98 pub max_tokens: Option<u64>,
99}
100
101pub fn resolve_agent_model<'a>(config: &'a AgentConfig, client_default: &'a str) -> &'a str {
111 config.model.as_deref().unwrap_or(client_default)
112}
113
114pub fn resolve_agent_system<'a>(
117 arg: Option<&'a str>,
118 config: &'a AgentConfig,
119) -> Option<&'a str> {
120 arg.or(config.system_prompt.as_deref())
121}
122
123pub fn resolve_agent_retry(
125 cli_arg: Option<u32>,
126 config: &AgentConfig,
127 client_default: &RetryConfig,
128) -> RetryConfig {
129 if let Some(n) = cli_arg {
130 let mut cfg = client_default.clone();
131 cfg.max_retries = n;
132 return cfg;
133 }
134 if let Some(agent_retry) = &config.retry {
135 return agent_retry.clone();
136 }
137 client_default.clone()
138}
139
140pub fn resolve_agent_tools(
144 config: &AgentConfig,
145 registry_tools: &[Tool],
146) -> Result<Vec<Tool>> {
147 let mut out = Vec::with_capacity(config.tools.len());
148 for name in &config.tools {
149 match registry_tools.iter().find(|t| t.name == *name) {
150 Some(t) => out.push(t.clone()),
151 None => {
152 return Err(LlmError::Config(format!(
153 "unknown tool in agent config: {name}"
154 )));
155 }
156 }
157 }
158 Ok(out)
159}
160
161pub fn resolve_agent_budget(config: &AgentConfig) -> Option<u64> {
163 config.budget.as_ref().and_then(|b| b.max_tokens)
164}
165
/// Which directory an agent definition was discovered in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentSource {
    /// Found in the global agents directory.
    Global,
    /// Found in the local agents directory.
    Local,
}

impl std::fmt::Display for AgentSource {
    /// Writes the lowercase label of the variant: `global` or `local`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            AgentSource::Global => "global",
            AgentSource::Local => "local",
        };
        // `write_str` emits the label verbatim, without applying width/fill options.
        f.write_str(label)
    }
}
185
186#[derive(Debug, Clone)]
188pub struct AgentInfo {
189 pub name: String,
190 pub path: PathBuf,
191 pub source: AgentSource,
192}
193
194pub fn discover_agents(
199 global_dir: &Path,
200 local_dir: Option<&Path>,
201) -> Result<Vec<AgentInfo>> {
202 let mut agents: HashMap<String, AgentInfo> = HashMap::new();
203
204 scan_agents_dir(global_dir, AgentSource::Global, &mut agents)?;
206
207 if let Some(local) = local_dir {
209 scan_agents_dir(local, AgentSource::Local, &mut agents)?;
210 }
211
212 let mut result: Vec<AgentInfo> = agents.into_values().collect();
213 result.sort_by(|a, b| a.name.cmp(&b.name));
214 Ok(result)
215}
216
217pub fn resolve_agent(
219 name: &str,
220 global_dir: &Path,
221 local_dir: Option<&Path>,
222) -> Result<(AgentConfig, PathBuf)> {
223 if let Some(local) = local_dir {
225 let path = local.join(format!("{name}.toml"));
226 if path.exists() {
227 let config = AgentConfig::load(&path)?;
228 return Ok((config, path));
229 }
230 }
231
232 let path = global_dir.join(format!("{name}.toml"));
234 if path.exists() {
235 let config = AgentConfig::load(&path)?;
236 return Ok((config, path));
237 }
238
239 Err(LlmError::Config(format!("agent not found: {name}")))
240}
241
242fn scan_agents_dir(
243 dir: &Path,
244 source: AgentSource,
245 agents: &mut HashMap<String, AgentInfo>,
246) -> Result<()> {
247 let entries = match std::fs::read_dir(dir) {
248 Ok(entries) => entries,
249 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
250 Err(e) => return Err(LlmError::Io(e)),
251 };
252
253 for entry in entries {
254 let entry = entry?;
255 let path = entry.path();
256 if path.extension().and_then(|e| e.to_str()) == Some("toml")
257 && let Some(stem) = path.file_stem().and_then(|s| s.to_str())
258 {
259 agents.insert(
260 stem.to_string(),
261 AgentInfo {
262 name: stem.to_string(),
263 path,
264 source: source.clone(),
265 },
266 );
267 }
268 }
269 Ok(())
270}
271
#[cfg(test)]
mod tests {
    //! Unit tests for agent config parsing, the `resolve_agent_*` helpers, and on-disk
    //! agent discovery. Filesystem tests run against fresh temporary directories.

    use super::*;

    // --- AgentConfig parsing -------------------------------------------------------------

    #[test]
    fn agent_config_default() {
        let config = AgentConfig::default();
        assert!(config.model.is_none());
        assert!(config.system_prompt.is_none());
        assert!(config.tools.is_empty());
        assert_eq!(config.chain_limit, 10);
        assert!(config.options.is_empty());
        assert!(config.budget.is_none());
        assert!(config.retry.is_none());
        assert!(config.parallel_tools);
        assert!(config.max_parallel_tools.is_none());
    }

    #[test]
    fn agent_config_load_full_toml() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("reviewer.toml");
        std::fs::write(
            &path,
            r#"
model = "claude-sonnet-4-20250514"
system_prompt = "You are a code reviewer."
tools = ["ripgrep", "read_file", "llm_time"]
chain_limit = 20

[options]
temperature = 0

[budget]
max_tokens = 50000
"#,
        )
        .unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert_eq!(config.model.as_deref(), Some("claude-sonnet-4-20250514"));
        assert_eq!(config.system_prompt.as_deref(), Some("You are a code reviewer."));
        assert_eq!(config.tools, vec!["ripgrep", "read_file", "llm_time"]);
        assert_eq!(config.chain_limit, 20);
        assert_eq!(config.options["temperature"], serde_json::json!(0));

        let budget = config.budget.unwrap();
        assert_eq!(budget.max_tokens, Some(50000));
    }

    #[test]
    fn agent_config_load_minimal_toml() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("minimal.toml");
        std::fs::write(&path, "model = \"gpt-4o-mini\"\n").unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert_eq!(config.model.as_deref(), Some("gpt-4o-mini"));
        assert!(config.system_prompt.is_none());
        assert!(config.tools.is_empty());
        assert_eq!(config.chain_limit, 10);
        assert!(config.options.is_empty());
    }

    #[test]
    fn agent_config_load_with_options() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("opts.toml");
        std::fs::write(
            &path,
            r#"
[options]
temperature = 0.7
max_tokens = 200
"#,
        )
        .unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert_eq!(config.options["temperature"], serde_json::json!(0.7));
        assert_eq!(config.options["max_tokens"], serde_json::json!(200));
    }

    #[test]
    fn agent_config_load_missing_file() {
        let result = AgentConfig::load(Path::new("/nonexistent/agent.toml"));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, LlmError::Config(_)));
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn agent_config_load_invalid_toml() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("bad.toml");
        std::fs::write(&path, "not valid {{{{ toml").unwrap();

        let result = AgentConfig::load(&path);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Config(_)));
    }

    #[test]
    fn agent_config_chain_limit_default() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("empty.toml");
        std::fs::write(&path, "").unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert_eq!(config.chain_limit, 10);
    }

    // --- Discovery and resolution on disk ------------------------------------------------

    #[test]
    fn discover_agents_empty_dirs() {
        let global = tempfile::tempdir().unwrap();
        let local = tempfile::tempdir().unwrap();

        let agents = discover_agents(global.path(), Some(local.path())).unwrap();
        assert!(agents.is_empty());
    }

    #[test]
    fn discover_agents_global_only() {
        let global = tempfile::tempdir().unwrap();
        std::fs::write(
            global.path().join("reviewer.toml"),
            "model = \"gpt-4o\"\n",
        )
        .unwrap();

        let agents = discover_agents(global.path(), None).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "reviewer");
        assert_eq!(agents[0].source, AgentSource::Global);
    }

    #[test]
    fn discover_agents_local_only() {
        let global = tempfile::tempdir().unwrap();
        let local = tempfile::tempdir().unwrap();
        std::fs::write(
            local.path().join("helper.toml"),
            "model = \"gpt-4o-mini\"\n",
        )
        .unwrap();

        let agents = discover_agents(global.path(), Some(local.path())).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "helper");
        assert_eq!(agents[0].source, AgentSource::Local);
    }

    #[test]
    fn discover_agents_local_shadows_global() {
        let global = tempfile::tempdir().unwrap();
        let local = tempfile::tempdir().unwrap();
        std::fs::write(
            global.path().join("reviewer.toml"),
            "model = \"gpt-4o\"\n",
        )
        .unwrap();
        std::fs::write(
            local.path().join("reviewer.toml"),
            "model = \"gpt-4o-mini\"\n",
        )
        .unwrap();

        let agents = discover_agents(global.path(), Some(local.path())).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "reviewer");
        assert_eq!(agents[0].source, AgentSource::Local);
    }

    #[test]
    fn discover_agents_sorted() {
        let global = tempfile::tempdir().unwrap();
        std::fs::write(global.path().join("zebra.toml"), "").unwrap();
        std::fs::write(global.path().join("alpha.toml"), "").unwrap();
        std::fs::write(global.path().join("mid.toml"), "").unwrap();

        let agents = discover_agents(global.path(), None).unwrap();
        let names: Vec<&str> = agents.iter().map(|a| a.name.as_str()).collect();
        assert_eq!(names, vec!["alpha", "mid", "zebra"]);
    }

    #[test]
    fn discover_agents_non_toml_ignored() {
        let global = tempfile::tempdir().unwrap();
        std::fs::write(global.path().join("agent.toml"), "").unwrap();
        std::fs::write(global.path().join("readme.md"), "# agents").unwrap();
        std::fs::write(global.path().join("notes.txt"), "some notes").unwrap();

        let agents = discover_agents(global.path(), None).unwrap();
        assert_eq!(agents.len(), 1);
        assert_eq!(agents[0].name, "agent");
    }

    #[test]
    fn discover_agents_nonexistent_dirs() {
        let agents = discover_agents(
            Path::new("/nonexistent/global"),
            Some(Path::new("/nonexistent/local")),
        )
        .unwrap();
        assert!(agents.is_empty());
    }

    #[test]
    fn resolve_agent_found() {
        let global = tempfile::tempdir().unwrap();
        std::fs::write(
            global.path().join("reviewer.toml"),
            "model = \"gpt-4o\"\nsystem_prompt = \"Review code.\"\n",
        )
        .unwrap();

        let (config, path) = resolve_agent("reviewer", global.path(), None).unwrap();
        assert_eq!(config.model.as_deref(), Some("gpt-4o"));
        assert_eq!(config.system_prompt.as_deref(), Some("Review code."));
        assert_eq!(path, global.path().join("reviewer.toml"));
    }

    #[test]
    fn resolve_agent_not_found() {
        let global = tempfile::tempdir().unwrap();
        let result = resolve_agent("nonexistent", global.path(), None);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, LlmError::Config(_)));
        assert!(err.to_string().contains("agent not found"));
    }

    // --- Retry and parallel-tools sections -----------------------------------------------

    #[test]
    fn agent_config_parses_retry() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("retry.toml");
        std::fs::write(
            &path,
            r#"
[retry]
max_retries = 5
base_delay_ms = 500
"#,
        )
        .unwrap();

        let config = AgentConfig::load(&path).unwrap();
        let retry = config.retry.unwrap();
        assert_eq!(retry.max_retries, 5);
        assert_eq!(retry.base_delay_ms, 500);
        // Keys not present in the file keep the retry defaults.
        assert_eq!(retry.max_delay_ms, 30_000);
        assert!(retry.jitter);
    }

    #[test]
    fn agent_config_retry_defaults() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("retry_defaults.toml");
        std::fs::write(&path, "[retry]\n").unwrap();

        let config = AgentConfig::load(&path).unwrap();
        let retry = config.retry.unwrap();
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.base_delay_ms, 1000);
        assert_eq!(retry.max_delay_ms, 30_000);
        assert!(retry.jitter);
    }

    #[test]
    fn agent_config_parses_parallel_tools_fields() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("parallel.toml");
        std::fs::write(
            &path,
            r#"
parallel_tools = false
max_parallel_tools = 3
"#,
        )
        .unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert!(!config.parallel_tools);
        assert_eq!(config.max_parallel_tools, Some(3));
    }

    #[test]
    fn agent_config_parallel_tools_defaults() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("defaults.toml");
        std::fs::write(&path, "model = \"gpt-4o-mini\"\n").unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert!(config.parallel_tools, "parallel_tools should default to true");
        assert_eq!(config.max_parallel_tools, None);
    }

    #[test]
    fn agent_config_no_retry() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("no_retry.toml");
        std::fs::write(&path, "model = \"gpt-4o-mini\"\n").unwrap();

        let config = AgentConfig::load(&path).unwrap();
        assert!(config.retry.is_none());
    }

    // --- resolve_agent_* helpers ---------------------------------------------------------

    /// Builds a minimal registry tool with the given name for the resolver tests.
    fn tool(name: &str) -> Tool {
        Tool {
            name: name.to_string(),
            description: format!("{name} tool"),
            input_schema: serde_json::json!({"type": "object"}),
        }
    }

    #[test]
    fn resolve_model_uses_config_when_set() {
        let mut cfg = AgentConfig::default();
        cfg.model = Some("gpt-4o".into());
        assert_eq!(resolve_agent_model(&cfg, "gpt-4o-mini"), "gpt-4o");
    }

    #[test]
    fn resolve_model_falls_back_to_client_default() {
        let cfg = AgentConfig::default();
        assert_eq!(resolve_agent_model(&cfg, "gpt-4o-mini"), "gpt-4o-mini");
    }

    #[test]
    fn resolve_system_prefers_arg_over_config() {
        let mut cfg = AgentConfig::default();
        cfg.system_prompt = Some("from config".into());
        assert_eq!(
            resolve_agent_system(Some("from arg"), &cfg),
            Some("from arg")
        );
    }

    #[test]
    fn resolve_system_uses_config_when_no_arg() {
        let mut cfg = AgentConfig::default();
        cfg.system_prompt = Some("from config".into());
        assert_eq!(resolve_agent_system(None, &cfg), Some("from config"));
    }

    #[test]
    fn resolve_system_none_when_neither() {
        let cfg = AgentConfig::default();
        assert_eq!(resolve_agent_system(None, &cfg), None);
    }

    #[test]
    fn resolve_retry_cli_arg_wins() {
        let cfg = AgentConfig::default();
        let client = RetryConfig::default();
        let out = resolve_agent_retry(Some(7), &cfg, &client);
        assert_eq!(out.max_retries, 7);
    }

    #[test]
    fn resolve_retry_agent_config_wins_over_default() {
        let mut cfg = AgentConfig::default();
        cfg.retry = Some(RetryConfig {
            max_retries: 5,
            base_delay_ms: 123,
            max_delay_ms: 456,
            jitter: false,
        });
        let client = RetryConfig::default();
        let out = resolve_agent_retry(None, &cfg, &client);
        assert_eq!(out.max_retries, 5);
        assert_eq!(out.base_delay_ms, 123);
    }

    #[test]
    fn resolve_retry_falls_back_to_client() {
        let cfg = AgentConfig::default();
        let client = RetryConfig {
            max_retries: 2,
            base_delay_ms: 1000,
            max_delay_ms: 30_000,
            jitter: true,
        };
        let out = resolve_agent_retry(None, &cfg, &client);
        assert_eq!(out.max_retries, 2);
    }

    #[test]
    fn resolve_tools_filters_to_agent_whitelist() {
        let mut cfg = AgentConfig::default();
        cfg.tools = vec!["read_file".into(), "llm_time".into()];
        let registry = vec![tool("read_file"), tool("ripgrep"), tool("llm_time")];

        let out = resolve_agent_tools(&cfg, &registry).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].name, "read_file");
        assert_eq!(out[1].name, "llm_time");
    }

    #[test]
    fn resolve_tools_errors_on_unknown_with_cli_format() {
        let mut cfg = AgentConfig::default();
        cfg.tools = vec!["missing".into()];
        let registry = vec![tool("read_file")];

        let err = resolve_agent_tools(&cfg, &registry).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("unknown tool in agent config: missing"),
            "got: {msg}"
        );
    }

    #[test]
    fn resolve_tools_empty_config_returns_empty() {
        let cfg = AgentConfig::default();
        let registry = vec![tool("read_file")];
        let out = resolve_agent_tools(&cfg, &registry).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn resolve_budget_extracts_max_tokens() {
        let mut cfg = AgentConfig::default();
        cfg.budget = Some(BudgetConfig {
            max_tokens: Some(5000),
        });
        assert_eq!(resolve_agent_budget(&cfg), Some(5000));
    }

    #[test]
    fn resolve_budget_none_when_unset() {
        let cfg = AgentConfig::default();
        assert_eq!(resolve_agent_budget(&cfg), None);
    }

    #[test]
    fn resolve_agent_local_wins() {
        let global = tempfile::tempdir().unwrap();
        let local = tempfile::tempdir().unwrap();
        std::fs::write(
            global.path().join("reviewer.toml"),
            "model = \"gpt-4o\"\n",
        )
        .unwrap();
        std::fs::write(
            local.path().join("reviewer.toml"),
            "model = \"claude-sonnet-4-20250514\"\n",
        )
        .unwrap();

        let (config, path) = resolve_agent("reviewer", global.path(), Some(local.path())).unwrap();
        assert_eq!(config.model.as_deref(), Some("claude-sonnet-4-20250514"));
        assert_eq!(path, local.path().join("reviewer.toml"));
    }
}