use crate::api::{ApiResponse, ApiState, extract_bearer_token, validate_api_token};
use axum::{
extract::{Request, State},
middleware::Next,
response::{IntoResponse, Response},
};
use std::time::{Duration, Instant};
/// Requires a valid bearer token on every non-public endpoint.
///
/// Requests whose path satisfies `is_public_endpoint` are forwarded untouched. For
/// all other requests the bearer token is pulled from the headers with
/// `extract_bearer_token` and checked with `validate_api_token`. On success the
/// validated token is stored in the request extensions before the request is
/// forwarded, so downstream layers/handlers can read it.
///
/// # Errors
///
/// Returns `ApiResponse::<()>::unauthorized()` (as a `Response`) when the header
/// carries no bearer token or when validation fails. The validation error value
/// itself is discarded.
pub async fn auth_middleware(
    State(state): State<ApiState>,
    mut request: Request,
    next: Next,
) -> Result<Response, Response> {
    // Public routes skip token handling entirely.
    if is_public_endpoint(request.uri().path()) {
        return Ok(next.run(request).await);
    }

    let Some(bearer) = extract_bearer_token(request.headers()) else {
        return Err(ApiResponse::<()>::unauthorized().into_response());
    };

    let validated = validate_api_token(&state.auth_framework, &bearer)
        .await
        .map_err(|_| ApiResponse::<()>::unauthorized().into_response())?;

    request.extensions_mut().insert(validated);
    Ok(next.run(request).await)
}
/// Lets the request through only if the caller holds the `admin` role.
///
/// Reads the `crate::tokens::AuthToken` that an earlier layer stored in the request
/// extensions. The `_state` extractor is accepted but unused.
///
/// # Errors
///
/// * `ApiResponse::<()>::unauthorized()` if no `AuthToken` is present in the
///   extensions (i.e. the request was never authenticated).
/// * `ApiResponse::<()>::forbidden()` if the token's roles do not include `"admin"`.
pub async fn admin_middleware(
    State(_state): State<ApiState>,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    let Some(auth_token) = request.extensions().get::<crate::tokens::AuthToken>() else {
        return Err(ApiResponse::<()>::unauthorized().into_response());
    };

    // Compare against the literal directly instead of allocating a `String`
    // for every request.
    let is_admin = auth_token.roles.iter().any(|role| role == "admin");
    if !is_admin {
        return Err(ApiResponse::<()>::forbidden().into_response());
    }

    Ok(next.run(request).await)
}
/// Attaches `X-RateLimit-*` headers to the response of the inner service.
///
/// NOTE(review): these values are hard-coded constants. No request counting or
/// limiting happens here. `X-RateLimit-Remaining` is always `95`, and
/// `X-RateLimit-Reset` is the fixed Unix timestamp `1692278400`
/// (2023-08-17T13:20:00Z). Treat this as a placeholder until a real limiter
/// backs it.
///
/// Always returns `Ok`; the `Result` return type only matches the middleware
/// signature.
pub async fn rate_limit_middleware(request: Request, next: Next) -> Result<Response, Response> {
    const RATE_LIMIT_HEADERS: [(&str, &str); 3] = [
        ("X-RateLimit-Limit", "100"),
        ("X-RateLimit-Remaining", "95"),
        ("X-RateLimit-Reset", "1692278400"),
    ];

    let mut resp = next.run(request).await;
    for (name, value) in RATE_LIMIT_HEADERS {
        // Static ASCII literals: parsing into a header value cannot fail.
        resp.headers_mut().insert(name, value.parse().unwrap());
    }
    Ok(resp)
}
/// Adds permissive CORS headers to whatever response the inner service produced.
///
/// It allows any origin (`*`), the methods `GET, POST, PUT, DELETE, OPTIONS`, and
/// the `Content-Type` and `Authorization` request headers, and lets preflight
/// results be cached for 3600 seconds. Existing values for these headers are
/// replaced.
///
/// This only decorates responses. Every request, including `OPTIONS`, is still
/// forwarded to the inner service.
pub async fn cors_middleware(request: Request, next: Next) -> Response {
    let mut resp = next.run(request).await;
    let hdrs = resp.headers_mut();
    for (name, value) in [
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
        ("Access-Control-Max-Age", "3600"),
    ] {
        // Static ASCII literals: parsing into a header value cannot fail.
        hdrs.insert(name, value.parse().unwrap());
    }
    resp
}
/// Logs the start and completion of every request with `tracing::info!`.
///
/// The start line contains the method, URI, the `X-Forwarded-For` header and the
/// `User-Agent` header. Each header reads `"unknown"` when it is absent or when
/// `HeaderValue::to_str` rejects it. The completion line contains the method, URI,
/// response status and the elapsed time measured with [`Instant`].
///
/// NOTE(review): the full URI, including any query string, is logged. Confirm that
/// no route carries secrets (e.g. OAuth codes/tokens) in the query. Also, the
/// `X-Forwarded-For` value is client-supplied unless a trusted proxy rewrites it.
pub async fn logging_middleware(request: Request, next: Next) -> Response {
    let start = Instant::now();

    // Method and URI are needed again after `request` is moved into `next.run`,
    // so they are cloned. The header values are only needed for the first log
    // line, so they are borrowed in place instead of cloning the whole
    // `HeaderMap` on every request.
    let method = request.method().clone();
    let uri = request.uri().clone();

    {
        let headers = request.headers();
        let user_agent = headers
            .get("user-agent")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("unknown");
        let forwarded_for = headers
            .get("x-forwarded-for")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("unknown");
        tracing::info!(
            "Request started: {} {} from {} ({})",
            method,
            uri,
            forwarded_for,
            user_agent
        );
    }

    let response = next.run(request).await;
    let duration = start.elapsed();
    let status = response.status();
    tracing::info!(
        "Request completed: {} {} {} in {:?}",
        method,
        uri,
        status,
        duration
    );
    response
}
/// Adds a fixed set of security-related headers to every response.
///
/// Each header below replaces any value the inner service already set:
/// * `X-Content-Type-Options: nosniff`
/// * `X-Frame-Options: DENY`
/// * `X-XSS-Protection: 0`
/// * `Strict-Transport-Security: max-age=31536000; includeSubDomains`
/// * `Referrer-Policy: strict-origin-when-cross-origin`
/// * `Permissions-Policy: camera=(), microphone=(), geolocation=()`
pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
    headers.insert("X-Frame-Options", "DENY".parse().unwrap());
    // `1; mode=block` is deprecated. The legacy XSS auditor it turns on has itself
    // been a source of cross-site information leaks, and modern browsers ignore
    // it. Current guidance (OWASP Secure Headers, MDN) is to disable it explicitly
    // and rely on Content-Security-Policy instead.
    headers.insert("X-XSS-Protection", "0".parse().unwrap());
    headers.insert(
        "Strict-Transport-Security",
        "max-age=31536000; includeSubDomains".parse().unwrap(),
    );
    headers.insert(
        "Referrer-Policy",
        "strict-origin-when-cross-origin".parse().unwrap(),
    );
    headers.insert(
        "Permissions-Policy",
        "camera=(), microphone=(), geolocation=()".parse().unwrap(),
    );
    response
}
pub async fn timeout_middleware(request: Request, next: Next) -> Result<Response, Response> {
match tokio::time::timeout(Duration::from_secs(30), next.run(request)).await {
Ok(response) => Ok(response),
Err(_) => {
let error_response =
ApiResponse::<()>::error("REQUEST_TIMEOUT", "Request timed out after 30 seconds");
Err(error_response.into_response())
}
}
}
/// Returns `true` when `path` may be served without a bearer token.
///
/// Public paths are:
/// * exact matches for the health/metrics probes, the login/refresh/providers
///   auth routes, and the OAuth authorize/token endpoints;
/// * anything strictly under `/oauth/.well-known/`, which covers the discovery
///   documents (both `openid-configuration` and `openid_configuration`).
///
/// Matching is on the raw path string. There is no trailing-slash, case, or
/// dot-segment normalization, so `/health/` is *not* public.
fn is_public_endpoint(path: &str) -> bool {
    const EXACT_PUBLIC_PATHS: [&str; 10] = [
        // Probes and metrics.
        "/health",
        "/health/detailed",
        "/metrics",
        "/readiness",
        "/liveness",
        // Credential exchange.
        "/auth/login",
        "/auth/refresh",
        "/auth/providers",
        // OAuth flow entry points.
        "/oauth/authorize",
        "/oauth/token",
    ];

    path.starts_with("/oauth/.well-known/") || EXACT_PUBLIC_PATHS.contains(&path)
}
/// Returns `true` if any permission granted to `auth_token` satisfies
/// `required_permission`.
///
/// A granted permission matches when it:
/// * equals `required_permission` exactly;
/// * is the bare wildcard `"*"`; or
/// * ends in `*` and the text before the `*` is a plain string prefix of
///   `required_permission`. For example, `users:*` grants `users:read`. Because
///   this is a raw prefix check, `user*` also grants `users:delete`.
pub fn check_permission(auth_token: &crate::tokens::AuthToken, required_permission: &str) -> bool {
    auth_token.permissions.iter().any(|granted| {
        if granted == required_permission || granted == "*" {
            return true;
        }
        match granted.strip_suffix('*') {
            Some(prefix) => required_permission.starts_with(prefix),
            None => false,
        }
    })
}
/// Returns `true` if `auth_token` holds `required_role`.
///
/// Holding the `"admin"` role satisfies every role check, whatever
/// `required_role` is.
pub fn check_role(auth_token: &crate::tokens::AuthToken, required_role: &str) -> bool {
    // Compare borrowed strings directly. The previous version allocated two
    // `String`s per call only to pass them to `contains`.
    auth_token
        .roles
        .iter()
        .any(|role| role == required_role || role == "admin")
}