1use std::collections::HashSet;
12use std::path::Path;
13use serde::{Deserialize, Serialize};
14use crate::email_guard::EmailGuardConfig;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ShieldConfig {
19 #[serde(default = "default_block_threshold")]
21 pub block_threshold: f64,
22 #[serde(default = "default_warn_threshold")]
24 pub warn_threshold: f64,
25 #[serde(default)]
27 pub sql: SqlFirewallConfig,
28 #[serde(default)]
30 pub ssrf: SsrfConfig,
31 #[serde(default)]
33 pub rate: RateConfig,
34 #[serde(default)]
36 pub quarantine: QuarantineConfig,
37 #[serde(default = "default_audit_max")]
39 pub audit_max_events: usize,
40 #[serde(default)]
42 pub email: EmailGuardConfig,
43 #[serde(default)]
45 pub api_token: Option<String>,
46 #[serde(default)]
48 pub tls_cert: Option<String>,
49 #[serde(default)]
51 pub tls_key: Option<String>,
52 #[serde(default)]
54 pub webhook_urls: Vec<WebhookConfig>,
55 #[serde(default)]
57 pub ferrum_mail: Option<FerrumMailConfig>,
58 #[serde(default)]
60 pub signature_update: Option<SignatureUpdateConfig>,
61}
62
/// Serde default for `ShieldConfig::block_threshold`.
fn default_block_threshold() -> f64 {
    0.7
}

/// Serde default for `ShieldConfig::warn_threshold`.
fn default_warn_threshold() -> f64 {
    0.4
}

/// Serde default for `ShieldConfig::audit_max_events`.
fn default_audit_max() -> usize {
    100_000
}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct WebhookConfig {
70 pub url: String,
72 #[serde(default = "default_webhook_min_severity")]
74 pub min_severity: String,
75 #[serde(default)]
77 pub headers: Vec<(String, String)>,
78 #[serde(default = "default_webhook_type")]
80 pub webhook_type: String,
81}
82
/// Serde default for the `min_severity` of webhooks and FerrumMail alerts.
fn default_webhook_min_severity() -> String {
    String::from("high")
}

/// Serde default for `WebhookConfig::webhook_type`.
fn default_webhook_type() -> String {
    String::from("generic")
}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct FerrumMailConfig {
89 pub api_url: String,
91 pub api_key: String,
93 pub from_address: String,
95 pub alert_recipients: Vec<String>,
97 #[serde(default = "default_webhook_min_severity")]
99 pub min_severity: String,
100 #[serde(default = "default_true")]
102 pub include_details: bool,
103}
104
/// Serde default used by boolean switches that are on unless configured off.
fn default_true() -> bool {
    !false
}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SignatureUpdateConfig {
110 pub feed_url: String,
112 #[serde(default = "default_sig_interval")]
114 pub interval_secs: u64,
115 #[serde(default)]
117 pub auth_header: Option<String>,
118}
119
/// Serde default for `SignatureUpdateConfig::interval_secs` (one hour).
fn default_sig_interval() -> u64 {
    60 * 60
}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct SqlFirewallConfig {
124 #[serde(default)]
125 pub allow_comments: bool,
126 #[serde(default = "default_max_query_length")]
127 pub max_query_length: usize,
128 #[serde(default = "default_max_subquery_depth")]
129 pub max_subquery_depth: u32,
130 #[serde(default)]
131 pub blocked_functions: Vec<String>,
132 #[serde(default)]
133 pub blocked_schemas: Vec<String>,
134}
135
/// Serde default for `SqlFirewallConfig::max_query_length`.
fn default_max_query_length() -> usize {
    10 * 1_000
}

/// Serde default for `SqlFirewallConfig::max_subquery_depth`.
fn default_max_subquery_depth() -> u32 {
    3
}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct SsrfConfig {
141 #[serde(default = "default_true")]
142 pub block_private_ips: bool,
143 #[serde(default = "default_true")]
144 pub block_loopback: bool,
145 #[serde(default = "default_true")]
146 pub block_link_local: bool,
147 #[serde(default = "default_true")]
148 pub block_metadata_endpoints: bool,
149 #[serde(default = "default_allowed_schemes")]
150 pub allowed_schemes: Vec<String>,
151 #[serde(default)]
152 pub allowlist: HashSet<String>,
153 #[serde(default)]
154 pub blocklist: HashSet<String>,
155 #[serde(default = "default_blocked_ports")]
156 pub blocked_ports: Vec<u16>,
157}
158
/// Serde default for `SsrfConfig::allowed_schemes`: plain and TLS HTTP only.
fn default_allowed_schemes() -> Vec<String> {
    ["http", "https"].iter().copied().map(String::from).collect()
}

