1pub mod audit_chain;
28pub mod compliance_report;
29pub mod config;
30pub mod credential_vault;
31pub mod endpoint;
32pub mod fingerprint;
33pub mod journal;
34pub mod quarantine;
35pub mod rate_governor;
36pub mod sanitizer;
37pub mod siem_export;
38pub mod sql_firewall;
39pub mod sse_events;
40pub mod ssrf_guard;
41pub mod email_guard;
42pub mod threat_score;
43
44use std::sync::Arc;
45
46use axum::extract::Request;
47use axum::http::StatusCode;
48use axum::middleware::Next;
49use axum::response::{IntoResponse, Response};
50use axum::Extension;
51
52pub use audit_chain::{AuditChain, AuditEvent, SecurityEventType};
53pub use config::ShieldConfig;
54pub use email_guard::{EmailGuardConfig, EmailRateLimiter};
55pub use threat_score::{ThreatAction, ThreatAssessment};
56
/// Errors produced by the [`Shield`] validators.
///
/// The `#[error]` strings below form the `Display` text, which embeds the detection detail
/// and is intended for logs/diagnostics. The HTTP response built by this type's
/// `IntoResponse` impl deliberately uses generic, detail-free messages instead.
#[derive(Debug, thiserror::Error)]
pub enum ShieldError {
    /// Returned by `Shield::validate_sql`; holds the `Debug`-rendered violations joined
    /// with `", "`.
    #[error("SQL injection detected: {0}")]
    SqlInjectionDetected(String),

    /// Returned by `Shield::validate_url` / `Shield::validate_ip`; holds the SSRF guard's
    /// rejection reason.
    #[error("Request blocked by SSRF guard: {0}")]
    SsrfBlocked(String),

    /// Request-rate limit hit. Not constructed in this file.
    #[error("Rate limit exceeded")]
    RateLimitExceeded {
        /// Optional retry hint. NOTE(review): presumably seconds (it mirrors the HTTP
        /// `Retry-After` header) — confirm with the producer.
        retry_after: Option<u64>,
    },

    /// Threat score too high. Not constructed in this file; presumably carries the
    /// computed score that crossed the block threshold.
    #[error("Request blocked (threat score: {0:.3})")]
    ThreatScoreExceeded(f64),

    /// Returned by `Shield::quarantine_json`; holds the rejection reason.
    #[error("Malicious input detected: {0}")]
    MaliciousInput(String),

    /// Returned by `Shield::validate_file_path`; holds the sanitizer's reason.
    #[error("Path traversal blocked: {0}")]
    PathTraversal(String),

    /// Returned by `Shield::validate_connection_string`; holds the sanitizer's reason.
    #[error("Invalid connection configuration: {0}")]
    InvalidConnectionString(String),

    /// Returned by `Shield::quarantine_csv`; holds the joined violation list.
    #[error("Data quarantine failed: {0}")]
    QuarantineFailed(String),

    /// Returned by the e-mail address/header/content/outbound validators.
    #[error("Email security violation: {0}")]
    EmailViolation(String),

    /// Returned by `Shield::check_email_rate`; holds the throttled recipient.
    #[error("Email rate limit exceeded for {0}")]
    EmailBombing(String),

    /// Not constructed in this file; presumably raised by the endpoint-protection modules.
    #[error("Malware detected: {0}")]
    MalwareDetected(String),

    /// Not constructed in this file; presumably an internal endpoint-engine failure.
    #[error("Endpoint protection error: {0}")]
    EndpointError(String),

