1use std::collections::HashSet;
12use std::path::Path;
13use serde::{Deserialize, Serialize};
14use crate::email_guard::EmailGuardConfig;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ShieldConfig {
19 #[serde(default = "default_block_threshold")]
21 pub block_threshold: f64,
22 #[serde(default = "default_warn_threshold")]
24 pub warn_threshold: f64,
25 #[serde(default)]
27 pub sql: SqlFirewallConfig,
28 #[serde(default)]
30 pub ssrf: SsrfConfig,
31 #[serde(default)]
33 pub rate: RateConfig,
34 #[serde(default)]
36 pub quarantine: QuarantineConfig,
37 #[serde(default = "default_audit_max")]
39 pub audit_max_events: usize,
40 #[serde(default)]
42 pub email: EmailGuardConfig,
43 #[serde(default)]
45 pub api_token: Option<String>,
46 #[serde(default)]
48 pub tls_cert: Option<String>,
49 #[serde(default)]
51 pub tls_key: Option<String>,
52 #[serde(default)]
54 pub webhook_urls: Vec<WebhookConfig>,
55 #[serde(default)]
57 pub ferrum_mail: Option<FerrumMailConfig>,
58 #[serde(default)]
60 pub signature_update: Option<SignatureUpdateConfig>,
61 #[serde(default)]
63 pub nexus_pulse: Option<NexusPulseConfig>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct NexusPulseConfig {
69 pub api_url: String,
71 pub api_key: String,
73 pub alert_recipients: Vec<String>,
75 #[serde(default)]
77 pub from_number: Option<String>,
78 #[serde(default = "default_pulse_min_severity")]
80 pub min_severity: String,
81 #[serde(default = "default_true")]
83 pub use_template: bool,
84}
85
/// Serde fallback for `ShieldConfig::block_threshold` when the key is omitted.
fn default_block_threshold() -> f64 {
    0.7
}

/// Serde fallback for `ShieldConfig::warn_threshold` when the key is omitted.
fn default_warn_threshold() -> f64 {
    0.4
}

/// Serde fallback for `ShieldConfig::audit_max_events` when the key is omitted.
fn default_audit_max() -> usize {
    100_000
}
89
/// One outbound webhook target, configured as a `[[webhook_urls]]` entry.
// NOTE(review): `url` (e.g. Slack incoming-webhook URLs) and `headers` may carry
// credentials, and both are printed by the derived `Debug` — consider redacting them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Destination URL of the webhook.
    pub url: String,
    /// Minimum severity name that triggers this webhook; defaults to `"high"`.
    #[serde(default = "default_webhook_min_severity")]
    pub min_severity: String,
    /// Extra `(name, value)` header pairs; empty by default.
    // NOTE(review): presumably sent as HTTP request headers — confirm in the dispatcher.
    #[serde(default)]
    pub headers: Vec<(String, String)>,
    /// Webhook flavour (the tests use `"slack"`); defaults to `"generic"`.
    #[serde(default = "default_webhook_type")]
    pub webhook_type: String,
}
105
/// Serde fallback for `min_severity` on `WebhookConfig` and `FerrumMailConfig`.
fn default_webhook_min_severity() -> String {
    String::from("high")
}

/// Serde fallback for `WebhookConfig::webhook_type`.
fn default_webhook_type() -> String {
    String::from("generic")
}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct FerrumMailConfig {
112 pub api_url: String,
114 pub api_key: String,
116 pub from_address: String,
118 pub alert_recipients: Vec<String>,
120 #[serde(default = "default_webhook_min_severity")]
122 pub min_severity: String,
123 #[serde(default = "default_true")]
125 pub include_details: bool,
126}
127
/// Serde fallback for boolean options that are enabled unless explicitly turned off.
fn default_true() -> bool {
    true
}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct SignatureUpdateConfig {
133 pub feed_url: String,
135 #[serde(default = "default_sig_interval")]
137 pub interval_secs: u64,
138 #[serde(default)]
140 pub auth_header: Option<String>,
141}
142
/// Serde fallback for `SignatureUpdateConfig::interval_secs`: one hour.
fn default_sig_interval() -> u64 {
    60 * 60
}