/// Serde default for `SsrfConfig::blocked_ports`: well-known ports of
/// services that are not plain web servers.
fn default_blocked_ports() -> Vec<u16> {
    const PORTS: &[u16] = &[
        22, 23, 25, 53, // SSH, Telnet, SMTP, DNS
        111, 135, 139, 445, // rpcbind, MS-RPC, NetBIOS, SMB
        514, 873, 2049, // syslog/rsh, rsync, NFS
        3306, 5432, // MySQL, PostgreSQL
        6379, 6380, // Redis
        9200, 9300, // Elasticsearch
        11211, // memcached
        27017, 27018, // MongoDB
        50070, // Hadoop NameNode web UI
    ];
    PORTS.to_vec()
}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct RateConfig {
166 #[serde(default = "default_rps")]
167 pub requests_per_second: f64,
168 #[serde(default = "default_burst")]
169 pub burst_capacity: f64,
170 #[serde(default = "default_warn_after")]
171 pub warn_after: u32,
172 #[serde(default = "default_throttle_after")]
173 pub throttle_after: u32,
174 #[serde(default = "default_block_after")]
175 pub block_after: u32,
176 #[serde(default = "default_ban_after")]
177 pub ban_after: u32,
178 #[serde(default = "default_ban_duration")]
179 pub ban_duration_secs: u64,
180 #[serde(default = "default_decay")]
181 pub violation_decay_secs: u64,
182}
183
/// Serde default for `RateConfig::requests_per_second`.
fn default_rps() -> f64 {
    50.0
}

/// Serde default for `RateConfig::burst_capacity`.
fn default_burst() -> f64 {
    100.0
}

/// Serde default for `RateConfig::warn_after`.
fn default_warn_after() -> u32 {
    3
}

/// Serde default for `RateConfig::throttle_after`.
fn default_throttle_after() -> u32 {
    8
}

/// Serde default for `RateConfig::block_after`.
fn default_block_after() -> u32 {
    15
}

/// Serde default for `RateConfig::ban_after`.
fn default_ban_after() -> u32 {
    30
}

/// Serde default for `RateConfig::ban_duration_secs` (five minutes).
fn default_ban_duration() -> u64 {
    5 * 60
}

/// Serde default for `RateConfig::violation_decay_secs` (one minute).
fn default_decay() -> u64 {
    60
}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct QuarantineConfig {
195 #[serde(default = "default_max_rows")]
196 pub max_rows: usize,
197 #[serde(default = "default_max_size")]
198 pub max_size_bytes: usize,
199 #[serde(default = "default_max_cols")]
200 pub max_columns: usize,
201 #[serde(default = "default_true")]
202 pub check_formula_injection: bool,
203 #[serde(default = "default_true")]
204 pub check_embedded_scripts: bool,
205}
206
/// Serde default for `QuarantineConfig::max_rows`.
fn default_max_rows() -> usize {
    5 * 1_000_000
}

/// Serde default for `QuarantineConfig::max_size_bytes` (500 MiB).
fn default_max_size() -> usize {
    const MIB: usize = 1024 * 1024;
    500 * MIB
}

