1pub mod auth;
28pub mod audit_chain;
29pub mod compliance_report;
30pub mod config;
31pub mod credential_vault;
32pub mod endpoint;
33pub mod ferrum_integration;
34pub mod fingerprint;
35pub mod journal;
36pub mod metrics;
37pub mod quarantine;
38pub mod rate_governor;
39pub mod sanitizer;
40pub mod siem_export;
41pub mod signature_updater;
42pub mod sql_firewall;
43pub mod sse_events;
44pub mod ssrf_guard;
45pub mod email_guard;
46pub mod threat_score;
47pub mod webhook;
48
49use std::sync::Arc;
50
51use axum::extract::Request;
52use axum::http::StatusCode;
53use axum::middleware::Next;
54use axum::response::{IntoResponse, Response};
55use axum::Extension;
56
57pub use audit_chain::{AuditChain, AuditEvent, SecurityEventType};
58pub use config::ShieldConfig;
59pub use email_guard::{EmailGuardConfig, EmailRateLimiter};
60pub use threat_score::{ThreatAction, ThreatAssessment};
61
/// Errors returned by the [`Shield`] validation methods.
///
/// The `#[error(...)]` strings become the `Display` text (via `thiserror`) and
/// include the detailed reason, so they are suitable for logs. The
/// `IntoResponse` impl for this type maps each variant to a fixed status code
/// and message that never includes that reason.
#[derive(Debug, thiserror::Error)]
pub enum ShieldError {
    /// The SQL firewall rejected a query. `Shield::validate_sql` builds it from
    /// the `Debug`-formatted violations joined with `", "`.
    #[error("SQL injection detected: {0}")]
    SqlInjectionDetected(String),

    /// A URL or IP was rejected by the SSRF guard (`Shield::validate_url`,
    /// `Shield::validate_ip`); carries the guard's reason.
    #[error("Request blocked by SSRF guard: {0}")]
    SsrfBlocked(String),

    /// Generic rate-limit rejection; its `Display` text omits `retry_after`.
    /// NOTE(review): `retry_after` is presumably a delay in seconds — confirm
    /// against the code that constructs this variant (none is visible here).
    #[error("Rate limit exceeded")]
    RateLimitExceeded { retry_after: Option<u64> },

    /// A request's threat score was too high; carries the score, displayed
    /// with three decimal places. No constructor is visible here.
    #[error("Request blocked (threat score: {0:.3})")]
    ThreatScoreExceeded(f64),

    /// Input failed validation; `Shield::quarantine_json` uses this for a
    /// rejected JSON payload.
    #[error("Malicious input detected: {0}")]
    MaliciousInput(String),

    /// `Shield::validate_file_path` rejected a path.
    #[error("Path traversal blocked: {0}")]
    PathTraversal(String),

    /// `Shield::validate_connection_string` rejected a connection string.
    #[error("Invalid connection configuration: {0}")]
    InvalidConnectionString(String),

    /// `Shield::quarantine_csv` found violations; carries them joined with `", "`.
    #[error("Data quarantine failed: {0}")]
    QuarantineFailed(String),

    /// An email address, header field, body field or outbound message failed
    /// the email guard (the `Shield::validate_email_*` methods and
    /// `Shield::validate_outbound_email`).
    #[error("Email security violation: {0}")]
    EmailViolation(String),

    /// `Shield::check_email_rate` refused a send; carries the recipient.
    #[error("Email rate limit exceeded for {0}")]
    EmailBombing(String),

    /// Malware was detected. No constructor is visible here — presumably
    /// raised by the endpoint-protection modules.
    #[error("Malware detected: {0}")]
    MalwareDetected(String),

    /// Failure inside the endpoint-protection engine (answered with HTTP 500;
    /// no constructor is visible here).
    #[error("Endpoint protection error: {0}")]
    EndpointError(String),

