Skip to main content

batty_cli/config/
mod.rs

1#![cfg_attr(not(test), allow(dead_code))]
2
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
/// Directory name looked for in `start` and each of its ancestors.
const CONFIG_DIR: &str = ".batty";
/// File name of the project configuration inside [`CONFIG_DIR`].
const CONFIG_FILENAME: &str = "config.toml";
10
11#[derive(Debug, Default, Clone, Copy, Deserialize, PartialEq)]
12#[serde(rename_all = "kebab-case")]
13pub enum Policy {
14    #[default]
15    Observe,
16    Suggest,
17    Act,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct Defaults {
22    #[serde(default = "default_agent")]
23    pub agent: String,
24    #[serde(default)]
25    pub policy: Policy,
26    #[serde(default)]
27    pub dod: Option<String>,
28    #[serde(default = "default_max_retries")]
29    pub max_retries: u32,
30}
31
32/// Policy section with auto-answer patterns.
33///
34/// ```toml
35/// [policy.auto_answer]
36/// "Continue? [y/n]" = "y"
37/// "Allow tool" = "y"
38/// ```
39#[derive(Debug, Deserialize, Default)]
40#[allow(dead_code)] // Used by policy engine (task #8), wired in task #12
41pub struct PolicyConfig {
42    #[serde(default)]
43    pub auto_answer: HashMap<String, String>,
44}
45
/// Default value for `defaults.agent`.
fn default_agent() -> String {
    String::from("claude")
}

/// Default value for `defaults.max_retries`.
fn default_max_retries() -> u32 {
    3_u32
}
53
/// Default value for `supervisor.enabled`: on unless the config disables it.
fn default_supervisor_enabled() -> bool {
    true
}

/// Default value for `supervisor.program`.
fn default_supervisor_program() -> String {
    "claude".to_owned()
}

/// Default value for `supervisor.args`: `-p --output-format text`.
fn default_supervisor_args() -> Vec<String> {
    ["-p", "--output-format", "text"]
        .iter()
        .map(|arg| (*arg).to_owned())
        .collect()
}

/// Default value for `supervisor.timeout_secs`.
fn default_supervisor_timeout_secs() -> u64 {
    60
}

/// Default value for `supervisor.trace_io`.
fn default_supervisor_trace_io() -> bool {
    true
}
77
/// Default value for `detector.silence_timeout_secs`.
fn default_detector_silence_timeout_secs() -> u64 {
    3
}

/// Default value for `detector.answer_cooldown_millis`.
fn default_detector_answer_cooldown_millis() -> u64 {
    1_000
}

/// Default value for `detector.unknown_request_fallback`.
fn default_detector_unknown_request_fallback() -> bool {
    true
}

/// Default value for `detector.idle_input_fallback`.
fn default_detector_idle_input_fallback() -> bool {
    true
}
93
/// Default value for `dangerous_mode.enabled`: off unless the config turns it on.
fn default_dangerous_mode_enabled() -> bool {
    false
}
97
98impl Default for Defaults {
99    fn default() -> Self {
100        Self {
101            agent: default_agent(),
102            policy: Policy::default(),
103            dod: None,
104            max_retries: default_max_retries(),
105        }
106    }
107}
108
109#[derive(Debug, Deserialize)]
110pub struct SupervisorConfig {
111    #[serde(default = "default_supervisor_enabled")]
112    pub enabled: bool,
113    #[serde(default = "default_supervisor_program")]
114    pub program: String,
115    #[serde(default = "default_supervisor_args")]
116    pub args: Vec<String>,
117    #[serde(default = "default_supervisor_timeout_secs")]
118    pub timeout_secs: u64,
119    #[serde(default = "default_supervisor_trace_io")]
120    pub trace_io: bool,
121}
122
123impl Default for SupervisorConfig {
124    fn default() -> Self {
125        Self {
126            enabled: default_supervisor_enabled(),
127            program: default_supervisor_program(),
128            args: default_supervisor_args(),
129            timeout_secs: default_supervisor_timeout_secs(),
130            trace_io: default_supervisor_trace_io(),
131        }
132    }
133}
134
135#[derive(Debug, Deserialize)]
136pub struct DetectorSettings {
137    #[serde(default = "default_detector_silence_timeout_secs")]
138    pub silence_timeout_secs: u64,
139    #[serde(default = "default_detector_answer_cooldown_millis")]
140    pub answer_cooldown_millis: u64,
141    #[serde(default = "default_detector_unknown_request_fallback")]
142    pub unknown_request_fallback: bool,
143    #[serde(default = "default_detector_idle_input_fallback")]
144    pub idle_input_fallback: bool,
145}
146
147impl Default for DetectorSettings {
148    fn default() -> Self {
149        Self {
150            silence_timeout_secs: default_detector_silence_timeout_secs(),
151            answer_cooldown_millis: default_detector_answer_cooldown_millis(),
152            unknown_request_fallback: default_detector_unknown_request_fallback(),
153            idle_input_fallback: default_detector_idle_input_fallback(),
154        }
155    }
156}
157
158#[derive(Debug, Deserialize)]
159pub struct DangerousModeConfig {
160    #[serde(default = "default_dangerous_mode_enabled")]
161    pub enabled: bool,
162}
163
164impl Default for DangerousModeConfig {
165    fn default() -> Self {
166        Self {
167            enabled: default_dangerous_mode_enabled(),
168        }
169    }
170}
171
172#[derive(Debug, Deserialize, Default)]
173pub struct ProjectConfig {
174    #[serde(default)]
175    pub defaults: Defaults,
176    #[serde(default)]
177    #[allow(dead_code)] // Used by policy engine, wired in task #12
178    pub policy: PolicyConfig,
179    #[serde(default)]
180    pub supervisor: SupervisorConfig,
181    #[serde(default)]
182    pub detector: DetectorSettings,
183    #[serde(default)]
184    pub dangerous_mode: DangerousModeConfig,
185}
186
187impl ProjectConfig {
188    /// Search upward from `start` for a `.batty/config.toml` file and load it.
189    /// Returns the default config if no file is found.
190    pub fn load(start: &Path) -> Result<(Self, Option<PathBuf>)> {
191        if let Some(path) = Self::find_config_file(start) {
192            let contents = std::fs::read_to_string(&path)
193                .with_context(|| format!("failed to read {}", path.display()))?;
194            let config: ProjectConfig = toml::from_str(&contents)
195                .with_context(|| format!("failed to parse {}", path.display()))?;
196            Ok((config, Some(path)))
197        } else {
198            Ok((ProjectConfig::default(), None))
199        }
200    }
201
202    fn find_config_file(start: &Path) -> Option<PathBuf> {
203        let mut dir = start.to_path_buf();
204        loop {
205            let candidate = dir.join(CONFIG_DIR).join(CONFIG_FILENAME);
206            if candidate.is_file() {
207                return Some(candidate);
208            }
209            if !dir.pop() {
210                return None;
211            }
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn default_config_values() {
        let config = ProjectConfig::default();
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.defaults.dod.is_none());
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(
            config.supervisor.args,
            vec!["-p", "--output-format", "text"]
        );
        assert_eq!(config.supervisor.timeout_secs, 60);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 3);
        assert_eq!(config.detector.answer_cooldown_millis, 1000);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    /// Every value differs from its default, so each assertion proves the key
    /// was actually parsed rather than defaulted.
    #[test]
    fn parse_full_config() {
        let toml = r#"
[defaults]
agent = "codex"
policy = "act"
dod = "cargo test --workspace"
max_retries = 5

[supervisor]
enabled = false
program = "my-supervisor"
args = ["--print"]
timeout_secs = 45
trace_io = false

[detector]
silence_timeout_secs = 5
answer_cooldown_millis = 1500
unknown_request_fallback = false
idle_input_fallback = false

[dangerous_mode]
enabled = true
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "codex");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(
            config.defaults.dod.as_deref(),
            Some("cargo test --workspace")
        );
        assert_eq!(config.defaults.max_retries, 5);
        assert!(!config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "my-supervisor");
        assert_eq!(config.supervisor.args, vec!["--print"]);
        assert_eq!(config.supervisor.timeout_secs, 45);
        assert!(!config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 5);
        assert_eq!(config.detector.answer_cooldown_millis, 1500);
        assert!(!config.detector.unknown_request_fallback);
        assert!(!config.detector.idle_input_fallback);
        assert!(config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_partial_config() {
        let toml = r#"
[defaults]
agent = "aider"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "aider");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_empty_config_uses_defaults() {
        let config: ProjectConfig = toml::from_str("").unwrap();
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert!(config.policy.auto_answer.is_empty());
        assert!(config.supervisor.enabled);
        assert_eq!(config.detector.answer_cooldown_millis, 1000);
        assert!(!config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_suggest_policy() {
        let toml = r#"
[defaults]
policy = "suggest"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.policy, Policy::Suggest);
    }

    #[test]
    fn unknown_policy_is_rejected() {
        let toml = r#"
[defaults]
policy = "yolo"
"#;
        assert!(toml::from_str::<ProjectConfig>(toml).is_err());
    }

    #[test]
    fn negative_max_retries_is_rejected() {
        let toml = r#"
[defaults]
max_retries = -1
"#;
        assert!(toml::from_str::<ProjectConfig>(toml).is_err());
    }

    #[test]
    fn load_from_directory() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(
            batty_dir.join("config.toml"),
            r#"
[defaults]
agent = "claude"
policy = "act"
dod = "cargo test"
max_retries = 2
"#,
        )
        .unwrap();

        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(config.defaults.max_retries, 2);
    }

    #[test]
    fn load_returns_default_when_no_file() {
        let tmp = tempfile::tempdir().unwrap();
        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_none());
        assert_eq!(config.defaults.agent, "claude");
    }

    #[test]
    fn load_reports_malformed_file() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(batty_dir.join("config.toml"), "[defaults\nagent = ").unwrap();

        let err = ProjectConfig::load(tmp.path()).unwrap_err();
        assert!(
            err.to_string().starts_with("failed to parse"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn parse_auto_answer_config() {
        let toml = r#"
[defaults]
agent = "claude"
policy = "act"

[policy.auto_answer]
"Continue? [y/n]" = "y"
"Allow tool" = "y"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.policy.auto_answer.len(), 2);
        assert_eq!(
            config.policy.auto_answer.get("Continue? [y/n]").unwrap(),
            "y"
        );
        assert_eq!(config.policy.auto_answer.get("Allow tool").unwrap(), "y");
    }

    #[test]
    fn load_walks_up_directories() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(
            batty_dir.join("config.toml"),
            r#"
[defaults]
agent = "codex"
"#,
        )
        .unwrap();

        let nested = tmp.path().join("src").join("deep").join("nested");
        fs::create_dir_all(&nested).unwrap();

        let (config, path) = ProjectConfig::load(&nested).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "codex");
    }
}