Skip to main content

openhawk_core/
config.rs

1use serde::{Deserialize, Serialize};
2use crate::error::HawkError;
3
4pub type Result<T> = std::result::Result<T, HawkError>;
5
6// ── Top-level config ──────────────────────────────────────────────────────────
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct HawkConfig {
10    #[serde(default)]
11    pub core: CoreConfig,
12    #[serde(default)]
13    pub privacy: PrivacyConfig,
14    #[serde(default)]
15    pub llm: LlmConfig,
16    #[serde(default)]
17    pub savepoint: SavepointConfig,
18    #[serde(default)]
19    pub bus: BusConfig,
20    #[serde(default)]
21    pub sync: SyncConfig,
22    #[serde(default)]
23    pub compress: CompressConfig,
24    #[serde(default)]
25    pub healing: HealingConfig,
26}
27
28// ── Section structs ───────────────────────────────────────────────────────────
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CoreConfig {
32    pub log_level: String,
33    pub session_retention_days: u32,
34    pub pattern_retention_days: u32,
35}
36
37impl Default for CoreConfig {
38    fn default() -> Self {
39        Self {
40            log_level: "info".to_string(),
41            session_retention_days: 30,
42            pattern_retention_days: 90,
43        }
44    }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct PrivacyConfig {
49    pub mode: String,
50}
51
52impl Default for PrivacyConfig {
53    fn default() -> Self {
54        Self { mode: "standard".to_string() }
55    }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, Default)]
59pub struct LlmConfig {
60    #[serde(default)]
61    pub providers: Vec<LlmProvider>,
62    #[serde(default)]
63    pub pricing: LlmPricing,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct LlmProvider {
68    pub name: String,
69    pub endpoint: String,
70    pub priority: u32,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, Default)]
74pub struct LlmPricing {
75    #[serde(default)]
76    pub openai_gpt4_prompt: f64,
77    #[serde(default)]
78    pub openai_gpt4_completion: f64,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SavepointConfig {
83    pub auto_snapshot: bool,
84    pub max_snapshots_per_agent: u32,
85}
86
87impl Default for SavepointConfig {
88    fn default() -> Self {
89        Self { auto_snapshot: true, max_snapshots_per_agent: 50 }
90    }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct BusConfig {
95    pub message_retention_seconds: u64,
96    pub max_queue_size: u64,
97}
98
99impl Default for BusConfig {
100    fn default() -> Self {
101        Self { message_retention_seconds: 3600, max_queue_size: 10000 }
102    }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct SyncConfig {
107    pub enabled: bool,
108    pub conflict_strategy: String,
109}
110
111impl Default for SyncConfig {
112    fn default() -> Self {
113        Self { enabled: false, conflict_strategy: "last-writer-wins".to_string() }
114    }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct CompressConfig {
119    pub token_threshold: u32,
120    pub cache_max_entries: u32,
121}
122
123impl Default for CompressConfig {
124    fn default() -> Self {
125        Self { token_threshold: 4000, cache_max_entries: 1000 }
126    }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct HealingConfig {
131    pub max_retries: u32,
132    pub enabled: bool,
133}
134
135impl Default for HealingConfig {
136    fn default() -> Self {
137        Self { max_retries: 3, enabled: true }
138    }
139}
140
141// ── Parse / serialize ─────────────────────────────────────────────────────────
142
143pub fn parse(toml_str: &str) -> Result<HawkConfig> {
144    let config: HawkConfig = toml::from_str(toml_str)
145        .map_err(|e| HawkError::Config(format!("parse error: {e}")))?;
146    validate(&config)?;
147    Ok(config)
148}
149
150pub fn to_toml(config: &HawkConfig) -> Result<String> {
151    toml::to_string_pretty(config)
152        .map_err(|e| HawkError::Config(format!("serialization error: {e}")))
153}
154
155// ── Validation ────────────────────────────────────────────────────────────────
156
157const VALID_LOG_LEVELS: &[&str] = &["error", "warn", "info", "debug", "trace"];
158const VALID_PRIVACY_MODES: &[&str] = &["standard", "local-only", "air-gapped"];
159
160fn validate(config: &HawkConfig) -> Result<()> {
161    if !VALID_LOG_LEVELS.contains(&config.core.log_level.as_str()) {
162        return Err(HawkError::Config(format!(
163            "[core] log_level \"{}\" is invalid; expected one of: {}",
164            config.core.log_level,
165            VALID_LOG_LEVELS.join(", ")
166        )));
167    }
168
169    if !VALID_PRIVACY_MODES.contains(&config.privacy.mode.as_str()) {
170        return Err(HawkError::Config(format!(
171            "[privacy] mode \"{}\" is invalid; expected one of: {}",
172            config.privacy.mode,
173            VALID_PRIVACY_MODES.join(", ")
174        )));
175    }
176
177    if config.core.session_retention_days == 0 {
178        return Err(HawkError::Config(
179            "[core] session_retention_days must be greater than 0".to_string(),
180        ));
181    }
182
183    if config.healing.max_retries < 1 {
184        return Err(HawkError::Config(
185            "[healing] max_retries must be at least 1".to_string(),
186        ));
187    }
188
189    Ok(())
190}
191
192// ── Tests ─────────────────────────────────────────────────────────────────────
193
#[cfg(test)]
mod tests {
    use super::*;

    const FULL_CONFIG: &str = r#"
[core]
log_level = "info"
session_retention_days = 30
pattern_retention_days = 90

[privacy]
mode = "standard"

[llm]
providers = [
    { name = "openai", endpoint = "https://api.openai.com/v1", priority = 1 },
    { name = "ollama", endpoint = "http://localhost:11434", priority = 2 },
]

[llm.pricing]
openai_gpt4_prompt = 0.00003
openai_gpt4_completion = 0.00006

[savepoint]
auto_snapshot = true
max_snapshots_per_agent = 50

[bus]
message_retention_seconds = 3600
max_queue_size = 10000

[sync]
enabled = false
conflict_strategy = "last-writer-wins"

[compress]
token_threshold = 4000
cache_max_entries = 1000

[healing]
max_retries = 3
enabled = true
"#;

    /// Parses `toml`, panics if parsing unexpectedly succeeds, and returns the error text.
    fn parse_err(toml: &str) -> String {
        parse(toml).unwrap_err().to_string()
    }

    #[test]
    fn test_parse_full_config() {
        let config = parse(FULL_CONFIG).expect("should parse valid config");
        assert_eq!(config.core.log_level, "info");
        assert_eq!(config.core.session_retention_days, 30);
        assert_eq!(config.core.pattern_retention_days, 90);
        assert_eq!(config.privacy.mode, "standard");
        assert_eq!(config.llm.providers.len(), 2);
        assert_eq!(config.llm.providers[0].name, "openai");
        assert_eq!(config.llm.providers[1].priority, 2);
        assert!((config.llm.pricing.openai_gpt4_prompt - 0.00003).abs() < f64::EPSILON);
        assert!(config.savepoint.auto_snapshot);
        assert_eq!(config.savepoint.max_snapshots_per_agent, 50);
        assert_eq!(config.bus.message_retention_seconds, 3600);
        assert_eq!(config.bus.max_queue_size, 10000);
        assert!(!config.sync.enabled);
        assert_eq!(config.sync.conflict_strategy, "last-writer-wins");
        assert_eq!(config.compress.token_threshold, 4000);
        assert_eq!(config.compress.cache_max_entries, 1000);
        assert_eq!(config.healing.max_retries, 3);
        assert!(config.healing.enabled);
    }

    #[test]
    fn test_round_trip() {
        let config = parse(FULL_CONFIG).unwrap();
        let serialized = to_toml(&config).unwrap();
        let reparsed = parse(&serialized).unwrap();
        assert_eq!(config.core.log_level, reparsed.core.log_level);
        assert_eq!(config.privacy.mode, reparsed.privacy.mode);
        assert_eq!(config.llm.providers.len(), reparsed.llm.providers.len());
        assert_eq!(config.healing.max_retries, reparsed.healing.max_retries);
        // HawkConfig has no PartialEq, so compare the re-serialized text instead. That covers
        // every field, not just the handful spot-checked above.
        assert_eq!(serialized, to_toml(&reparsed).unwrap());
    }

    #[test]
    fn test_default_round_trip() {
        let serialized = to_toml(&HawkConfig::default()).unwrap();
        let reparsed = parse(&serialized).expect("default config should serialize to valid TOML");
        assert_eq!(serialized, to_toml(&reparsed).unwrap());
    }

    #[test]
    fn test_defaults() {
        let config = parse("").unwrap();
        assert_eq!(config.core.log_level, "info");
        assert_eq!(config.core.session_retention_days, 30);
        assert_eq!(config.privacy.mode, "standard");
        assert_eq!(config.healing.max_retries, 3);
        assert!(config.healing.enabled);
    }

    #[test]
    fn test_invalid_log_level() {
        let msg = parse_err(
            r#"
[core]
log_level = "verbose"
session_retention_days = 30
pattern_retention_days = 90
"#,
        );
        assert!(msg.contains("log_level"), "error should mention log_level: {msg}");
        assert!(msg.contains("verbose"), "error should include the bad value: {msg}");
    }

    #[test]
    fn test_invalid_privacy_mode() {
        let msg = parse_err(
            r#"
[privacy]
mode = "cloud-only"
"#,
        );
        assert!(msg.contains("mode"), "error should mention mode: {msg}");
        assert!(msg.contains("cloud-only"), "error should include the bad value: {msg}");
    }

    #[test]
    fn test_session_retention_zero() {
        let msg = parse_err(
            r#"
[core]
log_level = "info"
session_retention_days = 0
pattern_retention_days = 90
"#,
        );
        assert!(msg.contains("session_retention_days"));
    }

    #[test]
    fn test_max_retries_zero() {
        let msg = parse_err(
            r#"
[healing]
max_retries = 0
enabled = true
"#,
        );
        assert!(msg.contains("max_retries"));
    }

    #[test]
    fn test_invalid_toml_syntax() {
        let msg = parse_err("[core\nlog_level = \"info\"");
        assert!(msg.contains("parse error"));
    }

    #[test]
    fn test_privacy_modes_all_valid() {
        for mode in ["standard", "local-only", "air-gapped"] {
            let toml = format!("[privacy]\nmode = \"{mode}\"");
            assert!(parse(&toml).is_ok(), "mode {mode} should be valid");
        }
    }

    #[test]
    fn test_log_levels_all_valid() {
        for level in ["error", "warn", "info", "debug", "trace"] {
            let toml = format!(
                "[core]\nlog_level = \"{level}\"\nsession_retention_days = 1\npattern_retention_days = 1"
            );
            assert!(parse(&toml).is_ok(), "log_level {level} should be valid");
        }
    }
}