Skip to main content

nexus_shield/
config.rs

1// ============================================================================
2// File: config.rs
3// Description: Shield security engine configuration for all defense layers
4// Author: Andrew Jewell Sr. - AutomataNexus
5// Updated: March 26, 2026
6//
7// DISCLAIMER: This software is provided "as is", without warranty of any kind,
8// express or implied. Use at your own risk. AutomataNexus and the author assume
9// no liability for any damages arising from the use of this software.
10// ============================================================================
11use std::collections::HashSet;
12use std::path::Path;
13use serde::{Deserialize, Serialize};
14use crate::email_guard::EmailGuardConfig;
15
/// Complete configuration for the Shield security engine.
///
/// Every field carries a serde default, so any subset of keys (including an empty
/// document) deserializes successfully. Nested sections that are missing entirely fall
/// back to their type's `Default` impl; scalar fields use the `default_*` helpers.
///
/// NOTE(review): the derived `Debug` prints `api_token` — and the `api_key` fields of the
/// nested `ferrum_mail` / `nexus_pulse` configs — in plain text. Avoid logging the whole
/// config with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShieldConfig {
    /// Threat score threshold above which requests are blocked (0.0–1.0).
    /// Defaults to `default_block_threshold()` (0.7).
    #[serde(default = "default_block_threshold")]
    pub block_threshold: f64,
    /// Threat score threshold for logging warnings (0.0–1.0).
    /// Defaults to `default_warn_threshold()` (0.4).
    #[serde(default = "default_warn_threshold")]
    pub warn_threshold: f64,
    /// SQL firewall configuration (the `[sql]` table).
    #[serde(default)]
    pub sql: SqlFirewallConfig,
    /// SSRF guard configuration (the `[ssrf]` table).
    #[serde(default)]
    pub ssrf: SsrfConfig,
    /// Rate limiting configuration (the `[rate]` table).
    #[serde(default)]
    pub rate: RateConfig,
    /// Data quarantine configuration (the `[quarantine]` table).
    #[serde(default)]
    pub quarantine: QuarantineConfig,
    /// Maximum audit chain events to keep in memory before pruning.
    /// Defaults to `default_audit_max()` (100,000).
    #[serde(default = "default_audit_max")]
    pub audit_max_events: usize,
    /// Email guard configuration (the `[email]` table).
    #[serde(default)]
    pub email: EmailGuardConfig,
    /// API authentication token (if set, all sensitive endpoints require Bearer auth).
    #[serde(default)]
    pub api_token: Option<String>,
    /// TLS certificate path.
    ///
    /// NOTE(review): nothing in this file checks that `tls_cert` and `tls_key` are set
    /// together — confirm the server rejects a lone value.
    #[serde(default)]
    pub tls_cert: Option<String>,
    /// TLS private key path (see the note on `tls_cert`).
    #[serde(default)]
    pub tls_key: Option<String>,
    /// Webhook alert targets, one `[[webhook_urls]]` entry each. Each entry's own
    /// `min_severity` (default "high") decides which detections it receives.
    #[serde(default)]
    pub webhook_urls: Vec<WebhookConfig>,
    /// Ferrum-Mail integration; disabled when the `[ferrum_mail]` table is absent.
    #[serde(default)]
    pub ferrum_mail: Option<FerrumMailConfig>,
    /// Signature auto-update configuration; disabled when `[signature_update]` is absent.
    #[serde(default)]
    pub signature_update: Option<SignatureUpdateConfig>,
    /// NexusPulse SMS alert integration; disabled when `[nexus_pulse]` is absent.
    #[serde(default)]
    pub nexus_pulse: Option<NexusPulseConfig>,
}
65
/// NexusPulse SMS alert configuration (the `[nexus_pulse]` table).
///
/// `api_url`, `api_key` and `alert_recipients` have no serde default, so they are
/// required whenever the table is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NexusPulseConfig {
    /// NexusPulse API base URL (e.g., http://localhost:8100).
    pub api_url: String,
    /// API key for NexusPulse authentication. Printed in plain text by the derived `Debug`.
    pub api_key: String,
    /// Phone numbers to receive SMS alerts (E.164 format).
    /// NOTE(review): the format is not validated here.
    pub alert_recipients: Vec<String>,
    /// Sender phone number (optional, uses NexusPulse default if omitted).
    #[serde(default)]
    pub from_number: Option<String>,
    /// Minimum severity to trigger SMS; defaults to `default_pulse_min_severity()` ("critical").
    #[serde(default = "default_pulse_min_severity")]
    pub min_severity: String,
    /// Use the built-in "alert" template for formatted messages (defaults to `true`).
    #[serde(default = "default_true")]
    pub use_template: bool,
}
85
/// Serde default for `ShieldConfig::block_threshold`.
fn default_block_threshold() -> f64 {
    0.7
}

