1pub mod auth;
28pub mod audit_chain;
29pub mod compliance_report;
30pub mod config;
31pub mod credential_vault;
32pub mod endpoint;
33pub mod ferrum_integration;
34pub mod fingerprint;
35pub mod journal;
36pub mod metrics;
37pub mod nexuspulse_integration;
38pub mod quarantine;
39pub mod rate_governor;
40pub mod sanitizer;
41pub mod siem_export;
42pub mod signature_updater;
43pub mod sql_firewall;
44pub mod sse_events;
45pub mod ssrf_guard;
46pub mod email_guard;
47pub mod threat_score;
48pub mod webhook;
49
50use std::sync::Arc;
51
52use axum::extract::Request;
53use axum::http::StatusCode;
54use axum::middleware::Next;
55use axum::response::{IntoResponse, Response};
56use axum::Extension;
57
58pub use audit_chain::{AuditChain, AuditEvent, SecurityEventType};
59pub use config::ShieldConfig;
60pub use email_guard::{EmailGuardConfig, EmailRateLimiter};
61pub use threat_score::{ThreatAction, ThreatAssessment};
62
63#[derive(Debug, thiserror::Error)]
65pub enum ShieldError {
66 #[error("SQL injection detected: {0}")]
67 SqlInjectionDetected(String),
68
69 #[error("Request blocked by SSRF guard: {0}")]
70 SsrfBlocked(String),
71
72 #[error("Rate limit exceeded")]
73 RateLimitExceeded { retry_after: Option<u64> },
74
75 #[error("Request blocked (threat score: {0:.3})")]
76 ThreatScoreExceeded(f64),
77
78 #[error("Malicious input detected: {0}")]
79 MaliciousInput(String),
80
81 #[error("Path traversal blocked: {0}")]
82 PathTraversal(String),
83
84 #[error("Invalid connection configuration: {0}")]
85 InvalidConnectionString(String),
86
87 #[error("Data quarantine failed: {0}")]
88 QuarantineFailed(String),
89
90 #[error("Email security violation: {0}")]
91 EmailViolation(String),
92
93 #[error("Email rate limit exceeded for {0}")]
94 EmailBombing(String),
95
96 #[error("Malware detected: {0}")]
97 MalwareDetected(String),
98
99 #[error("Endpoint protection error: {0}")]
100 EndpointError(String),
101
102 #[error("Quarantine vault error: {0}")]
103 QuarantineVaultError(String),
104}
105
106impl IntoResponse for ShieldError {
107 fn into_response(self) -> Response {
108 let (status, message) = match &self {
110 Self::SqlInjectionDetected(_) => {
111 (StatusCode::FORBIDDEN, "Request blocked by security policy")
112 }
113 Self::SsrfBlocked(_) => {
114 (StatusCode::FORBIDDEN, "Request blocked by security policy")
115 }
116 Self::RateLimitExceeded { .. } => {
117 (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded")
118 }
119 Self::ThreatScoreExceeded(_) => {
120 (StatusCode::FORBIDDEN, "Request blocked by security policy")
121 }
122 Self::MaliciousInput(_) => (StatusCode::BAD_REQUEST, "Invalid input detected"),
123 Self::PathTraversal(_) => {
124 (StatusCode::FORBIDDEN, "Request blocked by security policy")
125 }
126 Self::InvalidConnectionString(_) => {
127 (StatusCode::BAD_REQUEST, "Invalid connection configuration")
128 }
129 Self::QuarantineFailed(_) => {
130 (StatusCode::BAD_REQUEST, "Data validation failed")
131 }
132 Self::EmailViolation(_) => {
133 (StatusCode::BAD_REQUEST, "Email validation failed")
134 }
135 Self::EmailBombing(_) => {
136 (StatusCode::TOO_MANY_REQUESTS, "Email rate limit exceeded")
137 }
138 Self::MalwareDetected(_) => {
139 (StatusCode::FORBIDDEN, "Request blocked by security policy")
140 }
141 Self::EndpointError(_) => {
142 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
143 }
144 Self::QuarantineVaultError(_) => {
145 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
146 }
147 };
148 (status, message).into_response()
149 }
150}
151
/// Central security facade: the configuration plus the shared engines that
/// the validation methods and `shield_middleware` use.
///
/// Engines are held behind `Arc` so they can be shared; `shield_middleware`
/// itself receives the whole shield as `Arc<Shield>`.
pub struct Shield {
    /// Configuration the shield was built from (thresholds, SQL/SSRF/email/
    /// quarantine policies).
    pub config: ShieldConfig,
    /// Audit chain; every rejection path in `Shield` and `shield_middleware`
    /// records an event here.
    pub audit: Arc<AuditChain>,
    /// Per-client request rate governor consulted first by `shield_middleware`.
    pub rate_governor: Arc<rate_governor::RateGovernor>,
    /// Header fingerprinting and per-client behavioural scoring used in the
    /// threat assessment.
    pub fingerprinter: Arc<fingerprint::Fingerprinter>,
    /// Per-recipient outbound email rate limiter (see `check_email_rate`).
    pub email_limiter: Arc<EmailRateLimiter>,
    /// Optional endpoint protection engine; `Shield::new` leaves it `None`.
    pub endpoint: Option<Arc<endpoint::EndpointEngine>>,
}
161
162impl Shield {
163 pub fn new(config: ShieldConfig) -> Self {
164 let audit = Arc::new(AuditChain::with_max_events(config.audit_max_events));
165 let rate_governor = Arc::new(rate_governor::RateGovernor::new(&config));
166 let fingerprinter = Arc::new(fingerprint::Fingerprinter::new());
167 let email_limiter = Arc::new(EmailRateLimiter::new(config.email.clone()));
168 Self {
169 config,
170 audit,
171 rate_governor,
172 fingerprinter,
173 email_limiter,
174 endpoint: None,
175 }
176 }
177
178 pub fn validate_sql(&self, sql: &str) -> Result<(), ShieldError> {
182 let analysis = sql_firewall::analyze_query(sql, &self.config.sql);
183 if analysis.allowed {
184 Ok(())
185 } else {
186 let reason = analysis
187 .violations
188 .iter()
189 .map(|v| format!("{:?}", v))
190 .collect::<Vec<_>>()
191 .join(", ");
192 self.audit.record(
193 SecurityEventType::SqlInjectionAttempt,
194 "api",
195 &reason,
196 analysis.risk_score,
197 );
198 Err(ShieldError::SqlInjectionDetected(reason))
199 }
200 }
201
202 pub fn validate_url(&self, url: &str) -> Result<(), ShieldError> {
204 ssrf_guard::validate_url(url, &self.config.ssrf).map_err(|reason| {
205 self.audit.record(
206 SecurityEventType::SsrfAttempt,
207 "api",
208 &reason,
209 0.9,
210 );
211 ShieldError::SsrfBlocked(reason)
212 })
213 }
214
215 pub fn validate_ip(&self, ip: &str) -> Result<(), ShieldError> {
217 ssrf_guard::validate_ip_str(ip, &self.config.ssrf).map_err(|reason| {
218 self.audit.record(
219 SecurityEventType::SsrfAttempt,
220 "api",
221 &reason,
222 0.9,
223 );
224 ShieldError::SsrfBlocked(reason)
225 })
226 }
227
228 pub fn validate_connection_string(&self, conn_str: &str) -> Result<String, ShieldError> {
230 sanitizer::validate_connection_string(conn_str).map_err(|reason| {
231 self.audit.record(
232 SecurityEventType::MaliciousPayload,
233 "api",
234 &reason,
235 0.8,
236 );
237 ShieldError::InvalidConnectionString(reason)
238 })
239 }
240
241 pub fn validate_file_path(&self, path: &str) -> Result<(), ShieldError> {
243 sanitizer::validate_file_path(path).map_err(|reason| {
244 self.audit.record(
245 SecurityEventType::PathTraversalAttempt,
246 "api",
247 &reason,
248 0.9,
249 );
250 ShieldError::PathTraversal(reason)
251 })
252 }
253
254 pub fn quarantine_csv(&self, content: &str) -> Result<(), ShieldError> {
256 let result = quarantine::validate_csv(content, &self.config.quarantine);
257 if result.passed {
258 Ok(())
259 } else {
260 let reason = result
261 .violations
262 .iter()
263 .map(|v| format!("{:?}", v))
264 .collect::<Vec<_>>()
265 .join(", ");
266 self.audit.record(
267 SecurityEventType::DataQuarantined,
268 "api",
269 &reason,
270 0.7,
271 );
272 Err(ShieldError::QuarantineFailed(reason))
273 }
274 }
275
276 pub fn quarantine_json(&self, json: &str) -> Result<(), ShieldError> {
278 quarantine::validate_json_response(json, self.config.quarantine.max_size_bytes)
279 .map_err(|reason| {
280 self.audit.record(
281 SecurityEventType::DataQuarantined,
282 "api",
283 &reason,
284 0.6,
285 );
286 ShieldError::MaliciousInput(reason)
287 })
288 }
289
290 pub fn validate_email_address(&self, addr: &str) -> Result<(), ShieldError> {
294 let violations = email_guard::validate_email_address(addr, &self.config.email);
295 if violations.is_empty() {
296 Ok(())
297 } else {
298 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
299 self.audit.record(
300 SecurityEventType::MaliciousPayload,
301 "email",
302 &reason,
303 0.7,
304 );
305 Err(ShieldError::EmailViolation(reason))
306 }
307 }
308
309 pub fn validate_email_header(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
311 let max_len = match field_name {
312 "subject" => self.config.email.max_subject_len,
313 _ => self.config.email.max_name_len,
314 };
315 let violations = email_guard::validate_header_field(field_name, value, max_len);
316 if violations.is_empty() {
317 Ok(())
318 } else {
319 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
320 self.audit.record(
321 SecurityEventType::MaliciousPayload,
322 "email",
323 &format!("header injection in {}: {}", field_name, reason),
324 0.8,
325 );
326 Err(ShieldError::EmailViolation(reason))
327 }
328 }
329
330 pub fn validate_email_content(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
332 let violations = email_guard::validate_template_content(
333 field_name, value, self.config.email.max_body_len,
334 );
335 if violations.is_empty() {
336 Ok(())
337 } else {
338 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
339 self.audit.record(
340 SecurityEventType::MaliciousPayload,
341 "email",
342 &format!("content injection in {}: {}", field_name, reason),
343 0.8,
344 );
345 Err(ShieldError::EmailViolation(reason))
346 }
347 }
348
349 pub fn check_email_rate(&self, recipient: &str) -> Result<(), ShieldError> {
351 if self.email_limiter.check_and_record(recipient) {
352 Ok(())
353 } else {
354 self.audit.record(
355 SecurityEventType::RateLimitHit,
356 "email",
357 &format!("email bombing attempt to {}", recipient),
358 0.9,
359 );
360 Err(ShieldError::EmailBombing(recipient.to_string()))
361 }
362 }
363
364 pub fn validate_outbound_email(
366 &self,
367 to: &[&str],
368 subject: &str,
369 body_fields: &[(&str, &str)],
370 ) -> Result<(), ShieldError> {
371 email_guard::validate_outbound_email(to, subject, body_fields, &self.config.email)
373 .map_err(|reason| {
374 self.audit.record(
375 SecurityEventType::MaliciousPayload,
376 "email",
377 &reason,
378 0.7,
379 );
380 ShieldError::EmailViolation(reason)
381 })?;
382
383 for addr in to {
385 self.check_email_rate(addr)?;
386 }
387
388 Ok(())
389 }
390
391 pub fn escape_email_content(value: &str) -> String {
393 email_guard::html_escape(value)
394 }
395}
396
397pub async fn shield_middleware(
408 shield: Option<Extension<Arc<Shield>>>,
409 request: Request,
410 next: Next,
411) -> Response {
412 let shield = match shield {
413 Some(Extension(s)) => s,
414 None => return next.run(request).await,
415 };
416
417 let client_ip = extract_client_ip(&request);
418
419 let rate_result = shield.rate_governor.check(&client_ip);
421 if !rate_result.allowed {
422 shield.audit.record(
423 SecurityEventType::RateLimitHit,
424 &client_ip,
425 &format!(
426 "escalation={:?}, violations={}",
427 rate_result.escalation, rate_result.violations
428 ),
429 0.8,
430 );
431 let mut resp = (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
432 if let Some(retry_after) = rate_result.retry_after {
433 resp.headers_mut().insert(
434 "Retry-After",
435 retry_after.to_string().parse().unwrap(),
436 );
437 }
438 return resp;
439 }
440
441 let fp = shield.fingerprinter.analyze(request.headers());
443
444 let behavioral = shield.fingerprinter.behavioral_score(&client_ip);
446
447 let recent_violations = {
449 let since = chrono::Utc::now() - chrono::Duration::minutes(5);
450 shield
451 .audit
452 .count_since(&SecurityEventType::RequestBlocked, since)
453 > 0
454 };
455
456 let assessment = threat_score::assess(
458 &fp,
459 &rate_result,
460 behavioral,
461 recent_violations,
462 shield.config.warn_threshold,
463 shield.config.block_threshold,
464 );
465
466 match assessment.action {
467 ThreatAction::Block => {
468 shield.audit.record(
469 SecurityEventType::RequestBlocked,
470 &client_ip,
471 &format!(
472 "score={:.3}, fingerprint={:.3}, rate={:.3}, behavioral={:.3}",
473 assessment.score,
474 assessment.signals.fingerprint_anomaly,
475 assessment.signals.rate_pressure,
476 assessment.signals.behavioral_anomaly,
477 ),
478 assessment.score,
479 );
480 return (StatusCode::FORBIDDEN, "Request blocked by security policy").into_response();
481 }
482 ThreatAction::Warn => {
483 tracing::warn!(
484 ip = %client_ip,
485 score = assessment.score,
486 "NexusShield: elevated threat score"
487 );
488 shield.audit.record(
489 SecurityEventType::RequestAllowed,
490 &client_ip,
491 &format!("WARN: score={:.3}", assessment.score),
492 assessment.score,
493 );
494 }
495 ThreatAction::Allow => {
496 shield.fingerprinter.record_request(&client_ip);
498 }
499 }
500
501 let response = next.run(request).await;
503
504 if response.status().is_client_error() || response.status().is_server_error() {
506 shield.fingerprinter.record_error(&client_ip);
507 }
508
509 response
510}
511
512fn extract_client_ip(req: &Request) -> String {
514 if let Some(xff) = req.headers().get("x-forwarded-for") {
516 if let Ok(value) = xff.to_str() {
517 if let Some(first_ip) = value.split(',').next() {
518 let ip = first_ip.trim();
519 if !ip.is_empty() {
520 return ip.to_string();
521 }
522 }
523 }
524 }
525
526 if let Some(xri) = req.headers().get("x-real-ip") {
528 if let Ok(value) = xri.to_str() {
529 let ip = value.trim();
530 if !ip.is_empty() {
531 return ip.to_string();
532 }
533 }
534 }
535
536 "unknown".to_string()
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, StatusCode};

    // --- ShieldError Display text -------------------------------------

    #[test]
    fn shield_error_sql_injection_display() {
        let err = ShieldError::SqlInjectionDetected("UNION attack".to_string());
        assert_eq!(err.to_string(), "SQL injection detected: UNION attack");
    }

    #[test]
    fn shield_error_ssrf_display() {
        let err = ShieldError::SsrfBlocked("private IP".to_string());
        assert_eq!(err.to_string(), "Request blocked by SSRF guard: private IP");
    }

    #[test]
    fn shield_error_rate_limit_display() {
        let err = ShieldError::RateLimitExceeded { retry_after: Some(60) };
        assert_eq!(err.to_string(), "Rate limit exceeded");
    }

    #[test]
    fn shield_error_threat_score_display() {
        let err = ShieldError::ThreatScoreExceeded(0.85);
        assert_eq!(err.to_string(), "Request blocked (threat score: 0.850)");
    }

    #[test]
    fn shield_error_malicious_input_display() {
        let err = ShieldError::MaliciousInput("script tag".to_string());
        assert_eq!(err.to_string(), "Malicious input detected: script tag");
    }

    #[test]
    fn shield_error_path_traversal_display() {
        let err = ShieldError::PathTraversal("../../etc/passwd".to_string());
        assert_eq!(err.to_string(), "Path traversal blocked: ../../etc/passwd");
    }

    #[test]
    fn shield_error_invalid_connection_display() {
        let err = ShieldError::InvalidConnectionString("bad string".to_string());
        assert_eq!(err.to_string(), "Invalid connection configuration: bad string");
    }

    #[test]
    fn shield_error_quarantine_display() {
        let err = ShieldError::QuarantineFailed("oversized".to_string());
        assert_eq!(err.to_string(), "Data quarantine failed: oversized");
    }

    #[test]
    fn shield_error_email_violation_display() {
        let err = ShieldError::EmailViolation("header injection".to_string());
        assert_eq!(err.to_string(), "Email security violation: header injection");
    }

    #[test]
    fn shield_error_email_bombing_display() {
        let err = ShieldError::EmailBombing("test@example.com".to_string());
        assert_eq!(err.to_string(), "Email rate limit exceeded for test@example.com");
    }

    // --- ShieldError -> HTTP status -----------------------------------

    #[test]
    fn shield_error_sql_injection_returns_forbidden() {
        let err = ShieldError::SqlInjectionDetected("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_ssrf_returns_forbidden() {
        let err = ShieldError::SsrfBlocked("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_rate_limit_returns_429() {
        let err = ShieldError::RateLimitExceeded { retry_after: None };
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn shield_error_threat_score_returns_forbidden() {
        let err = ShieldError::ThreatScoreExceeded(0.9);
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_malicious_input_returns_bad_request() {
        let err = ShieldError::MaliciousInput("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_path_traversal_returns_forbidden() {
        let err = ShieldError::PathTraversal("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_invalid_conn_returns_bad_request() {
        let err = ShieldError::InvalidConnectionString("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_quarantine_returns_bad_request() {
        let err = ShieldError::QuarantineFailed("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_violation_returns_bad_request() {
        let err = ShieldError::EmailViolation("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_bombing_returns_429() {
        let err = ShieldError::EmailBombing("test@test.com".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // --- Shield construction and helpers ------------------------------

    #[test]
    fn shield_new_with_default_config() {
        let shield = Shield::new(ShieldConfig::default());
        assert!(shield.config.block_threshold > 0.0);
        assert!(shield.config.warn_threshold > 0.0);
    }

    #[test]
    fn shield_html_escape() {
        let escaped = Shield::escape_email_content("<script>alert('xss')</script>");
        // The raw tag must not survive escaping...
        assert!(!escaped.contains("<script>"));
        assert!(!escaped.contains('<'));
        // ...and the opening bracket must appear entity-encoded (`lt;` entity)
        // in front of the tag name. (The previous assertion checked for the
        // raw `<script>` tag again, contradicting the line above, so this
        // test could never pass.)
        assert!(escaped.contains("lt;script"));
    }

    // --- Client IP extraction -----------------------------------------

    #[test]
    fn extract_ip_from_x_forwarded_for() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut()
            .insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4, 5.6.7.8"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.2.3.4");
    }

    #[test]
    fn extract_ip_from_x_real_ip() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut()
            .insert("x-real-ip", HeaderValue::from_static("10.0.0.1"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "10.0.0.1");
    }

    #[test]
    fn extract_ip_xff_takes_precedence_over_xri() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut()
            .insert("x-forwarded-for", HeaderValue::from_static("1.1.1.1"));
        req.headers_mut()
            .insert("x-real-ip", HeaderValue::from_static("2.2.2.2"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.1.1.1");
    }

    #[test]
    fn extract_ip_unknown_when_no_headers() {
        let req = Request::builder().body(axum::body::Body::empty()).unwrap();
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "unknown");
    }

    #[test]
    fn extract_ip_xff_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut()
            .insert("x-forwarded-for", HeaderValue::from_static(" 3.3.3.3 , 4.4.4.4"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "3.3.3.3");
    }

    #[test]
    fn extract_ip_xri_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut()
            .insert("x-real-ip", HeaderValue::from_static(" 5.5.5.5 "));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "5.5.5.5");
    }
}