1#![cfg_attr(not(test), allow(dead_code))]
2
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
/// Directory name looked for inside the start directory and each of its
/// ancestors when locating the project configuration.
const CONFIG_DIR: &str = ".batty";
/// TOML file name expected inside `CONFIG_DIR`.
const CONFIG_FILENAME: &str = "config.toml";
10
11#[derive(Debug, Default, Clone, Copy, Deserialize, PartialEq)]
12#[serde(rename_all = "kebab-case")]
13pub enum Policy {
14 #[default]
15 Observe,
16 Suggest,
17 Act,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct Defaults {
22 #[serde(default = "default_agent")]
23 pub agent: String,
24 #[serde(default)]
25 pub policy: Policy,
26 #[serde(default)]
27 pub dod: Option<String>,
28 #[serde(default = "default_max_retries")]
29 pub max_retries: u32,
30}
31
32#[derive(Debug, Deserialize, Default)]
40#[allow(dead_code)] pub struct PolicyConfig {
42 #[serde(default)]
43 pub auto_answer: HashMap<String, String>,
44}
45
/// Agent used when `[defaults] agent` is not set.
fn default_agent() -> String {
    String::from("claude")
}

/// Retry limit used when `[defaults] max_retries` is not set.
fn default_max_retries() -> u32 {
    3_u32
}
53
/// `[supervisor] enabled` when not set: on.
fn default_supervisor_enabled() -> bool {
    true
}

/// `[supervisor] program` when not set.
fn default_supervisor_program() -> String {
    String::from("claude")
}

/// `[supervisor] args` when not set: `-p --output-format text`.
fn default_supervisor_args() -> Vec<String> {
    const ARGS: [&str; 3] = ["-p", "--output-format", "text"];
    ARGS.iter().map(|&arg| String::from(arg)).collect()
}

/// `[supervisor] timeout_secs` when not set.
fn default_supervisor_timeout_secs() -> u64 {
    60_u64
}

/// `[supervisor] trace_io` when not set: on.
fn default_supervisor_trace_io() -> bool {
    true
}
77
/// `[detector] silence_timeout_secs` when not set.
fn default_detector_silence_timeout_secs() -> u64 {
    3_u64
}

/// `[detector] answer_cooldown_millis` when not set (one second).
fn default_detector_answer_cooldown_millis() -> u64 {
    1_000
}

/// `[detector] unknown_request_fallback` when not set: on.
fn default_detector_unknown_request_fallback() -> bool {
    true
}

/// `[detector] idle_input_fallback` when not set: on.
fn default_detector_idle_input_fallback() -> bool {
    true
}

