1#![cfg_attr(not(test), allow(dead_code))]
2
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
/// File name of the project configuration file, looked up inside `CONFIG_DIR`.
const CONFIG_FILENAME: &str = "config.toml";
/// Directory name (joined onto each searched directory) that holds `CONFIG_FILENAME`.
const CONFIG_DIR: &str = ".batty";
10
11#[derive(Debug, Default, Clone, Copy, Deserialize, PartialEq)]
12#[serde(rename_all = "kebab-case")]
13pub enum Policy {
14 #[default]
15 Observe,
16 Suggest,
17 Act,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct Defaults {
22 #[serde(default = "default_agent")]
23 pub agent: String,
24 #[serde(default)]
25 pub policy: Policy,
26 #[serde(default)]
27 pub dod: Option<String>,
28 #[serde(default = "default_max_retries")]
29 pub max_retries: u32,
30}
31
32#[derive(Debug, Deserialize, Default)]
40pub struct PolicyConfig {
41 #[serde(default)]
42 pub auto_answer: HashMap<String, String>,
43}
44
/// Default value of `[defaults] agent`.
fn default_agent() -> String {
    String::from("claude")
}

/// Default value of `[defaults] max_retries`.
fn default_max_retries() -> u32 {
    3_u32
}
52
/// Default value of `[supervisor] enabled`: on.
fn default_supervisor_enabled() -> bool {
    true
}

/// Default value of `[supervisor] program`.
fn default_supervisor_program() -> String {
    String::from("claude")
}

/// Default value of `[supervisor] args`: `-p --output-format text`.
fn default_supervisor_args() -> Vec<String> {
    ["-p", "--output-format", "text"]
        .iter()
        .map(|&arg| arg.to_owned())
        .collect()
}

/// Default value of `[supervisor] timeout_secs`.
fn default_supervisor_timeout_secs() -> u64 {
    60_u64
}

/// Default value of `[supervisor] trace_io`: on.
fn default_supervisor_trace_io() -> bool {
    true
}
76
/// Default value of `[detector] silence_timeout_secs`.
fn default_detector_silence_timeout_secs() -> u64 {
    3_u64
}

/// Default value of `[detector] answer_cooldown_millis` (1000 ms).
fn default_detector_answer_cooldown_millis() -> u64 {
    1_000
}

/// Default value of `[detector] unknown_request_fallback`: enabled.
fn default_detector_unknown_request_fallback() -> bool {
    true
}

/// Default value of `[detector] idle_input_fallback`: enabled.
fn default_detector_idle_input_fallback() -> bool {
    true
}

/// Default value of `[dangerous_mode] enabled`: off unless the config opts in.
fn default_dangerous_mode_enabled() -> bool {
    false
}
96
97impl Default for Defaults {
98 fn default() -> Self {
99 Self {
100 agent: default_agent(),
101 policy: Policy::default(),
102 dod: None,
103 max_retries: default_max_retries(),
104 }
105 }
106}
107
108#[derive(Debug, Deserialize)]
109pub struct SupervisorConfig {
110 #[serde(default = "default_supervisor_enabled")]
111 pub enabled: bool,
112 #[serde(default = "default_supervisor_program")]
113 pub program: String,
114 #[serde(default = "default_supervisor_args")]
115 pub args: Vec<String>,
116 #[serde(default = "default_supervisor_timeout_secs")]
117 pub timeout_secs: u64,
118 #[serde(default = "default_supervisor_trace_io")]
119 pub trace_io: bool,
120}
121
122impl Default for SupervisorConfig {
123 fn default() -> Self {
124 Self {
125 enabled: default_supervisor_enabled(),
126 program: default_supervisor_program(),
127 args: default_supervisor_args(),
128 timeout_secs: default_supervisor_timeout_secs(),
129 trace_io: default_supervisor_trace_io(),
130 }
131 }
132}
133
134#[derive(Debug, Deserialize)]
135pub struct DetectorSettings {
136 #[serde(default = "default_detector_silence_timeout_secs")]
137 pub silence_timeout_secs: u64,
138 #[serde(default = "default_detector_answer_cooldown_millis")]
139 pub answer_cooldown_millis: u64,
140 #[serde(default = "default_detector_unknown_request_fallback")]
141 pub unknown_request_fallback: bool,
142 #[serde(default = "default_detector_idle_input_fallback")]
143 pub idle_input_fallback: bool,
144}
145
146impl Default for DetectorSettings {
147 fn default() -> Self {
148 Self {
149 silence_timeout_secs: default_detector_silence_timeout_secs(),
150 answer_cooldown_millis: default_detector_answer_cooldown_millis(),
151 unknown_request_fallback: default_detector_unknown_request_fallback(),
152 idle_input_fallback: default_detector_idle_input_fallback(),
153 }
154 }
155}
156
157#[derive(Debug, Deserialize)]
158pub struct DangerousModeConfig {
159 #[serde(default = "default_dangerous_mode_enabled")]
160 pub enabled: bool,
161}
162
163impl Default for DangerousModeConfig {
164 fn default() -> Self {
165 Self {
166 enabled: default_dangerous_mode_enabled(),
167 }
168 }
169}
170
171#[derive(Debug, Deserialize, Default)]
172pub struct ProjectConfig {
173 #[serde(default)]
174 pub defaults: Defaults,
175 #[serde(default)]
176 pub policy: PolicyConfig,
177 #[serde(default)]
178 pub supervisor: SupervisorConfig,
179 #[serde(default)]
180 pub detector: DetectorSettings,
181 #[serde(default)]
182 pub dangerous_mode: DangerousModeConfig,
183}
184
185impl ProjectConfig {
186 pub fn load(start: &Path) -> Result<(Self, Option<PathBuf>)> {
189 if let Some(path) = Self::find_config_file(start) {
190 let contents = std::fs::read_to_string(&path)
191 .with_context(|| format!("failed to read {}", path.display()))?;
192 let config: ProjectConfig = toml::from_str(&contents)
193 .with_context(|| format!("failed to parse {}", path.display()))?;
194 Ok((config, Some(path)))
195 } else {
196 Ok((ProjectConfig::default(), None))
197 }
198 }
199
200 fn find_config_file(start: &Path) -> Option<PathBuf> {
201 let mut dir = start.to_path_buf();
202 loop {
203 let candidate = dir.join(CONFIG_DIR).join(CONFIG_FILENAME);
204 if candidate.is_file() {
205 return Some(candidate);
206 }
207 if !dir.pop() {
208 return None;
209 }
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// With no config file, every section takes its documented default.
    #[test]
    fn default_config_values() {
        let config = ProjectConfig::default();
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.defaults.dod.is_none());
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(
            config.supervisor.args,
            vec!["-p", "--output-format", "text"]
        );
        assert_eq!(config.supervisor.timeout_secs, 60);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 3);
        assert_eq!(config.detector.answer_cooldown_millis, 1000);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    /// Every key in every table is honoured when present.
    #[test]
    fn parse_full_config() {
        let toml = r#"
[defaults]
agent = "codex"
policy = "act"
dod = "cargo test --workspace"
max_retries = 5

[supervisor]
enabled = true
program = "claude"
args = ["-p", "--output-format", "text"]
timeout_secs = 45
trace_io = true

[detector]
silence_timeout_secs = 5
answer_cooldown_millis = 1500
unknown_request_fallback = true
idle_input_fallback = true

[dangerous_mode]
enabled = true
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "codex");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(
            config.defaults.dod.as_deref(),
            Some("cargo test --workspace")
        );
        assert_eq!(config.defaults.max_retries, 5);
        assert!(config.supervisor.enabled);
        assert_eq!(config.supervisor.program, "claude");
        assert_eq!(config.supervisor.timeout_secs, 45);
        assert!(config.supervisor.trace_io);
        assert_eq!(config.detector.silence_timeout_secs, 5);
        assert_eq!(config.detector.answer_cooldown_millis, 1500);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(config.dangerous_mode.enabled);
    }

    /// Missing keys and tables fall back to defaults.
    #[test]
    fn parse_partial_config() {
        let toml = r#"
[defaults]
agent = "aider"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.defaults.agent, "aider");
        assert_eq!(config.defaults.policy, Policy::Observe);
        assert_eq!(config.defaults.max_retries, 3);
        assert!(config.detector.unknown_request_fallback);
        assert!(config.detector.idle_input_fallback);
        assert!(!config.dangerous_mode.enabled);
    }

    /// The `"suggest"` spelling maps to `Policy::Suggest`.
    #[test]
    fn parse_suggest_policy() {
        let config: ProjectConfig =
            toml::from_str("[defaults]\npolicy = \"suggest\"\n").unwrap();
        assert_eq!(config.defaults.policy, Policy::Suggest);
    }

    /// Policy strings outside the known set are rejected, not silently defaulted.
    #[test]
    fn unknown_policy_is_rejected() {
        let result = toml::from_str::<ProjectConfig>("[defaults]\npolicy = \"yolo\"\n");
        assert!(result.is_err());
    }

    #[test]
    fn load_from_directory() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(
            batty_dir.join("config.toml"),
            r#"
[defaults]
agent = "claude"
policy = "act"
dod = "cargo test"
max_retries = 2
"#,
        )
        .unwrap();

        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "claude");
        assert_eq!(config.defaults.policy, Policy::Act);
        assert_eq!(config.defaults.max_retries, 2);
    }

    #[test]
    fn load_returns_default_when_no_file() {
        let tmp = tempfile::tempdir().unwrap();
        let (config, path) = ProjectConfig::load(tmp.path()).unwrap();
        assert!(path.is_none());
        assert_eq!(config.defaults.agent, "claude");
    }

    /// A malformed file is an error whose message names the parse failure.
    #[test]
    fn load_reports_parse_errors() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(batty_dir.join("config.toml"), "[defaults\nagent = ").unwrap();

        let err = ProjectConfig::load(tmp.path()).unwrap_err();
        assert!(
            err.to_string().starts_with("failed to parse"),
            "unexpected error: {:#}",
            err
        );
    }

    #[test]
    fn parse_auto_answer_config() {
        let toml = r#"
[defaults]
agent = "claude"
policy = "act"

[policy.auto_answer]
"Continue? [y/n]" = "y"
"Allow tool" = "y"
"#;
        let config: ProjectConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.policy.auto_answer.len(), 2);
        assert_eq!(
            config.policy.auto_answer.get("Continue? [y/n]").unwrap(),
            "y"
        );
        assert_eq!(config.policy.auto_answer.get("Allow tool").unwrap(), "y");
    }

    #[test]
    fn load_walks_up_directories() {
        let tmp = tempfile::tempdir().unwrap();
        let batty_dir = tmp.path().join(".batty");
        fs::create_dir_all(&batty_dir).unwrap();
        fs::write(
            batty_dir.join("config.toml"),
            r#"
[defaults]
agent = "codex"
"#,
        )
        .unwrap();

        let nested = tmp.path().join("src").join("deep").join("nested");
        fs::create_dir_all(&nested).unwrap();

        let (config, path) = ProjectConfig::load(&nested).unwrap();
        assert!(path.is_some());
        assert_eq!(config.defaults.agent, "codex");
    }

    /// When several ancestors carry a config, the nearest one wins.
    #[test]
    fn load_prefers_nearest_config() {
        let tmp = tempfile::tempdir().unwrap();
        let outer = tmp.path().join(".batty");
        fs::create_dir_all(&outer).unwrap();
        fs::write(outer.join("config.toml"), "[defaults]\nagent = \"codex\"\n").unwrap();

        let project = tmp.path().join("project");
        let inner = project.join(".batty");
        fs::create_dir_all(&inner).unwrap();
        fs::write(inner.join("config.toml"), "[defaults]\nagent = \"aider\"\n").unwrap();

        let nested = project.join("src");
        fs::create_dir_all(&nested).unwrap();

        let (config, path) = ProjectConfig::load(&nested).unwrap();
        assert_eq!(path, Some(inner.join("config.toml")));
        assert_eq!(config.defaults.agent, "aider");
    }
}