/// Serde default for `QuarantineConfig::max_columns`.
fn default_max_cols() -> usize {
    500
}
210
211impl Default for ShieldConfig {
214 fn default() -> Self {
215 Self {
216 block_threshold: 0.7,
217 warn_threshold: 0.4,
218 sql: SqlFirewallConfig::default(),
219 ssrf: SsrfConfig::default(),
220 rate: RateConfig::default(),
221 quarantine: QuarantineConfig::default(),
222 audit_max_events: 100_000,
223 email: EmailGuardConfig::default(),
224 api_token: None,
225 tls_cert: None,
226 tls_key: None,
227 webhook_urls: Vec::new(),
228 ferrum_mail: None,
229 signature_update: None,
230 }
231 }
232}
233
234impl Default for SqlFirewallConfig {
235 fn default() -> Self {
236 Self {
237 allow_comments: false,
238 max_query_length: 10_000,
239 max_subquery_depth: 3,
240 blocked_functions: Vec::new(),
241 blocked_schemas: Vec::new(),
242 }
243 }
244}
245
246impl Default for SsrfConfig {
247 fn default() -> Self {
248 Self {
249 block_private_ips: true,
250 block_loopback: true,
251 block_link_local: true,
252 block_metadata_endpoints: true,
253 allowed_schemes: default_allowed_schemes(),
254 allowlist: HashSet::new(),
255 blocklist: HashSet::new(),
256 blocked_ports: default_blocked_ports(),
257 }
258 }
259}
260
261impl Default for RateConfig {
262 fn default() -> Self {
263 Self {
264 requests_per_second: 50.0,
265 burst_capacity: 100.0,
266 warn_after: 3,
267 throttle_after: 8,
268 block_after: 15,
269 ban_after: 30,
270 ban_duration_secs: 300,
271 violation_decay_secs: 60,
272 }
273 }
274}
275
276impl Default for QuarantineConfig {
277 fn default() -> Self {
278 Self {
279 max_rows: 5_000_000,
280 max_size_bytes: 500 * 1024 * 1024,
281 max_columns: 500,
282 check_formula_injection: true,
283 check_embedded_scripts: true,
284 }
285 }
286}
287
288pub fn load_config(path: &Path) -> Result<ShieldConfig, String> {
290 let content = std::fs::read_to_string(path)
291 .map_err(|e| format!("Failed to read config file {}: {}", path.display(), e))?;
292 let config: ShieldConfig = toml::from_str(&content)
293 .map_err(|e| format!("Failed to parse config file {}: {}", path.display(), e))?;
294 Ok(config)
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = ShieldConfig::default();
        assert_eq!(config.block_threshold, 0.7);
        assert_eq!(config.warn_threshold, 0.4);
        assert!(config.api_token.is_none());
        assert!(config.tls_cert.is_none());
        assert!(config.webhook_urls.is_empty());
        assert!(config.ferrum_mail.is_none());
        assert!(config.signature_update.is_none());
    }

    #[test]
    fn parse_minimal_toml() {
        let toml = r#"
block_threshold = 0.8
warn_threshold = 0.5
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.8);
        assert_eq!(config.warn_threshold, 0.5);
        assert_eq!(config.rate.requests_per_second, 50.0);
    }

    #[test]
    fn parse_full_toml() {
        let toml = r#"
block_threshold = 0.9
warn_threshold = 0.6
api_token = "my-secret-token"
tls_cert = "/etc/nexus-shield/cert.pem"
tls_key = "/etc/nexus-shield/key.pem"

[sql]
allow_comments = true
max_query_length = 20000

[rate]
requests_per_second = 100.0
burst_capacity = 200.0
ban_duration_secs = 600

[[webhook_urls]]
url = "https://hooks.slack.com/services/xxx"
min_severity = "critical"
webhook_type = "slack"

[ferrum_mail]
api_url = "http://localhost:3030"
api_key = "fm-key-123"
from_address = "shield@company.com"
alert_recipients = ["admin@company.com", "security@company.com"]

[signature_update]
feed_url = "https://signatures.nexusshield.dev/v1/latest.ndjson"
interval_secs = 1800
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.9);
        assert_eq!(config.api_token, Some("my-secret-token".to_string()));
        assert_eq!(config.tls_key, Some("/etc/nexus-shield/key.pem".to_string()));

        assert!(config.sql.allow_comments);
        assert_eq!(config.sql.max_query_length, 20000);
        // Keys absent from a present table still take their serde defaults.
        assert_eq!(config.sql.max_subquery_depth, 3);

        assert_eq!(config.rate.requests_per_second, 100.0);
        assert_eq!(config.rate.burst_capacity, 200.0);
        assert_eq!(config.rate.ban_duration_secs, 600);
        assert_eq!(config.rate.warn_after, 3);

        assert_eq!(config.webhook_urls.len(), 1);
        assert_eq!(config.webhook_urls[0].webhook_type, "slack");
        assert_eq!(config.webhook_urls[0].min_severity, "critical");

        let fm = config.ferrum_mail.unwrap();
        assert_eq!(fm.alert_recipients.len(), 2);
        assert_eq!(fm.min_severity, "high");
        assert!(fm.include_details);

        let su = config.signature_update.unwrap();
        assert_eq!(su.interval_secs, 1800);
        assert!(su.auth_header.is_none());
    }

    #[test]
    fn parse_empty_toml() {
        let config: ShieldConfig = toml::from_str("").unwrap();
        assert_eq!(config.block_threshold, 0.7);
    }

    /// Guards against drift between the serde `default = "…"` functions and
    /// the hand-written `Default` impls. Empty sub-tables force the serde
    /// field-level defaults to run for every nested section.
    #[test]
    fn serde_defaults_match_default_impl() {
        let parsed: ShieldConfig =
            toml::from_str("[sql]\n[ssrf]\n[rate]\n[quarantine]\n").unwrap();
        let built = ShieldConfig::default();

        assert_eq!(parsed.block_threshold, built.block_threshold);
        assert_eq!(parsed.warn_threshold, built.warn_threshold);
        assert_eq!(parsed.audit_max_events, built.audit_max_events);

        assert_eq!(parsed.sql.allow_comments, built.sql.allow_comments);
        assert_eq!(parsed.sql.max_query_length, built.sql.max_query_length);
        assert_eq!(parsed.sql.max_subquery_depth, built.sql.max_subquery_depth);
        assert_eq!(parsed.sql.blocked_functions, built.sql.blocked_functions);
        assert_eq!(parsed.sql.blocked_schemas, built.sql.blocked_schemas);

        assert_eq!(parsed.ssrf.block_private_ips, built.ssrf.block_private_ips);
        assert_eq!(parsed.ssrf.block_loopback, built.ssrf.block_loopback);
        assert_eq!(parsed.ssrf.block_link_local, built.ssrf.block_link_local);
        assert_eq!(
            parsed.ssrf.block_metadata_endpoints,
            built.ssrf.block_metadata_endpoints
        );
        assert_eq!(parsed.ssrf.allowed_schemes, built.ssrf.allowed_schemes);
        assert_eq!(parsed.ssrf.allowlist, built.ssrf.allowlist);
        assert_eq!(parsed.ssrf.blocklist, built.ssrf.blocklist);
        assert_eq!(parsed.ssrf.blocked_ports, built.ssrf.blocked_ports);

        assert_eq!(parsed.rate.requests_per_second, built.rate.requests_per_second);
        assert_eq!(parsed.rate.burst_capacity, built.rate.burst_capacity);
        assert_eq!(parsed.rate.warn_after, built.rate.warn_after);
        assert_eq!(parsed.rate.throttle_after, built.rate.throttle_after);
        assert_eq!(parsed.rate.block_after, built.rate.block_after);
        assert_eq!(parsed.rate.ban_after, built.rate.ban_after);
        assert_eq!(parsed.rate.ban_duration_secs, built.rate.ban_duration_secs);
        assert_eq!(parsed.rate.violation_decay_secs, built.rate.violation_decay_secs);

        assert_eq!(parsed.quarantine.max_rows, built.quarantine.max_rows);
        assert_eq!(parsed.quarantine.max_size_bytes, built.quarantine.max_size_bytes);
        assert_eq!(parsed.quarantine.max_columns, built.quarantine.max_columns);
        assert_eq!(
            parsed.quarantine.check_formula_injection,
            built.quarantine.check_formula_injection
        );
        assert_eq!(
            parsed.quarantine.check_embedded_scripts,
            built.quarantine.check_embedded_scripts
        );
    }

    #[test]
    fn load_nonexistent_file() {
        let result = load_config(Path::new("/nonexistent/config.toml"));
        assert!(result.is_err());
    }

    #[test]
    fn webhook_config_defaults() {
        let toml = r#"
[[webhook_urls]]
url = "https://example.com/hook"
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.webhook_urls[0].min_severity, "high");
        assert_eq!(config.webhook_urls[0].webhook_type, "generic");
    }
}