/// Serde fallback for `NexusPulseConfig::min_severity`.
fn default_pulse_min_severity() -> String {
    "critical".to_owned()
}
145
/// SQL firewall settings (the `[sql]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlFirewallConfig {
    /// Whether SQL comments are allowed; `false` when omitted.
    #[serde(default)]
    pub allow_comments: bool,
    /// Maximum query length; defaults to 10_000.
    // NOTE(review): unit (bytes vs. chars) is decided by the consumer — confirm.
    #[serde(default = "default_max_query_length")]
    pub max_query_length: usize,
    /// Maximum subquery nesting depth; defaults to 3.
    #[serde(default = "default_max_subquery_depth")]
    pub max_subquery_depth: u32,
    /// Function names to block; empty by default.
    #[serde(default)]
    pub blocked_functions: Vec<String>,
    /// Schema names to block; empty by default.
    #[serde(default)]
    pub blocked_schemas: Vec<String>,
}
159
/// Serde fallback for `SqlFirewallConfig::max_query_length`.
fn default_max_query_length() -> usize {
    10 * 1_000
}

/// Serde fallback for `SqlFirewallConfig::max_subquery_depth`.
fn default_max_subquery_depth() -> u32 {
    3
}
162
/// SSRF protection settings (the `[ssrf]` table). Every blocking switch defaults to on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsrfConfig {
    /// Block private IP ranges; `true` when omitted.
    #[serde(default = "default_true")]
    pub block_private_ips: bool,
    /// Block loopback addresses; `true` when omitted.
    #[serde(default = "default_true")]
    pub block_loopback: bool,
    /// Block link-local addresses; `true` when omitted.
    #[serde(default = "default_true")]
    pub block_link_local: bool,
    /// Block cloud metadata endpoints; `true` when omitted.
    #[serde(default = "default_true")]
    pub block_metadata_endpoints: bool,
    /// Permitted URL schemes; defaults to `["http", "https"]`.
    #[serde(default = "default_allowed_schemes")]
    pub allowed_schemes: Vec<String>,
    /// Explicit allowlist entries; empty by default.
    // NOTE(review): presumably hostnames/addresses — confirm against the SSRF checker.
    #[serde(default)]
    pub allowlist: HashSet<String>,
    /// Explicit blocklist entries; empty by default.
    #[serde(default)]
    pub blocklist: HashSet<String>,
    /// Destination ports to block; defaults to the list in `default_blocked_ports`.
    #[serde(default = "default_blocked_ports")]
    pub blocked_ports: Vec<u16>,
}
182
/// Fallback for `SsrfConfig::allowed_schemes`: plain HTTP and HTTPS only.
fn default_allowed_schemes() -> Vec<String> {
    ["http", "https"].iter().map(|scheme| scheme.to_string()).collect()
}

/// Fallback for `SsrfConfig::blocked_ports`.
///
/// These are ports conventionally used by internal services (SSH, Telnet, SMTP, DNS,
/// RPC/SMB, syslog, rsync, NFS, MySQL, PostgreSQL, Redis, Elasticsearch, Memcached,
/// MongoDB, HDFS). NOTE(review): confirm that intent before changing the list.
fn default_blocked_ports() -> Vec<u16> {
    const PORTS: &[u16] = &[
        22, 23, 25, 53, 111, 135, 139, 445, 514, 873, 2049, 3306, 5432, 6379, 6380, 9200,
        9300, 11211, 27017, 27018, 50070,
    ];
    PORTS.to_vec()
}
187
/// Rate-limiting settings (the `[rate]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateConfig {
    /// Sustained request rate per second; defaults to 50.0.
    #[serde(default = "default_rps")]
    pub requests_per_second: f64,
    /// Burst capacity; defaults to 100.0.
    // NOTE(review): presumably a token-bucket capacity — confirm in the limiter.
    #[serde(default = "default_burst")]
    pub burst_capacity: f64,
    /// Count at which a warning is issued; defaults to 3.
    // NOTE(review): the four `*_after` counts presumably apply to rate-limit violations
    // (see `violation_decay_secs`) — confirm in the limiter.
    #[serde(default = "default_warn_after")]
    pub warn_after: u32,
    /// Count at which the client is throttled; defaults to 8.
    #[serde(default = "default_throttle_after")]
    pub throttle_after: u32,
    /// Count at which requests are blocked; defaults to 15.
    #[serde(default = "default_block_after")]
    pub block_after: u32,
    /// Count at which the client is banned; defaults to 30.
    #[serde(default = "default_ban_after")]
    pub ban_after: u32,
    /// Ban length in seconds; defaults to 300.
    #[serde(default = "default_ban_duration")]
    pub ban_duration_secs: u64,
    /// Violation decay period in seconds; defaults to 60.
    #[serde(default = "default_decay")]
    pub violation_decay_secs: u64,
}
207
// Serde fallbacks for the `RateConfig` fields of the same name.

