Skip to main content

nexus_shield/
config.rs

// ============================================================================
// File: config.rs
// Description: Shield security engine configuration for all defense layers
// Author: Andrew Jewell Sr. - AutomataNexus
// Updated: March 26, 2026
//
// DISCLAIMER: This software is provided "as is", without warranty of any kind,
// express or implied. Use at your own risk. AutomataNexus and the author assume
// no liability for any damages arising from the use of this software.
// ============================================================================
11use std::collections::HashSet;
12use std::path::Path;
13use serde::{Deserialize, Serialize};
14use crate::email_guard::EmailGuardConfig;
15
16/// Complete configuration for the Shield security engine.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ShieldConfig {
19    /// Threat score threshold above which requests are blocked (0.0–1.0).
20    #[serde(default = "default_block_threshold")]
21    pub block_threshold: f64,
22    /// Threat score threshold for logging warnings (0.0–1.0).
23    #[serde(default = "default_warn_threshold")]
24    pub warn_threshold: f64,
25    /// SQL firewall configuration.
26    #[serde(default)]
27    pub sql: SqlFirewallConfig,
28    /// SSRF guard configuration.
29    #[serde(default)]
30    pub ssrf: SsrfConfig,
31    /// Rate limiting configuration.
32    #[serde(default)]
33    pub rate: RateConfig,
34    /// Data quarantine configuration.
35    #[serde(default)]
36    pub quarantine: QuarantineConfig,
37    /// Maximum audit chain events to keep in memory before pruning.
38    #[serde(default = "default_audit_max")]
39    pub audit_max_events: usize,
40    /// Email guard configuration.
41    #[serde(default)]
42    pub email: EmailGuardConfig,
43    /// API authentication token (if set, all sensitive endpoints require Bearer auth).
44    #[serde(default)]
45    pub api_token: Option<String>,
46    /// TLS certificate path.
47    #[serde(default)]
48    pub tls_cert: Option<String>,
49    /// TLS private key path.
50    #[serde(default)]
51    pub tls_key: Option<String>,
52    /// Webhook alert URLs for critical/high detections.
53    #[serde(default)]
54    pub webhook_urls: Vec<WebhookConfig>,
55    /// Ferrum-Mail integration.
56    #[serde(default)]
57    pub ferrum_mail: Option<FerrumMailConfig>,
58    /// Signature auto-update configuration.
59    #[serde(default)]
60    pub signature_update: Option<SignatureUpdateConfig>,
61}
62
/// Serde default for `ShieldConfig::block_threshold`.
fn default_block_threshold() -> f64 {
    0.7
}

/// Serde default for `ShieldConfig::warn_threshold`.
fn default_warn_threshold() -> f64 {
    0.4
}

/// Serde default for `ShieldConfig::audit_max_events`.
fn default_audit_max() -> usize {
    100 * 1_000
}
66
67/// Webhook alert configuration.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct WebhookConfig {
70    /// Webhook URL (Slack, Discord, PagerDuty, generic).
71    pub url: String,
72    /// Minimum severity to trigger (info, low, medium, high, critical).
73    #[serde(default = "default_webhook_min_severity")]
74    pub min_severity: String,
75    /// Optional custom headers.
76    #[serde(default)]
77    pub headers: Vec<(String, String)>,
78    /// Webhook type hint for formatting (slack, discord, generic).
79    #[serde(default = "default_webhook_type")]
80    pub webhook_type: String,
81}
82
/// Serde default for the `min_severity` of webhooks and Ferrum-Mail alerts.
fn default_webhook_min_severity() -> String {
    String::from("high")
}