    /// Not constructed in this file; presumably an internal quarantine-vault failure.
    #[error("Quarantine vault error: {0}")]
    QuarantineVaultError(String),
}
99
100impl IntoResponse for ShieldError {
101 fn into_response(self) -> Response {
102 let (status, message) = match &self {
104 Self::SqlInjectionDetected(_) => {
105 (StatusCode::FORBIDDEN, "Request blocked by security policy")
106 }
107 Self::SsrfBlocked(_) => {
108 (StatusCode::FORBIDDEN, "Request blocked by security policy")
109 }
110 Self::RateLimitExceeded { .. } => {
111 (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded")
112 }
113 Self::ThreatScoreExceeded(_) => {
114 (StatusCode::FORBIDDEN, "Request blocked by security policy")
115 }
116 Self::MaliciousInput(_) => (StatusCode::BAD_REQUEST, "Invalid input detected"),
117 Self::PathTraversal(_) => {
118 (StatusCode::FORBIDDEN, "Request blocked by security policy")
119 }
120 Self::InvalidConnectionString(_) => {
121 (StatusCode::BAD_REQUEST, "Invalid connection configuration")
122 }
123 Self::QuarantineFailed(_) => {
124 (StatusCode::BAD_REQUEST, "Data validation failed")
125 }
126 Self::EmailViolation(_) => {
127 (StatusCode::BAD_REQUEST, "Email validation failed")
128 }
129 Self::EmailBombing(_) => {
130 (StatusCode::TOO_MANY_REQUESTS, "Email rate limit exceeded")
131 }
132 Self::MalwareDetected(_) => {
133 (StatusCode::FORBIDDEN, "Request blocked by security policy")
134 }
135 Self::EndpointError(_) => {
136 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
137 }
138 Self::QuarantineVaultError(_) => {
139 (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
140 }
141 };
142 (status, message).into_response()
143 }
144}
145
/// Security facade bundling the configuration with the shared detection engines.
///
/// Build it with [`Shield::new`]; `shield_middleware` reads it from an
/// `Extension<Arc<Shield>>`.
pub struct Shield {
    /// Settings the shield was created from; validators read their sub-configs from here.
    pub config: ShieldConfig,
    /// Audit trail; validators and the middleware record security events into it.
    pub audit: Arc<AuditChain>,
    /// Per-client request rate limiter, consulted first by the middleware.
    pub rate_governor: Arc<rate_governor::RateGovernor>,
    /// Header fingerprinting plus per-client behavioural scoring.
    pub fingerprinter: Arc<fingerprint::Fingerprinter>,
    /// Per-recipient outbound e-mail limiter used by `Shield::check_email_rate`.
    pub email_limiter: Arc<EmailRateLimiter>,
    /// Optional endpoint-protection engine; `Shield::new` leaves it as `None`.
    pub endpoint: Option<Arc<endpoint::EndpointEngine>>,
}
155
156impl Shield {
157 pub fn new(config: ShieldConfig) -> Self {
158 let audit = Arc::new(AuditChain::with_max_events(config.audit_max_events));
159 let rate_governor = Arc::new(rate_governor::RateGovernor::new(&config));
160 let fingerprinter = Arc::new(fingerprint::Fingerprinter::new());
161 let email_limiter = Arc::new(EmailRateLimiter::new(config.email.clone()));
162 Self {
163 config,
164 audit,
165 rate_governor,
166 fingerprinter,
167 email_limiter,
168 endpoint: None,
169 }
170 }
171
172 pub fn validate_sql(&self, sql: &str) -> Result<(), ShieldError> {
176 let analysis = sql_firewall::analyze_query(sql, &self.config.sql);
177 if analysis.allowed {
178 Ok(())
179 } else {
180 let reason = analysis
181 .violations
182 .iter()
183 .map(|v| format!("{:?}", v))
184 .collect::<Vec<_>>()
185 .join(", ");
186 self.audit.record(
187 SecurityEventType::SqlInjectionAttempt,
188 "api",
189 &reason,
190 analysis.risk_score,
191 );
192 Err(ShieldError::SqlInjectionDetected(reason))
193 }
194 }
195
196 pub fn validate_url(&self, url: &str) -> Result<(), ShieldError> {
198 ssrf_guard::validate_url(url, &self.config.ssrf).map_err(|reason| {
199 self.audit.record(
200 SecurityEventType::SsrfAttempt,
201 "api",
202 &reason,
203 0.9,
204 );
205 ShieldError::SsrfBlocked(reason)
206 })
207 }
208
209 pub fn validate_ip(&self, ip: &str) -> Result<(), ShieldError> {
211 ssrf_guard::validate_ip_str(ip, &self.config.ssrf).map_err(|reason| {
212 self.audit.record(
213 SecurityEventType::SsrfAttempt,
214 "api",
215 &reason,
216 0.9,
217 );
218 ShieldError::SsrfBlocked(reason)
219 })
220 }
221
222 pub fn validate_connection_string(&self, conn_str: &str) -> Result<String, ShieldError> {
224 sanitizer::validate_connection_string(conn_str).map_err(|reason| {
225 self.audit.record(
226 SecurityEventType::MaliciousPayload,
227 "api",
228 &reason,
229 0.8,
230 );
231 ShieldError::InvalidConnectionString(reason)
232 })
233 }
234
235 pub fn validate_file_path(&self, path: &str) -> Result<(), ShieldError> {
237 sanitizer::validate_file_path(path).map_err(|reason| {
238 self.audit.record(
239 SecurityEventType::PathTraversalAttempt,
240 "api",
241 &reason,
242 0.9,
243 );
244 ShieldError::PathTraversal(reason)
245 })
246 }
247
248 pub fn quarantine_csv(&self, content: &str) -> Result<(), ShieldError> {
250 let result = quarantine::validate_csv(content, &self.config.quarantine);
251 if result.passed {
252 Ok(())
253 } else {
254 let reason = result
255 .violations
256 .iter()
257 .map(|v| format!("{:?}", v))
258 .collect::<Vec<_>>()
259 .join(", ");
260 self.audit.record(
261 SecurityEventType::DataQuarantined,
262 "api",
263 &reason,
264 0.7,
265 );
266 Err(ShieldError::QuarantineFailed(reason))
267 }
268 }
269
270 pub fn quarantine_json(&self, json: &str) -> Result<(), ShieldError> {
272 quarantine::validate_json_response(json, self.config.quarantine.max_size_bytes)
273 .map_err(|reason| {
274 self.audit.record(
275 SecurityEventType::DataQuarantined,
276 "api",
277 &reason,
278 0.6,
279 );
280 ShieldError::MaliciousInput(reason)
281 })
282 }
283
284 pub fn validate_email_address(&self, addr: &str) -> Result<(), ShieldError> {
288 let violations = email_guard::validate_email_address(addr, &self.config.email);
289 if violations.is_empty() {
290 Ok(())
291 } else {
292 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
293 self.audit.record(
294 SecurityEventType::MaliciousPayload,
295 "email",
296 &reason,
297 0.7,
298 );
299 Err(ShieldError::EmailViolation(reason))
300 }
301 }
302
303 pub fn validate_email_header(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
305 let max_len = match field_name {
306 "subject" => self.config.email.max_subject_len,
307 _ => self.config.email.max_name_len,
308 };
309 let violations = email_guard::validate_header_field(field_name, value, max_len);
310 if violations.is_empty() {
311 Ok(())
312 } else {
313 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
314 self.audit.record(
315 SecurityEventType::MaliciousPayload,
316 "email",
317 &format!("header injection in {}: {}", field_name, reason),
318 0.8,
319 );
320 Err(ShieldError::EmailViolation(reason))
321 }
322 }
323
324 pub fn validate_email_content(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
326 let violations = email_guard::validate_template_content(
327 field_name, value, self.config.email.max_body_len,
328 );
329 if violations.is_empty() {
330 Ok(())
331 } else {
332 let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
333 self.audit.record(
334 SecurityEventType::MaliciousPayload,
335 "email",
336 &format!("content injection in {}: {}", field_name, reason),
337 0.8,
338 );
339 Err(ShieldError::EmailViolation(reason))
340 }
341 }
342
343 pub fn check_email_rate(&self, recipient: &str) -> Result<(), ShieldError> {
345 if self.email_limiter.check_and_record(recipient) {
346 Ok(())
347 } else {
348 self.audit.record(
349 SecurityEventType::RateLimitHit,
350 "email",
351 &format!("email bombing attempt to {}", recipient),
352 0.9,
353 );
354 Err(ShieldError::EmailBombing(recipient.to_string()))
355 }
356 }
357
358 pub fn validate_outbound_email(
360 &self,
361 to: &[&str],
362 subject: &str,
363 body_fields: &[(&str, &str)],
364 ) -> Result<(), ShieldError> {
365 email_guard::validate_outbound_email(to, subject, body_fields, &self.config.email)
367 .map_err(|reason| {
368 self.audit.record(
369 SecurityEventType::MaliciousPayload,
370 "email",
371 &reason,
372 0.7,
373 );
374 ShieldError::EmailViolation(reason)
375 })?;
376
377 for addr in to {
379 self.check_email_rate(addr)?;
380 }
381
382 Ok(())
383 }
384
385 pub fn escape_email_content(value: &str) -> String {
387 email_guard::html_escape(value)
388 }
389}
390
391pub async fn shield_middleware(
402 shield: Option<Extension<Arc<Shield>>>,
403 request: Request,
404 next: Next,
405) -> Response {
406 let shield = match shield {
407 Some(Extension(s)) => s,
408 None => return next.run(request).await,
409 };
410
411 let client_ip = extract_client_ip(&request);
412
413 let rate_result = shield.rate_governor.check(&client_ip);
415 if !rate_result.allowed {
416 shield.audit.record(
417 SecurityEventType::RateLimitHit,
418 &client_ip,
419 &format!(
420 "escalation={:?}, violations={}",
421 rate_result.escalation, rate_result.violations
422 ),
423 0.8,
424 );
425 let mut resp = (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
426 if let Some(retry_after) = rate_result.retry_after {
427 resp.headers_mut().insert(
428 "Retry-After",
429 retry_after.to_string().parse().unwrap(),
430 );
431 }
432 return resp;
433 }
434
435 let fp = shield.fingerprinter.analyze(request.headers());
437
438 let behavioral = shield.fingerprinter.behavioral_score(&client_ip);
440
441 let recent_violations = {
443 let since = chrono::Utc::now() - chrono::Duration::minutes(5);
444 shield
445 .audit
446 .count_since(&SecurityEventType::RequestBlocked, since)
447 > 0
448 };
449
450 let assessment = threat_score::assess(
452 &fp,
453 &rate_result,
454 behavioral,
455 recent_violations,
456 shield.config.warn_threshold,
457 shield.config.block_threshold,
458 );
459
460 match assessment.action {
461 ThreatAction::Block => {
462 shield.audit.record(
463 SecurityEventType::RequestBlocked,
464 &client_ip,
465 &format!(
466 "score={:.3}, fingerprint={:.3}, rate={:.3}, behavioral={:.3}",
467 assessment.score,
468 assessment.signals.fingerprint_anomaly,
469 assessment.signals.rate_pressure,
470 assessment.signals.behavioral_anomaly,
471 ),
472 assessment.score,
473 );
474 return (StatusCode::FORBIDDEN, "Request blocked by security policy").into_response();
475 }
476 ThreatAction::Warn => {
477 tracing::warn!(
478 ip = %client_ip,
479 score = assessment.score,
480 "NexusShield: elevated threat score"
481 );
482 shield.audit.record(
483 SecurityEventType::RequestAllowed,
484 &client_ip,
485 &format!("WARN: score={:.3}", assessment.score),
486 assessment.score,
487 );
488 }
489 ThreatAction::Allow => {
490 shield.fingerprinter.record_request(&client_ip);
492 }
493 }
494
495 let response = next.run(request).await;
497
498 if response.status().is_client_error() || response.status().is_server_error() {
500 shield.fingerprinter.record_error(&client_ip);
501 }
502
503 response
504}
505
506fn extract_client_ip(req: &Request) -> String {
508 if let Some(xff) = req.headers().get("x-forwarded-for") {
510 if let Ok(value) = xff.to_str() {
511 if let Some(first_ip) = value.split(',').next() {
512 let ip = first_ip.trim();
513 if !ip.is_empty() {
514 return ip.to_string();
515 }
516 }
517 }
518 }
519
520 if let Some(xri) = req.headers().get("x-real-ip") {
522 if let Ok(value) = xri.to_str() {
523 let ip = value.trim();
524 if !ip.is_empty() {
525 return ip.to_string();
526 }
527 }
528 }
529
530 "unknown".to_string()
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, StatusCode};

    // ── ShieldError: Display text ───────────────────────────────────────────

    #[test]
    fn shield_error_sql_injection_display() {
        let err = ShieldError::SqlInjectionDetected("UNION attack".to_string());
        assert_eq!(err.to_string(), "SQL injection detected: UNION attack");
    }

    #[test]
    fn shield_error_ssrf_display() {
        let err = ShieldError::SsrfBlocked("private IP".to_string());
        assert_eq!(err.to_string(), "Request blocked by SSRF guard: private IP");
    }

    #[test]
    fn shield_error_rate_limit_display() {
        let err = ShieldError::RateLimitExceeded { retry_after: Some(60) };
        assert_eq!(err.to_string(), "Rate limit exceeded");
    }

    #[test]
    fn shield_error_threat_score_display() {
        let err = ShieldError::ThreatScoreExceeded(0.85);
        assert_eq!(err.to_string(), "Request blocked (threat score: 0.850)");
    }

    #[test]
    fn shield_error_malicious_input_display() {
        let err = ShieldError::MaliciousInput("script tag".to_string());
        assert_eq!(err.to_string(), "Malicious input detected: script tag");
    }

    #[test]
    fn shield_error_path_traversal_display() {
        let err = ShieldError::PathTraversal("../../etc/passwd".to_string());
        assert_eq!(err.to_string(), "Path traversal blocked: ../../etc/passwd");
    }

    #[test]
    fn shield_error_invalid_connection_display() {
        let err = ShieldError::InvalidConnectionString("bad string".to_string());
        assert_eq!(err.to_string(), "Invalid connection configuration: bad string");
    }

    #[test]
    fn shield_error_quarantine_display() {
        let err = ShieldError::QuarantineFailed("oversized".to_string());
        assert_eq!(err.to_string(), "Data quarantine failed: oversized");
    }

    #[test]
    fn shield_error_email_violation_display() {
        let err = ShieldError::EmailViolation("header injection".to_string());
        assert_eq!(err.to_string(), "Email security violation: header injection");
    }

    #[test]
    fn shield_error_email_bombing_display() {
        let err = ShieldError::EmailBombing("test@example.com".to_string());
        assert_eq!(err.to_string(), "Email rate limit exceeded for test@example.com");
    }

    #[test]
    fn shield_error_malware_display() {
        let err = ShieldError::MalwareDetected("EICAR".to_string());
        assert_eq!(err.to_string(), "Malware detected: EICAR");
    }

    #[test]
    fn shield_error_endpoint_display() {
        let err = ShieldError::EndpointError("engine offline".to_string());
        assert_eq!(err.to_string(), "Endpoint protection error: engine offline");
    }

    #[test]
    fn shield_error_quarantine_vault_display() {
        let err = ShieldError::QuarantineVaultError("disk full".to_string());
        assert_eq!(err.to_string(), "Quarantine vault error: disk full");
    }

    // ── ShieldError: HTTP status mapping ────────────────────────────────────

    #[test]
    fn shield_error_sql_injection_returns_forbidden() {
        let err = ShieldError::SqlInjectionDetected("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_ssrf_returns_forbidden() {
        let err = ShieldError::SsrfBlocked("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_rate_limit_returns_429() {
        let err = ShieldError::RateLimitExceeded { retry_after: None };
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn shield_error_threat_score_returns_forbidden() {
        let err = ShieldError::ThreatScoreExceeded(0.9);
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_malicious_input_returns_bad_request() {
        let err = ShieldError::MaliciousInput("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_path_traversal_returns_forbidden() {
        let err = ShieldError::PathTraversal("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_invalid_conn_returns_bad_request() {
        let err = ShieldError::InvalidConnectionString("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_quarantine_returns_bad_request() {
        let err = ShieldError::QuarantineFailed("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_violation_returns_bad_request() {
        let err = ShieldError::EmailViolation("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn shield_error_email_bombing_returns_429() {
        let err = ShieldError::EmailBombing("test@test.com".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn shield_error_malware_returns_forbidden() {
        let err = ShieldError::MalwareDetected("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn shield_error_endpoint_returns_500() {
        let err = ShieldError::EndpointError("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn shield_error_quarantine_vault_returns_500() {
        let err = ShieldError::QuarantineVaultError("test".to_string());
        let resp = err.into_response();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // ── Shield construction & helpers ───────────────────────────────────────

    #[test]
    fn shield_new_with_default_config() {
        let shield = Shield::new(ShieldConfig::default());
        assert!(shield.config.block_threshold > 0.0);
        assert!(shield.config.warn_threshold > 0.0);
    }

    #[test]
    fn shield_html_escape() {
        let escaped = Shield::escape_email_content("<script>alert('xss')</script>");
        // The raw tag must be gone and its entity-escaped form present. The original
        // second assertion checked for "<script>" again, contradicting the first assertion,
        // so the test could never pass.
        assert!(!escaped.contains("<script>"));
        assert!(!escaped.contains('<'));
        assert!(escaped.contains("&lt;script&gt;"));
    }

    // ── Client IP extraction ────────────────────────────────────────────────

    #[test]
    fn extract_ip_from_x_forwarded_for() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4, 5.6.7.8"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.2.3.4");
    }

    #[test]
    fn extract_ip_from_x_real_ip() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("10.0.0.1"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "10.0.0.1");
    }

    #[test]
    fn extract_ip_xff_takes_precedence_over_xri() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.1.1.1"));
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("2.2.2.2"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "1.1.1.1");
    }

    #[test]
    fn extract_ip_unknown_when_no_headers() {
        let req = Request::builder().body(axum::body::Body::empty()).unwrap();
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "unknown");
    }

    #[test]
    fn extract_ip_xff_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static(" 3.3.3.3 , 4.4.4.4"));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "3.3.3.3");
    }

    #[test]
    fn extract_ip_xri_trims_whitespace() {
        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
        req.headers_mut().insert("x-real-ip", HeaderValue::from_static(" 5.5.5.5 "));
        let ip = extract_client_ip(&req);
        assert_eq!(ip, "5.5.5.5");
    }
}