/// Fallback for `RateConfig::requests_per_second`.
fn default_rps() -> f64 {
    50.0
}

/// Fallback for `RateConfig::burst_capacity`.
fn default_burst() -> f64 {
    100.0
}

/// Fallback for `RateConfig::warn_after`.
fn default_warn_after() -> u32 {
    3
}

/// Fallback for `RateConfig::throttle_after`.
fn default_throttle_after() -> u32 {
    8
}

/// Fallback for `RateConfig::block_after`.
fn default_block_after() -> u32 {
    15
}

/// Fallback for `RateConfig::ban_after`.
fn default_ban_after() -> u32 {
    30
}

/// Fallback for `RateConfig::ban_duration_secs`: five minutes.
fn default_ban_duration() -> u64 {
    5 * 60
}

/// Fallback for `RateConfig::violation_decay_secs`: one minute.
fn default_decay() -> u64 {
    60
}
216
/// Quarantine limits and content checks (the `[quarantine]` table).
// NOTE(review): row/column limits suggest tabular uploads (e.g. CSV) — confirm with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuarantineConfig {
    /// Maximum number of rows; defaults to 5_000_000.
    #[serde(default = "default_max_rows")]
    pub max_rows: usize,
    /// Maximum size in bytes; defaults to 500 MiB.
    #[serde(default = "default_max_size")]
    pub max_size_bytes: usize,
    /// Maximum number of columns; defaults to 500.
    #[serde(default = "default_max_cols")]
    pub max_columns: usize,
    /// Enable the formula-injection check; `true` when omitted.
    #[serde(default = "default_true")]
    pub check_formula_injection: bool,
    /// Enable the embedded-script check; `true` when omitted.
    #[serde(default = "default_true")]
    pub check_embedded_scripts: bool,
}
230
/// Serde fallback for `QuarantineConfig::max_rows`.
fn default_max_rows() -> usize {
    5_000_000
}

/// Serde fallback for `QuarantineConfig::max_size_bytes`: 500 MiB.
fn default_max_size() -> usize {
    const MIB: usize = 1024 * 1024;
    500 * MIB
}

