Skip to main content

auth_framework/api/
validation.rs

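//! Request validation middleware.
//!
//! Rejects requests that are missing required headers, declare an oversized body
//! via `Content-Length`, or carry script-like patterns in header values, and
//! provides a simple fixed-window per-IP rate limiter.
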
5use axum::{
6    extract::Request,
7    http::{HeaderMap, StatusCode},
8    middleware::Next,
9    response::Response,
10};
11
/// Headers every incoming request must carry.
///
/// Only `user-agent` is required. `accept` is intentionally not listed: many
/// legitimate clients (curl, automated tools, health checks) never send it.
const REQUIRED_SECURITY_HEADERS: &[&str] = &["user-agent"];

/// Substrings that mark a header value as suspicious.
///
/// `contains_suspicious_content` lowercases its input before searching, so every
/// entry here must be written in lowercase or it can never match.
const SUSPICIOUS_PATTERNS: &[&str] = &[
    "<script",     // inline <script> tag
    "javascript:", // javascript: URL scheme
    "onload=",     // inline event-handler attribute
    "onerror=",    // inline event-handler attribute
    "eval(",       // JavaScript eval call
    "alert(",      // JavaScript alert call
];
25
26/// Request validation middleware
27pub async fn validate_request_middleware(
28    request: Request,
29    next: Next,
30) -> Result<Response, StatusCode> {
31    let headers = request.headers();
32
33    // Validate security headers
34    validate_security_headers(headers)?;
35
36    // Validate request size
37    if let Some(content_length) = headers.get("content-length")
38        && let Ok(length_str) = content_length.to_str()
39        && let Ok(length) = length_str.parse::<usize>()
40        && length > 10_000_000
41    {
42        // 10MB limit
43        return Err(StatusCode::PAYLOAD_TOO_LARGE);
44    }
45
46    // Check for suspicious patterns in headers
47    for (name, value) in headers.iter() {
48        if let Ok(value_str) = value.to_str()
49            && contains_suspicious_content(value_str)
50        {
51            tracing::warn!(
52                "Suspicious content detected in header {}: {}",
53                name,
54                value_str
55            );
56            return Err(StatusCode::BAD_REQUEST);
57        }
58    }
59
60    Ok(next.run(request).await)
61}
62
63/// Validate required security headers
64fn validate_security_headers(headers: &HeaderMap) -> Result<(), StatusCode> {
65    let missing_headers: Vec<&str> = REQUIRED_SECURITY_HEADERS
66        .iter()
67        .filter(|&&header| !headers.contains_key(header))
68        .copied()
69        .collect();
70
71    if !missing_headers.is_empty() {
72        tracing::warn!("Missing required headers: {:?}", missing_headers);
73        return Err(StatusCode::BAD_REQUEST);
74    }
75
76    Ok(())
77}
78
79/// Check for suspicious content patterns
80fn contains_suspicious_content(content: &str) -> bool {
81    let content_lower = content.to_lowercase();
82    SUSPICIOUS_PATTERNS
83        .iter()
84        .any(|&pattern| content_lower.contains(pattern))
85}
86
87/// Rate limiting by IP
88pub struct IpRateLimiter {
89    requests: std::sync::Mutex<std::collections::HashMap<String, (u32, std::time::Instant)>>,
90    max_requests: u32,
91    window_duration: std::time::Duration,
92}
93
94impl IpRateLimiter {
95    pub fn new(max_requests: u32, window_minutes: u64) -> Self {
96        Self {
97            requests: std::sync::Mutex::new(std::collections::HashMap::new()),
98            max_requests,
99            window_duration: std::time::Duration::from_secs(window_minutes * 60),
100        }
101    }
102
103    pub fn check_rate_limit(&self, ip: &str) -> Result<(), StatusCode> {
104        let mut requests = self
105            .requests
106            .lock()
107            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
108        let now = std::time::Instant::now();
109
110        // Clean expired entries
111        requests.retain(|_, (_, timestamp)| now.duration_since(*timestamp) < self.window_duration);
112
113        // Check current IP
114        match requests.get_mut(ip) {
115            Some((count, timestamp)) => {
116                if now.duration_since(*timestamp) < self.window_duration {
117                    if *count >= self.max_requests {
118                        return Err(StatusCode::TOO_MANY_REQUESTS);
119                    }
120                    *count += 1;
121                } else {
122                    *count = 1;
123                    *timestamp = now;
124                }
125            }
126            None => {
127                requests.insert(ip.to_string(), (1, now));
128            }
129        }
130
131        Ok(())
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_suspicious_content_detection() {
        assert!(contains_suspicious_content("<script>alert('xss')</script>"));
        assert!(contains_suspicious_content("javascript:void(0)"));
        assert!(contains_suspicious_content("onload=malicious()"));
        assert!(!contains_suspicious_content("normal content"));
        assert!(!contains_suspicious_content("user@example.com"));
    }

    #[test]
    fn test_suspicious_content_is_case_insensitive() {
        assert!(contains_suspicious_content("<SCRIPT src=x>"));
        assert!(contains_suspicious_content("JavaScript:alert(1)"));
        assert!(contains_suspicious_content("<img OnError=x>"));
        assert!(!contains_suspicious_content(""));
    }

    #[test]
    fn test_required_headers() {
        let mut headers = HeaderMap::new();
        assert_eq!(
            validate_security_headers(&headers),
            Err(StatusCode::BAD_REQUEST)
        );

        headers.insert(
            "user-agent",
            axum::http::HeaderValue::from_static("curl/8.5.0"),
        );
        assert_eq!(validate_security_headers(&headers), Ok(()));
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = IpRateLimiter::new(5, 1);

        // Should allow first 5 requests
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("192.168.1.1").is_ok());
        }

        // Should block 6th request
        assert!(limiter.check_rate_limit("192.168.1.1").is_err());
    }

    #[test]
    fn test_rate_limiter_tracks_ips_independently() {
        let limiter = IpRateLimiter::new(1, 1);
        assert!(limiter.check_rate_limit("10.0.0.1").is_ok());
        assert!(limiter.check_rate_limit("10.0.0.1").is_err());
        // A different client is unaffected by the first one's exhausted quota.
        assert!(limiter.check_rate_limit("10.0.0.2").is_ok());
    }
}