Skip to main content

nexus_shield/
lib.rs

1// ============================================================================
2// File: lib.rs
3// Description: Adaptive zero-trust security engine with layered threat defense
4// Author: Andrew Jewell Sr. - AutomataNexus
5// Updated: March 24, 2026
6//
7// DISCLAIMER: This software is provided "as is", without warranty of any kind,
8// express or implied. Use at your own risk. AutomataNexus and the author assume
9// no liability for any damages arising from the use of this software.
10// ============================================================================
11//! # NexusShield
12//!
13//! Adaptive zero-trust security engine for the Nexus platform.
14//!
15//! Protects services from SQL injection, SSRF, command injection, path traversal,
16//! and automated attacks through a layered defense architecture:
17//!
18//! - **SQL Firewall** — AST-level SQL parsing (not regex) to detect injection
19//! - **SSRF Guard** — IP/DNS validation blocking internal network probing
20//! - **Rate Governor** — Adaptive rate limiting with behavioral escalation
21//! - **Request Fingerprinting** — Bot detection via header/behavioral analysis
22//! - **Data Quarantine** — Validates imported data for malicious payloads
23//! - **Audit Chain** — Hash-chained tamper-evident security event log
24//! - **Input Sanitizer** — Connection string and path traversal prevention
25//! - **Threat Scoring** — Multi-signal adaptive threat assessment (0.0–1.0)
26
27pub mod auth;
28pub mod audit_chain;
29pub mod compliance_report;
30pub mod config;
31pub mod credential_vault;
32pub mod endpoint;
33pub mod ferrum_integration;
34pub mod fingerprint;
35pub mod journal;
36pub mod metrics;
37pub mod quarantine;
38pub mod rate_governor;
39pub mod sanitizer;
40pub mod siem_export;
41pub mod signature_updater;
42pub mod sql_firewall;
43pub mod sse_events;
44pub mod ssrf_guard;
45pub mod email_guard;
46pub mod threat_score;
47pub mod webhook;
48
49use std::sync::Arc;
50
51use axum::extract::Request;
52use axum::http::StatusCode;
53use axum::middleware::Next;
54use axum::response::{IntoResponse, Response};
55use axum::Extension;
56
57pub use audit_chain::{AuditChain, AuditEvent, SecurityEventType};
58pub use config::ShieldConfig;
59pub use email_guard::{EmailGuardConfig, EmailRateLimiter};
60pub use threat_score::{ThreatAction, ThreatAssessment};
61
62/// Errors raised by the NexusShield security engine.
63#[derive(Debug, thiserror::Error)]
64pub enum ShieldError {
65    #[error("SQL injection detected: {0}")]
66    SqlInjectionDetected(String),
67
68    #[error("Request blocked by SSRF guard: {0}")]
69    SsrfBlocked(String),
70
71    #[error("Rate limit exceeded")]
72    RateLimitExceeded { retry_after: Option<u64> },
73
74    #[error("Request blocked (threat score: {0:.3})")]
75    ThreatScoreExceeded(f64),
76
77    #[error("Malicious input detected: {0}")]
78    MaliciousInput(String),
79
80    #[error("Path traversal blocked: {0}")]
81    PathTraversal(String),
82
83    #[error("Invalid connection configuration: {0}")]
84    InvalidConnectionString(String),
85
86    #[error("Data quarantine failed: {0}")]
87    QuarantineFailed(String),
88
89    #[error("Email security violation: {0}")]
90    EmailViolation(String),
91
92    #[error("Email rate limit exceeded for {0}")]
93    EmailBombing(String),
94
95    #[error("Malware detected: {0}")]
96    MalwareDetected(String),
97
98    #[error("Endpoint protection error: {0}")]
99    EndpointError(String),
100
101    #[error("Quarantine vault error: {0}")]
102    QuarantineVaultError(String),
103}
104
105impl IntoResponse for ShieldError {
106    fn into_response(self) -> Response {
107        // Deliberately vague error messages to avoid leaking security internals
108        let (status, message) = match &self {
109            Self::SqlInjectionDetected(_) => {
110                (StatusCode::FORBIDDEN, "Request blocked by security policy")
111            }
112            Self::SsrfBlocked(_) => {
113                (StatusCode::FORBIDDEN, "Request blocked by security policy")
114            }
115            Self::RateLimitExceeded { .. } => {
116                (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded")
117            }
118            Self::ThreatScoreExceeded(_) => {
119                (StatusCode::FORBIDDEN, "Request blocked by security policy")
120            }
121            Self::MaliciousInput(_) => (StatusCode::BAD_REQUEST, "Invalid input detected"),
122            Self::PathTraversal(_) => {
123                (StatusCode::FORBIDDEN, "Request blocked by security policy")
124            }
125            Self::InvalidConnectionString(_) => {
126                (StatusCode::BAD_REQUEST, "Invalid connection configuration")
127            }
128            Self::QuarantineFailed(_) => {
129                (StatusCode::BAD_REQUEST, "Data validation failed")
130            }
131            Self::EmailViolation(_) => {
132                (StatusCode::BAD_REQUEST, "Email validation failed")
133            }
134            Self::EmailBombing(_) => {
135                (StatusCode::TOO_MANY_REQUESTS, "Email rate limit exceeded")
136            }
137            Self::MalwareDetected(_) => {
138                (StatusCode::FORBIDDEN, "Request blocked by security policy")
139            }
140            Self::EndpointError(_) => {
141                (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
142            }
143            Self::QuarantineVaultError(_) => {
144                (StatusCode::INTERNAL_SERVER_ERROR, "Security engine error")
145            }
146        };
147        (status, message).into_response()
148    }
149}
150
151/// Core security engine that orchestrates all NexusShield components.
152pub struct Shield {
153    pub config: ShieldConfig,
154    pub audit: Arc<AuditChain>,
155    pub rate_governor: Arc<rate_governor::RateGovernor>,
156    pub fingerprinter: Arc<fingerprint::Fingerprinter>,
157    pub email_limiter: Arc<EmailRateLimiter>,
158    pub endpoint: Option<Arc<endpoint::EndpointEngine>>,
159}
160
161impl Shield {
162    pub fn new(config: ShieldConfig) -> Self {
163        let audit = Arc::new(AuditChain::with_max_events(config.audit_max_events));
164        let rate_governor = Arc::new(rate_governor::RateGovernor::new(&config));
165        let fingerprinter = Arc::new(fingerprint::Fingerprinter::new());
166        let email_limiter = Arc::new(EmailRateLimiter::new(config.email.clone()));
167        Self {
168            config,
169            audit,
170            rate_governor,
171            fingerprinter,
172            email_limiter,
173            endpoint: None,
174        }
175    }
176
177    // --- Convenience methods for direct validation from API handlers ---
178
179    /// Validate a SQL query through the AST-based firewall.
180    pub fn validate_sql(&self, sql: &str) -> Result<(), ShieldError> {
181        let analysis = sql_firewall::analyze_query(sql, &self.config.sql);
182        if analysis.allowed {
183            Ok(())
184        } else {
185            let reason = analysis
186                .violations
187                .iter()
188                .map(|v| format!("{:?}", v))
189                .collect::<Vec<_>>()
190                .join(", ");
191            self.audit.record(
192                SecurityEventType::SqlInjectionAttempt,
193                "api",
194                &reason,
195                analysis.risk_score,
196            );
197            Err(ShieldError::SqlInjectionDetected(reason))
198        }
199    }
200
201    /// Validate a URL through the SSRF guard.
202    pub fn validate_url(&self, url: &str) -> Result<(), ShieldError> {
203        ssrf_guard::validate_url(url, &self.config.ssrf).map_err(|reason| {
204            self.audit.record(
205                SecurityEventType::SsrfAttempt,
206                "api",
207                &reason,
208                0.9,
209            );
210            ShieldError::SsrfBlocked(reason)
211        })
212    }
213
214    /// Validate an IP address through the SSRF guard.
215    pub fn validate_ip(&self, ip: &str) -> Result<(), ShieldError> {
216        ssrf_guard::validate_ip_str(ip, &self.config.ssrf).map_err(|reason| {
217            self.audit.record(
218                SecurityEventType::SsrfAttempt,
219                "api",
220                &reason,
221                0.9,
222            );
223            ShieldError::SsrfBlocked(reason)
224        })
225    }
226
227    /// Validate and sanitize a database connection string.
228    pub fn validate_connection_string(&self, conn_str: &str) -> Result<String, ShieldError> {
229        sanitizer::validate_connection_string(conn_str).map_err(|reason| {
230            self.audit.record(
231                SecurityEventType::MaliciousPayload,
232                "api",
233                &reason,
234                0.8,
235            );
236            ShieldError::InvalidConnectionString(reason)
237        })
238    }
239
240    /// Validate a file path (SQLite database path).
241    pub fn validate_file_path(&self, path: &str) -> Result<(), ShieldError> {
242        sanitizer::validate_file_path(path).map_err(|reason| {
243            self.audit.record(
244                SecurityEventType::PathTraversalAttempt,
245                "api",
246                &reason,
247                0.9,
248            );
249            ShieldError::PathTraversal(reason)
250        })
251    }
252
253    /// Run imported CSV data through quarantine validation.
254    pub fn quarantine_csv(&self, content: &str) -> Result<(), ShieldError> {
255        let result = quarantine::validate_csv(content, &self.config.quarantine);
256        if result.passed {
257            Ok(())
258        } else {
259            let reason = result
260                .violations
261                .iter()
262                .map(|v| format!("{:?}", v))
263                .collect::<Vec<_>>()
264                .join(", ");
265            self.audit.record(
266                SecurityEventType::DataQuarantined,
267                "api",
268                &reason,
269                0.7,
270            );
271            Err(ShieldError::QuarantineFailed(reason))
272        }
273    }
274
275    /// Validate a JSON response from an external source.
276    pub fn quarantine_json(&self, json: &str) -> Result<(), ShieldError> {
277        quarantine::validate_json_response(json, self.config.quarantine.max_size_bytes)
278            .map_err(|reason| {
279                self.audit.record(
280                    SecurityEventType::DataQuarantined,
281                    "api",
282                    &reason,
283                    0.6,
284                );
285                ShieldError::MaliciousInput(reason)
286            })
287    }
288
289    // --- Email security methods ---
290
291    /// Validate an email address for format, domain safety, and injection.
292    pub fn validate_email_address(&self, addr: &str) -> Result<(), ShieldError> {
293        let violations = email_guard::validate_email_address(addr, &self.config.email);
294        if violations.is_empty() {
295            Ok(())
296        } else {
297            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
298            self.audit.record(
299                SecurityEventType::MaliciousPayload,
300                "email",
301                &reason,
302                0.7,
303            );
304            Err(ShieldError::EmailViolation(reason))
305        }
306    }
307
308    /// Validate a header field (subject, name, ticket_id) for injection.
309    pub fn validate_email_header(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
310        let max_len = match field_name {
311            "subject" => self.config.email.max_subject_len,
312            _ => self.config.email.max_name_len,
313        };
314        let violations = email_guard::validate_header_field(field_name, value, max_len);
315        if violations.is_empty() {
316            Ok(())
317        } else {
318            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
319            self.audit.record(
320                SecurityEventType::MaliciousPayload,
321                "email",
322                &format!("header injection in {}: {}", field_name, reason),
323                0.8,
324            );
325            Err(ShieldError::EmailViolation(reason))
326        }
327    }
328
329    /// Validate content that will be interpolated into an HTML email template.
330    pub fn validate_email_content(&self, field_name: &str, value: &str) -> Result<(), ShieldError> {
331        let violations = email_guard::validate_template_content(
332            field_name, value, self.config.email.max_body_len,
333        );
334        if violations.is_empty() {
335            Ok(())
336        } else {
337            let reason = violations.iter().map(|v| format!("{:?}", v)).collect::<Vec<_>>().join(", ");
338            self.audit.record(
339                SecurityEventType::MaliciousPayload,
340                "email",
341                &format!("content injection in {}: {}", field_name, reason),
342                0.8,
343            );
344            Err(ShieldError::EmailViolation(reason))
345        }
346    }
347
348    /// Check per-recipient email rate limit (anti-bombing).
349    pub fn check_email_rate(&self, recipient: &str) -> Result<(), ShieldError> {
350        if self.email_limiter.check_and_record(recipient) {
351            Ok(())
352        } else {
353            self.audit.record(
354                SecurityEventType::RateLimitHit,
355                "email",
356                &format!("email bombing attempt to {}", recipient),
357                0.9,
358            );
359            Err(ShieldError::EmailBombing(recipient.to_string()))
360        }
361    }
362
363    /// Full outbound email validation: addresses, headers, content, and rate limits.
364    pub fn validate_outbound_email(
365        &self,
366        to: &[&str],
367        subject: &str,
368        body_fields: &[(&str, &str)],
369    ) -> Result<(), ShieldError> {
370        // Validate all fields
371        email_guard::validate_outbound_email(to, subject, body_fields, &self.config.email)
372            .map_err(|reason| {
373                self.audit.record(
374                    SecurityEventType::MaliciousPayload,
375                    "email",
376                    &reason,
377                    0.7,
378                );
379                ShieldError::EmailViolation(reason)
380            })?;
381
382        // Check per-recipient rate limits
383        for addr in to {
384            self.check_email_rate(addr)?;
385        }
386
387        Ok(())
388    }
389
390    /// HTML-escape user content for safe template interpolation.
391    pub fn escape_email_content(value: &str) -> String {
392        email_guard::html_escape(value)
393    }
394}
395
396/// Axum middleware that performs per-request threat assessment.
397///
398/// Install via:
399/// ```ignore
400/// let shield = Arc::new(Shield::new(ShieldConfig::default()));
401/// let app = Router::new()
402///     .route(...)
403///     .layer(Extension(shield.clone()))
404///     .layer(axum::middleware::from_fn(shield_middleware));
405/// ```
406pub async fn shield_middleware(
407    shield: Option<Extension<Arc<Shield>>>,
408    request: Request,
409    next: Next,
410) -> Response {
411    let shield = match shield {
412        Some(Extension(s)) => s,
413        None => return next.run(request).await,
414    };
415
416    let client_ip = extract_client_ip(&request);
417
418    // 1. Rate limiting
419    let rate_result = shield.rate_governor.check(&client_ip);
420    if !rate_result.allowed {
421        shield.audit.record(
422            SecurityEventType::RateLimitHit,
423            &client_ip,
424            &format!(
425                "escalation={:?}, violations={}",
426                rate_result.escalation, rate_result.violations
427            ),
428            0.8,
429        );
430        let mut resp = (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
431        if let Some(retry_after) = rate_result.retry_after {
432            resp.headers_mut().insert(
433                "Retry-After",
434                retry_after.to_string().parse().unwrap(),
435            );
436        }
437        return resp;
438    }
439
440    // 2. Request fingerprinting
441    let fp = shield.fingerprinter.analyze(request.headers());
442
443    // 3. Behavioral score
444    let behavioral = shield.fingerprinter.behavioral_score(&client_ip);
445
446    // 4. Check for recent violations
447    let recent_violations = {
448        let since = chrono::Utc::now() - chrono::Duration::minutes(5);
449        shield
450            .audit
451            .count_since(&SecurityEventType::RequestBlocked, since)
452            > 0
453    };
454
455    // 5. Compute threat score
456    let assessment = threat_score::assess(
457        &fp,
458        &rate_result,
459        behavioral,
460        recent_violations,
461        shield.config.warn_threshold,
462        shield.config.block_threshold,
463    );
464
465    match assessment.action {
466        ThreatAction::Block => {
467            shield.audit.record(
468                SecurityEventType::RequestBlocked,
469                &client_ip,
470                &format!(
471                    "score={:.3}, fingerprint={:.3}, rate={:.3}, behavioral={:.3}",
472                    assessment.score,
473                    assessment.signals.fingerprint_anomaly,
474                    assessment.signals.rate_pressure,
475                    assessment.signals.behavioral_anomaly,
476                ),
477                assessment.score,
478            );
479            return (StatusCode::FORBIDDEN, "Request blocked by security policy").into_response();
480        }
481        ThreatAction::Warn => {
482            tracing::warn!(
483                ip = %client_ip,
484                score = assessment.score,
485                "NexusShield: elevated threat score"
486            );
487            shield.audit.record(
488                SecurityEventType::RequestAllowed,
489                &client_ip,
490                &format!("WARN: score={:.3}", assessment.score),
491                assessment.score,
492            );
493        }
494        ThreatAction::Allow => {
495            // Record silently for behavioral tracking
496            shield.fingerprinter.record_request(&client_ip);
497        }
498    }
499
500    // 6. Proceed to handler
501    let response = next.run(request).await;
502
503    // 7. Track errors for behavioral analysis
504    if response.status().is_client_error() || response.status().is_server_error() {
505        shield.fingerprinter.record_error(&client_ip);
506    }
507
508    response
509}
510
511/// Extract the client IP from the request, checking proxy headers first.
512fn extract_client_ip(req: &Request) -> String {
513    // X-Forwarded-For (first IP in chain)
514    if let Some(xff) = req.headers().get("x-forwarded-for") {
515        if let Ok(value) = xff.to_str() {
516            if let Some(first_ip) = value.split(',').next() {
517                let ip = first_ip.trim();
518                if !ip.is_empty() {
519                    return ip.to_string();
520                }
521            }
522        }
523    }
524
525    // X-Real-IP
526    if let Some(xri) = req.headers().get("x-real-ip") {
527        if let Ok(value) = xri.to_str() {
528            let ip = value.trim();
529            if !ip.is_empty() {
530                return ip.to_string();
531            }
532        }
533    }
534
535    "unknown".to_string()
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, StatusCode};

    /// Build an empty-bodied request carrying the given `(name, value)` headers.
    fn request_with_headers(headers: &[(&'static str, &'static str)]) -> Request {
        let mut req = Request::builder()
            .body(axum::body::Body::empty())
            .unwrap();
        for &(name, value) in headers {
            req.headers_mut().insert(name, HeaderValue::from_static(value));
        }
        req
    }

    /// HTTP status a `ShieldError` maps to once converted into a response.
    fn status_of(err: ShieldError) -> StatusCode {
        err.into_response().status()
    }

    // ── Display text of each ShieldError variant ──────────

    #[test]
    fn shield_error_sql_injection_display() {
        assert_eq!(
            ShieldError::SqlInjectionDetected("UNION attack".into()).to_string(),
            "SQL injection detected: UNION attack"
        );
    }

    #[test]
    fn shield_error_ssrf_display() {
        assert_eq!(
            ShieldError::SsrfBlocked("private IP".into()).to_string(),
            "Request blocked by SSRF guard: private IP"
        );
    }

    #[test]
    fn shield_error_rate_limit_display() {
        assert_eq!(
            ShieldError::RateLimitExceeded { retry_after: Some(60) }.to_string(),
            "Rate limit exceeded"
        );
    }

    #[test]
    fn shield_error_threat_score_display() {
        assert_eq!(
            ShieldError::ThreatScoreExceeded(0.85).to_string(),
            "Request blocked (threat score: 0.850)"
        );
    }

    #[test]
    fn shield_error_malicious_input_display() {
        assert_eq!(
            ShieldError::MaliciousInput("script tag".into()).to_string(),
            "Malicious input detected: script tag"
        );
    }

    #[test]
    fn shield_error_path_traversal_display() {
        assert_eq!(
            ShieldError::PathTraversal("../../etc/passwd".into()).to_string(),
            "Path traversal blocked: ../../etc/passwd"
        );
    }

    #[test]
    fn shield_error_invalid_connection_display() {
        assert_eq!(
            ShieldError::InvalidConnectionString("bad string".into()).to_string(),
            "Invalid connection configuration: bad string"
        );
    }

    #[test]
    fn shield_error_quarantine_display() {
        assert_eq!(
            ShieldError::QuarantineFailed("oversized".into()).to_string(),
            "Data quarantine failed: oversized"
        );
    }

    #[test]
    fn shield_error_email_violation_display() {
        assert_eq!(
            ShieldError::EmailViolation("header injection".into()).to_string(),
            "Email security violation: header injection"
        );
    }

    #[test]
    fn shield_error_email_bombing_display() {
        assert_eq!(
            ShieldError::EmailBombing("test@example.com".into()).to_string(),
            "Email rate limit exceeded for test@example.com"
        );
    }

    // ── HTTP status mapping ───────────────────────────────

    #[test]
    fn shield_error_sql_injection_returns_forbidden() {
        assert_eq!(
            status_of(ShieldError::SqlInjectionDetected("test".into())),
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn shield_error_ssrf_returns_forbidden() {
        assert_eq!(
            status_of(ShieldError::SsrfBlocked("test".into())),
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn shield_error_rate_limit_returns_429() {
        assert_eq!(
            status_of(ShieldError::RateLimitExceeded { retry_after: None }),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[test]
    fn shield_error_threat_score_returns_forbidden() {
        assert_eq!(
            status_of(ShieldError::ThreatScoreExceeded(0.9)),
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn shield_error_malicious_input_returns_bad_request() {
        assert_eq!(
            status_of(ShieldError::MaliciousInput("test".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn shield_error_path_traversal_returns_forbidden() {
        assert_eq!(
            status_of(ShieldError::PathTraversal("test".into())),
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn shield_error_invalid_conn_returns_bad_request() {
        assert_eq!(
            status_of(ShieldError::InvalidConnectionString("test".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn shield_error_quarantine_returns_bad_request() {
        assert_eq!(
            status_of(ShieldError::QuarantineFailed("test".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn shield_error_email_violation_returns_bad_request() {
        assert_eq!(
            status_of(ShieldError::EmailViolation("test".into())),
            StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn shield_error_email_bombing_returns_429() {
        assert_eq!(
            status_of(ShieldError::EmailBombing("test@test.com".into())),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    // ── Shield construction & helpers ─────────────────────

    #[test]
    fn shield_new_with_default_config() {
        let config = Shield::new(ShieldConfig::default()).config;
        assert!(config.block_threshold > 0.0);
        assert!(config.warn_threshold > 0.0);
    }

    #[test]
    fn shield_html_escape() {
        let out = Shield::escape_email_content("<script>alert('xss')</script>");
        assert!(!out.contains("<script>"));
        assert!(out.contains("&lt;script&gt;"));
    }

    // ── Client IP extraction ──────────────────────────────

    #[test]
    fn extract_ip_from_x_forwarded_for() {
        let req = request_with_headers(&[("x-forwarded-for", "1.2.3.4, 5.6.7.8")]);
        assert_eq!(extract_client_ip(&req), "1.2.3.4");
    }

    #[test]
    fn extract_ip_from_x_real_ip() {
        let req = request_with_headers(&[("x-real-ip", "10.0.0.1")]);
        assert_eq!(extract_client_ip(&req), "10.0.0.1");
    }

    #[test]
    fn extract_ip_xff_takes_precedence_over_xri() {
        let req = request_with_headers(&[
            ("x-forwarded-for", "1.1.1.1"),
            ("x-real-ip", "2.2.2.2"),
        ]);
        assert_eq!(extract_client_ip(&req), "1.1.1.1");
    }

    #[test]
    fn extract_ip_unknown_when_no_headers() {
        let req = request_with_headers(&[]);
        assert_eq!(extract_client_ip(&req), "unknown");
    }

    #[test]
    fn extract_ip_xff_trims_whitespace() {
        let req = request_with_headers(&[("x-forwarded-for", "  3.3.3.3  , 4.4.4.4")]);
        assert_eq!(extract_client_ip(&req), "3.3.3.3");
    }

    #[test]
    fn extract_ip_xri_trims_whitespace() {
        let req = request_with_headers(&[("x-real-ip", "  5.5.5.5  ")]);
        assert_eq!(extract_client_ip(&req), "5.5.5.5");
    }
}