/// Serde fallback for `QuarantineConfig::max_columns`.
fn default_max_cols() -> usize {
    500
}
234
235impl Default for ShieldConfig {
238 fn default() -> Self {
239 Self {
240 block_threshold: 0.7,
241 warn_threshold: 0.4,
242 sql: SqlFirewallConfig::default(),
243 ssrf: SsrfConfig::default(),
244 rate: RateConfig::default(),
245 quarantine: QuarantineConfig::default(),
246 audit_max_events: 100_000,
247 email: EmailGuardConfig::default(),
248 api_token: None,
249 tls_cert: None,
250 tls_key: None,
251 webhook_urls: Vec::new(),
252 ferrum_mail: None,
253 signature_update: None,
254 nexus_pulse: None,
255 }
256 }
257}
258
259impl Default for SqlFirewallConfig {
260 fn default() -> Self {
261 Self {
262 allow_comments: false,
263 max_query_length: 10_000,
264 max_subquery_depth: 3,
265 blocked_functions: Vec::new(),
266 blocked_schemas: Vec::new(),
267 }
268 }
269}
270
271impl Default for SsrfConfig {
272 fn default() -> Self {
273 Self {
274 block_private_ips: true,
275 block_loopback: true,
276 block_link_local: true,
277 block_metadata_endpoints: true,
278 allowed_schemes: default_allowed_schemes(),
279 allowlist: HashSet::new(),
280 blocklist: HashSet::new(),
281 blocked_ports: default_blocked_ports(),
282 }
283 }
284}
285
286impl Default for RateConfig {
287 fn default() -> Self {
288 Self {
289 requests_per_second: 50.0,
290 burst_capacity: 100.0,
291 warn_after: 3,
292 throttle_after: 8,
293 block_after: 15,
294 ban_after: 30,
295 ban_duration_secs: 300,
296 violation_decay_secs: 60,
297 }
298 }
299}
300
301impl Default for QuarantineConfig {
302 fn default() -> Self {
303 Self {
304 max_rows: 5_000_000,
305 max_size_bytes: 500 * 1024 * 1024,
306 max_columns: 500,
307 check_formula_injection: true,
308 check_embedded_scripts: true,
309 }
310 }
311}
312
313pub fn load_config(path: &Path) -> Result<ShieldConfig, String> {
315 let content = std::fs::read_to_string(path)
316 .map_err(|e| format!("Failed to read config file {}: {}", path.display(), e))?;
317 let config: ShieldConfig = toml::from_str(&content)
318 .map_err(|e| format!("Failed to parse config file {}: {}", path.display(), e))?;
319 Ok(config)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = ShieldConfig::default();
        assert_eq!(config.block_threshold, 0.7);
        assert_eq!(config.warn_threshold, 0.4);
        assert!(config.api_token.is_none());
        assert!(config.tls_cert.is_none());
        assert!(config.webhook_urls.is_empty());
        assert!(config.ferrum_mail.is_none());
        assert!(config.signature_update.is_none());
        assert!(config.nexus_pulse.is_none());
    }

    #[test]
    fn parse_minimal_toml() {
        let toml = r#"
block_threshold = 0.8
warn_threshold = 0.5
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.8);
        assert_eq!(config.warn_threshold, 0.5);
        assert_eq!(config.rate.requests_per_second, 50.0);
    }

    #[test]
    fn parse_full_toml() {
        let toml = r#"
block_threshold = 0.9
warn_threshold = 0.6
api_token = "my-secret-token"
tls_cert = "/etc/nexus-shield/cert.pem"
tls_key = "/etc/nexus-shield/key.pem"

[sql]
allow_comments = true
max_query_length = 20000

[rate]
requests_per_second = 100.0
burst_capacity = 200.0
ban_duration_secs = 600

[[webhook_urls]]
url = "https://hooks.slack.com/services/xxx"
min_severity = "critical"
webhook_type = "slack"

[ferrum_mail]
api_url = "http://localhost:3030"
api_key = "fm-key-123"
from_address = "shield@company.com"
alert_recipients = ["admin@company.com", "security@company.com"]

[signature_update]
feed_url = "https://signatures.nexusshield.dev/v1/latest.ndjson"
interval_secs = 1800
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.9);
        assert_eq!(config.api_token, Some("my-secret-token".to_string()));
        assert_eq!(config.sql.max_query_length, 20000);
        assert_eq!(config.rate.requests_per_second, 100.0);
        assert_eq!(config.webhook_urls.len(), 1);
        assert_eq!(config.webhook_urls[0].webhook_type, "slack");
        let fm = config.ferrum_mail.unwrap();
        assert_eq!(fm.alert_recipients.len(), 2);
        // Omitted keys in an optional section still get their serde defaults.
        assert_eq!(fm.min_severity, "high");
        assert!(fm.include_details);
        let su = config.signature_update.unwrap();
        assert_eq!(su.interval_secs, 1800);
        assert!(su.auth_header.is_none());
    }

    #[test]
    fn parse_nexus_pulse_defaults() {
        let toml = r#"
[nexus_pulse]
api_url = "https://pulse.example.com"
api_key = "np-key"
alert_recipients = ["+15550100"]
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        let pulse = config.nexus_pulse.unwrap();
        assert_eq!(pulse.min_severity, "critical");
        assert!(pulse.use_template);
        assert!(pulse.from_number.is_none());
    }

    #[test]
    fn parse_empty_toml() {
        let config: ShieldConfig = toml::from_str("").unwrap();
        assert_eq!(config.block_threshold, 0.7);
    }

    /// An absent table uses the type's `Default` impl, while an empty table makes serde
    /// fall back to the per-field `default = "..."` helpers. The two must agree.
    #[test]
    fn serde_field_defaults_match_default_impls() {
        let parsed: ShieldConfig =
            toml::from_str("[sql]\n[ssrf]\n[rate]\n[quarantine]\n").unwrap();
        let built = ShieldConfig::default();

        assert_eq!(parsed.block_threshold, built.block_threshold);
        assert_eq!(parsed.warn_threshold, built.warn_threshold);
        assert_eq!(parsed.audit_max_events, built.audit_max_events);

        assert_eq!(parsed.sql.allow_comments, built.sql.allow_comments);
        assert_eq!(parsed.sql.max_query_length, built.sql.max_query_length);
        assert_eq!(parsed.sql.max_subquery_depth, built.sql.max_subquery_depth);
        assert_eq!(parsed.sql.blocked_functions, built.sql.blocked_functions);
        assert_eq!(parsed.sql.blocked_schemas, built.sql.blocked_schemas);

        assert_eq!(parsed.ssrf.block_private_ips, built.ssrf.block_private_ips);
        assert_eq!(parsed.ssrf.block_loopback, built.ssrf.block_loopback);
        assert_eq!(parsed.ssrf.block_link_local, built.ssrf.block_link_local);
        assert_eq!(
            parsed.ssrf.block_metadata_endpoints,
            built.ssrf.block_metadata_endpoints
        );
        assert_eq!(parsed.ssrf.allowed_schemes, built.ssrf.allowed_schemes);
        assert_eq!(parsed.ssrf.allowlist, built.ssrf.allowlist);
        assert_eq!(parsed.ssrf.blocklist, built.ssrf.blocklist);
        assert_eq!(parsed.ssrf.blocked_ports, built.ssrf.blocked_ports);

        assert_eq!(parsed.rate.requests_per_second, built.rate.requests_per_second);
        assert_eq!(parsed.rate.burst_capacity, built.rate.burst_capacity);
        assert_eq!(parsed.rate.warn_after, built.rate.warn_after);
        assert_eq!(parsed.rate.throttle_after, built.rate.throttle_after);
        assert_eq!(parsed.rate.block_after, built.rate.block_after);
        assert_eq!(parsed.rate.ban_after, built.rate.ban_after);
        assert_eq!(parsed.rate.ban_duration_secs, built.rate.ban_duration_secs);
        assert_eq!(parsed.rate.violation_decay_secs, built.rate.violation_decay_secs);

        assert_eq!(parsed.quarantine.max_rows, built.quarantine.max_rows);
        assert_eq!(parsed.quarantine.max_size_bytes, built.quarantine.max_size_bytes);
        assert_eq!(parsed.quarantine.max_columns, built.quarantine.max_columns);
        assert_eq!(
            parsed.quarantine.check_formula_injection,
            built.quarantine.check_formula_injection
        );
        assert_eq!(
            parsed.quarantine.check_embedded_scripts,
            built.quarantine.check_embedded_scripts
        );
    }

    #[test]
    fn load_nonexistent_file() {
        let result = load_config(Path::new("/nonexistent/config.toml"));
        let err = result.unwrap_err();
        // The message must point the operator at the offending path.
        assert!(err.contains("/nonexistent/config.toml"));
    }

    #[test]
    fn load_config_reads_file() {
        let path = std::env::temp_dir().join(format!(
            "nexus_shield_config_test_{}.toml",
            std::process::id()
        ));
        std::fs::write(&path, "block_threshold = 0.95\n").unwrap();
        let result = load_config(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);
        let config = result.unwrap();
        assert_eq!(config.block_threshold, 0.95);
        assert_eq!(config.warn_threshold, 0.4);
    }

    #[test]
    fn webhook_config_defaults() {
        let toml = r#"
[[webhook_urls]]
url = "https://example.com/hook"
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.webhook_urls[0].min_severity, "high");
        assert_eq!(config.webhook_urls[0].webhook_type, "generic");
    }
}