1use serde::{Deserialize, Serialize};
2use crate::error::HawkError;
3
4pub type Result<T> = std::result::Result<T, HawkError>;
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
9pub struct HawkConfig {
10 #[serde(default)]
11 pub core: CoreConfig,
12 #[serde(default)]
13 pub privacy: PrivacyConfig,
14 #[serde(default)]
15 pub llm: LlmConfig,
16 #[serde(default)]
17 pub savepoint: SavepointConfig,
18 #[serde(default)]
19 pub bus: BusConfig,
20 #[serde(default)]
21 pub sync: SyncConfig,
22 #[serde(default)]
23 pub compress: CompressConfig,
24 #[serde(default)]
25 pub healing: HealingConfig,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CoreConfig {
32 pub log_level: String,
33 pub session_retention_days: u32,
34 pub pattern_retention_days: u32,
35}
36
37impl Default for CoreConfig {
38 fn default() -> Self {
39 Self {
40 log_level: "info".to_string(),
41 session_retention_days: 30,
42 pattern_retention_days: 90,
43 }
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct PrivacyConfig {
49 pub mode: String,
50}
51
52impl Default for PrivacyConfig {
53 fn default() -> Self {
54 Self { mode: "standard".to_string() }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, Default)]
59pub struct LlmConfig {
60 #[serde(default)]
61 pub providers: Vec<LlmProvider>,
62 #[serde(default)]
63 pub pricing: LlmPricing,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct LlmProvider {
68 pub name: String,
69 pub endpoint: String,
70 pub priority: u32,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, Default)]
74pub struct LlmPricing {
75 #[serde(default)]
76 pub openai_gpt4_prompt: f64,
77 #[serde(default)]
78 pub openai_gpt4_completion: f64,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SavepointConfig {
83 pub auto_snapshot: bool,
84 pub max_snapshots_per_agent: u32,
85}
86
87impl Default for SavepointConfig {
88 fn default() -> Self {
89 Self { auto_snapshot: true, max_snapshots_per_agent: 50 }
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct BusConfig {
95 pub message_retention_seconds: u64,
96 pub max_queue_size: u64,
97}
98
99impl Default for BusConfig {
100 fn default() -> Self {
101 Self { message_retention_seconds: 3600, max_queue_size: 10000 }
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct SyncConfig {
107 pub enabled: bool,
108 pub conflict_strategy: String,
109}
110
111impl Default for SyncConfig {
112 fn default() -> Self {
113 Self { enabled: false, conflict_strategy: "last-writer-wins".to_string() }
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct CompressConfig {
119 pub token_threshold: u32,
120 pub cache_max_entries: u32,
121}
122
123impl Default for CompressConfig {
124 fn default() -> Self {
125 Self { token_threshold: 4000, cache_max_entries: 1000 }
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct HealingConfig {
131 pub max_retries: u32,
132 pub enabled: bool,
133}
134
135impl Default for HealingConfig {
136 fn default() -> Self {
137 Self { max_retries: 3, enabled: true }
138 }
139}
140
141pub fn parse(toml_str: &str) -> Result<HawkConfig> {
144 let config: HawkConfig = toml::from_str(toml_str)
145 .map_err(|e| HawkError::Config(format!("parse error: {e}")))?;
146 validate(&config)?;
147 Ok(config)
148}
149
150pub fn to_toml(config: &HawkConfig) -> Result<String> {
151 toml::to_string_pretty(config)
152 .map_err(|e| HawkError::Config(format!("serialization error: {e}")))
153}
154
/// Values accepted for `[core] log_level`. The order shown here is the order
/// listed in the validation error message.
const VALID_LOG_LEVELS: &[&str] = &["error", "warn", "info", "debug", "trace"];
/// Values accepted for `[privacy] mode`. The order shown here is the order
/// listed in the validation error message.
const VALID_PRIVACY_MODES: &[&str] = &["standard", "local-only", "air-gapped"];
159
160fn validate(config: &HawkConfig) -> Result<()> {
161 if !VALID_LOG_LEVELS.contains(&config.core.log_level.as_str()) {
162 return Err(HawkError::Config(format!(
163 "[core] log_level \"{}\" is invalid; expected one of: {}",
164 config.core.log_level,
165 VALID_LOG_LEVELS.join(", ")
166 )));
167 }
168
169 if !VALID_PRIVACY_MODES.contains(&config.privacy.mode.as_str()) {
170 return Err(HawkError::Config(format!(
171 "[privacy] mode \"{}\" is invalid; expected one of: {}",
172 config.privacy.mode,
173 VALID_PRIVACY_MODES.join(", ")
174 )));
175 }
176
177 if config.core.session_retention_days == 0 {
178 return Err(HawkError::Config(
179 "[core] session_retention_days must be greater than 0".to_string(),
180 ));
181 }
182
183 if config.healing.max_retries < 1 {
184 return Err(HawkError::Config(
185 "[healing] max_retries must be at least 1".to_string(),
186 ));
187 }
188
189 Ok(())
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration that spells out every section and every field.
    const FULL_CONFIG: &str = r#"
[core]
log_level = "info"
session_retention_days = 30
pattern_retention_days = 90

[privacy]
mode = "standard"

[llm]
providers = [
    { name = "openai", endpoint = "https://api.openai.com/v1", priority = 1 },
    { name = "ollama", endpoint = "http://localhost:11434", priority = 2 },
]

[llm.pricing]
openai_gpt4_prompt = 0.00003
openai_gpt4_completion = 0.00006

[savepoint]
auto_snapshot = true
max_snapshots_per_agent = 50

[bus]
message_retention_seconds = 3600
max_queue_size = 10000

[sync]
enabled = false
conflict_strategy = "last-writer-wins"

[compress]
token_threshold = 4000
cache_max_entries = 1000

[healing]
max_retries = 3
enabled = true
"#;

    /// Asserts that every field of `expected` and `actual` matches.
    ///
    /// `HawkConfig` does not derive `PartialEq`, so the comparison is spelled
    /// out field by field. That way a field dropped by serialization is caught.
    fn assert_configs_match(expected: &HawkConfig, actual: &HawkConfig) {
        assert_eq!(expected.core.log_level, actual.core.log_level);
        assert_eq!(
            expected.core.session_retention_days,
            actual.core.session_retention_days
        );
        assert_eq!(
            expected.core.pattern_retention_days,
            actual.core.pattern_retention_days
        );
        assert_eq!(expected.privacy.mode, actual.privacy.mode);

        assert_eq!(expected.llm.providers.len(), actual.llm.providers.len());
        for (want, got) in expected.llm.providers.iter().zip(&actual.llm.providers) {
            assert_eq!(want.name, got.name);
            assert_eq!(want.endpoint, got.endpoint);
            assert_eq!(want.priority, got.priority);
        }
        assert!(
            (expected.llm.pricing.openai_gpt4_prompt - actual.llm.pricing.openai_gpt4_prompt)
                .abs()
                < f64::EPSILON
        );
        assert!(
            (expected.llm.pricing.openai_gpt4_completion
                - actual.llm.pricing.openai_gpt4_completion)
                .abs()
                < f64::EPSILON
        );

        assert_eq!(expected.savepoint.auto_snapshot, actual.savepoint.auto_snapshot);
        assert_eq!(
            expected.savepoint.max_snapshots_per_agent,
            actual.savepoint.max_snapshots_per_agent
        );
        assert_eq!(
            expected.bus.message_retention_seconds,
            actual.bus.message_retention_seconds
        );
        assert_eq!(expected.bus.max_queue_size, actual.bus.max_queue_size);
        assert_eq!(expected.sync.enabled, actual.sync.enabled);
        assert_eq!(expected.sync.conflict_strategy, actual.sync.conflict_strategy);
        assert_eq!(expected.compress.token_threshold, actual.compress.token_threshold);
        assert_eq!(
            expected.compress.cache_max_entries,
            actual.compress.cache_max_entries
        );
        assert_eq!(expected.healing.max_retries, actual.healing.max_retries);
        assert_eq!(expected.healing.enabled, actual.healing.enabled);
    }

    #[test]
    fn test_parse_full_config() {
        let config = parse(FULL_CONFIG).expect("should parse valid config");
        assert_eq!(config.core.log_level, "info");
        assert_eq!(config.core.session_retention_days, 30);
        assert_eq!(config.core.pattern_retention_days, 90);
        assert_eq!(config.privacy.mode, "standard");
        assert_eq!(config.llm.providers.len(), 2);
        assert_eq!(config.llm.providers[0].name, "openai");
        assert_eq!(config.llm.providers[1].priority, 2);
        assert!((config.llm.pricing.openai_gpt4_prompt - 0.00003).abs() < f64::EPSILON);
        assert!(config.savepoint.auto_snapshot);
        assert_eq!(config.savepoint.max_snapshots_per_agent, 50);
        assert_eq!(config.bus.message_retention_seconds, 3600);
        assert_eq!(config.bus.max_queue_size, 10000);
        assert!(!config.sync.enabled);
        assert_eq!(config.sync.conflict_strategy, "last-writer-wins");
        assert_eq!(config.compress.token_threshold, 4000);
        assert_eq!(config.compress.cache_max_entries, 1000);
        assert_eq!(config.healing.max_retries, 3);
        assert!(config.healing.enabled);
    }

    #[test]
    fn test_round_trip() {
        let config = parse(FULL_CONFIG).unwrap();
        let serialized = to_toml(&config).unwrap();
        let reparsed = parse(&serialized).unwrap();
        // Compare every field, not just a sample, so a field lost in
        // serialization fails this test.
        assert_configs_match(&config, &reparsed);
    }

    #[test]
    fn test_default_config_round_trips() {
        let defaults = HawkConfig::default();
        let serialized = to_toml(&defaults).expect("default config should serialize");
        let reparsed = parse(&serialized).expect("serialized defaults should parse and validate");
        assert_configs_match(&defaults, &reparsed);
    }

    #[test]
    fn test_defaults() {
        let config = parse("").unwrap();
        // An empty document must be equivalent to the built-in defaults.
        assert_configs_match(&HawkConfig::default(), &config);
        assert_eq!(config.core.log_level, "info");
        assert_eq!(config.core.session_retention_days, 30);
        assert_eq!(config.privacy.mode, "standard");
        assert_eq!(config.healing.max_retries, 3);
        assert!(config.healing.enabled);
    }

    #[test]
    fn test_invalid_log_level() {
        let toml = r#"
[core]
log_level = "verbose"
session_retention_days = 30
pattern_retention_days = 90
"#;
        let err = parse(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("log_level"), "error should mention log_level: {msg}");
        assert!(msg.contains("verbose"), "error should include the bad value: {msg}");
    }

    #[test]
    fn test_invalid_privacy_mode() {
        let toml = r#"
[privacy]
mode = "cloud-only"
"#;
        let err = parse(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("mode"), "error should mention mode: {msg}");
        assert!(msg.contains("cloud-only"), "error should include the bad value: {msg}");
    }

    #[test]
    fn test_session_retention_zero() {
        let toml = r#"
[core]
log_level = "info"
session_retention_days = 0
pattern_retention_days = 90
"#;
        let err = parse(toml).unwrap_err();
        assert!(err.to_string().contains("session_retention_days"));
    }

    #[test]
    fn test_max_retries_zero() {
        let toml = r#"
[healing]
max_retries = 0
enabled = true
"#;
        let err = parse(toml).unwrap_err();
        assert!(err.to_string().contains("max_retries"));
    }

    #[test]
    fn test_invalid_toml_syntax() {
        let toml = "[core\nlog_level = \"info\"";
        let err = parse(toml).unwrap_err();
        assert!(err.to_string().contains("parse error"));
    }

    #[test]
    fn test_privacy_modes_all_valid() {
        for mode in &["standard", "local-only", "air-gapped"] {
            let toml = format!("[privacy]\nmode = \"{mode}\"");
            assert!(parse(&toml).is_ok(), "mode {mode} should be valid");
        }
    }

    #[test]
    fn test_log_levels_all_valid() {
        for level in &["error", "warn", "info", "debug", "trace"] {
            let toml = format!(
                "[core]\nlog_level = \"{level}\"\nsession_retention_days = 1\npattern_retention_days = 1"
            );
            assert!(parse(&toml).is_ok(), "log_level {level} should be valid");
        }
    }
}