auth_framework/api/
validation.rs1use axum::{
6 extract::Request,
7 http::{HeaderMap, StatusCode},
8 middleware::Next,
9 response::Response,
10};
11
/// Header names that must be present on every incoming request.
///
/// A request missing any of these is rejected by `validate_security_headers`.
const REQUIRED_SECURITY_HEADERS: &[&str] = &["user-agent"];
16
/// Lowercase substrings treated as indicators of script-injection attempts.
///
/// Candidates are lowercased before comparison, so every entry here must
/// itself be lowercase.
const SUSPICIOUS_PATTERNS: &[&str] = &[
    // Markup / URL-scheme based injection.
    "<script", "javascript:",
    // Inline event-handler attributes.
    "onload=", "onerror=",
    // Direct JavaScript calls.
    "eval(", "alert(",
];
25
26pub async fn validate_request_middleware(
28 request: Request,
29 next: Next,
30) -> Result<Response, StatusCode> {
31 let headers = request.headers();
32
33 validate_security_headers(headers)?;
35
36 if let Some(content_length) = headers.get("content-length")
38 && let Ok(length_str) = content_length.to_str()
39 && let Ok(length) = length_str.parse::<usize>()
40 && length > 10_000_000
41 {
42 return Err(StatusCode::PAYLOAD_TOO_LARGE);
44 }
45
46 for (name, value) in headers.iter() {
48 if let Ok(value_str) = value.to_str()
49 && contains_suspicious_content(value_str)
50 {
51 tracing::warn!(
52 "Suspicious content detected in header {}: {}",
53 name,
54 value_str
55 );
56 return Err(StatusCode::BAD_REQUEST);
57 }
58 }
59
60 Ok(next.run(request).await)
61}
62
63fn validate_security_headers(headers: &HeaderMap) -> Result<(), StatusCode> {
65 let missing_headers: Vec<&str> = REQUIRED_SECURITY_HEADERS
66 .iter()
67 .filter(|&&header| !headers.contains_key(header))
68 .copied()
69 .collect();
70
71 if !missing_headers.is_empty() {
72 tracing::warn!("Missing required headers: {:?}", missing_headers);
73 return Err(StatusCode::BAD_REQUEST);
74 }
75
76 Ok(())
77}
78
79fn contains_suspicious_content(content: &str) -> bool {
81 let content_lower = content.to_lowercase();
82 SUSPICIOUS_PATTERNS
83 .iter()
84 .any(|&pattern| content_lower.contains(pattern))
85}
86
87pub struct IpRateLimiter {
89 requests: std::sync::Mutex<std::collections::HashMap<String, (u32, std::time::Instant)>>,
90 max_requests: u32,
91 window_duration: std::time::Duration,
92}
93
94impl IpRateLimiter {
95 pub fn new(max_requests: u32, window_minutes: u64) -> Self {
96 Self {
97 requests: std::sync::Mutex::new(std::collections::HashMap::new()),
98 max_requests,
99 window_duration: std::time::Duration::from_secs(window_minutes * 60),
100 }
101 }
102
103 pub fn check_rate_limit(&self, ip: &str) -> Result<(), StatusCode> {
104 let mut requests = self
105 .requests
106 .lock()
107 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
108 let now = std::time::Instant::now();
109
110 requests.retain(|_, (_, timestamp)| now.duration_since(*timestamp) < self.window_duration);
112
113 match requests.get_mut(ip) {
115 Some((count, timestamp)) => {
116 if now.duration_since(*timestamp) < self.window_duration {
117 if *count >= self.max_requests {
118 return Err(StatusCode::TOO_MANY_REQUESTS);
119 }
120 *count += 1;
121 } else {
122 *count = 1;
123 *timestamp = now;
124 }
125 }
126 None => {
127 requests.insert(ip.to_string(), (1, now));
128 }
129 }
130
131 Ok(())
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Known injection fragments are flagged; ordinary text is not.
    #[test]
    fn test_suspicious_content_detection() {
        let flagged = [
            "<script>alert('xss')</script>",
            "javascript:void(0)",
            "onload=malicious()",
        ];
        for sample in flagged {
            assert!(contains_suspicious_content(sample));
        }

        let benign = ["normal content", "user@example.com"];
        for sample in benign {
            assert!(!contains_suspicious_content(sample));
        }
    }

    /// The first five requests in a window pass and the sixth is refused.
    #[test]
    fn test_rate_limiter() {
        let limiter = IpRateLimiter::new(5, 1);
        let client = "192.168.1.1";

        let mut admitted = 0;
        while admitted < 5 {
            assert!(limiter.check_rate_limit(client).is_ok());
            admitted += 1;
        }

        assert!(limiter.check_rate_limit(client).is_err());
    }
}