/// Serde default for `ShieldConfig::warn_threshold`.
fn default_warn_threshold() -> f64 {
    0.4
}

/// Serde default for `ShieldConfig::audit_max_events`.
fn default_audit_max() -> usize {
    100 * 1_000
}
89
/// Webhook alert configuration (one `[[webhook_urls]]` entry).
///
/// Only `url` is required; the other fields have serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Webhook URL (Slack, Discord, PagerDuty, generic).
    pub url: String,
    /// Minimum severity to trigger (info, low, medium, high, critical).
    /// Defaults to `default_webhook_min_severity()` ("high").
    #[serde(default = "default_webhook_min_severity")]
    pub min_severity: String,
    /// Optional custom headers as `(name, value)` pairs. In TOML each pair is written as a
    /// two-element array, e.g. `headers = [["X-Token", "abc"]]`.
    #[serde(default)]
    pub headers: Vec<(String, String)>,
    /// Webhook type hint for formatting (slack, discord, generic).
    /// Defaults to `default_webhook_type()` ("generic").
    #[serde(default = "default_webhook_type")]
    pub webhook_type: String,
}
105
/// Serde default for the `min_severity` of webhook and Ferrum-Mail alerts.
fn default_webhook_min_severity() -> String {
    String::from("high")
}

/// Serde default for `WebhookConfig::webhook_type`.
fn default_webhook_type() -> String {
    String::from("generic")
}
108
/// Ferrum-Mail integration for email alerts (the `[ferrum_mail]` table).
///
/// `api_url`, `api_key`, `from_address` and `alert_recipients` have no serde default,
/// so they are required whenever the table is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FerrumMailConfig {
    /// Ferrum-Mail API base URL.
    pub api_url: String,
    /// API key for Ferrum-Mail authentication. Printed in plain text by the derived `Debug`.
    pub api_key: String,
    /// Sender address for alert emails.
    pub from_address: String,
    /// Recipient addresses for alerts.
    pub alert_recipients: Vec<String>,
    /// Minimum severity to send email (default: high). This shares the webhook default helper.
    #[serde(default = "default_webhook_min_severity")]
    pub min_severity: String,
    /// Include full event details in email body (defaults to `true`).
    #[serde(default = "default_true")]
    pub include_details: bool,
}
127
/// Serde default for boolean options that are on unless the config turns them off.
fn default_true() -> bool {
    true
}
129
/// Automatic signature update configuration (the `[signature_update]` table).
///
/// `feed_url` is required whenever the table is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignatureUpdateConfig {
    /// URL to fetch NDJSON signatures from.
    pub feed_url: String,
    /// Update interval in seconds (default: 3600 = 1 hour, via `default_sig_interval`).
    #[serde(default = "default_sig_interval")]
    pub interval_secs: u64,
    /// Optional authentication header for the feed.
    /// NOTE(review): whether this holds a full `Name: value` line or only the value is
    /// decided by the updater, which is not in this file — confirm.
    #[serde(default)]
    pub auth_header: Option<String>,
}
142
/// Serde default for `SignatureUpdateConfig::interval_secs`: one hour.
fn default_sig_interval() -> u64 {
    60 * 60
}