    /// Failure inside the quarantine vault (answered with HTTP 500; no
    /// constructor is visible here).
    #[error("Quarantine vault error: {0}")]
    QuarantineVaultError(String),
}
104
105impl IntoResponse for ShieldError {
106 fn into_response(self) -> Response {
107 let (status, message) = match &self {
109 Self::SqlInjectionDetected(_) => {
110 (StatusCode::FORBIDDEN, "Request blocked by security policy")
111 }
112 Self::SsrfBlocked(_) => {
113 (StatusCode::FORBIDDEN, "Request blocked by security policy")
114 }
115 Self::RateLimitExceeded { .. } => {
116 (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded")
117 }
118 Self::ThreatScoreExceeded(_) => {
119 (StatusCode::FORBIDDEN, "Request blocked by security policy")
120 }
121 Self::MaliciousInput(_) => (StatusCode::BAD_REQUEST, "Invalid input detected"),
122 Self::PathTraversal(_) => {
123 (StatusCode::FORBIDDEN, "Request blocked by security policy")
124 }
125 Self::InvalidConnectionString(_) => {
126 (StatusCode::BAD_REQUEST, "Invalid connection configuration")
127 }
128 Self::QuarantineFailed(_) => {
129 (StatusCode::BAD_REQUEST, "Data validation failed")
130 }
131 Self::EmailViolation(_) => {
132 (StatusCode::BAD_REQUEST, "Email validation failed")
133 }
134 Self::EmailBombing(_) => {
135 (StatusCode::TOO_MANY_REQUESTS, "Email rate limit exceeded")
136 }
137 Self::MalwareDetected(_) => {
138 (StatusCode::FORBIDDEN, "Request blocked by security policy")
139 }
140 Self::EndpointError(_) => {
141 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
142 }
143 Self::QuarantineVaultError(_) => {
144 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
145 }
146 };
147 (status, message).into_response()
148 }
149}
150
/// The security engine: its configuration plus the stateful components used
/// by the validation methods and by `shield_middleware`.
///
/// `shield_middleware` reads it from an `Extension<Arc<Shield>>`, so it is
/// meant to be shared behind an `Arc`.
pub struct Shield {
    /// Configuration the shield was built from. It supplies the warn/block
    /// thresholds and the SQL, SSRF, email and quarantine settings.
    pub config: ShieldConfig,
    /// Audit chain that validation failures and middleware decisions are
    /// recorded into.
    pub audit: Arc<AuditChain>,
    /// Rate governor that `shield_middleware` consults first, keyed by client IP.
    pub rate_governor: Arc<rate_governor::RateGovernor>,
    /// Header fingerprinting plus per-client request/error tracking used for
    /// behavioral scoring.
    pub fingerprinter: Arc<fingerprint::Fingerprinter>,
    /// Per-recipient outbound email limiter (see `Shield::check_email_rate`).
    pub email_limiter: Arc<EmailRateLimiter>,
    /// Optional endpoint-protection engine; `Shield::new` leaves it `None`.
    pub endpoint: Option<Arc<endpoint::EndpointEngine>>,
}
160
161impl Shield {
162 pub fn new(config: ShieldConfig) -> Self {
163 let audit = Arc::new(AuditChain::with_max_events(config.audit_max_events));
164 let rate_governor = Arc::new(rate_governor::RateGovernor::new(&config));
165 let fingerprinter = Arc::new(fingerprint::Fingerprinter::new());
166 let email_limiter = Arc::new(EmailRateLimiter::new(config.email.clone()));
167 Self {
168 config,
169 audit,
170 rate_governor,
171 fingerprinter,
172 email_limiter,
173 endpoint: None,
174 }
175 }
176
177 pub fn validate_sql(&self, sql: &str) -> Result<(), ShieldError> {
181 let analysis = sql_firewall::analyze_query(sql, &self.config.sql);
182 if analysis.allowed {
183 Ok(())
184 } else {
185 let reason = analysis
186 .violations
187 .iter()
188 .map(|v| format!("{:?}", v))
189 .collect::<Vec<_>>()
190 .join(", ");
191 self.audit.record(
192 SecurityEventType::SqlInjectionAttempt,
193 "api",
194 &reason,
195 analysis.risk_score,
196 );
197 Err(ShieldError::SqlInjectionDetected(reason))
198 }
199 }
200
201 pub fn validate_url(&self, url: &str) -> Result<(), ShieldError> {
203 ssrf_guard::validate_url(url, &self.config.ssrf).map_err(|reason| {
204 self.audit.record(
205 SecurityEventType::SsrfAttempt,
206 "api",
207 &reason,
208 0.9,
209 );
210 ShieldError::SsrfBlocked(reason)
211 })
212 }
213
214 pub fn validate_ip(&self, ip: &str) -> Result<(), ShieldError> {
216 ssrf_guard::validate_ip_str(ip, &self.config.ssrf).map_err(|reason| {
217 self.audit.record(
218 SecurityEventType::SsrfAttempt,
219 "api",
220 &reason,
221 0.9,
222 );
223 ShieldError::SsrfBlocked(reason)
224 })
225 }
226
227 pub fn validate_connection_string(&self, conn_str: &str) -> Result<String, ShieldError> {
229 sanitizer::validate_connection_string(conn_str).map_err(|reason| {
230 self.audit.record(
231 SecurityEventType::MaliciousPayload,
232 "api",
233 &reason,
234 0.8,
235 );
236 ShieldError::InvalidConnectionString(reason)
237 })
238 }
239
240 pub fn validate_file_path(&self, path: &str) -> Result<(), ShieldError> {
242 sanitizer::validate_file_path(path).map_err(|reason| {
243 self.audit.record(
244 SecurityEventType::PathTraversalAttempt,
245 "api",
246 &reason,
247 0.9,
248 );
249 ShieldError::PathTraversal(reason)
250 })
251 }
252
253 pub fn quarantine_csv(&self, content: &str) -> Result<(), ShieldError> {
255 let result = quarantine::validate_csv(content, &self.config.quarantine);
256 if result.passed {
257 Ok(())
258 } else {
259 let reason = result
260 .violations
261 .iter()
262 .map(|v| format!("{:?}", v))
263 .collect::<Vec<_>>()
264 .join(", ");
265 self.audit.record(
266 SecurityEventType::DataQuarantined,
267 "api",
268 &reason,
269 0.7,
270 );
271 Err(ShieldError::QuarantineFailed(reason))
272 }
273 }
274
275 pub fn quarantine_json(&self, json: &str) -> Result<(), ShieldError> {
277 quarantine::validate_json_response(json, self.config.quarantine.max_size_bytes)
278 .map_err(|reason| {
279 self.audit.record(
280 SecurityEventType::DataQuarantined,
281 "api",
282 &reason,
283 0.6,
284 );
285 ShieldError::MaliciousInput(reason)
286 })
287 }
288
289 pub fn validate_email_address(&self, addr: &str) -> Result<(), ShieldError> {
293 let violations = email_guard::validate_email_address(addr, &self.config.email);
294 if violations.is_empty() {
295 Ok(())
296 } else {
297 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
298 self.audit.record(
299 SecurityEventType::MaliciousPayload,
300 "email",
301 &reason,
302 0.7,
303 );
304 Err(ShieldError::EmailViolation(reason))
305 }
306 }
307
308 pub fn validate_email_header(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
310 let max_len = match field_name {
311 "subject" => self.config.email.max_subject_len,
312 _ => self.config.email.max_name_len,
313 };
314 let violations = email_guard::validate_header_field(field_name, value, max_len);
315 if violations.is_empty() {
316 Ok(())
317 } else {
318 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
319 self.audit.record(
320 SecurityEventType::MaliciousPayload,
321 "email",
322 &format!("header injection in {}: {}", field_name, reason),
323 0.8,
324 );
325 Err(ShieldError::EmailViolation(reason))
326 }
327 }
328
329 pub fn validate_email_content(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
331 let violations = email_guard::validate_template_content(
332 field_name, value, self.config.email.max_body_len,
333 );
334 if violations.is_empty() {
335 Ok(())
336 } else {
337 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
338 self.audit.record(
339 SecurityEventType::MaliciousPayload,
340 "email",
341 &format!("content injection in {}: {}", field_name, reason),
342 0.8,
343 );
344 Err(ShieldError::EmailViolation(reason))
345 }
346 }
347
348 pub fn check_email_rate(&self, recipient: &str) -> Result<(), ShieldError> {
350 if self.email_limiter.check_and_record(recipient) {
351 Ok(())
352 } else {
353 self.audit.record(
354 SecurityEventType::RateLimitHit,
355 "email",
356 &format!("email bombing attempt to {}", recipient),
357 0.9,
358 );
359 Err(ShieldError::EmailBombing(recipient.to_string()))
360 }
361 }
362
363 pub fn validate_outbound_email(
365 &self,
366 to: &[&str],
367 subject: &str,
368 body_fields: &[(&str, &str)],
369 ) -> Result<(), ShieldError> {
370 email_guard::validate_outbound_email(to, subject, body_fields, &self.config.email)
372 .map_err(|reason| {
373 self.audit.record(
374 SecurityEventType::MaliciousPayload,
375 "email",
376 &reason,
377 0.7,
378 );
379 ShieldError::EmailViolation(reason)
380 })?;
381
382 for addr in to {
384 self.check_email_rate(addr)?;
385 }
386
387 Ok(())
388 }
389
390 pub fn escape_email_content(value: &str) -> String {
392 email_guard::html_escape(value)
393 }
394}
395
396pub async fn shield_middleware(
407 shield: Option<Extension<Arc<Shield>>>,
408 request: Request,
409 next: Next,
410) -> Response {
411 let shield = match shield {
412 Some(Extension(s)) => s,
413 None => return next.run(request).await,
414 };
415
416 let client_ip = extract_client_ip(&request);
417
418 let rate_result = shield.rate_governor.check(&client_ip);
420 if !rate_result.allowed {
421 shield.audit.record(
422 SecurityEventType::RateLimitHit,
423 &client_ip,
424 &format!(
425 "escalation={:?}, violations={}",
426 rate_result.escalation, rate_result.violations
427 ),
428 0.8,
429 );
430 let mut resp = (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
431 if let Some(retry_after) = rate_result.retry_after {
432 resp.headers_mut().insert(
433 "Retry-After",
434 retry_after.to_string().parse().unwrap(),
435 );
436 }
437 return resp;
438 }
439
440 let fp = shield.fingerprinter.analyze(request.headers());
442
443 let behavioral = shield.fingerprinter.behavioral_score(&client_ip);
445
446 let recent_violations = {
448 let since = chrono::Utc::now() - chrono::Duration::minutes(5);
449 shield
450 .audit
451 .count_since(&SecurityEventType::RequestBlocked, since)
452 > 0
453 };
454
455 let assessment = threat_score::assess(
457 &fp,
458 &rate_result,
459 behavioral,
460 recent_violations,
461 shield.config.warn_threshold,
462 shield.config.block_threshold,
463 );
464
465 match assessment.action {
466 ThreatAction::Block => {
467 shield.audit.record(
468 SecurityEventType::RequestBlocked,
469 &client_ip,
470 &format!(
471 "score={:.3}, fingerprint={:.3}, rate={:.3}, behavioral={:.3}",
472 assessment.score,
473 assessment.signals.fingerprint_anomaly,
474 assessment.signals.rate_pressure,
475 assessment.signals.behavioral_anomaly,
476 ),
477 assessment.score,
478 );
479 return (StatusCode::FORBIDDEN, "Request blocked by security policy").into_response();
480 }
481 ThreatAction::Warn => {
482 tracing::warn!(
483 ip = %client_ip,
484 score = assessment.score,
485 "NexusShield: elevated threat score"
486 );
487 shield.audit.record(
488 SecurityEventType::RequestAllowed,
489 &client_ip,
490 &format!("WARN: score={:.3}", assessment.score),
491 assessment.score,
492 );
493 }
494 ThreatAction::Allow => {
495 shield.fingerprinter.record_request(&client_ip);
497 }
498 }
499
500 let response = next.run(request).await;
502
503 if response.status().is_client_error() || response.status().is_server_error() {
505 shield.fingerprinter.record_error(&client_ip);
506 }
507
508 response
509}
510
511fn extract_client_ip(req: &Request) -> String {
513 if let Some(xff) = req.headers().get("x-forwarded-for") {
515 if let Ok(value) = xff.to_str() {
516 if let Some(first_ip) = value.split(',').next() {
517 let ip = first_ip.trim();
518 if !ip.is_empty() {
519 return ip.to_string();
520 }
521 }
522 }
523 }
524
525 if let Some(xri) = req.headers().get("x-real-ip") {
527 if let Ok(value) = xri.to_str() {
528 let ip = value.trim();
529 if !ip.is_empty() {
530 return ip.to_string();
531 }
532 }
533 }
534
535 "unknown".to_string()
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, StatusCode};

    // --- ShieldError Display (log-facing messages) ---------------------------

    #[test]
    fn shield_error_sql_injection_display() {
        let err = ShieldError::SqlInjectionDetected("UNION attack".to_string());
        assert_eq!(err.to_string(), "SQL injection detected: UNION attack");
    }

    #[test]
    fn shield_error_ssrf_display() {
        let err = ShieldError::SsrfBlocked("private IP".to_string());
        assert_eq!(err.to_string(), "Request blocked by SSRF guard: private IP");
    }

    #[test]
    fn shield_error_rate_limit_display() {
        let err = ShieldError::RateLimitExceeded { retry_after: Some(60) };
        assert_eq!(err.to_string(), "Rate limit exceeded");
    }

    #[test]
    fn shield_error_threat_score_display() {
        let err = ShieldError::ThreatScoreExceeded(0.85);
        assert_eq!(err.to_string(), "Request blocked (threat score: 0.850)");
    }

    #[test]
    fn shield_error_malicious_input_display() {
        let err = ShieldError::MaliciousInput("script tag".to_string());
        assert_eq!(err.to_string(), "Malicious input detected: script tag");
    }

    #[test]
    fn shield_error_path_traversal_display() {
        let err = ShieldError::PathTraversal("../../etc/passwd".to_string());
        assert_eq!(err.to_string(), "Path traversal blocked: ../../etc/passwd");
    }

    #[test]
    fn shield_error_invalid_connection_display() {
        let err = ShieldError::InvalidConnectionString("bad string".to_string());
        assert_eq!(err.to_string(), "Invalid connection configuration: bad string");
    }

    #[test]
    fn shield_error_quarantine_display() {
        let err = ShieldError::QuarantineFailed("oversized".to_string());
        assert_eq!(err.to_string(), "Data quarantine failed: oversized");
    }

    #[test]
    fn shield_error_email_violation_display() {
        let err = ShieldError::EmailViolation("header injection".to_string());
        assert_eq!(err.to_string(), "Email security violation: header injection");
    }

    #[test]
    fn shield_error_email_bombing_display() {
        let err = ShieldError::EmailBombing("test@example.com".to_string());
        assert_eq!(err.to_string(), "Email rate limit exceeded for test@example.com");
    }

    // --- ShieldError -> HTTP status mapping ---------------------------------

    #[test]
    fn shield_error_sql_injection_returns_forbidden() {
        let err = ShieldError::SqlInjectionDetected("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_ssrf_returns_forbidden() {
        let err = ShieldError::SsrfBlocked("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_rate_limit_returns_429() {
        let err = ShieldError::RateLimitExceeded { retry_after: None };
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn shield_error_threat_score_returns_forbidden() {
        let err = ShieldError::ThreatScoreExceeded(0.9);
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_malicious_input_returns_bad_request() {
        let err = ShieldError::MaliciousInput("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_path_traversal_returns_forbidden() {
        let err = ShieldError::PathTraversal("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_invalid_conn_returns_bad_request() {
        let err = ShieldError::InvalidConnectionString("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_quarantine_returns_bad_request() {
        let err = ShieldError::QuarantineFailed("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_violation_returns_bad_request() {
        let err = ShieldError::EmailViolation("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_bombing_returns_429() {
        let err = ShieldError::EmailBombing("test@test.com".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // --- Shield construction and helpers ------------------------------------

    #[test]
    fn shield_new_with_default_config() {
        let shield = Shield::new(ShieldConfig::default());
        assert!(shield.config.block_threshold > 0.0);
        assert!(shield.config.warn_threshold > 0.0);
    }

    #[test]
    fn shield_html_escape() {
        let escaped = Shield::escape_email_content("<script>alert('xss')</script>");
        // The raw tag must be gone...
        assert!(!escaped.contains("<script>"));
        // ...and survive only as inert entities. The previous version asserted
        // the raw tag here too, which contradicts the line above and always failed.
        assert!(escaped.contains("&lt;script&gt;"));
    }

    // --- Client IP extraction -----------------------------------------------

    #[test]
    fn extract_ip_from_x_forwarded_for() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4, 5.6.7.8"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.2.3.4");
    }

    #[test]
    fn extract_ip_from_x_real_ip() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("10.0.0.1"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "10.0.0.1");
    }

    #[test]
    fn extract_ip_xff_takes_precedence_over_xri() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.1.1.1"));
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("2.2.2.2"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.1.1.1");
    }

    #[test]
    fn extract_ip_unknown_when_no_headers() {
        let req = Request::builder().body(axum::body::Body::empty()).unwrap();
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "unknown");
    }

    #[test]
    fn extract_ip_xff_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static(" 3.3.3.3 , 4.4.4.4"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "3.3.3.3");
    }

    #[test]
    fn extract_ip_xri_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static(" 5.5.5.5 "));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "5.5.5.5");
    }
}