Skip to main content

openhawk_core/
config_engine.rs

1use std::path::{Path, PathBuf};
2use std::fs;
3
4use crate::config::{self, HawkConfig};
5use crate::error::HawkError;
6
7pub type Result<T> = std::result::Result<T, HawkError>;
8
/// A configuration layer: where a value was read from, or where a write should go.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConfigScope {
    /// User-wide settings (`~/.hawk/config.toml`).
    Global,
    /// Per-project settings (`hawk.toml` in the project directory).
    Project,
    /// Settings for one agent (presumably identified by its name); these live in the agent
    /// manifest rather than in a config file handled by this module.
    Agent(String),
}

/// A resolved configuration value rendered as text, tagged with the layer that supplied it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigValue {
    /// The value as text: strings verbatim, numbers/booleans via `to_string`, lists as JSON.
    pub value: String,
    /// The layer `value` came from.
    pub source: ConfigScope,
}
19
20pub struct LayeredConfig {
21    global: Option<HawkConfig>,
22    project: Option<HawkConfig>,
23    global_path: PathBuf,
24    project_path: Option<PathBuf>,
25}
26
27impl LayeredConfig {
28    pub fn load(project_dir: Option<&Path>) -> Result<Self> {
29        let global_path = global_config_path()?;
30        let global = load_optional(&global_path)?;
31
32        let (project, project_path) = if let Some(dir) = project_dir {
33            let p = dir.join("hawk.toml");
34            let cfg = load_optional(&p)?;
35            (cfg, Some(p))
36        } else {
37            (None, None)
38        };
39
40        Ok(Self { global, project, global_path, project_path })
41    }
42
43    /// Returns the effective value for a dot-notation key, with source annotation.
44    /// Priority: project > global (agent-level is handled by callers via manifest).
45    pub fn get_effective(&self, key: &str) -> Option<ConfigValue> {
46        if let Some(proj) = &self.project {
47            if let Some(v) = extract(proj, key) {
48                return Some(ConfigValue { value: v, source: ConfigScope::Project });
49            }
50        }
51        if let Some(glob) = &self.global {
52            if let Some(v) = extract(glob, key) {
53                return Some(ConfigValue { value: v, source: ConfigScope::Global });
54            }
55        }
56        None
57    }
58
59    /// Persists `value` for `key` to the file indicated by `scope`.
60    /// Parses the existing file (if any), updates the key, and writes back.
61    pub fn set(&self, key: &str, value: &str, scope: ConfigScope) -> Result<()> {
62        let path = match scope {
63            ConfigScope::Global => self.global_path.clone(),
64            ConfigScope::Project => self
65                .project_path
66                .clone()
67                .ok_or_else(|| HawkError::Config("no project directory set".to_string()))?,
68            ConfigScope::Agent(_) => {
69                return Err(HawkError::Config(
70                    "agent-level config is managed via the agent manifest".to_string(),
71                ))
72            }
73        };
74
75        let mut doc = load_toml_document(&path)?;
76        apply_key(&mut doc, key, value)?;
77
78        if let Some(parent) = path.parent() {
79            fs::create_dir_all(parent)
80                .map_err(|e| HawkError::Config(format!("cannot create config dir: {e}")))?;
81        }
82        fs::write(&path, doc.to_string())
83            .map_err(|e| HawkError::Config(format!("cannot write config: {e}")))?;
84        Ok(())
85    }
86
87    /// Returns the merged `HawkConfig` (project layer overrides global defaults).
88    pub fn merged(&self) -> HawkConfig {
89        let base = self.global.clone().unwrap_or_default();
90        if let Some(proj) = &self.project {
91            // Merge: project values win where they differ from defaults.
92            let mut merged = base;
93            if !proj.llm.providers.is_empty() {
94                merged.llm.providers = proj.llm.providers.clone();
95            }
96            merged
97        } else {
98            base
99        }
100    }
101
102    /// Validates all loaded layers against the schema.
103    /// Returns a list of error messages (empty = valid).
104    pub fn validate(&self) -> Result<Vec<String>> {
105        let mut errors = Vec::new();
106        for (label, cfg) in [("global", &self.global), ("project", &self.project)] {
107            if let Some(c) = cfg {
108                let toml_str = config::to_toml(c)
109                    .map_err(|e| HawkError::Config(format!("serialization error: {e}")))?;
110                if let Err(e) = config::parse(&toml_str) {
111                    errors.push(format!("[{label}] {e}"));
112                }
113            }
114        }
115        Ok(errors)
116    }
117}
118
119// ── Helpers ───────────────────────────────────────────────────────────────────
120
121fn global_config_path() -> Result<PathBuf> {
122    let home = dirs_next::home_dir()
123        .ok_or_else(|| HawkError::Config("cannot determine home directory".to_string()))?;
124    Ok(home.join(".hawk").join("config.toml"))
125}
126
127fn load_optional(path: &Path) -> Result<Option<HawkConfig>> {
128    if !path.exists() {
129        return Ok(None);
130    }
131    let text = fs::read_to_string(path)
132        .map_err(|e| HawkError::Config(format!("cannot read {}: {e}", path.display())))?;
133    let cfg = config::parse(&text)?;
134    Ok(Some(cfg))
135}
136
137/// Extract a dot-notation key from a `HawkConfig` as a string.
138fn extract(cfg: &HawkConfig, key: &str) -> Option<String> {
139    match key {
140        "core.log_level" => Some(cfg.core.log_level.clone()),
141        "core.session_retention_days" => Some(cfg.core.session_retention_days.to_string()),
142        "core.pattern_retention_days" => Some(cfg.core.pattern_retention_days.to_string()),
143        "privacy.mode" => Some(cfg.privacy.mode.clone()),
144        "llm.providers" => {
145            let s = serde_json::to_string(&cfg.llm.providers).ok()?;
146            Some(s)
147        }
148        "llm.pricing.openai_gpt4_prompt" => {
149            Some(cfg.llm.pricing.openai_gpt4_prompt.to_string())
150        }
151        "llm.pricing.openai_gpt4_completion" => {
152            Some(cfg.llm.pricing.openai_gpt4_completion.to_string())
153        }
154        "savepoint.auto_snapshot" => Some(cfg.savepoint.auto_snapshot.to_string()),
155        "savepoint.max_snapshots_per_agent" => {
156            Some(cfg.savepoint.max_snapshots_per_agent.to_string())
157        }
158        "bus.message_retention_seconds" => {
159            Some(cfg.bus.message_retention_seconds.to_string())
160        }
161        "bus.max_queue_size" => Some(cfg.bus.max_queue_size.to_string()),
162        "sync.enabled" => Some(cfg.sync.enabled.to_string()),
163        "sync.conflict_strategy" => Some(cfg.sync.conflict_strategy.clone()),
164        "compress.token_threshold" => Some(cfg.compress.token_threshold.to_string()),
165        "compress.cache_max_entries" => Some(cfg.compress.cache_max_entries.to_string()),
166        "healing.max_retries" => Some(cfg.healing.max_retries.to_string()),
167        "healing.enabled" => Some(cfg.healing.enabled.to_string()),
168        _ => None,
169    }
170}
171
172/// Load the file as a raw `toml_edit::DocumentMut` (preserves formatting).
173/// Returns an empty document if the file does not exist.
174fn load_toml_document(path: &Path) -> Result<toml_edit::DocumentMut> {
175    if !path.exists() {
176        return Ok(toml_edit::DocumentMut::new());
177    }
178    let text = fs::read_to_string(path)
179        .map_err(|e| HawkError::Config(format!("cannot read {}: {e}", path.display())))?;
180    text.parse::<toml_edit::DocumentMut>()
181        .map_err(|e| HawkError::Config(format!("TOML parse error in {}: {e}", path.display())))
182}
183
184/// Write a scalar string value at a dot-notation path into a `toml_edit::DocumentMut`.
185fn apply_key(doc: &mut toml_edit::DocumentMut, key: &str, value: &str) -> Result<()> {
186    let parts: Vec<&str> = key.splitn(3, '.').collect();
187    match parts.as_slice() {
188        [section, field] => {
189            let table = doc[section].or_insert(toml_edit::table());
190            table[field] = toml_edit::value(value);
191        }
192        [section, subsection, field] => {
193            let outer = doc[section].or_insert(toml_edit::table());
194            let inner = outer[subsection].or_insert(toml_edit::table());
195            inner[field] = toml_edit::value(value);
196        }
197        [field] => {
198            doc[field] = toml_edit::value(value);
199        }
200        _ => {
201            return Err(HawkError::Config(format!(
202                "key \"{key}\" has too many segments (max 3)"
203            )))
204        }
205    }
206    Ok(())
207}
208
209// ── Tests ─────────────────────────────────────────────────────────────────────
210
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn write_toml(dir: &Path, name: &str, content: &str) {
        fs::write(dir.join(name), content).unwrap();
    }

    const PROJECT_TOML: &str = r#"
