Skip to main content

nexus_shield/
lib.rs

1// ============================================================================
2// File: lib.rs
3// Description: Adaptive zero-trust security engine with layered threat defense
4// Author: Andrew Jewell Sr. - AutomataNexus
5// Updated: March 24, 2026
6//
7// DISCLAIMER: This software is provided "as is", without warranty of any kind,
8// express or implied. Use at your own risk. AutomataNexus and the author assume
9// no liability for any damages arising from the use of this software.
10// ============================================================================
11//! # NexusShield
12//!
13//! Adaptive zero-trust security engine for the Nexus platform.
14//!
15//! Protects services from SQL injection, SSRF, command injection, path traversal,
16//! and automated attacks through a layered defense architecture:
17//!
18//! - **SQL Firewall** — AST-level SQL parsing (not regex) to detect injection
19//! - **SSRF Guard** — IP/DNS validation blocking internal network probing
20//! - **Rate Governor** — Adaptive rate limiting with behavioral escalation
21//! - **Request Fingerprinting** — Bot detection via header/behavioral analysis
22//! - **Data Quarantine** — Validates imported data for malicious payloads
23//! - **Audit Chain** — Hash-chained tamper-evident security event log
24//! - **Input Sanitizer** — Connection string and path traversal prevention
25//! - **Threat Scoring** — Multi-signal adaptive threat assessment (0.0–1.0)
26
27pub mod auth;
28pub mod audit_chain;
29pub mod compliance_report;
30pub mod config;
31pub mod credential_vault;
32pub mod endpoint;
33pub mod ferrum_integration;
34pub mod fingerprint;
35pub mod journal;
36pub mod metrics;
37pub mod nexuspulse_integration;
38pub mod quarantine;
39pub mod rate_governor;
40pub mod sanitizer;
41pub mod siem_export;
42pub mod signature_updater;
43pub mod sql_firewall;
44pub mod sse_events;
45pub mod ssrf_guard;
46pub mod email_guard;
47pub mod threat_score;
48pub mod webhook;
49
50use std::sync::Arc;
51
52use axum::extract::Request;
53use axum::http::StatusCode;
54use axum::middleware::Next;
55use axum::response::{IntoResponse, Response};
56use axum::Extension;
57
58pub use audit_chain::{AuditChain, AuditEvent, SecurityEventType};
59pub use config::ShieldConfig;
60pub use email_guard::{EmailGuardConfig, EmailRateLimiter};
61pub use threat_score::{ThreatAction, ThreatAssessment};
62
63/// Errors raised by the NexusShield security engine.
64#[derive(Debug, thiserror::Error)]
65pub enum ShieldError {
66    #[error("SQL injection detected: {0}")]
67    SqlInjectionDetected(String),
68
69    #[error("Request blocked by SSRF guard: {0}")]
70    SsrfBlocked(String),
71
72    #[error("Rate limit exceeded")]
73    RateLimitExceeded { retry_after: Option<u64> },
74
75    #[error("Request blocked (threat score: {0:.3})")]
76    ThreatScoreExceeded(f64),
77
78    #[error("Malicious input detected: {0}")]
79    MaliciousInput(String),
80
81    #[error("Path traversal blocked: {0}")]
82    PathTraversal(String),
83
84    #[error("Invalid connection configuration: {0}")]
85    InvalidConnectionString(String),
86
87    #[error("Data quarantine failed: {0}")]
88    QuarantineFailed(String),
89
90    #[error("Email security violation: {0}")]
91    EmailViolation(String),
92
93    #[error("Email rate limit exceeded for {0}")]
94    EmailBombing(String),
95
96    #[error("Malware detected: {0}")]
97    MalwareDetected(String),
98
99    #[error("Endpoint protection error: {0}")]
100    EndpointError(String),
101
102    #[error("Quarantine vault error: {0}")]
103    QuarantineVaultError(String),
104}
105
106impl IntoResponse for ShieldError {
107    fn into_response(self) -> Response {
108        // Deliberately vague error messages to avoid leaking security internals
109        let (status, message) = match &self {
110            Self::SqlInjectionDetected(_) => {
111                (StatusCode::FORBIDDEN, "Request blocked by security policy")
112            }
113            Self::SsrfBlocked(_) => {
114                (StatusCode::FORBIDDEN, "Request blocked by security policy")
115            }
116            Self::RateLimitExceeded { .. } => {
117                (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded")
118            }
119            Self::ThreatScoreExceeded(_) => {
120                (StatusCode::FORBIDDEN, "Request blocked by security policy")
121            }
122            Self::MaliciousInput(_) => (StatusCode::BAD_REQUEST, "Invalid input detected"),
123            Self::PathTraversal(_) => {
124                (StatusCode::FORBIDDEN, "Request blocked by security policy")
125            }
126            Self::InvalidConnectionString(_) => {
127                (StatusCode::BAD_REQUEST, "Invalid connection configuration")
128            }
129            Self::QuarantineFailed(_) => {
130                (StatusCode::BAD_REQUEST, "Data validation failed")
131            }
132            Self::EmailViolation(_) => {
133                (StatusCode::BAD_REQUEST, "Email validation failed")
134            }
135            Self::EmailBombing(_) => {
136                (StatusCode::TOO_MANY_REQUESTS, "Email rate limit exceeded")
137            }
138            Self::MalwareDetected(_) => {
139                (StatusCode::FORBIDDEN, "Request blocked by security policy")
140            }
141            Self::EndpointError(_) => {
142                (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
143            }
144            Self::QuarantineVaultError(_) => {
145                (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
146            }
147        };
148        (status, message).into_response()
149    }
150}
151
152/// Core security engine that orchestrates all NexusShield components.
153pub struct Shield {
154    pub config: ShieldConfig,
155    pub audit: Arc<AuditChain>,
156    pub rate_governor: Arc<rate_governor::RateGovernor>,
157    pub fingerprinter: Arc<fingerprint::Fingerprinter>,
158    pub email_limiter: Arc<EmailRateLimiter>,
159    pub endpoint: Option<Arc<endpoint::EndpointEngine>>,
160}
161
162impl Shield {
163    pub fn new(config: ShieldConfig) -> Self {
164        let audit = Arc::new(AuditChain::with_max_events(config.audit_max_events));
165        let rate_governor = Arc::new(rate_governor::RateGovernor::new(&config));
166        let fingerprinter = Arc::new(fingerprint::Fingerprinter::new());
167        let email_limiter = Arc::new(EmailRateLimiter::new(config.email.clone()));
168        Self {
169            config,
170            audit,
171            rate_governor,
172            fingerprinter,
173            email_limiter,
174            endpoint: None,
175        }
176    }
177
178    // --- Convenience methods for direct validation from API handlers ---
179
180    /// Validate a SQL query through the AST-based firewall.
181    pub fn validate_sql(&self, sql: &str) -> Result<(), ShieldError> {
182        let analysis = sql_firewall::analyze_query(sql, &self.config.sql);
183        if analysis.allowed {
184            Ok(())
185        } else {
186            let reason = analysis
187                .violations
188                .iter()
189                .map(|v| format!("{:?}", v))
190                .collect::<Vec<_>>()
191                .join(", ");
192            self.audit.record(
193                SecurityEventType::SqlInjectionAttempt,
194                "api",
195                &reason,
196                analysis.risk_score,
197            );
198            Err(ShieldError::SqlInjectionDetected(reason))
199        }
200    }
201
202    /// Validate a URL through the SSRF guard.
203    pub fn validate_url(&self, url: &str) -> Result<(), ShieldError> {
204        ssrf_guard::validate_url(url, &self.config.ssrf).map_err(|reason| {
205            self.audit.record(
206                SecurityEventType::SsrfAttempt,
207                "api",
208                &reason,
209                0.9,
210            );
211            ShieldError::SsrfBlocked(reason)
212        })
213    }
214
215    /// Validate an IP address through the SSRF guard.
216    pub fn validate_ip(&self, ip: &str) -> Result<(), ShieldError> {
217        ssrf_guard::validate_ip_str(ip, &self.config.ssrf).map_err(|reason| {
218            self.audit.record(
219                SecurityEventType::SsrfAttempt,
220                "api",
221                &reason,
222                0.9,
223            );
224            ShieldError::SsrfBlocked(reason)
225        })
226    }
227
228    /// Validate and sanitize a database connection string.
229    pub fn validate_connection_string(&self, conn_str: &str) -> Result<String, ShieldError> {
230        sanitizer::validate_connection_string(conn_str).map_err(|reason| {
231            self.audit.record(
232                SecurityEventType::MaliciousPayload,
233                "api",
234                &reason,
235                0.8,
236            );
237            ShieldError::InvalidConnectionString(reason)
238        })
239    }
240
241    /// Validate a file path (SQLite database path).
242    pub fn validate_file_path(&self, path: &str) -> Result<(), ShieldError> {
243        sanitizer::validate_file_path(path).map_err(|reason| {
244            self.audit.record(
245                SecurityEventType::PathTraversalAttempt,
246                "api",
247                &reason,
248                0.9,
249            );
250            ShieldError::PathTraversal(reason)
251        })
252    }
253
254    /// Run imported CSV data through quarantine validation.
255    pub fn quarantine_csv(&self, content: &str) -> Result<(), ShieldError> {
256        let result = quarantine::validate_csv(content, &self.config.quarantine);
257        if result.passed {
258            Ok(())
259        } else {
260            let reason = result
261                .violations
262                .iter()
263                .map(|v| format!("{:?}", v))
264                .collect::<Vec<_>>()
265                .join(", ");
266            self.audit.record(
267                SecurityEventType::DataQuarantined,
268                "api",
269                &reason,
270                0.7,
271            );
272            Err(ShieldError::QuarantineFailed(reason))
273        }
274    }
275
276    /// Validate a JSON response from an external source.
277    pub fn quarantine_json(&self, json: &str) -> Result<(), ShieldError> {
278        quarantine::validate_json_response(json, self.config.quarantine.max_size_bytes)
279            .map_err(|reason| {
280                self.audit.record(
281                    SecurityEventType::DataQuarantined,
282                    "api",
283                    &reason,
284                    0.6,
285                );
286                ShieldError::MaliciousInput(reason)
287            })
288    }
289
290    // --- Email security methods ---
291
292    /// Validate an email address for format, domain safety, and injection.
293    pub fn validate_email_address(&self, addr: &str) -> Result<(), ShieldError> {
294        let violations = email_guard::validate_email_address(addr, &self.config.email);
295        if violations.is_empty() {
296            Ok(())
297        } else {
298            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
299            self.audit.record(
300                SecurityEventType::MaliciousPayload,
301                "email",
302                &reason,
303                0.7,
304            );
305            Err(ShieldError::EmailViolation(reason))
306        }
307    }
308
309    /// Validate a header field (subject, name, ticket_id) for injection.
310    pub fn validate_email_header(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
311        let max_len = match field_name {
312            "subject" => self.config.email.max_subject_len,
313            _ => self.config.email.max_name_len,
314        };
315        let violations = email_guard::validate_header_field(field_name, value, max_len);
316        if violations.is_empty() {
317            Ok(())
318        } else {
319            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
320            self.audit.record(
321                SecurityEventType::MaliciousPayload,
322                "email",
323                &format!("header injection in {}: {}", field_name, reason),
324                0.8,
325            );
326            Err(ShieldError::EmailViolation(reason))
327        }
328    }
329
330    /// Validate content that will be interpolated into an HTML email template.
331    pub fn validate_email_content(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
332        let violations = email_guard::validate_template_content(
333            field_name, value, self.config.email.max_body_len,
334        );
335        if violations.is_empty() {
336            Ok(())
337        } else {
338            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
339            self.audit.record(
340                SecurityEventType::MaliciousPayload,
341                "email",
342                &format!("content injection in {}: {}", field_name, reason),
343                0.8,
344            );
345            Err(ShieldError::EmailViolation(reason))
346        }
347    }
348
349    /// Check per-recipient email rate limit (anti-bombing).
350    pub fn check_email_rate(&self, recipient: &str) -> Result<(), ShieldError> {
351        if self.email_limiter.check_and_record(recipient) {
352            Ok(())
353        } else {
354            self.audit.record(
355                SecurityEventType::RateLimitHit,
356                "email",
357                &format!("email bombing attempt to {}", recipient),
358                0.9,
359            );
360            Err(ShieldError::EmailBombing(recipient.to_string()))
361        }
362    }
363
364    /// Full outbound email validation: addresses, headers, content, and rate limits.
365    pub fn validate_outbound_email(
366        &self,
367        to: &[&str],
368        subject: &str,
369        body_fields: &[(&str, &str)],
370    ) -> Result<(), ShieldError> {
371        // Validate all fields
372        email_guard::validate_outbound_email(to, subject, body_fields, &self.config.email)
373            .map_err(|reason| {
374                self.audit.record(
375                    SecurityEventType::MaliciousPayload,
376                    "email",
377                    &reason,
378                    0.7,
379                );
380                ShieldError::EmailViolation(reason)
381            })?;
382
383        // Check per-recipient rate limits
384        for addr in to {
385            self.check_email_rate(addr)?;
386        }
387
388        Ok(())
389    }
390
391    /// HTML-escape user content for safe template interpolation.
392    pub fn escape_email_content(value: &str) -> String {
393        email_guard::html_escape(value)
394    }
395}
396
397/// Axum middleware that performs per-request threat assessment.
398///
399/// Install via:
400/// ```ignore
401/// let shield = Arc::new(Shield::new(ShieldConfig::default()));
402/// let app = Router::new()
403///     .route(...)
404///     .layer(Extension(shield.clone()))
405///     .layer(axum::middleware::from_fn(shield_middleware));
406/// ```
407pub async fn shield_middleware(
408    shield: Option<Extension<Arc<Shield>>>,
409    request: Request,
410    next: Next,
411) -> Response {
412    let shield = match shield {
413        Some(Extension(s)) => s,
414        None => return next.run(request).await,
415    };
416
417    let client_ip = extract_client_ip(&request);
418
419    // 1. Rate limiting
420    let rate_result = shield.rate_governor.check(&client_ip);
421    if !rate_result.allowed {
422        shield.audit.record(
423            SecurityEventType::RateLimitHit,
424            &client_ip,
425            &format!(
426                "escalation={:?}, violations={}",
427                rate_result.escalation, rate_result.violations
428            ),
429            0.8,
430        );
431        let mut resp = (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
432        if let Some(retry_after) = rate_result.retry_after {
433            resp.headers_mut().insert(
434                "Retry-After",
435                retry_after.to_string().parse().unwrap(),
436            );
437        }
438        return resp;
439    }
440
441    // 2. Request fingerprinting
442    let fp = shield.fingerprinter.analyze(request.headers());
443
444    // 3. Behavioral score
445    let behavioral = shield.fingerprinter.behavioral_score(&client_ip);
446
447    // 4. Check for recent violations
448    let recent_violations = {
449        let since = chrono::Utc::now() - chrono::Duration::minutes(5);
450        shield
451            .audit
452            .count_since(&SecurityEventType::RequestBlocked, since)
453            > 0
454    };
455
456    // 5. Compute threat score
457    let assessment = threat_score::assess(
458        &fp,
459        &rate_result,
460        behavioral,
461        recent_violations,
462        shield.config.warn_threshold,
463        shield.config.block_threshold,
464    );
465
466    match assessment.action {
467        ThreatAction::Block => {
468            shield.audit.record(
469                SecurityEventType::RequestBlocked,
470                &client_ip,
471                &format!(
472                    "score={:.3}, fingerprint={:.3}, rate={:.3}, behavioral={:.3}",
473                    assessment.score,
474                    assessment.signals.fingerprint_anomaly,
475                    assessment.signals.rate_pressure,
476                    assessment.signals.behavioral_anomaly,
477                ),
478                assessment.score,
479            );
480            return (StatusCode::FORBIDDEN, "Request blocked by security policy").into_response();
481        }
482        ThreatAction::Warn => {
483            tracing::warn!(
484                ip = %client_ip,
485                score = assessment.score,
486                "NexusShield: elevated threat score"
487            );
488            shield.audit.record(
489                SecurityEventType::RequestAllowed,
490                &client_ip,
491                &format!("WARN: score={:.3}", assessment.score),
492                assessment.score,
493            );
494        }
495        ThreatAction::Allow => {
496            // Record silently for behavioral tracking
497            shield.fingerprinter.record_request(&client_ip);
498        }
499    }
500
501    // 6. Proceed to handler
502    let response = next.run(request).await;
503
504    // 7. Track errors for behavioral analysis
505    if response.status().is_client_error() || response.status().is_server_error() {
506        shield.fingerprinter.record_error(&client_ip);
507    }
508
509    response
510}
511
512/// Extract the client IP from the request, checking proxy headers first.
513fn extract_client_ip(req: &Request) -> String {
514    // X-Forwarded-For (first IP in chain)
515    if let Some(xff) = req.headers().get("x-forwarded-for") {
516        if let Ok(value) = xff.to_str() {
517            if let Some(first_ip) = value.split(',').next() {
518                let ip = first_ip.trim();
519                if !ip.is_empty() {
520                    return ip.to_string();
521                }
522            }
523        }
524    }
525
526    // X-Real-IP
527    if let Some(xri) = req.headers().get("x-real-ip") {
528        if let Ok(value) = xri.to_str() {
529            let ip = value.trim();
530            if !ip.is_empty() {
531                return ip.to_string();
532            }
533        }
534    }
535
536    "unknown".to_string()
537}
538
539#[cfg(test)]
540mod tests {
541    use super::*;
542    use axum::http::{HeaderValue, StatusCode};
543
544    // ── ShieldError display messages ───────────────────────
545
546    #[test]
547    fn shield_error_sql_injection_display() {
548        let err = ShieldError::SqlInjectionDetected("UNION attack".to_string());
549        assert_eq!(err.to_string(), "SQL injection detected: UNION attack");
550    }
551
552    #[test]
553    fn shield_error_ssrf_display() {
554        let err = ShieldError::SsrfBlocked("private IP".to_string());
555        assert_eq!(err.to_string(), "Request blocked by SSRF guard: private IP");
556    }
557
558    #[test]
559    fn shield_error_rate_limit_display() {
560        let err = ShieldError::RateLimitExceeded { retry_after: Some(60) };
561        assert_eq!(err.to_string(), "Rate limit exceeded");
562    }
563
564    #[test]
565    fn shield_error_threat_score_display() {
566        let err = ShieldError::ThreatScoreExceeded(0.85);
567        assert_eq!(err.to_string(), "Request blocked (threat score: 0.850)");
568    }
569
570    #[test]
571    fn shield_error_malicious_input_display() {
572        let err = ShieldError::MaliciousInput("script tag".to_string());
573        assert_eq!(err.to_string(), "Malicious input detected: script tag");
574    }
575
576    #[test]
577    fn shield_error_path_traversal_display() {
578        let err = ShieldError::PathTraversal("../../etc/passwd".to_string());
579        assert_eq!(err.to_string(), "Path traversal blocked: ../../etc/passwd");
580    }
581
582    #[test]
583    fn shield_error_invalid_connection_display() {
584        let err = ShieldError::InvalidConnectionString("bad string".to_string());
585        assert_eq!(err.to_string(), "Invalid connection configuration: bad string");
586    }
587
588    #[test]
589    fn shield_error_quarantine_display() {
590        let err = ShieldError::QuarantineFailed("oversized".to_string());
591        assert_eq!(err.to_string(), "Data quarantine failed: oversized");
592    }
593
594    #[test]
595    fn shield_error_email_violation_display() {
596        let err = ShieldError::EmailViolation("header injection".to_string());
597        assert_eq!(err.to_string(), "Email security violation: header injection");
598    }
599
600    #[test]
601    fn shield_error_email_bombing_display() {
602        let err = ShieldError::EmailBombing("test@example.com".to_string());
603        assert_eq!(err.to_string(), "Email rate limit exceeded for test@example.com");
604    }
605
606    // ── ShieldError -> HTTP status codes ───────────────────
607
608    #[test]
609    fn shield_error_sql_injection_returns_forbidden() {
610        let err = ShieldError::SqlInjectionDetected("test".to_string());
611        let resp = err.into_response();
612        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
613    }
614
615    #[test]
616    fn shield_error_ssrf_returns_forbidden() {
617        let err = ShieldError::SsrfBlocked("test".to_string());
618        let resp = err.into_response();
619        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
620    }
621
622    #[test]
623    fn shield_error_rate_limit_returns_429() {
624        let err = ShieldError::RateLimitExceeded { retry_after: None };
625        let resp = err.into_response();
626        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
627    }
628
629    #[test]
630    fn shield_error_threat_score_returns_forbidden() {
631        let err = ShieldError::ThreatScoreExceeded(0.9);
632        let resp = err.into_response();
633        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
634    }
635
636    #[test]
637    fn shield_error_malicious_input_returns_bad_request() {
638        let err = ShieldError::MaliciousInput("test".to_string());
639        let resp = err.into_response();
640        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
641    }
642
643    #[test]
644    fn shield_error_path_traversal_returns_forbidden() {
645        let err = ShieldError::PathTraversal("test".to_string());
646        let resp = err.into_response();
647        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
648    }
649
650    #[test]
651    fn shield_error_invalid_conn_returns_bad_request() {
652        let err = ShieldError::InvalidConnectionString("test".to_string());
653        let resp = err.into_response();
654        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
655    }
656
657    #[test]
658    fn shield_error_quarantine_returns_bad_request() {
659        let err = ShieldError::QuarantineFailed("test".to_string());
660        let resp = err.into_response();
661        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
662    }
663
664    #[test]
665    fn shield_error_email_violation_returns_bad_request() {
666        let err = ShieldError::EmailViolation("test".to_string());
667        let resp = err.into_response();
668        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
669    }
670
671    #[test]
672    fn shield_error_email_bombing_returns_429() {
673        let err = ShieldError::EmailBombing("test@test.com".to_string());
674        let resp = err.into_response();
675        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
676    }
677
678    // ── Shield construction ────────────────────────────────
679
680    #[test]
681    fn shield_new_with_default_config() {
682        let shield = Shield::new(ShieldConfig::default());
683        assert!(shield.config.block_threshold > 0.0);
684        assert!(shield.config.warn_threshold > 0.0);
685    }
686
687    #[test]
688    fn shield_html_escape() {
689        let escaped = Shield::escape_email_content("<script>alert('xss')</script>");
690        assert!(!escaped.contains("<script>"));
691        assert!(escaped.contains("&lt;script&gt;"));
692    }
693
694    // ── extract_client_ip ──────────────────────────────────
695
696    #[test]
697    fn extract_ip_from_x_forwarded_for() {
698        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
699        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.2.3.4, 5.6.7.8"));
700        let ip = extract_client_ip(&req);
701        assert_eq!(ip, "1.2.3.4");
702    }
703
704    #[test]
705    fn extract_ip_from_x_real_ip() {
706        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
707        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("10.0.0.1"));
708        let ip = extract_client_ip(&req);
709        assert_eq!(ip, "10.0.0.1");
710    }
711
712    #[test]
713    fn extract_ip_xff_takes_precedence_over_xri() {
714        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
715        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("1.1.1.1"));
716        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("2.2.2.2"));
717        let ip = extract_client_ip(&req);
718        assert_eq!(ip, "1.1.1.1");
719    }
720
721    #[test]
722    fn extract_ip_unknown_when_no_headers() {
723        let req = Request::builder().body(axum::body::Body::empty()).unwrap();
724        let ip = extract_client_ip(&req);
725        assert_eq!(ip, "unknown");
726    }
727
728    #[test]
729    fn extract_ip_xff_trims_whitespace() {
730        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
731        req.headers_mut().insert("x-forwarded-for", HeaderValue::from_static("  3.3.3.3  , 4.4.4.4"));
732        let ip = extract_client_ip(&req);
733        assert_eq!(ip, "3.3.3.3");
734    }
735
736    #[test]
737    fn extract_ip_xri_trims_whitespace() {
738        let mut req = Request::builder().body(axum::body::Body::empty()).unwrap();
739        req.headers_mut().insert("x-real-ip", HeaderValue::from_static("  5.5.5.5  "));
740        let ip = extract_client_ip(&req);
741        assert_eq!(ip, "5.5.5.5");
742    }
743}