/// Serde default for `WebhookConfig::webhook_type`.
fn default_webhook_type() -> String {
    String::from("generic")
}
85
86/// Ferrum-Mail integration for email alerts.
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct FerrumMailConfig {
89    /// Ferrum-Mail API base URL.
90    pub api_url: String,
91    /// API key for Ferrum-Mail authentication.
92    pub api_key: String,
93    /// Sender address for alert emails.
94    pub from_address: String,
95    /// Recipient addresses for alerts.
96    pub alert_recipients: Vec<String>,
97    /// Minimum severity to send email (default: high).
98    #[serde(default = "default_webhook_min_severity")]
99    pub min_severity: String,
100    /// Include full event details in email body.
101    #[serde(default = "default_true")]
102    pub include_details: bool,
103}
104
/// Serde default for boolean flags that stay enabled unless explicitly turned off.
fn default_true() -> bool {
    !false
}
106
107/// Automatic signature update configuration.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SignatureUpdateConfig {
110    /// URL to fetch NDJSON signatures from.
111    pub feed_url: String,
112    /// Update interval in seconds (default: 3600 = 1 hour).
113    #[serde(default = "default_sig_interval")]
114    pub interval_secs: u64,
115    /// Optional authentication header for the feed.
116    #[serde(default)]
117    pub auth_header: Option<String>,
118}
119
/// Serde default for `SignatureUpdateConfig::interval_secs`: one hour, in seconds.
fn default_sig_interval() -> u64 {
    60 * 60
}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct SqlFirewallConfig {
124    #[serde(default)]
125    pub allow_comments: bool,
126    #[serde(default = "default_max_query_length")]
127    pub max_query_length: usize,
128    #[serde(default = "default_max_subquery_depth")]
129    pub max_subquery_depth: u32,
130    #[serde(default)]
131    pub blocked_functions: Vec<String>,
132    #[serde(default)]
133    pub blocked_schemas: Vec<String>,
134}
135
/// Serde default for `SqlFirewallConfig::max_query_length`.
fn default_max_query_length() -> usize {
    10 * 1_000
}

/// Serde default for `SqlFirewallConfig::max_subquery_depth`.
fn default_max_subquery_depth() -> u32 {
    3
}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct SsrfConfig {
141    #[serde(default = "default_true")]
142    pub block_private_ips: bool,
143    #[serde(default = "default_true")]
144    pub block_loopback: bool,
145    #[serde(default = "default_true")]
146    pub block_link_local: bool,
147    #[serde(default = "default_true")]
148    pub block_metadata_endpoints: bool,
149    #[serde(default = "default_allowed_schemes")]
150    pub allowed_schemes: Vec<String>,
151    #[serde(default)]
152    pub allowlist: HashSet<String>,
153    #[serde(default)]
154    pub blocklist: HashSet<String>,
155    #[serde(default = "default_blocked_ports")]
156    pub blocked_ports: Vec<u16>,
157}
158
/// Serde default for `SsrfConfig::allowed_schemes`.
fn default_allowed_schemes() -> Vec<String> {
    ["http", "https"].iter().map(|scheme| scheme.to_string()).collect()
}

/// Serde default for `SsrfConfig::blocked_ports`.
fn default_blocked_ports() -> Vec<u16> {
    // Trailing notes name the services conventionally found on these ports,
    // for orientation only; the order of the list is significant and preserved.
    const PORTS: [u16; 21] = [
        22, 23, 25, 53, // SSH, Telnet, SMTP, DNS
        111, 135, 139, 445, // portmapper, MS-RPC, NetBIOS, SMB
        514, 873, 2049, // rsh/syslog, rsync, NFS
        3306, 5432, // MySQL, PostgreSQL
        6379, 6380, // Redis
        9200, 9300, // Elasticsearch
        11211, // Memcached
        27017, 27018, // MongoDB
        50070, // Hadoop NameNode web UI
    ];
    PORTS.to_vec()
}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct RateConfig {
166    #[serde(default = "default_rps")]
167    pub requests_per_second: f64,
168    #[serde(default = "default_burst")]
169    pub burst_capacity: f64,
170    #[serde(default = "default_warn_after")]
171    pub warn_after: u32,
172    #[serde(default = "default_throttle_after")]
173    pub throttle_after: u32,
174    #[serde(default = "default_block_after")]
175    pub block_after: u32,
176    #[serde(default = "default_ban_after")]
177    pub ban_after: u32,
178    #[serde(default = "default_ban_duration")]
179    pub ban_duration_secs: u64,
180    #[serde(default = "default_decay")]
181    pub violation_decay_secs: u64,
182}
183
/// Serde default for `RateConfig::requests_per_second`.
fn default_rps() -> f64 {
    50.0
}

/// Serde default for `RateConfig::burst_capacity`.
fn default_burst() -> f64 {
    100.0
}

/// Serde default for `RateConfig::warn_after`.
fn default_warn_after() -> u32 {
    3
}

/// Serde default for `RateConfig::throttle_after`.
fn default_throttle_after() -> u32 {
    8
}

/// Serde default for `RateConfig::block_after`.
fn default_block_after() -> u32 {
    15
}

/// Serde default for `RateConfig::ban_after`.
fn default_ban_after() -> u32 {
    30
}

/// Serde default for `RateConfig::ban_duration_secs`: five minutes.
fn default_ban_duration() -> u64 {
    5 * 60
}