[core]
log_level = "debug"
session_retention_days = 7
pattern_retention_days = 14

[privacy]
mode = "local-only"

[healing]
max_retries = 5
enabled = true
"#;

    const GLOBAL_TOML: &str = r#"
[core]
log_level = "warn"
session_retention_days = 30
pattern_retention_days = 90

[privacy]
mode = "standard"

[healing]
max_retries = 3
enabled = true
"#;

    /// Writes both layers under `tmp` and builds a `LayeredConfig` pointing at them.
    ///
    /// The struct is constructed by hand (not via `LayeredConfig::load`) so the tests do not
    /// depend on `$HOME`.
    fn make_layered(tmp: &TempDir, global: &str, project: &str) -> LayeredConfig {
        let global_dir = tmp.path().join("global");
        fs::create_dir_all(&global_dir).unwrap();
        write_toml(&global_dir, "config.toml", global);

        let project_dir = tmp.path().join("project");
        fs::create_dir_all(&project_dir).unwrap();
        write_toml(&project_dir, "hawk.toml", project);

        LayeredConfig {
            global: Some(config::parse(global).unwrap()),
            project: Some(config::parse(project).unwrap()),
            global_path: global_dir.join("config.toml"),
            project_path: Some(project_dir.join("hawk.toml")),
        }
    }

    /// A config with only the `GLOBAL_TOML` layer and no project directory.
    fn global_only() -> LayeredConfig {
        LayeredConfig {
            global: Some(config::parse(GLOBAL_TOML).unwrap()),
            project: None,
            global_path: PathBuf::from("/tmp/g.toml"),
            project_path: None,
        }
    }

    #[test]
    fn project_overrides_global() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);

        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "debug");
        assert!(matches!(v.source, ConfigScope::Project));
    }

    #[test]
    fn global_used_when_no_project_layer() {
        // Only the global layer is present, so every key must come from it.
        let lc = global_only();
        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "warn");
        assert!(matches!(v.source, ConfigScope::Global));
    }

    #[test]
    fn unknown_key_returns_none() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);
        assert!(lc.get_effective("nonexistent.key").is_none());
    }

    #[test]
    fn set_project_scope_writes_file() {
        let tmp = TempDir::new().unwrap();
        let project_dir = tmp.path().join("proj");
        fs::create_dir_all(&project_dir).unwrap();
        write_toml(&project_dir, "hawk.toml", PROJECT_TOML);

        let lc = LayeredConfig {
            global: None,
            project: config::parse(PROJECT_TOML).ok(),
            global_path: tmp.path().join("g.toml"),
            project_path: Some(project_dir.join("hawk.toml")),
        };

        lc.set("core.log_level", "trace", ConfigScope::Project).unwrap();

        let written = fs::read_to_string(project_dir.join("hawk.toml")).unwrap();
        assert!(written.contains("trace"));
        // Unrelated keys and sections survive the format-preserving rewrite.
        assert!(written.contains("session_retention_days = 7"), "got: {written}");
        assert!(written.contains("mode = \"local-only\""), "got: {written}");
        assert!(written.contains("[healing]"), "got: {written}");
    }

    #[test]
    fn set_project_scope_without_project_dir_errors() {
        let tmp = TempDir::new().unwrap();
        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let err = lc.set("core.log_level", "info", ConfigScope::Project).unwrap_err();
        assert!(err.to_string().contains("no project directory"));
    }

    #[test]
    fn set_global_scope_writes_file() {
        let tmp = TempDir::new().unwrap();
        let global_path = tmp.path().join("config.toml");
        fs::write(&global_path, GLOBAL_TOML).unwrap();

        let lc = LayeredConfig {
            global: config::parse(GLOBAL_TOML).ok(),
            project: None,
            global_path: global_path.clone(),
            project_path: None,
        };

        lc.set("privacy.mode", "air-gapped", ConfigScope::Global).unwrap();

        let written = fs::read_to_string(&global_path).unwrap();
        assert!(written.contains("air-gapped"));
    }

    #[test]
    fn set_nested_key_creates_subtable() {
        let tmp = TempDir::new().unwrap();
        let global_path = tmp.path().join("nested.toml");
        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: global_path.clone(),
            project_path: None,
        };

        lc.set("llm.pricing.openai_gpt4_prompt", "0.05", ConfigScope::Global).unwrap();

        let written = fs::read_to_string(&global_path).unwrap();
        assert!(written.contains("[llm.pricing]"), "got: {written}");
        assert!(written.contains("openai_gpt4_prompt"), "got: {written}");
    }

    #[test]
    fn set_agent_scope_returns_error() {
        let tmp = TempDir::new().unwrap();
        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let err = lc
            .set("core.log_level", "info", ConfigScope::Agent("my-agent".to_string()))
            .unwrap_err();
        assert!(err.to_string().contains("manifest"));
    }

    #[test]
    fn validate_returns_empty_for_valid_configs() {
        let tmp = TempDir::new().unwrap();
        let lc = make_layered(&tmp, GLOBAL_TOML, PROJECT_TOML);
        let errors = lc.validate().unwrap();
        assert!(errors.is_empty(), "unexpected errors: {errors:?}");
    }

    #[test]
    fn validate_reports_invalid_layer() {
        // `bad` is expected to fail validation (`log_level = "verbose"`). `config::parse`
        // validates while parsing, so deserialize with `toml` directly to get a layer that
        // bypasses validation, then check that `validate()` surfaces the problem.
        let bad = "[core]\nlog_level = \"verbose\"\nsession_retention_days = 30\npattern_retention_days = 90\n";
        let tmp = TempDir::new().unwrap();
        let bad_cfg = toml::from_str::<HawkConfig>(bad).unwrap();
        let lc = LayeredConfig {
            global: Some(bad_cfg),
            project: None,
            global_path: tmp.path().join("g.toml"),
            project_path: None,
        };
        let errors = lc.validate().unwrap();
        assert!(!errors.is_empty(), "expected validation errors for bad config");
    }

    #[test]
    fn no_project_dir_loads_only_global() {
        let lc = global_only();
        let v = lc.get_effective("core.log_level").unwrap();
        assert_eq!(v.value, "warn");
        assert!(matches!(v.source, ConfigScope::Global));

        let mode = lc.get_effective("privacy.mode").unwrap();
        assert_eq!(mode.value, "standard");
        assert!(matches!(mode.source, ConfigScope::Global));
    }

    #[test]
    fn set_creates_file_if_not_exists() {
        let tmp = TempDir::new().unwrap();
        let global_path = tmp.path().join("new_config.toml");
        assert!(!global_path.exists());

        let lc = LayeredConfig {
            global: None,
            project: None,
            global_path: global_path.clone(),
            project_path: None,
        };

        lc.set("core.log_level", "error", ConfigScope::Global).unwrap();
        assert!(global_path.exists());
        let content = fs::read_to_string(&global_path).unwrap();
        assert!(content.contains("error"));
    }
}