auth_framework/api/
validation.rs

1//! Request Validation Middleware
2//!
3//! Provides comprehensive request validation and sanitization
4
5use axum::{
6    extract::Request,
7    http::{HeaderMap, StatusCode},
8    middleware::Next,
9    response::Response,
10};
11
/// Header names that every request must carry.
///
/// `validate_security_headers` rejects a request with `400 Bad Request` when any
/// of these is absent. Despite the constant's name, these are ordinary
/// client-supplied request headers rather than security controls, so their
/// presence is only a weak signal on its own.
const REQUIRED_SECURITY_HEADERS: &[&str] = &[
    "user-agent",
    "accept",
];
14
/// Substrings that mark a header value as suspicious (common script-injection
/// markers).
///
/// `contains_suspicious_content` lowercases its input before searching, so every
/// entry here must itself be lowercase or it can never match.
const SUSPICIOUS_PATTERNS: &[&str] = &[
    // Script tags and script-scheme URLs.
    "<script",
    "javascript:",
    // Inline event-handler attributes.
    "onload=",
    "onerror=",
    // Direct script function calls.
    "eval(",
    "alert(",
];
23
24/// Request validation middleware
25pub async fn validate_request_middleware(
26    request: Request,
27    next: Next,
28) -> Result<Response, StatusCode> {
29    let headers = request.headers();
30
31    // Validate security headers
32    validate_security_headers(headers)?;
33
34    // Validate request size
35    if let Some(content_length) = headers.get("content-length")
36        && let Ok(length_str) = content_length.to_str()
37        && let Ok(length) = length_str.parse::<usize>()
38        && length > 10_000_000
39    {
40        // 10MB limit
41        return Err(StatusCode::PAYLOAD_TOO_LARGE);
42    }
43
44    // Check for suspicious patterns in headers
45    for (name, value) in headers.iter() {
46        if let Ok(value_str) = value.to_str()
47            && contains_suspicious_content(value_str)
48        {
49            tracing::warn!(
50                "Suspicious content detected in header {}: {}",
51                name,
52                value_str
53            );
54            return Err(StatusCode::BAD_REQUEST);
55        }
56    }
57
58    Ok(next.run(request).await)
59}
60
61/// Validate required security headers
62fn validate_security_headers(headers: &HeaderMap) -> Result<(), StatusCode> {
63    let missing_headers: Vec<&str> = REQUIRED_SECURITY_HEADERS
64        .iter()
65        .filter(|&&header| !headers.contains_key(header))
66        .copied()
67        .collect();
68
69    if !missing_headers.is_empty() {
70        tracing::warn!("Missing required headers: {:?}", missing_headers);
71        return Err(StatusCode::BAD_REQUEST);
72    }
73
74    Ok(())
75}
76
77/// Check for suspicious content patterns
78fn contains_suspicious_content(content: &str) -> bool {
79    let content_lower = content.to_lowercase();
80    SUSPICIOUS_PATTERNS
81        .iter()
82        .any(|&pattern| content_lower.contains(pattern))
83}
84
85/// Rate limiting by IP
86pub struct IpRateLimiter {
87    requests: std::sync::Mutex<std::collections::HashMap<String, (u32, std::time::Instant)>>,
88    max_requests: u32,
89    window_duration: std::time::Duration,
90}
91
92impl IpRateLimiter {
93    pub fn new(max_requests: u32, window_minutes: u64) -> Self {
94        Self {
95            requests: std::sync::Mutex::new(std::collections::HashMap::new()),
96            max_requests,
97            window_duration: std::time::Duration::from_secs(window_minutes * 60),
98        }
99    }
100
101    pub fn check_rate_limit(&self, ip: &str) -> Result<(), StatusCode> {
102        let mut requests = self.requests.lock().unwrap();
103        let now = std::time::Instant::now();
104
105        // Clean expired entries
106        requests.retain(|_, (_, timestamp)| now.duration_since(*timestamp) < self.window_duration);
107
108        // Check current IP
109        match requests.get_mut(ip) {
110            Some((count, timestamp)) => {
111                if now.duration_since(*timestamp) < self.window_duration {
112                    if *count >= self.max_requests {
113                        return Err(StatusCode::TOO_MANY_REQUESTS);
114                    }
115                    *count += 1;
116                } else {
117                    *count = 1;
118                    *timestamp = now;
119                }
120            }
121            None => {
122                requests.insert(ip.to_string(), (1, now));
123            }
124        }
125
126        Ok(())
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_suspicious_content_detection() {
        assert!(contains_suspicious_content("<script>alert('xss')</script>"));
        assert!(contains_suspicious_content("javascript:void(0)"));
        assert!(contains_suspicious_content("onload=malicious()"));
        assert!(!contains_suspicious_content("normal content"));
        assert!(!contains_suspicious_content("user@example.com"));
    }

    #[test]
    fn test_suspicious_content_is_case_insensitive() {
        // Input is lowercased before matching, so mixed/upper case must still be caught.
        assert!(contains_suspicious_content("<SCRIPT src=x>"));
        assert!(contains_suspicious_content("JavaScript:void(0)"));
        assert!(contains_suspicious_content("OnError=steal()"));
    }

    #[test]
    fn test_security_headers_required() {
        let mut headers = HeaderMap::new();
        assert_eq!(
            validate_security_headers(&headers),
            Err(StatusCode::BAD_REQUEST)
        );

        headers.insert(
            "user-agent",
            axum::http::HeaderValue::from_static("test-agent"),
        );
        assert_eq!(
            validate_security_headers(&headers),
            Err(StatusCode::BAD_REQUEST)
        );

        headers.insert("accept", axum::http::HeaderValue::from_static("*/*"));
        assert_eq!(validate_security_headers(&headers), Ok(()));
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = IpRateLimiter::new(5, 1);

        // Should allow first 5 requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("192.168.1.1").is_ok());
        }

        // Should block 6th request with 429
        assert_eq!(
            limiter.check_rate_limit("192.168.1.1"),
            Err(StatusCode::TOO_MANY_REQUESTS)
        );
    }

    #[test]
    fn test_rate_limiter_tracks_ips_independently() {
        let limiter = IpRateLimiter::new(1, 1);
        assert!(limiter.check_rate_limit("192.168.1.1").is_ok());
        assert!(limiter.check_rate_limit("192.168.1.1").is_err());
        // A different IP has its own budget.
        assert!(limiter.check_rate_limit("192.168.1.2").is_ok());
    }
}
156
157