/// Serde default for `RateConfig::violation_decay_secs`: one minute.
fn default_decay() -> u64 {
    60
}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct QuarantineConfig {
195    #[serde(default = "default_max_rows")]
196    pub max_rows: usize,
197    #[serde(default = "default_max_size")]
198    pub max_size_bytes: usize,
199    #[serde(default = "default_max_cols")]
200    pub max_columns: usize,
201    #[serde(default = "default_true")]
202    pub check_formula_injection: bool,
203    #[serde(default = "default_true")]
204    pub check_embedded_scripts: bool,
205}
206
/// Serde default for `QuarantineConfig::max_rows`.
fn default_max_rows() -> usize {
    5_000_000
}

/// Serde default for `QuarantineConfig::max_size_bytes`: 500 MiB.
fn default_max_size() -> usize {
    500 << 20
}

/// Serde default for `QuarantineConfig::max_columns`.
fn default_max_cols() -> usize {
    500
}
210
211// === Default implementations ===
212
213impl Default for ShieldConfig {
214    fn default() -> Self {
215        Self {
216            block_threshold: 0.7,
217            warn_threshold: 0.4,
218            sql: SqlFirewallConfig::default(),
219            ssrf: SsrfConfig::default(),
220            rate: RateConfig::default(),
221            quarantine: QuarantineConfig::default(),
222            audit_max_events: 100_000,
223            email: EmailGuardConfig::default(),
224            api_token: None,
225            tls_cert: None,
226            tls_key: None,
227            webhook_urls: Vec::new(),
228            ferrum_mail: None,
229            signature_update: None,
230        }
231    }
232}
233
234impl Default for SqlFirewallConfig {
235    fn default() -> Self {
236        Self {
237            allow_comments: false,
238            max_query_length: 10_000,
239            max_subquery_depth: 3,
240            blocked_functions: Vec::new(),
241            blocked_schemas: Vec::new(),
242        }
243    }
244}
245
246impl Default for SsrfConfig {
247    fn default() -> Self {
248        Self {
249            block_private_ips: true,
250            block_loopback: true,
251            block_link_local: true,
252            block_metadata_endpoints: true,
253            allowed_schemes: default_allowed_schemes(),
254            allowlist: HashSet::new(),
255            blocklist: HashSet::new(),
256            blocked_ports: default_blocked_ports(),
257        }
258    }
259}
260
261impl Default for RateConfig {
262    fn default() -> Self {
263        Self {
264            requests_per_second: 50.0,
265            burst_capacity: 100.0,
266            warn_after: 3,
267            throttle_after: 8,
268            block_after: 15,
269            ban_after: 30,
270            ban_duration_secs: 300,
271            violation_decay_secs: 60,
272        }
273    }
274}
275
276impl Default for QuarantineConfig {
277    fn default() -> Self {
278        Self {
279            max_rows: 5_000_000,
280            max_size_bytes: 500 * 1024 * 1024,
281            max_columns: 500,
282            check_formula_injection: true,
283            check_embedded_scripts: true,
284        }
285    }
286}
287
288/// Load configuration from a TOML file, falling back to defaults for missing fields.
289pub fn load_config(path: &Path) -> Result<ShieldConfig, String> {
290    let content = std::fs::read_to_string(path)
291        .map_err(|e| format!("Failed to read config file {}: {}", path.display(), e))?;
292    let config: ShieldConfig = toml::from_str(&content)
293        .map_err(|e| format!("Failed to parse config file {}: {}", path.display(), e))?;
294    Ok(config)
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = ShieldConfig::default();
        assert_eq!(config.block_threshold, 0.7);
        assert_eq!(config.warn_threshold, 0.4);
        assert!(config.api_token.is_none());
        assert!(config.tls_cert.is_none());
        assert!(config.webhook_urls.is_empty());
        assert!(config.ferrum_mail.is_none());
        assert!(config.signature_update.is_none());
    }

    #[test]
    fn parse_minimal_toml() {
        let toml = r#"
block_threshold = 0.8
warn_threshold = 0.5
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.8);
        assert_eq!(config.warn_threshold, 0.5);
        assert_eq!(config.rate.requests_per_second, 50.0); // default
    }

    #[test]
    fn parse_full_toml() {
        let toml = r#"
block_threshold = 0.9
warn_threshold = 0.6
api_token = "my-secret-token"
tls_cert = "/etc/nexus-shield/cert.pem"
tls_key = "/etc/nexus-shield/key.pem"

[sql]
allow_comments = true
max_query_length = 20000

[rate]
requests_per_second = 100.0
burst_capacity = 200.0
ban_duration_secs = 600

[[webhook_urls]]
url = "https://hooks.slack.com/services/xxx"
min_severity = "critical"
webhook_type = "slack"

[ferrum_mail]
api_url = "http://localhost:3030"
api_key = "fm-key-123"
from_address = "shield@company.com"
alert_recipients = ["admin@company.com", "security@company.com"]

[signature_update]
feed_url = "https://signatures.nexusshield.dev/v1/latest.ndjson"
interval_secs = 1800
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.9);
        assert_eq!(config.api_token, Some("my-secret-token".to_string()));
        assert_eq!(config.sql.max_query_length, 20000);
        assert_eq!(config.rate.requests_per_second, 100.0);
        assert_eq!(config.webhook_urls.len(), 1);
        assert_eq!(config.webhook_urls[0].webhook_type, "slack");
        let fm = config.ferrum_mail.unwrap();
        assert_eq!(fm.alert_recipients.len(), 2);
        let su = config.signature_update.unwrap();
        assert_eq!(su.interval_secs, 1800);
    }

    #[test]
    fn parse_empty_toml() {
        let config: ShieldConfig = toml::from_str("").unwrap();
        assert_eq!(config.block_threshold, 0.7); // all defaults
    }

    #[test]
    fn load_nonexistent_file() {
        let result = load_config(Path::new("/nonexistent/config.toml"));
        assert!(result.is_err());
    }

    #[test]
    fn webhook_config_defaults() {
        let toml = r#"
[[webhook_urls]]
url = "https://example.com/hook"
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.webhook_urls[0].min_severity, "high");
        assert_eq!(config.webhook_urls[0].webhook_type, "generic");
    }

    #[test]
    fn serde_defaults_match_default_impls() {
        // Each nested table is present but empty, so serde fills every field from
        // its `default = "..."` function rather than from the struct's `Default`
        // impl. Comparing the two catches drift between those sources of truth.
        let src = r#"
[sql]
[ssrf]
[rate]
[quarantine]
"#;
        let parsed: ShieldConfig = toml::from_str(src).unwrap();
        let built = ShieldConfig::default();

        assert_eq!(parsed.block_threshold, built.block_threshold);
        assert_eq!(parsed.warn_threshold, built.warn_threshold);
        assert_eq!(parsed.audit_max_events, built.audit_max_events);

        assert_eq!(parsed.sql.allow_comments, built.sql.allow_comments);
        assert_eq!(parsed.sql.max_query_length, built.sql.max_query_length);
        assert_eq!(parsed.sql.max_subquery_depth, built.sql.max_subquery_depth);
        assert_eq!(parsed.sql.blocked_functions, built.sql.blocked_functions);
        assert_eq!(parsed.sql.blocked_schemas, built.sql.blocked_schemas);

        assert_eq!(parsed.ssrf.block_private_ips, built.ssrf.block_private_ips);
        assert_eq!(parsed.ssrf.block_loopback, built.ssrf.block_loopback);
        assert_eq!(parsed.ssrf.block_link_local, built.ssrf.block_link_local);
        assert_eq!(
            parsed.ssrf.block_metadata_endpoints,
            built.ssrf.block_metadata_endpoints
        );
        assert_eq!(parsed.ssrf.allowed_schemes, built.ssrf.allowed_schemes);
        assert_eq!(parsed.ssrf.allowlist, built.ssrf.allowlist);
        assert_eq!(parsed.ssrf.blocklist, built.ssrf.blocklist);
        assert_eq!(parsed.ssrf.blocked_ports, built.ssrf.blocked_ports);

        assert_eq!(parsed.rate.requests_per_second, built.rate.requests_per_second);
        assert_eq!(parsed.rate.burst_capacity, built.rate.burst_capacity);
        assert_eq!(parsed.rate.warn_after, built.rate.warn_after);
        assert_eq!(parsed.rate.throttle_after, built.rate.throttle_after);
        assert_eq!(parsed.rate.block_after, built.rate.block_after);
        assert_eq!(parsed.rate.ban_after, built.rate.ban_after);
        assert_eq!(parsed.rate.ban_duration_secs, built.rate.ban_duration_secs);
        assert_eq!(parsed.rate.violation_decay_secs, built.rate.violation_decay_secs);

        assert_eq!(parsed.quarantine.max_rows, built.quarantine.max_rows);
        assert_eq!(parsed.quarantine.max_size_bytes, built.quarantine.max_size_bytes);
        assert_eq!(parsed.quarantine.max_columns, built.quarantine.max_columns);
        assert_eq!(
            parsed.quarantine.check_formula_injection,
            built.quarantine.check_formula_injection
        );
        assert_eq!(
            parsed.quarantine.check_embedded_scripts,
            built.quarantine.check_embedded_scripts
        );
    }
}