Skip to main content

batty_cli/config/
mod.rs

1#![cfg_attr(not(test), allow(dead_code))]
2
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
/// Directory, searched for in the start directory and each ancestor, that holds batty's project config.
const CONFIG_DIR: &str = ".batty";
/// Name of the config file inside [`CONFIG_DIR`].
const CONFIG_FILENAME: &str = "config.toml";
10
11#[derive(Debug, Default, Clone, Copy, Deserialize, PartialEq)]
12#[serde(rename_all = "kebab-case")]
13pub enum Policy {
14    #[default]
15    Observe,
16    Suggest,
17    Act,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct Defaults {
22    #[serde(default = "default_agent")]
23    pub agent: String,
24    #[serde(default)]
25    pub policy: Policy,
26    #[serde(default)]
27    pub dod: Option<String>,
28    #[serde(default = "default_max_retries")]
29    pub max_retries: u32,
30}
31
32/// Policy section with auto-answer patterns.
33///
34/// ```toml
35/// [policy.auto_answer]
36/// "Continue? [y/n]" = "y"
37/// "Allow tool" = "y"
38/// ```
39#[derive(Debug, Deserialize, Default)]
40pub struct PolicyConfig {
41    #[serde(default)]
42    pub auto_answer: HashMap<String, String>,
43}
44
/// Value of `[defaults] agent` when the key is omitted.
fn default_agent() -> String {
    String::from("claude")
}

/// Value of `[defaults] max_retries` when the key is omitted.
fn default_max_retries() -> u32 {
    3u32
}
52
/// Value of `[supervisor] enabled` when the key is omitted.
fn default_supervisor_enabled() -> bool {
    true
}

/// Value of `[supervisor] program` when the key is omitted.
fn default_supervisor_program() -> String {
    "claude".to_owned()
}

/// Value of `[supervisor] args` when the key is omitted.
fn default_supervisor_args() -> Vec<String> {
    ["-p", "--output-format", "text"]
        .iter()
        .map(|arg| arg.to_string())
        .collect()
}

/// Value of `[supervisor] timeout_secs` when the key is omitted
/// (presumably seconds, per the `_secs` suffix).
fn default_supervisor_timeout_secs() -> u64 {
    60u64
}

/// Value of `[supervisor] trace_io` when the key is omitted.
fn default_supervisor_trace_io() -> bool {
    true
}
76
/// Value of `[detector] silence_timeout_secs` when the key is omitted
/// (presumably seconds, per the `_secs` suffix).
fn default_detector_silence_timeout_secs() -> u64 {
    3u64
}

/// Value of `[detector] answer_cooldown_millis` when the key is omitted
/// (presumably milliseconds, per the `_millis` suffix).
fn default_detector_answer_cooldown_millis() -> u64 {
    1_000
}

/// Value of `[detector] unknown_request_fallback` when the key is omitted.
fn default_detector_unknown_request_fallback() -> bool {
    true
}

/// Value of `[detector] idle_input_fallback` when the key is omitted.
fn default_detector_idle_input_fallback() -> bool {
    true
}
92
/// Value of `[dangerous_mode] enabled` when the key is omitted: off, so the
/// mode must be opted into explicitly.
fn default_dangerous_mode_enabled() -> bool {
    false
}
96
97impl Default for Defaults {
98    fn default() -> Self {
99        Self {
100            agent: default_agent(),
101            policy: Policy::default(),
102            dod: None,
103            max_retries: default_max_retries(),
104        }
105    }
106}
107
108#[derive(Debug, Deserialize)]
109pub struct SupervisorConfig {
110    #[serde(default = "default_supervisor_enabled")]
111    pub enabled: bool,
112    #[serde(default = "default_supervisor_program")]
113    pub program: String,
114    #[serde(default = "default_supervisor_args")]
115    pub args: Vec<String>,
116    #[serde(default = "default_supervisor_timeout_secs")]
117    pub timeout_secs: u64,
118    #[serde(default = "default_supervisor_trace_io")]
119    pub trace_io: bool,
120}
121
122impl Default for SupervisorConfig {
123    fn default() -> Self {
124        Self {
125            enabled: default_supervisor_enabled(),
126            program: default_supervisor_program(),
127            args: default_supervisor_args(),
128            timeout_secs: default_supervisor_timeout_secs(),
129            trace_io: default_supervisor_trace_io(),
130        }
131    }
132}
133
134#[derive(Debug, Deserialize)]
135pub struct DetectorSettings {
136    #[serde(default = "default_detector_silence_timeout_secs")]
137    pub silence_timeout_secs: u64,
138    #[serde(default = "default_detector_answer_cooldown_millis")]
139    pub answer_cooldown_millis: u64,
140    #[serde(default = "default_detector_unknown_request_fallback")]
141    pub unknown_request_fallback: bool,
142    #[serde(default = "default_detector_idle_input_fallback")]
143    pub idle_input_fallback: bool,
144}
145
146impl Default for DetectorSettings {
147    fn default() -> Self {
148        Self {
149            silence_timeout_secs: default_detector_silence_timeout_secs(),
150            answer_cooldown_millis: default_detector_answer_cooldown_millis(),
151            unknown_request_fallback: default_detector_unknown_request_fallback(),
152            idle_input_fallback: default_detector_idle_input_fallback(),
153        }
154    }
155}
156
157#[derive(Debug, Deserialize)]
158pub struct DangerousModeConfig {
159    #[serde(default = "default_dangerous_mode_enabled")]
160    pub enabled: bool,
161}
162
163impl Default for DangerousModeConfig {
164    fn default() -> Self {
165        Self {
166            enabled: default_dangerous_mode_enabled(),
167        }
168    }
169}
170
171#[derive(Debug, Deserialize, Default)]
172pub struct ProjectConfig {
173    #[serde(default)]
174    pub defaults: Defaults,
175    #[serde(default)]
176    pub policy: PolicyConfig,
177    #[serde(default)]
178    pub supervisor: SupervisorConfig,
179    #[serde(default)]
180    pub detector: DetectorSettings,
181    #[serde(default)]
182    pub dangerous_mode: DangerousModeConfig,
183}
184
185impl ProjectConfig {
186    /// Search upward from `start` for a `.batty/config.toml` file and load it.
187    /// Returns the default config if no file is found.
188    pub fn load(start: &Path) -> Result<(Self, Option<PathBuf>)> {
189        if let Some(path) = Self::find_config_file(start) {
190            let contents = std::fs::read_to_string(&path)
191                .with_context(|| format!("failed to read {}", path.display()))?;
192            let config: ProjectConfig = toml::from_str(&contents)
193                .with_context(|| format!("failed to parse {}", path.display()))?;
194            Ok((config, Some(path)))
195        } else {
196            Ok((ProjectConfig::default(), None))
197        }
198    }
199
200    fn find_config_file(start: &Path) -> Option<PathBuf> {
201        let mut dir = start.to_path_buf();
202        loop {
203            let candidate = dir.join(CONFIG_DIR).join(CONFIG_FILENAME);
204            if candidate.is_file() {
205                return Some(candidate);
206            }
207            if !dir.pop() {
208                return None;
209            }
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Creates `<root>/.batty/config.toml` containing `contents` and returns its path.
    fn write_config(root: &Path, contents: &str) -> PathBuf {
        let batty_dir = root.join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        let path = batty_dir.join("config.toml");
        fs::write(&path, contents).unwrap();
        path
    }

    #[test]
    fn default_config_values() {
        let config = ProjectConfig::default();
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.defaults.dod.is_none());
        assert!(config.policy.auto_answer.is_empty());
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(
            config.supervisor.args,
            vec!["-p", "--output-format", "text"]
        );
        assert_eq!(config.supervisor.timeout_secs, 60);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 3);
        assert_eq!(config.detector.answer_cooldown_millis, 1000);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    /// Every key is set to a NON-default value, so a key that is silently
    /// ignored by deserialization makes this test fail.
    #[test]
    fn parse_full_config() {
        let toml = r#"
[defaults]
agent = "codex"
policy = "act"
dod = "cargo test --workspace"
max_retries = 5

[supervisor]
enabled = false
program = "codex"
args = ["--quiet"]
timeout_secs = 45
trace_io = false

[detector]
silence_timeout_secs = 5
answer_cooldown_millis = 1500
unknown_request_fallback = false
idle_input_fallback = false

[dangerous_mode]
enabled = true
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "codex");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(
            config.defaults.dod.as_deref(),
            Some("cargo test --workspace")
        );
        assert_eq!(config.defaults.max_retries, 5);
        assert!(!config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "codex");
        assert_eq!(config.supervisor.args, vec!["--quiet"]);
        assert_eq!(config.supervisor.timeout_secs, 45);
        assert!(!config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 5);
        assert_eq!(config.detector.answer_cooldown_millis, 1500);
        assert!(!config.detector.unknown_request_fallback);
        assert!(!config.detector.idle_input_fallback);
        assert!(config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_partial_config() {
        let toml = r#"
[defaults]
agent = "aider"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "aider");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.policy.auto_answer.is_empty());
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.timeout_secs, 60);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_suggest_policy() {
        let toml = r#"
[defaults]
policy = "suggest"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.policy, Policy::Suggest);
    }

    #[test]
    fn unknown_policy_is_rejected() {
        let toml = r#"
[defaults]
policy = "yolo"
"#;
        assert!(toml::from_str::<ProjectConfig>(toml).is_err());
    }

    #[test]
    fn load_from_directory() {
        let tmp = tempfile::tempdir().unwrap();
        let expected = write_config(
            tmp.path(),
            r#"
[defaults]
agent = "claude"
policy = "act"
dod = "cargo test"
max_retries = 2
"#,
        );

        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert_eq!(path, Some(expected));
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(config.defaults.dod.as_deref(), Some("cargo test"));
        assert_eq!(config.defaults.max_retries, 2);
    }

    #[test]
    fn load_returns_default_when_no_file() {
        let tmp = tempfile::tempdir().unwrap();
        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_none());
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
    }

    #[test]
    fn load_reports_parse_errors() {
        let tmp = tempfile::tempdir().unwrap();
        write_config(tmp.path(), "[defaults\nagent = \"claude\"\n");

        let err = ProjectConfig::load(tmp.path()).unwrap_err();
        assert!(
            err.to_string().starts_with("failed to parse"),
            "unexpected error: {:#}",
            err
        );
    }

    #[test]
    fn parse_auto_answer_config() {
        let toml = r#"
[defaults]
agent = "claude"
policy = "act"

[policy.auto_answer]
"Continue? [y/n]" = "y"
"Allow tool" = "y"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.policy.auto_answer.len(), 2);
        assert_eq!(
            config.policy.auto_answer.get("Continue? [y/n]").unwrap(),
            "y"
        );
        assert_eq!(config.policy.auto_answer.get("Allow tool").unwrap(), "y");
    }

    #[test]
    fn load_walks_up_directories() {
        let tmp = tempfile::tempdir().unwrap();
        let expected = write_config(
            tmp.path(),
            r#"
[defaults]
agent = "codex"
"#,
        );

        let nested = tmp.path().join("src").join("deep").join("nested");
        fs::create_dir_all(&nested).unwrap();

        let (config, path) = ProjectConfig::load(&nested).unwrap();
        assert_eq!(path, Some(expected));
        assert_eq!(config.defaults.agent, "codex");
    }
}