/// Serde default for `NexusPulseConfig::min_severity`.
fn default_pulse_min_severity() -> String {
    "critical".to_owned()
}
145
/// SQL firewall limits and deny-lists (the `[sql]` table).
///
/// NOTE(review): these settings are enforced by the SQL firewall module, not in this file.
/// The field descriptions below follow the field names — confirm against the enforcement code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqlFirewallConfig {
    /// Whether comments are permitted in inspected SQL (defaults to `false`).
    #[serde(default)]
    pub allow_comments: bool,
    /// Maximum accepted query length (defaults to 10,000 via `default_max_query_length`).
    /// Presumably measured in bytes/chars of the raw query.
    #[serde(default = "default_max_query_length")]
    pub max_query_length: usize,
    /// Maximum allowed subquery nesting depth (defaults to 3).
    #[serde(default = "default_max_subquery_depth")]
    pub max_subquery_depth: u32,
    /// Additional SQL function names to reject (defaults to empty).
    #[serde(default)]
    pub blocked_functions: Vec<String>,
    /// Schema names queries may not reference (defaults to empty).
    #[serde(default)]
    pub blocked_schemas: Vec<String>,
}
159
/// Serde default for `SqlFirewallConfig::max_query_length`.
fn default_max_query_length() -> usize {
    10_000
}

/// Serde default for `SqlFirewallConfig::max_subquery_depth`.
fn default_max_subquery_depth() -> u32 {
    3
}
162
/// SSRF guard settings (the `[ssrf]` table).
///
/// All `block_*` switches default to `true`, so an absent or partial table keeps every
/// protection enabled.
///
/// NOTE(review): enforcement lives in the SSRF guard module, not in this file. The field
/// descriptions follow the field names — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsrfConfig {
    /// Reject targets in private address ranges.
    #[serde(default = "default_true")]
    pub block_private_ips: bool,
    /// Reject loopback targets.
    #[serde(default = "default_true")]
    pub block_loopback: bool,
    /// Reject link-local targets.
    #[serde(default = "default_true")]
    pub block_link_local: bool,
    /// Reject known cloud metadata endpoints.
    #[serde(default = "default_true")]
    pub block_metadata_endpoints: bool,
    /// URL schemes that may be requested (defaults to `http` and `https`).
    #[serde(default = "default_allowed_schemes")]
    pub allowed_schemes: Vec<String>,
    /// Explicitly allowed targets (defaults to empty); presumably hostnames — TODO confirm.
    #[serde(default)]
    pub allowlist: HashSet<String>,
    /// Explicitly blocked targets (defaults to empty); presumably hostnames — TODO confirm.
    #[serde(default)]
    pub blocklist: HashSet<String>,
    /// Destination ports that are always rejected (defaults to `default_blocked_ports()`).
    #[serde(default = "default_blocked_ports")]
    pub blocked_ports: Vec<u16>,
}
182
/// Serde default for `SsrfConfig::allowed_schemes`: plain HTTP and HTTPS only.
fn default_allowed_schemes() -> Vec<String> {
    ["http", "https"].iter().map(|scheme| scheme.to_string()).collect()
}

/// Serde default for `SsrfConfig::blocked_ports`.
///
/// These are ports conventionally used by remote-access, mail, DNS, RPC, file-sharing,
/// database, cache and search services.
fn default_blocked_ports() -> Vec<u16> {
    const PORTS: [u16; 21] = [
        // SSH, Telnet, SMTP, DNS, rpcbind, MS-RPC, NetBIOS, SMB, syslog, rsync, NFS
        22, 23, 25, 53, 111, 135, 139, 445, 514, 873, 2049,
        // MySQL, PostgreSQL, Redis, Elasticsearch, memcached, MongoDB, Hadoop NameNode UI
        3306, 5432, 6379, 6380, 9200, 9300, 11211, 27017, 27018, 50070,
    ];
    PORTS.to_vec()
}
187
/// Rate limiting and violation-escalation settings (the `[rate]` table).
///
/// NOTE(review): the limiter that consumes these is not in this file. The rate/burst pair
/// suggests a token bucket, and the `*_after` fields look like violation counts for
/// successive escalation steps — confirm in the limiter. Nothing here enforces
/// `warn_after <= throttle_after <= block_after <= ban_after`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateConfig {
    /// Sustained request rate per second (defaults to 50.0).
    #[serde(default = "default_rps")]
    pub requests_per_second: f64,
    /// Burst capacity (defaults to 100.0).
    #[serde(default = "default_burst")]
    pub burst_capacity: f64,
    /// Escalation step: warn (defaults to 3).
    #[serde(default = "default_warn_after")]
    pub warn_after: u32,
    /// Escalation step: throttle (defaults to 8).
    #[serde(default = "default_throttle_after")]
    pub throttle_after: u32,
    /// Escalation step: block (defaults to 15).
    #[serde(default = "default_block_after")]
    pub block_after: u32,
    /// Escalation step: ban (defaults to 30).
    #[serde(default = "default_ban_after")]
    pub ban_after: u32,
    /// Ban length in seconds (defaults to 300).
    #[serde(default = "default_ban_duration")]
    pub ban_duration_secs: u64,
    /// Seconds after which recorded violations decay (defaults to 60).
    #[serde(default = "default_decay")]
    pub violation_decay_secs: u64,
}
207
/// Serde default for `RateConfig::requests_per_second`.
fn default_rps() -> f64 {
    50.0
}

