ohttp_gateway/middleware/
security.rs1use axum::{
2 body::Body,
3 extract::{ConnectInfo, Request, State},
4 http::{HeaderMap, StatusCode, header},
5 middleware::Next,
6 response::{IntoResponse, Response},
7};
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::Mutex;
13use tracing::{info, warn};
14use uuid::Uuid;
15
16use crate::{config::RateLimitConfig, state::AppState};
17
18pub struct RateLimiter {
20 config: RateLimitConfig,
21 buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
22}
23
/// Mutable state of a single rate-limit bucket.
struct TokenBucket {
    /// Tokens currently available; fractional because refill is computed
    /// from elapsed time as a floating-point amount.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}
28
29impl RateLimiter {
30 pub fn new(config: RateLimitConfig) -> Self {
31 Self {
32 config,
33 buckets: Arc::new(Mutex::new(HashMap::new())),
34 }
35 }
36
37 pub async fn check_rate_limit(&self, key: &str) -> bool {
38 let mut buckets = self.buckets.lock().await;
39 let now = Instant::now();
40
41 let bucket = buckets
42 .entry(key.to_string())
43 .or_insert_with(|| TokenBucket {
44 tokens: self.config.burst_size as f64,
45 last_update: now,
46 });
47
48 let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
50 let tokens_to_add = elapsed * (self.config.requests_per_second as f64);
51
52 bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.burst_size as f64);
53 bucket.last_update = now;
54
55 if bucket.tokens >= 1.0 {
57 bucket.tokens -= 1.0;
58 true
59 } else {
60 false
61 }
62 }
63}
64
65pub async fn security_middleware(
67 State(state): State<AppState>,
68 ConnectInfo(addr): ConnectInfo<SocketAddr>,
69 request: Request<Body>,
70 next: Next,
71) -> Result<Response, StatusCode> {
72 let request_id = Uuid::new_v4();
74
75 let mut request = request;
77 request
78 .headers_mut()
79 .insert("x-request-id", request_id.to_string().parse().unwrap());
80
81 let is_https = matches!(request.uri().scheme_str(), Some("https"));
82
83 if let Some(rate_limit_config) = &state.config.rate_limit {
85 let rate_limiter = RateLimiter::new(rate_limit_config.clone());
86
87 let rate_limit_key = if rate_limit_config.by_ip {
88 addr.ip().to_string()
89 } else {
90 "global".to_string()
91 };
92
93 if !rate_limiter.check_rate_limit(&rate_limit_key).await {
94 warn!(
95 "Rate limit exceeded for key: {}, request_id: {}",
96 rate_limit_key, request_id
97 );
98
99 return Ok((
100 StatusCode::TOO_MANY_REQUESTS,
101 [
102 (
103 "X-RateLimit-Limit",
104 rate_limit_config.requests_per_second.to_string(),
105 ),
106 ("X-RateLimit-Remaining", "0".to_string()),
107 ("Retry-After", "1".to_string()),
108 ],
109 "Rate limit exceeded",
110 )
111 .into_response());
112 }
113 }
114
115 let mut response = next.run(request).await;
117
118 let headers = response.headers_mut();
120
121 headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
123 headers.insert("X-Frame-Options", "DENY".parse().unwrap());
124 headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
125 headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
126 headers.insert("X-Request-ID", request_id.to_string().parse().unwrap());
127
128 if is_https {
130 headers.insert(
131 "Strict-Transport-Security",
132 "max-age=31536000; includeSubDomains".parse().unwrap(),
133 );
134 }
135
136 headers.insert(
138 "Content-Security-Policy",
139 "default-src 'none'; frame-ancestors 'none';"
140 .parse()
141 .unwrap(),
142 );
143
144 headers.remove("Server");
146 headers.remove("X-Powered-By");
147
148 Ok(response)
149}
150
151pub async fn request_validation_middleware(
153 headers: HeaderMap,
154 request: Request<Body>,
155 next: Next,
156) -> Result<Response, StatusCode> {
157 if matches!(
159 request.method(),
160 &axum::http::Method::POST | &axum::http::Method::PUT | &axum::http::Method::PATCH
161 ) && !headers.contains_key(header::CONTENT_TYPE)
162 {
163 return Err(StatusCode::BAD_REQUEST);
164 }
165
166 if let Some(user_agent) = headers.get(header::USER_AGENT)
168 && let Ok(ua_str) = user_agent.to_str()
169 {
170 if ua_str.is_empty() || ua_str.contains("bot") || ua_str.contains("crawler") {
172 info!("Blocked suspicious user agent: {}", ua_str);
173 return Err(StatusCode::FORBIDDEN);
174 }
175 }
176
177 Ok(next.run(request).await)
178}