Skip to main content

ohttp_gateway/middleware/
security.rs

1use axum::{
2    body::Body,
3    extract::{ConnectInfo, Request, State},
4    http::{HeaderMap, StatusCode, header},
5    middleware::Next,
6    response::{IntoResponse, Response},
7};
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Instant;
12use tokio::sync::Mutex;
13use tracing::{info, warn};
14use uuid::Uuid;
15
16use crate::{config::RateLimitConfig, state::AppState};
17
18/// Rate limiter implementation
19pub struct RateLimiter {
20    config: RateLimitConfig,
21    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
22}
23
24struct TokenBucket {
25    tokens: f64,
26    last_update: Instant,
27}
28
29impl RateLimiter {
30    pub fn new(config: RateLimitConfig) -> Self {
31        Self {
32            config,
33            buckets: Arc::new(Mutex::new(HashMap::new())),
34        }
35    }
36
37    pub async fn check_rate_limit(&self, key: &str) -> bool {
38        let mut buckets = self.buckets.lock().await;
39        let now = Instant::now();
40
41        let bucket = buckets
42            .entry(key.to_string())
43            .or_insert_with(|| TokenBucket {
44                tokens: self.config.burst_size as f64,
45                last_update: now,
46            });
47
48        // Calculate tokens to add based on time elapsed
49        let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
50        let tokens_to_add = elapsed * (self.config.requests_per_second as f64);
51
52        bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.burst_size as f64);
53        bucket.last_update = now;
54
55        // Check if we have tokens available
56        if bucket.tokens >= 1.0 {
57            bucket.tokens -= 1.0;
58            true
59        } else {
60            false
61        }
62    }
63}
64
65/// Security middleware that adds various security headers and checks
66pub async fn security_middleware(
67    State(state): State<AppState>,
68    ConnectInfo(addr): ConnectInfo<SocketAddr>,
69    request: Request<Body>,
70    next: Next,
71) -> Result<Response, StatusCode> {
72    // Generate request ID for tracing
73    let request_id = Uuid::new_v4();
74
75    // Add security headers to the request context
76    let mut request = request;
77    request
78        .headers_mut()
79        .insert("x-request-id", request_id.to_string().parse().unwrap());
80
81    let is_https = matches!(request.uri().scheme_str(), Some("https"));
82
83    // Apply rate limiting if configured
84    if let Some(rate_limit_config) = &state.config.rate_limit {
85        let rate_limiter = RateLimiter::new(rate_limit_config.clone());
86
87        let rate_limit_key = if rate_limit_config.by_ip {
88            addr.ip().to_string()
89        } else {
90            "global".to_string()
91        };
92
93        if !rate_limiter.check_rate_limit(&rate_limit_key).await {
94            warn!(
95                "Rate limit exceeded for key: {}, request_id: {}",
96                rate_limit_key, request_id
97            );
98
99            return Ok((
100                StatusCode::TOO_MANY_REQUESTS,
101                [
102                    (
103                        "X-RateLimit-Limit",
104                        rate_limit_config.requests_per_second.to_string(),
105                    ),
106                    ("X-RateLimit-Remaining", "0".to_string()),
107                    ("Retry-After", "1".to_string()),
108                ],
109                "Rate limit exceeded",
110            )
111                .into_response());
112        }
113    }
114
115    // Process the request
116    let mut response = next.run(request).await;
117
118    // Add security headers to the response
119    let headers = response.headers_mut();
120
121    // Security headers
122    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
123    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
124    headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
125    headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
126    headers.insert("X-Request-ID", request_id.to_string().parse().unwrap());
127
128    // HSTS header for HTTPS connections
129    if is_https {
130        headers.insert(
131            "Strict-Transport-Security",
132            "max-age=31536000; includeSubDomains".parse().unwrap(),
133        );
134    }
135
136    // Content Security Policy
137    headers.insert(
138        "Content-Security-Policy",
139        "default-src 'none'; frame-ancestors 'none';"
140            .parse()
141            .unwrap(),
142    );
143
144    // Remove sensitive headers
145    headers.remove("Server");
146    headers.remove("X-Powered-By");
147
148    Ok(response)
149}
150
151/// Middleware to validate and sanitize incoming requests
152pub async fn request_validation_middleware(
153    headers: HeaderMap,
154    request: Request<Body>,
155    next: Next,
156) -> Result<Response, StatusCode> {
157    // Check for required headers only on requests with bodies
158    if matches!(
159        request.method(),
160        &axum::http::Method::POST | &axum::http::Method::PUT | &axum::http::Method::PATCH
161    ) && !headers.contains_key(header::CONTENT_TYPE)
162    {
163        return Err(StatusCode::BAD_REQUEST);
164    }
165
166    // Validate User-Agent
167    if let Some(user_agent) = headers.get(header::USER_AGENT)
168        && let Ok(ua_str) = user_agent.to_str()
169    {
170        // Block known bad user agents
171        if ua_str.is_empty() || ua_str.contains("bot") || ua_str.contains("crawler") {
172            info!("Blocked suspicious user agent: {}", ua_str);
173            return Err(StatusCode::FORBIDDEN);
174        }
175    }
176
177    Ok(next.run(request).await)
178}