auth_framework/api/
validation.rs1use axum::{
6 extract::Request,
7 http::{HeaderMap, StatusCode},
8 middleware::Next,
9 response::Response,
10};
11
/// Header names that `validate_security_headers` requires on every request.
///
/// A request lacking any one of these is rejected.
const REQUIRED_SECURITY_HEADERS: &[&str] = &["user-agent", "accept"];
14
/// Substrings that mark a header value as suspicious.
///
/// Every entry is lowercase because `contains_suspicious_content` lowercases
/// its input before searching for these.
const SUSPICIOUS_PATTERNS: &[&str] = &[
    // Opening of an HTML script element.
    "<script",
    // `javascript:` URL scheme.
    "javascript:",
    // Inline event-handler attributes.
    "onload=",
    "onerror=",
    // Calls to script functions.
    "eval(",
    "alert(",
];
23
24pub async fn validate_request_middleware(
26 request: Request,
27 next: Next,
28) -> Result<Response, StatusCode> {
29 let headers = request.headers();
30
31 validate_security_headers(headers)?;
33
34 if let Some(content_length) = headers.get("content-length")
36 && let Ok(length_str) = content_length.to_str()
37 && let Ok(length) = length_str.parse::<usize>()
38 && length > 10_000_000
39 {
40 return Err(StatusCode::PAYLOAD_TOO_LARGE);
42 }
43
44 for (name, value) in headers.iter() {
46 if let Ok(value_str) = value.to_str()
47 && contains_suspicious_content(value_str)
48 {
49 tracing::warn!(
50 "Suspicious content detected in header {}: {}",
51 name,
52 value_str
53 );
54 return Err(StatusCode::BAD_REQUEST);
55 }
56 }
57
58 Ok(next.run(request).await)
59}
60
61fn validate_security_headers(headers: &HeaderMap) -> Result<(), StatusCode> {
63 let missing_headers: Vec<&str> = REQUIRED_SECURITY_HEADERS
64 .iter()
65 .filter(|&&header| !headers.contains_key(header))
66 .copied()
67 .collect();
68
69 if !missing_headers.is_empty() {
70 tracing::warn!("Missing required headers: {:?}", missing_headers);
71 return Err(StatusCode::BAD_REQUEST);
72 }
73
74 Ok(())
75}
76
77fn contains_suspicious_content(content: &str) -> bool {
79 let content_lower = content.to_lowercase();
80 SUSPICIOUS_PATTERNS
81 .iter()
82 .any(|&pattern| content_lower.contains(pattern))
83}
84
/// In-memory request counter keyed by client IP string.
///
/// Each key is allowed `max_requests` accepted calls within
/// `window_duration`, measured from the first call counted for that key.
#[derive(Debug)]
pub struct IpRateLimiter {
    /// Per-key state: number of requests counted in the current window and
    /// the instant at which that window began.
    requests: std::sync::Mutex<std::collections::HashMap<String, (u32, std::time::Instant)>>,
    /// Maximum number of requests accepted per key within one window.
    max_requests: u32,
    /// Length of a window.
    window_duration: std::time::Duration,
}
91
92impl IpRateLimiter {
93 pub fn new(max_requests: u32, window_minutes: u64) -> Self {
94 Self {
95 requests: std::sync::Mutex::new(std::collections::HashMap::new()),
96 max_requests,
97 window_duration: std::time::Duration::from_secs(window_minutes * 60),
98 }
99 }
100
101 pub fn check_rate_limit(&self, ip: &str) -> Result<(), StatusCode> {
102 let mut requests = self.requests.lock().unwrap();
103 let now = std::time::Instant::now();
104
105 requests.retain(|_, (_, timestamp)| now.duration_since(*timestamp) < self.window_duration);
107
108 match requests.get_mut(ip) {
110 Some((count, timestamp)) => {
111 if now.duration_since(*timestamp) < self.window_duration {
112 if *count >= self.max_requests {
113 return Err(StatusCode::TOO_MANY_REQUESTS);
114 }
115 *count += 1;
116 } else {
117 *count = 1;
118 *timestamp = now;
119 }
120 }
121 None => {
122 requests.insert(ip.to_string(), (1, now));
123 }
124 }
125
126 Ok(())
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Known attack fragments are flagged; ordinary text is not.
    #[test]
    fn test_suspicious_content_detection() {
        let flagged = [
            "<script>alert('xss')</script>",
            "javascript:void(0)",
            "onload=malicious()",
        ];
        for sample in flagged {
            assert!(contains_suspicious_content(sample));
        }

        let benign = ["normal content", "user@example.com"];
        for sample in benign {
            assert!(!contains_suspicious_content(sample));
        }
    }

    /// Five requests fit a budget of five; the sixth is refused.
    #[test]
    fn test_rate_limiter() {
        let client = "192.168.1.1";
        let limiter = IpRateLimiter::new(5, 1);

        for _attempt in 1..=5 {
            assert!(limiter.check_rate_limit(client).is_ok());
        }

        assert!(limiter.check_rate_limit(client).is_err());
    }
}
156
157