use crate::api::{ApiResponse, ApiState};
use crate::distributed::rate_limiting::RateLimitResult;
use axum::{
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::time::{Duration, Instant};
/// Makes a header value safe to embed in a single log line.
///
/// Control characters (CR, LF, tab, NUL, ...) are removed so a crafted header
/// cannot forge extra log lines, and the result is capped at 200 characters.
/// The cap is applied *after* filtering, so removed characters do not count
/// towards it. Spaces and all other non-control characters are kept as-is.
fn sanitize_header_for_log(value: &str) -> String {
    const MAX_LOGGED_CHARS: usize = 200;

    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in value.chars() {
        if kept == MAX_LOGGED_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
pub async fn rate_limit_middleware_with_state(
state: ApiState,
request: Request,
next: Next,
) -> Result<Response, Response> {
let client_key = request
.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split(',').next())
.map(str::trim)
.and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
.or_else(|| {
request
.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.and_then(|s| s.parse::<std::net::IpAddr>().ok().map(|ip| ip.to_string()))
})
.unwrap_or_else(|| {
tracing::warn!(
"Rate limiter: no identifiable client IP from X-Forwarded-For or X-Real-IP; \
falling back to shared 'unidentified' bucket"
);
"unidentified".to_string()
});
match state.rate_limiter.check_rate_limit(&client_key).await {
Ok(RateLimitResult::Allowed {
remaining,
reset_at,
}) => {
let mut response = next.run(request).await;
let headers = response.headers_mut();
let reset_secs = reset_at
.checked_duration_since(Instant::now())
.unwrap_or(Duration::ZERO)
.as_secs();
if let Ok(v) = remaining.to_string().parse() {
headers.insert("X-RateLimit-Remaining", v);
}
if let Ok(v) = reset_secs.to_string().parse() {
headers.insert("X-RateLimit-Reset", v);
}
Ok(response)
}
Ok(RateLimitResult::Denied { retry_after, .. }) => {
tracing::warn!(client = %client_key, "Rate limit exceeded");
let mut response = ApiResponse::<()>::error(
"RATE_LIMIT_EXCEEDED",
"Too many requests — please retry after the indicated delay",
)
.into_response();
*response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
let headers = response.headers_mut();
if let Ok(v) = retry_after.as_secs().to_string().parse() {
headers.insert("Retry-After", v);
}
Err(response)
}
Ok(RateLimitResult::Blocked { unblock_at, reason }) => {
tracing::warn!(client = %client_key, reason = %reason, "Client is blocked");
let mut response = ApiResponse::<()>::error(
"CLIENT_BLOCKED",
"Access temporarily blocked due to repeated rate limit violations",
)
.into_response();
*response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
let unblock_secs = unblock_at
.checked_duration_since(Instant::now())
.unwrap_or(Duration::ZERO)
.as_secs();
if let Ok(v) = unblock_secs.to_string().parse() {
let headers = response.headers_mut();
headers.insert("Retry-After", v);
}
Err(response)
}
Err(e) => {
tracing::error!(error = %e, "Rate limiter error — rejecting request");
let mut response = ApiResponse::<()>::error(
"RATE_LIMIT_UNAVAILABLE",
"Rate limiting is temporarily unavailable; request rejected to protect the service",
)
.into_response();
*response.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
Err(response)
}
}
}
/// Logs one line when a request arrives and one when its response is ready.
///
/// The start line carries the method, URI, `X-Forwarded-For` and `User-Agent`
/// (each passed through `sanitize_header_for_log`, or `"unknown"` when the
/// header is missing or unreadable). The completion line carries the method,
/// URI, response status and elapsed time. The response itself is returned
/// untouched.
pub async fn logging_middleware(request: Request, next: Next) -> Response {
    let start = Instant::now();
    let method = request.method().clone();
    let uri = request.uri().clone();

    // Only two headers are logged, so read them through a borrow instead of
    // cloning the whole header map. The borrow ends with this block, before
    // `request` is moved into `next.run`.
    let (forwarded_for, user_agent) = {
        let headers = request.headers();
        let loggable = |name: &str| {
            let raw = headers
                .get(name)
                .and_then(|value| value.to_str().ok())
                .unwrap_or("unknown");
            sanitize_header_for_log(raw)
        };
        (loggable("x-forwarded-for"), loggable("user-agent"))
    };

    tracing::info!(
        "Request started: {} {} from {} ({})",
        method,
        uri,
        forwarded_for,
        user_agent
    );

    let response = next.run(request).await;

    let duration = start.elapsed();
    let status = response.status();
    tracing::info!(
        "Request completed: {} {} {} in {:?}",
        method,
        uri,
        status,
        duration
    );
    response
}
/// Runs the inner service, then stamps a fixed set of security headers on the
/// response.
///
/// Each header is written with `insert`, so a value the inner service already
/// set under the same name is replaced.
pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
    /// Header name/value pairs applied to every response, in insertion order.
    const SECURITY_HEADERS: [(&str, &str); 7] = [
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        (
            "Strict-Transport-Security",
            "max-age=31536000; includeSubDomains",
        ),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        (
            "Permissions-Policy",
            "camera=(), microphone=(), geolocation=()",
        ),
        (
            "Content-Security-Policy",
            "default-src 'none'; frame-ancestors 'none'",
        ),
    ];

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for (name, value) in SECURITY_HEADERS {
        headers.insert(name, axum::http::HeaderValue::from_static(value));
    }
    response
}
pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
Ok(response) => Ok(response),
Err(_) => {
let error_response =
ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
Err(error_response.into_response())
}
}
}
/// Reports whether `auth_token` grants `required_permission`.
///
/// A granted permission matches when it is
/// * exactly equal to `required_permission`, or
/// * a trailing-wildcard pattern such as `users:*`, whose text before the
///   final `*` is a prefix of `required_permission`. A bare `*` has an empty
///   prefix and therefore matches everything.
///
/// Matching is case-sensitive; a token with no permissions grants nothing.
pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
    for granted in auth_token.permissions.iter() {
        if granted == required_permission {
            return true;
        }
        if let Some(prefix) = granted.strip_suffix('*') {
            if required_permission.starts_with(prefix) {
                return true;
            }
        }
    }
    false
}
/// Reports whether `auth_token` carries `required_role`.
///
/// The `admin` role is treated as a superset of every role: a token holding
/// it passes regardless of `required_role`.
pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
    let held = &auth_token.roles;
    if held.contains(&required_role.to_string()) {
        return true;
    }
    // Not held directly - fall back to the admin override.
    held.contains(&"admin".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::{AuthToken, TokenMetadata};

    /// Builds a token whose only meaningful fields are `permissions` and
    /// `roles`; everything else is filled with fixed placeholder values.
    fn make_token(permissions: Vec<&str>, roles: Vec<&str>) -> AuthToken {
        AuthToken {
            token_id: "tid".into(),
            user_id: "uid".into(),
            access_token: "at".into(),
            token_type: Some("Bearer".into()),
            subject: Some("uid".into()),
            issuer: Some("iss".into()),
            refresh_token: None,
            issued_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now(),
            scopes: vec![].into(),
            auth_method: "jwt".into(),
            client_id: None,
            user_profile: None,
            permissions: permissions
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            roles: roles
                .into_iter()
                .map(String::from)
                .collect::<Vec<_>>()
                .into(),
            metadata: TokenMetadata::default(),
        }
    }

    #[test]
    fn test_check_permission_exact_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(check_permission(&token, "users:read"));
    }

    #[test]
    fn test_check_permission_no_match() {
        let token = make_token(vec!["users:read"], vec![]);
        assert!(!check_permission(&token, "users:write"));
    }

    #[test]
    fn test_check_permission_wildcard_all() {
        let token = make_token(vec!["*"], vec![]);
        assert!(check_permission(&token, "anything:at:all"));
    }

    #[test]
    fn test_check_permission_wildcard_prefix() {
        let token = make_token(vec!["users:*"], vec![]);
        assert!(check_permission(&token, "users:read"));
        assert!(check_permission(&token, "users:write"));
        assert!(!check_permission(&token, "admin:read"));
    }

    #[test]
    fn test_check_permission_wildcard_requires_full_prefix() {
        // "user:read" shares most of the prefix but not the whole "users:".
        let token = make_token(vec!["users:*"], vec![]);
        assert!(!check_permission(&token, "user:read"));
    }

    #[test]
    fn test_check_permission_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_permission(&token, "anything"));
    }

    #[test]
    fn test_check_role_exact_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_no_match() {
        let token = make_token(vec![], vec!["editor"]);
        assert!(!check_role(&token, "moderator"));
    }

    #[test]
    fn test_check_role_admin_has_all_roles() {
        let token = make_token(vec![], vec!["admin"]);
        assert!(check_role(&token, "anything"));
        assert!(check_role(&token, "editor"));
    }

    #[test]
    fn test_check_role_empty() {
        let token = make_token(vec![], vec![]);
        assert!(!check_role(&token, "user"));
    }

    #[test]
    fn test_sanitize_header_for_log_keeps_plain_text() {
        assert_eq!(
            sanitize_header_for_log("Mozilla/5.0 (X11; Linux)"),
            "Mozilla/5.0 (X11; Linux)"
        );
        assert_eq!(sanitize_header_for_log(""), "");
    }

    #[test]
    fn test_sanitize_header_for_log_strips_control_characters() {
        assert_eq!(
            sanitize_header_for_log("curl/8.0\r\nX-Injected: 1"),
            "curl/8.0X-Injected: 1"
        );
        assert_eq!(sanitize_header_for_log("a b\tc\u{0}d"), "a bcd");
    }

    #[test]
    fn test_sanitize_header_for_log_truncates_after_filtering() {
        let long = "x".repeat(300);
        assert_eq!(sanitize_header_for_log(&long).chars().count(), 200);

        // Removed characters must not count towards the 200-character cap.
        let padded = format!("{}{}", "\n".repeat(50), "y".repeat(250));
        assert_eq!(sanitize_header_for_log(&padded), "y".repeat(200));
    }
}