use crate::http_server_security::{
AuthValidator, InputValidator, RateLimiter, RequestSizeLimiter, SecurityLogger,
};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use std::net::IpAddr;
use std::sync::Arc;
pub async fn rate_limit_middleware(request: Request, next: Next) -> Response {
let rate_limiter = request
.extensions()
.get::<Arc<RateLimiter>>()
.cloned()
.unwrap_or_else(|| Arc::new(RateLimiter::new(100, 60)));
let ip = extract_ip(&request);
match rate_limiter.check_rate_limit(ip).await {
Ok(_) => next.run(request).await,
Err(status) => {
SecurityLogger::log_rate_limit(&ip.to_string());
Response::builder()
.status(status)
.body(axum::body::Body::from("Too many requests"))
.unwrap()
}
}
}
/// Rejects requests whose declared body size or URL length fail the size limiter.
///
/// The limiter is the `Arc<RequestSizeLimiter>` from the request extensions when present,
/// otherwise a fresh `RequestSizeLimiter::default()`.
///
/// The body size is read only from the `Content-Length` header and is treated as 0 when
/// that header is missing or unparseable. Bodies sent without it are therefore not
/// measured here. The URL length is the length of the full URI's string form.
///
/// On rejection, a `REQUEST_SIZE_LIMIT` event is logged with the client IP. The response
/// carries the limiter's status and the body "Request too large".
pub async fn request_size_middleware(request: Request, next: Next) -> Response {
    let limiter = match request.extensions().get::<Arc<RequestSizeLimiter>>() {
        Some(configured) => Arc::clone(configured),
        None => Arc::new(RequestSizeLimiter::default()),
    };

    // Missing header, non-visible-ASCII value, or non-numeric value all count as 0.
    let declared_len: usize = match request
        .headers()
        .get("Content-Length")
        .map(|raw| raw.to_str())
    {
        Some(Ok(text)) => text.parse().unwrap_or(0),
        _ => 0,
    };
    let url_len = request.uri().to_string().len();

    match limiter.validate_request(request.headers(), declared_len, url_len) {
        Ok(_) => next.run(request).await,
        Err(status) => {
            let client_ip = extract_ip(&request);
            let detail = format!("Body: {} bytes, URL: {} chars", declared_len, url_len);
            SecurityLogger::log_event(
                "REQUEST_SIZE_LIMIT",
                &detail,
                Some(&client_ip.to_string()),
            );
            Response::builder()
                .status(status)
                .body(axum::body::Body::from("Request too large"))
                .unwrap()
        }
    }
}
/// Requires a valid API key on every request.
///
/// The token is taken from the headers with `AuthValidator::extract_token` and checked
/// with `AuthValidator::default().validate_api_key`.
///
/// A missing token, or a token that fails validation, is logged via
/// `SecurityLogger::log_auth_failure` with the client IP. The response is then
/// `401 Unauthorized` with body "Unauthorized".
///
/// On success the request is forwarded to `next`. The validated claims are not attached
/// to the request.
pub async fn auth_middleware(request: Request, next: Next) -> Response {
    /// Shared 401 response for every rejection path.
    fn unauthorized() -> Response {
        Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .body(axum::body::Body::from("Unauthorized"))
            .unwrap()
    }

    let Some(token) = AuthValidator::extract_token(request.headers()) else {
        let client_ip = extract_ip(&request);
        SecurityLogger::log_auth_failure(&client_ip.to_string(), "Missing Authorization header");
        return unauthorized();
    };

    let validator = AuthValidator::default();
    match validator.validate_api_key(&token) {
        Ok(_claims) => next.run(request).await,
        Err(e) => {
            let client_ip = extract_ip(&request);
            SecurityLogger::log_auth_failure(&client_ip.to_string(), &e);
            unauthorized()
        }
    }
}
/// Rejects requests whose query parameters or header values fail
/// `InputValidator::validate_string`.
///
/// Checks, in order:
/// - Every non-empty `&`-separated query segment:
///   - The text before the first `=` is validated as a key (limit 100).
///   - When a `=` is present, the text after it is validated as a value (limit 1000).
///   - A segment with no `=` is validated as a key.
/// - Every header value (limit 1000). Values are decoded lossily, so bytes outside
///   visible ASCII are still validated rather than skipped.
///
/// Query text is validated exactly as it appears in the URI; it is not percent-decoded.
///
/// The first failure is logged via `SecurityLogger::log_invalid_input` with the client IP.
/// It is answered with `400 Bad Request` and the body "Invalid query parameter" or
/// "Invalid header". Otherwise the request is forwarded to `next`.
pub async fn input_validation_middleware(request: Request, next: Next) -> Response {
    /// Logs the rejected input against the client IP and builds the 400 response.
    fn reject(request: &Request, detail: String, message: &'static str) -> Response {
        SecurityLogger::log_invalid_input(&extract_ip(request).to_string(), &detail);
        Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(axum::body::Body::from(message))
            .unwrap()
    }

    if let Some(query) = request.uri().query() {
        // Empty segments (from `a=1&&b=2` or a trailing `&`) carry no data.
        for param in query.split('&').filter(|segment| !segment.is_empty()) {
            // A bare segment such as `?payload` is still user input. Validate it as a key;
            // skipping it would let arbitrary text through unchecked.
            let (key, value) = match param.split_once('=') {
                Some((key, value)) => (key, Some(value)),
                None => (param, None),
            };
            if let Err(e) = InputValidator::validate_string(key, 100) {
                return reject(&request, format!("Query key: {}", e), "Invalid query parameter");
            }
            if let Some(value) = value {
                if let Err(e) = InputValidator::validate_string(value, 1000) {
                    return reject(
                        &request,
                        format!("Query value: {}", e),
                        "Invalid query parameter",
                    );
                }
            }
        }
    }

    for (name, value) in request.headers() {
        // `HeaderValue::to_str` fails on any non-visible-ASCII byte. Skipping those values
        // would let such headers bypass validation entirely, so decode lossily instead.
        let value_str = String::from_utf8_lossy(value.as_bytes());
        if let Err(e) = InputValidator::validate_string(&*value_str, 1000) {
            return reject(&request, format!("Header {}: {}", name, e), "Invalid header");
        }
    }

    next.run(request).await
}
fn extract_ip(request: &Request) -> IpAddr {
if let Some(forwarded) = request.headers().get("X-Forwarded-For") {
if let Ok(forwarded_str) = forwarded.to_str() {
if let Some(ip_str) = forwarded_str.split(',').next() {
if let Ok(ip) = ip_str.trim().parse::<IpAddr>() {
return ip;
}
}
}
}
if let Some(real_ip) = request.headers().get("X-Real-IP") {
if let Ok(ip_str) = real_ip.to_str() {
if let Ok(ip) = ip_str.parse::<IpAddr>() {
return ip;
}
}
}
"127.0.0.1".parse().unwrap()
}