/// Serde default for `RateConfig::burst_capacity` (twice the default rate).
fn default_burst() -> f64 {
    100.0
}

/// Serde default for `RateConfig::warn_after`.
fn default_warn_after() -> u32 {
    3
}

/// Serde default for `RateConfig::throttle_after`.
fn default_throttle_after() -> u32 {
    8
}

/// Serde default for `RateConfig::block_after`.
fn default_block_after() -> u32 {
    15
}

/// Serde default for `RateConfig::ban_after`.
fn default_ban_after() -> u32 {
    30
}

/// Serde default for `RateConfig::ban_duration_secs`: five minutes.
fn default_ban_duration() -> u64 {
    5 * 60
}

/// Serde default for `RateConfig::violation_decay_secs`: one minute.
fn default_decay() -> u64 {
    60
}
216
/// Data quarantine limits and content checks (the `[quarantine]` table).
///
/// NOTE(review): the quarantine module that applies these is not in this file. The field
/// descriptions follow the field names — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuarantineConfig {
    /// Maximum number of rows accepted (defaults to 5,000,000).
    #[serde(default = "default_max_rows")]
    pub max_rows: usize,
    /// Maximum payload size in bytes (defaults to 500 MiB).
    #[serde(default = "default_max_size")]
    pub max_size_bytes: usize,
    /// Maximum number of columns accepted (defaults to 500).
    #[serde(default = "default_max_cols")]
    pub max_columns: usize,
    /// Check cell values for spreadsheet formula injection (defaults to `true`).
    #[serde(default = "default_true")]
    pub check_formula_injection: bool,
    /// Check content for embedded scripts (defaults to `true`).
    #[serde(default = "default_true")]
    pub check_embedded_scripts: bool,
}
230
/// Serde default for `QuarantineConfig::max_rows`.
fn default_max_rows() -> usize {
    5_000_000
}

/// Serde default for `QuarantineConfig::max_size_bytes`: 500 MiB.
fn default_max_size() -> usize {
    const MIB: usize = 1024 * 1024;
    500 * MIB
}

