1use std::path::{Path, PathBuf};
2use std::fs;
3
4use crate::config::{self, HawkConfig};
5use crate::error::HawkError;
6
7pub type Result<T> = std::result::Result<T, HawkError>;
8
/// The configuration layer a value is read from or written to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConfigScope {
    /// User-wide settings (`~/.hawk/config.toml` when loaded via `LayeredConfig::load`).
    Global,
    /// Per-project settings (`<project_dir>/hawk.toml` when loaded via `LayeredConfig::load`).
    Project,
    /// Agent-level settings. `LayeredConfig::set` rejects this scope because
    /// agent config is managed via the agent manifest. The string presumably
    /// names the agent — confirm against callers.
    Agent(String),
}

/// A resolved configuration value together with the layer that supplied it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigValue {
    /// The value rendered as a string (JSON for structured values such as `llm.providers`).
    pub value: String,
    /// The layer the value was taken from.
    pub source: ConfigScope,
}
19
20pub struct LayeredConfig {
21 global: Option<HawkConfig>,
22 project: Option<HawkConfig>,
23 global_path: PathBuf,
24 project_path: Option<PathBuf>,
25}
26
27impl LayeredConfig {
28 pub fn load(project_dir: Option<&Path>) -> Result<Self> {
29 let global_path = global_config_path()?;
30 let global = load_optional(&global_path)?;
31
32 let (project, project_path) = if let Some(dir) = project_dir {
33 let p = dir.join("hawk.toml");
34 let cfg = load_optional(&p)?;
35 (cfg, Some(p))
36 } else {
37 (None, None)
38 };
39
40 Ok(Self { global, project, global_path, project_path })
41 }
42
43 pub fn get_effective(&self, key: &str) -> Option<ConfigValue> {
46 if let Some(proj) = &self.project {
47 if let Some(v) = extract(proj, key) {
48 return Some(ConfigValue { value: v, source: ConfigScope::Project });
49 }
50 }
51 if let Some(glob) = &self.global {
52 if let Some(v) = extract(glob, key) {
53 return Some(ConfigValue { value: v, source: ConfigScope::Global });
54 }
55 }
56 None
57 }
58
59 pub fn set(&self, key: &str, value: &str, scope: ConfigScope) -> Result<()> {
62 let path = match scope {
63 ConfigScope::Global => self.global_path.clone(),
64 ConfigScope::Project => self
65 .project_path
66 .clone()
67 .ok_or_else(|| HawkError::Config("no project directory set".to_string()))?,
68 ConfigScope::Agent(_) => {
69 return Err(HawkError::Config(
70 "agent-level config is managed via the agent manifest".to_string(),
71 ))
72 }
73 };
74
75 let mut doc = load_toml_document(&path)?;
76 apply_key(&mut doc, key, value)?;
77
78 if let Some(parent) = path.parent() {
79 fs::create_dir_all(parent)
80 .map_err(|e| HawkError::Config(format!("cannot create config dir: {e}")))?;
81 }
82 fs::write(&path, doc.to_string())
83 .map_err(|e| HawkError::Config(format!("cannot write config: {e}")))?;
84 Ok(())
85 }
86
87 pub fn merged(&self) -> HawkConfig {
89 let base = self.global.clone().unwrap_or_default();
90 if let Some(proj) = &self.project {
91 let mut merged = base;
93 if !proj.llm.providers.is_empty() {
94 merged.llm.providers = proj.llm.providers.clone();
95 }
96 merged
97 } else {
98 base
99 }
100 }
101
102 pub fn validate(&self) -> Result<Vec<String>> {
105 let mut errors = Vec::new();
106 for (label, cfg) in [("global", &self.global), ("project", &self.project)] {
107 if let Some(c) = cfg {
108 let toml_str = config::to_toml(c)
109 .map_err(|e| HawkError::Config(format!("serialization error: {e}")))?;
110 if let Err(e) = config::parse(&toml_str) {
111 errors.push(format!("[{label}] {e}"));
112 }
113 }
114 }
115 Ok(errors)
116 }
117}
118
119fn global_config_path() -> Result<PathBuf> {
122 let home = dirs_next::home_dir()
123 .ok_or_else(|| HawkError::Config("cannot determine home directory".to_string()))?;
124 Ok(home.join(".hawk").join("config.toml"))
125}
126
127fn load_optional(path: &Path) -> Result<Option<HawkConfig>> {
128 if !path.exists() {
129 return Ok(None);
130 }
131 let text = fs::read_to_string(path)
132 .map_err(|e| HawkError::Config(format!("cannot read {}: {e}", path.display())))?;
133 let cfg = config::parse(&text)?;
134 Ok(Some(cfg))
135}
136
137fn extract(cfg: &HawkConfig, key: &str) -> Option<String> {
139 match key {
140 "core.log_level" => Some(cfg.core.log_level.clone()),
141 "core.session_retention_days" => Some(cfg.core.session_retention_days.to_string()),
142 "core.pattern_retention_days" => Some(cfg.core.pattern_retention_days.to_string()),
143 "privacy.mode" => Some(cfg.privacy.mode.clone()),
144 "llm.providers" => {
145 let s = serde_json::to_string(&cfg.llm.providers).ok()?;
146 Some(s)
147 }
148 "llm.pricing.openai_gpt4_prompt" => {
149 Some(cfg.llm.pricing.openai_gpt4_prompt.to_string())
150 }
151 "llm.pricing.openai_gpt4_completion" => {
152 Some(cfg.llm.pricing.openai_gpt4_completion.to_string())
153 }
154 "savepoint.auto_snapshot" => Some(cfg.savepoint.auto_snapshot.to_string()),
155 "savepoint.max_snapshots_per_agent" => {
156 Some(cfg.savepoint.max_snapshots_per_agent.to_string())
157 }
158 "bus.message_retention_seconds" => {
159 Some(cfg.bus.message_retention_seconds.to_string())
160 }
161 "bus.max_queue_size" => Some(cfg.bus.max_queue_size.to_string()),
162 "sync.enabled" => Some(cfg.sync.enabled.to_string()),
163 "sync.conflict_strategy" => Some(cfg.sync.conflict_strategy.clone()),
164 "compress.token_threshold" => Some(cfg.compress.token_threshold.to_string()),
165 "compress.cache_max_entries" => Some(cfg.compress.cache_max_entries.to_string()),
166 "healing.max_retries" => Some(cfg.healing.max_retries.to_string()),
167 "healing.enabled" => Some(cfg.healing.enabled.to_string()),
168 _ => None,
169 }
170}
171
172fn load_toml_document(path: &Path) -> Result<toml_edit::DocumentMut> {
175 if !path.exists() {
176 return Ok(toml_edit::DocumentMut::new());
177 }
178 let text = fs::read_to_string(path)
179 .map_err(|e| HawkError::Config(format!("cannot read {}: {e}", path.display())))?;
180 text.parse::<toml_edit::DocumentMut>()
181 .map_err(|e| HawkError::Config(format!("TOML parse error in {}: {e}", path.display())))
182}
183
184fn apply_key(doc: &mut toml_edit::DocumentMut, key: &str, value: &str) -> Result<()> {
186 let parts: Vec<&str> = key.splitn(3, '.').collect();
187 match parts.as_slice() {
188 [section, field] => {
189 let table = doc[section].or_insert(toml_edit::table());
190 table[field] = toml_edit::value(value);
191 }
192 [section, subsection, field] => {
193 let outer = doc[section].or_insert(toml_edit::table());
194 let inner = outer[subsection].or_insert(toml_edit::table());
195 inner[field] = toml_edit::value(value);
196 }
197 [field] => {
198 doc[field] = toml_edit::value(value);
199 }
200 _ => {
201 return Err(HawkError::Config(format!(
202 "key \"{key}\" has too many segments (max 3)"
203 )))
204 }
205 }
206 Ok(())
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `content` to `dir/name`, panicking on I/O failure.
    fn write_toml(dir: &Path, name: &str, content: &str) {
        fs::write(dir.join(name), content).unwrap();
    }

    const PROJECT_TOML: &str = r#"
[core]
log_level = "debug"
session_retention_days = 7
pattern_retention_days = 14

[privacy]
mode = "local-only"

[healing]
max_retries = 5
enabled = true
"#;

    const GLOBAL_TOML: &str = r#"
[core]
log_level = "warn"
session_retention_days = 30
pattern_retention_days = 90

[privacy]
mode = "standard"

[healing]
max_retries = 3
enabled = true
"#;

    /// Builds a `LayeredConfig` with both layers parsed from the given TOML,
    /// also writing the sources under `tmp` so the recorded paths exist.
    fn make_layered(tmp: &TempDir, global: &str, project: &str) -> LayeredConfig {
        let global_dir = tmp.path().join("global");
        fs::create_dir_all(&global_dir).unwrap();
        write_toml(&global_dir, "config.toml", global);

        let project_dir = tmp.path().join("project");
        fs::create_dir_all(&project_dir).unwrap();
        write_toml(&project_dir, "hawk.toml", project);

        LayeredConfig {
            global: Some(config::parse(global).unwrap()),
            project: Some(config::parse(project).unwrap()),
            global_path: global_dir.join("config.toml"),
            project_path: Some(project_dir.join("hawk.toml")),
        }
    }

    #[test]
    fn project_overrides_global() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);

        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "debug");
        assert!(matches!(v.source, ConfigScope::Project));
    }

    #[test]
    fn global_used_when_no_project_layer() {
        let global_cfg = config::parse(GLOBAL_TOML).unwrap();
        let lc = LayeredConfig {
            global: Some(global_cfg),
            project: None,
            global_path: PathBuf::from("/tmp/g.toml"),
            project_path: None,
        };
        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "warn");
        assert!(matches!(v.source, ConfigScope::Global));
    }

    #[test]
    fn unknown_key_returns_none() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);
        assert!(lc.get_effective("nonexistent.key").is_none());
    }

    #[test]
    fn set_project_scope_writes_file() {
        let tmp = TempDir::new().unwrap();
        let project_dir = tmp.path().join("proj");
        fs::create_dir_all(&project_dir).unwrap();
        write_toml(&project_dir, "hawk.toml", PROJECT_TOML);

        let lc = LayeredConfig {
            global: None,
            project: config::parse(PROJECT_TOML).ok(),
            global_path: tmp.path().join("g.toml"),
            project_path: Some(project_dir.join("hawk.toml")),
        };

        lc.set("core.log_level", "trace", ConfigScope::Project).unwrap();

        let written = fs::read_to_string(project_dir.join("hawk.toml")).unwrap();
        assert!(written.contains("trace"));
    }

    #[test]
    fn set_global_scope_writes_file() {
        let tmp = TempDir::new().unwrap();
        let global_path = tmp.path().join("config.toml");
        fs::write(&global_path, GLOBAL_TOML).unwrap();

        let lc = LayeredConfig {
            global: config::parse(GLOBAL_TOML).ok(),
            project: None,
            global_path: global_path.clone(),
            project_path: None,
        };

        lc.set("privacy.mode", "air-gapped", ConfigScope::Global).unwrap();

        let written = fs::read_to_string(&global_path).unwrap();
        assert!(written.contains("air-gapped"));
    }

    #[test]
    fn set_agent_scope_returns_error() {
        let tmp = TempDir::new().unwrap();
        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let err = lc
            .set("core.log_level", "info", ConfigScope::Agent("my-agent".to_string()))
            .unwrap_err();
        assert!(err.to_string().contains("manifest"));
    }

    #[test]
    fn set_project_scope_without_project_dir_errors() {
        let tmp = TempDir::new().unwrap();
        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let err = lc
            .set("core.log_level", "info", ConfigScope::Project)
            .unwrap_err();
        assert!(err.to_string().contains("no project directory set"));
        // The failed project write must not fall back to the global file.
        assert!(!tmp.path().join("g.toml").exists());
    }

    #[test]
    fn validate_returns_empty_for_valid_configs() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);
        let errors = lc.validate().unwrap();
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn validate_reports_invalid_layer() {
        // Deserialize with plain `toml` to bypass `config::parse`, which
        // presumably rejects the "verbose" log level; `validate` must then
        // report the layer.
        let bad = "[core]\nlog_level = \"verbose\"\nsession_retention_days = 30\npattern_retention_days = 90\n";
        let tmp = TempDir::new().unwrap();
        let bad_cfg = toml::from_str::<HawkConfig>(bad).unwrap();
        let lc = LayeredConfig {
            global: Some(bad_cfg),
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let errors = lc.validate().unwrap();
        assert!(!errors.is_empty(), "expected validation errors for bad config");
    }

    #[test]
    fn no_project_dir_loads_only_global() {
        let lc = LayeredConfig {
            global: Some(config::parse(GLOBAL_TOML).unwrap()),
            project: None,
            global_path: PathBuf::from("/tmp/g.toml"),
            project_path: None,
        };
        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "warn");
        assert!(matches!(v.source, ConfigScope::Global));

        // Without a project layer the merged view is exactly the global layer.
        let merged = lc.merged();
        assert_eq!(merged.core.log_level, "warn");
        assert_eq!(merged.healing.max_retries, 3);
    }

    #[test]
    fn set_creates_file_if_not_exists() {
        let tmp = TempDir::new().unwrap();
        let global_path = tmp.path().join("new_config.toml");
        assert!(!global_path.exists());

        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: global_path.clone(),
            project_path: None,
        };

        lc.set("core.log_level", "error", ConfigScope::Global).unwrap();
        assert!(global_path.exists());
        let content = fs::read_to_string(&global_path).unwrap();
        assert!(content.contains("error"));
    }
}