/// `[dangerous_mode] enabled` when not set: off, so it must be opted into.
fn default_dangerous_mode_enabled() -> bool {
    false
}
97
98impl Default for Defaults {
99 fn default() -> Self {
100 Self {
101 agent: default_agent(),
102 policy: Policy::default(),
103 dod: None,
104 max_retries: default_max_retries(),
105 }
106 }
107}
108
109#[derive(Debug, Deserialize)]
110pub struct SupervisorConfig {
111 #[serde(default = "default_supervisor_enabled")]
112 pub enabled: bool,
113 #[serde(default = "default_supervisor_program")]
114 pub program: String,
115 #[serde(default = "default_supervisor_args")]
116 pub args: Vec<String>,
117 #[serde(default = "default_supervisor_timeout_secs")]
118 pub timeout_secs: u64,
119 #[serde(default = "default_supervisor_trace_io")]
120 pub trace_io: bool,
121}
122
123impl Default for SupervisorConfig {
124 fn default() -> Self {
125 Self {
126 enabled: default_supervisor_enabled(),
127 program: default_supervisor_program(),
128 args: default_supervisor_args(),
129 timeout_secs: default_supervisor_timeout_secs(),
130 trace_io: default_supervisor_trace_io(),
131 }
132 }
133}
134
135#[derive(Debug, Deserialize)]
136pub struct DetectorSettings {
137 #[serde(default = "default_detector_silence_timeout_secs")]
138 pub silence_timeout_secs: u64,
139 #[serde(default = "default_detector_answer_cooldown_millis")]
140 pub answer_cooldown_millis: u64,
141 #[serde(default = "default_detector_unknown_request_fallback")]
142 pub unknown_request_fallback: bool,
143 #[serde(default = "default_detector_idle_input_fallback")]
144 pub idle_input_fallback: bool,
145}
146
147impl Default for DetectorSettings {
148 fn default() -> Self {
149 Self {
150 silence_timeout_secs: default_detector_silence_timeout_secs(),
151 answer_cooldown_millis: default_detector_answer_cooldown_millis(),
152 unknown_request_fallback: default_detector_unknown_request_fallback(),
153 idle_input_fallback: default_detector_idle_input_fallback(),
154 }
155 }
156}
157
158#[derive(Debug, Deserialize)]
159pub struct DangerousModeConfig {
160 #[serde(default = "default_dangerous_mode_enabled")]
161 pub enabled: bool,
162}
163
164impl Default for DangerousModeConfig {
165 fn default() -> Self {
166 Self {
167 enabled: default_dangerous_mode_enabled(),
168 }
169 }
170}
171
172#[derive(Debug, Deserialize, Default)]
173pub struct ProjectConfig {
174 #[serde(default)]
175 pub defaults: Defaults,
176 #[serde(default)]
177 #[allow(dead_code)] pub policy: PolicyConfig,
179 #[serde(default)]
180 pub supervisor: SupervisorConfig,
181 #[serde(default)]
182 pub detector: DetectorSettings,
183 #[serde(default)]
184 pub dangerous_mode: DangerousModeConfig,
185}
186
187impl ProjectConfig {
188 pub fn load(start: &Path) -> Result<(Self, Option<PathBuf>)> {
191 if let Some(path) = Self::find_config_file(start) {
192 let contents = std::fs::read_to_string(&path)
193 .with_context(|| format!("failed to read {}", path.display()))?;
194 let config: ProjectConfig = toml::from_str(&contents)
195 .with_context(|| format!("failed to parse {}", path.display()))?;
196 Ok((config, Some(path)))
197 } else {
198 Ok((ProjectConfig::default(), None))
199 }
200 }
201
202 fn find_config_file(start: &Path) -> Option<PathBuf> {
203 let mut dir = start.to_path_buf();
204 loop {
205 let candidate = dir.join(CONFIG_DIR).join(CONFIG_FILENAME);
206 if candidate.is_file() {
207 return Some(candidate);
208 }
209 if !dir.pop() {
210 return None;
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Writes `contents` to `<root>/.batty/config.toml`, creating the
    /// directory, and returns the file path.
    fn write_config(root: &Path, contents: &str) -> PathBuf {
        let batty_dir = root.join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        let file = batty_dir.join("config.toml");
        fs::write(&file, contents).unwrap();
        file
    }

    #[test]
    fn default_config_values() {
        let config = ProjectConfig::default();
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.defaults.dod.is_none());
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(
            config.supervisor.args,
            vec!["-p", "--output-format", "text"]
        );
        assert_eq!(config.supervisor.timeout_secs, 60);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 3);
        assert_eq!(config.detector.answer_cooldown_millis, 1000);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_full_config() {
        let toml = r#"
[defaults]
agent = "codex"
policy = "act"
dod = "cargo test --workspace"
max_retries = 5

[supervisor]
enabled = true
program = "claude"
args = ["-p", "--output-format", "text"]
timeout_secs = 45
trace_io = true

[detector]
silence_timeout_secs = 5
answer_cooldown_millis = 1500
unknown_request_fallback = true
idle_input_fallback = true

[dangerous_mode]
enabled = true
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "codex");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(
            config.defaults.dod.as_deref(),
            Some("cargo test --workspace")
        );
        assert_eq!(config.defaults.max_retries, 5);
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(config.supervisor.timeout_secs, 45);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 5);
        assert_eq!(config.detector.answer_cooldown_millis, 1500);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_partial_config() {
        let toml = r#"
[defaults]
agent = "aider"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "aider");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    #[test]
    fn parse_suggest_policy() {
        let toml = r#"
[defaults]
policy = "suggest"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.policy, Policy::Suggest);
    }

    #[test]
    fn parse_rejects_unknown_policy() {
        let toml = r#"
[defaults]
policy = "yolo"
"#;
        assert!(toml::from_str::<ProjectConfig>(toml).is_err());
    }

    #[test]
    fn load_from_directory() {
        let tmp = tempfile::tempdir().unwrap();
        write_config(
            tmp.path(),
            r#"
[defaults]
agent = "claude"
policy = "act"
dod = "cargo test"
max_retries = 2
"#,
        );

        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(config.defaults.max_retries, 2);
    }

    #[test]
    fn load_returns_default_when_no_file() {
        let tmp = tempfile::tempdir().unwrap();
        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_none());
        assert_eq!(config.defaults.agent, "claude");
    }

    #[test]
    fn load_reports_malformed_file() {
        let tmp = tempfile::tempdir().unwrap();
        write_config(tmp.path(), "[defaults\nagent = ");

        let err = ProjectConfig::load(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("failed to parse"));
    }

    #[test]
    fn parse_auto_answer_config() {
        let toml = r#"
[defaults]
agent = "claude"
policy = "act"

[policy.auto_answer]
"Continue? [y/n]" = "y"
"Allow tool" = "y"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.policy.auto_answer.len(), 2);
        assert_eq!(
            config.policy.auto_answer.get("Continue? [y/n]").unwrap(),
            "y"
        );
        assert_eq!(config.policy.auto_answer.get("Allow tool").unwrap(), "y");
    }

    #[test]
    fn load_walks_up_directories() {
        let tmp = tempfile::tempdir().unwrap();
        write_config(
            tmp.path(),
            r#"
[defaults]
agent = "codex"
"#,
        );

        let nested = tmp.path().join("src").join("deep").join("nested");
        fs::create_dir_all(&nested).unwrap();

        let (config, path) = ProjectConfig::load(&nested).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "codex");
    }

    #[test]
    fn load_prefers_nearest_config() {
        let tmp = tempfile::tempdir().unwrap();
        write_config(tmp.path(), "[defaults]\nagent = \"codex\"\n");
        let inner_root = tmp.path().join("sub");
        let inner_file = write_config(&inner_root, "[defaults]\nagent = \"aider\"\n");

        let start = inner_root.join("deeper");
        fs::create_dir_all(&start).unwrap();

        let (config, path) = ProjectConfig::load(&start).unwrap();
        assert_eq!(config.defaults.agent, "aider");
        assert_eq!(path, Some(inner_file));
    }
}