/// Serde default for `QuarantineConfig::max_columns`.
fn default_max_cols() -> usize {
    500
}
234
235// === Default implementations ===
236
237impl Default for ShieldConfig {
238    fn default() -> Self {
239        Self {
240            block_threshold: 0.7,
241            warn_threshold: 0.4,
242            sql: SqlFirewallConfig::default(),
243            ssrf: SsrfConfig::default(),
244            rate: RateConfig::default(),
245            quarantine: QuarantineConfig::default(),
246            audit_max_events: 100_000,
247            email: EmailGuardConfig::default(),
248            api_token: None,
249            tls_cert: None,
250            tls_key: None,
251            webhook_urls: Vec::new(),
252            ferrum_mail: None,
253            signature_update: None,
254            nexus_pulse: None,
255        }
256    }
257}
258
259impl Default for SqlFirewallConfig {
260    fn default() -> Self {
261        Self {
262            allow_comments: false,
263            max_query_length: 10_000,
264            max_subquery_depth: 3,
265            blocked_functions: Vec::new(),
266            blocked_schemas: Vec::new(),
267        }
268    }
269}
270
271impl Default for SsrfConfig {
272    fn default() -> Self {
273        Self {
274            block_private_ips: true,
275            block_loopback: true,
276            block_link_local: true,
277            block_metadata_endpoints: true,
278            allowed_schemes: default_allowed_schemes(),
279            allowlist: HashSet::new(),
280            blocklist: HashSet::new(),
281            blocked_ports: default_blocked_ports(),
282        }
283    }
284}
285
286impl Default for RateConfig {
287    fn default() -> Self {
288        Self {
289            requests_per_second: 50.0,
290            burst_capacity: 100.0,
291            warn_after: 3,
292            throttle_after: 8,
293            block_after: 15,
294            ban_after: 30,
295            ban_duration_secs: 300,
296            violation_decay_secs: 60,
297        }
298    }
299}
300
301impl Default for QuarantineConfig {
302    fn default() -> Self {
303        Self {
304            max_rows: 5_000_000,
305            max_size_bytes: 500 * 1024 * 1024,
306            max_columns: 500,
307            check_formula_injection: true,
308            check_embedded_scripts: true,
309        }
310    }
311}
312
313/// Load configuration from a TOML file, falling back to defaults for missing fields.
314pub fn load_config(path: &Path) -> Result<ShieldConfig, String> {
315    let content = std::fs::read_to_string(path)
316        .map_err(|e| format!("Failed to read config file {}: {}", path.display(), e))?;
317    let config: ShieldConfig = toml::from_str(&content)
318        .map_err(|e| format!("Failed to parse config file {}: {}", path.display(), e))?;
319    Ok(config)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = ShieldConfig::default();
        assert_eq!(config.block_threshold, 0.7);
        assert_eq!(config.warn_threshold, 0.4);
        assert!(config.api_token.is_none());
        assert!(config.tls_cert.is_none());
        assert!(config.tls_key.is_none());
        assert!(config.webhook_urls.is_empty());
        assert!(config.ferrum_mail.is_none());
        assert!(config.signature_update.is_none());
        assert!(config.nexus_pulse.is_none());
    }

    /// Per-field serde defaults (used when a section is present but incomplete) must agree
    /// with the hand-written `Default` impls (used when a section is missing entirely).
    #[test]
    fn serde_field_defaults_match_default_impls() {
        // Empty tables force serde to call the per-field `default = "..."` helpers
        // instead of the sub-config `Default` impls.
        let sections = "[sql]\n[ssrf]\n[rate]\n[quarantine]\n";
        let parsed: ShieldConfig = toml::from_str(sections).unwrap();
        let built = ShieldConfig::default();

        assert_eq!(parsed.block_threshold, built.block_threshold);
        assert_eq!(parsed.warn_threshold, built.warn_threshold);
        assert_eq!(parsed.audit_max_events, built.audit_max_events);

        assert_eq!(parsed.sql.allow_comments, built.sql.allow_comments);
        assert_eq!(parsed.sql.max_query_length, built.sql.max_query_length);
        assert_eq!(parsed.sql.max_subquery_depth, built.sql.max_subquery_depth);
        assert_eq!(parsed.sql.blocked_functions, built.sql.blocked_functions);
        assert_eq!(parsed.sql.blocked_schemas, built.sql.blocked_schemas);

        assert_eq!(parsed.ssrf.block_private_ips, built.ssrf.block_private_ips);
        assert_eq!(parsed.ssrf.block_loopback, built.ssrf.block_loopback);
        assert_eq!(parsed.ssrf.block_link_local, built.ssrf.block_link_local);
        assert_eq!(parsed.ssrf.block_metadata_endpoints, built.ssrf.block_metadata_endpoints);
        assert_eq!(parsed.ssrf.allowed_schemes, built.ssrf.allowed_schemes);
        assert_eq!(parsed.ssrf.allowlist, built.ssrf.allowlist);
        assert_eq!(parsed.ssrf.blocklist, built.ssrf.blocklist);
        assert_eq!(parsed.ssrf.blocked_ports, built.ssrf.blocked_ports);

        assert_eq!(parsed.rate.requests_per_second, built.rate.requests_per_second);
        assert_eq!(parsed.rate.burst_capacity, built.rate.burst_capacity);
        assert_eq!(parsed.rate.warn_after, built.rate.warn_after);
        assert_eq!(parsed.rate.throttle_after, built.rate.throttle_after);
        assert_eq!(parsed.rate.block_after, built.rate.block_after);
        assert_eq!(parsed.rate.ban_after, built.rate.ban_after);
        assert_eq!(parsed.rate.ban_duration_secs, built.rate.ban_duration_secs);
        assert_eq!(parsed.rate.violation_decay_secs, built.rate.violation_decay_secs);

        assert_eq!(parsed.quarantine.max_rows, built.quarantine.max_rows);
        assert_eq!(parsed.quarantine.max_size_bytes, built.quarantine.max_size_bytes);
        assert_eq!(parsed.quarantine.max_columns, built.quarantine.max_columns);
        assert_eq!(
            parsed.quarantine.check_formula_injection,
            built.quarantine.check_formula_injection
        );
        assert_eq!(
            parsed.quarantine.check_embedded_scripts,
            built.quarantine.check_embedded_scripts
        );
    }

    #[test]
    fn parse_minimal_toml() {
        let toml = r#"
block_threshold = 0.8
warn_threshold = 0.5
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.8);
        assert_eq!(config.warn_threshold, 0.5);
        assert_eq!(config.rate.requests_per_second, 50.0); // default
    }

    #[test]
    fn parse_full_toml() {
        let toml = r#"
block_threshold = 0.9
warn_threshold = 0.6
api_token = "my-secret-token"
tls_cert = "/etc/nexus-shield/cert.pem"
tls_key = "/etc/nexus-shield/key.pem"

[sql]
allow_comments = true
max_query_length = 20000

[rate]
requests_per_second = 100.0
burst_capacity = 200.0
ban_duration_secs = 600

[[webhook_urls]]
url = "https://hooks.slack.com/services/xxx"
min_severity = "critical"
webhook_type = "slack"

[ferrum_mail]
api_url = "http://localhost:3030"
api_key = "fm-key-123"
from_address = "shield@company.com"
alert_recipients = ["admin@company.com", "security@company.com"]

[signature_update]
feed_url = "https://signatures.nexusshield.dev/v1/latest.ndjson"
interval_secs = 1800
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.block_threshold, 0.9);
        assert_eq!(config.api_token, Some("my-secret-token".to_string()));
        assert_eq!(config.sql.max_query_length, 20000);
        assert_eq!(config.rate.requests_per_second, 100.0);
        assert_eq!(config.webhook_urls.len(), 1);
        assert_eq!(config.webhook_urls[0].webhook_type, "slack");
        let fm = config.ferrum_mail.unwrap();
        assert_eq!(fm.alert_recipients.len(), 2);
        assert_eq!(fm.min_severity, "high"); // default
        assert!(fm.include_details); // default
        let su = config.signature_update.unwrap();
        assert_eq!(su.interval_secs, 1800);
        assert!(su.auth_header.is_none());
    }

    #[test]
    fn parse_empty_toml() {
        let config: ShieldConfig = toml::from_str("").unwrap();
        assert_eq!(config.block_threshold, 0.7); // all defaults
    }

    #[test]
    fn nexus_pulse_config_defaults() {
        let toml = r#"
[nexus_pulse]
api_url = "http://localhost:8100"
api_key = "np-key-123"
alert_recipients = ["+15555550100"]
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        let np = config.nexus_pulse.unwrap();
        assert_eq!(np.alert_recipients.len(), 1);
        assert!(np.from_number.is_none());
        assert_eq!(np.min_severity, "critical");
        assert!(np.use_template);
    }

    #[test]
    fn missing_required_integration_field_is_rejected() {
        // `[ferrum_mail]` requires api_key, from_address and alert_recipients.
        let toml = r#"
[ferrum_mail]
api_url = "http://localhost:3030"
"#;
        assert!(toml::from_str::<ShieldConfig>(toml).is_err());
    }

    #[test]
    fn load_nonexistent_file() {
        let result = load_config(Path::new("/nonexistent/config.toml"));
        assert!(result.is_err());
    }

    #[test]
    fn load_config_reads_file() {
        let path = std::env::temp_dir().join(format!(
            "nexus_shield_load_config_{}.toml",
            std::process::id()
        ));
        std::fs::write(&path, "block_threshold = 0.85\n").unwrap();
        let result = load_config(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);
        let config = result.unwrap();
        assert_eq!(config.block_threshold, 0.85);
        assert_eq!(config.warn_threshold, 0.4); // default
    }

    #[test]
    fn webhook_config_defaults() {
        let toml = r#"
[[webhook_urls]]
url = "https://example.com/hook"
"#;
        let config: ShieldConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.webhook_urls[0].min_severity, "high");
        assert_eq!(config.webhook_urls[0].webhook_type, "generic");
        assert!(config.webhook_urls[0].headers.is_empty());
    }
}