use crate::api::{ApiResponse, ApiState};
use crate::tokens::AuthToken;
use axum::{
extract::{Request, State},
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Global role-based access-control layer.
///
/// Requests for public endpoints (see `is_public_endpoint`) are forwarded
/// unchanged. Every other request must carry an [`AuthToken`] in its
/// extensions; without one the caller gets `ApiResponse::unauthorized()`.
///
/// With a token present, a context map is built by `build_request_context`
/// and the decision is delegated to `check_authorization`:
/// * `Ok(true)`  – logged at `debug`, request forwarded to `next`;
/// * `Ok(false)` – logged at `info`, `ApiResponse::forbidden()` returned;
/// * `Err(_)`    – logged at `warn`, `ApiResponse::forbidden()` returned
///   (the check fails closed).
pub async fn rbac_middleware(
    State(state): State<ApiState>,
    request: Request,
    next: Next,
) -> Result<Response, Response> {
    if is_public_endpoint(request.uri().path()) {
        return Ok(next.run(request).await);
    }

    // Cloned so the token does not borrow from `request`, which is moved into
    // `next.run` when access is granted.
    let Some(auth_token) = request.extensions().get::<AuthToken>().cloned() else {
        return Err(ApiResponse::<()>::unauthorized().into_response());
    };

    let context = build_request_context(&request, &auth_token);
    let verdict = check_authorization(&state, &auth_token, &request, &context).await;
    let path = request.uri().path();

    match verdict {
        Ok(true) => {
            debug!(
                "Authorization granted for user '{}' on {}",
                auth_token.user_id, path
            );
            Ok(next.run(request).await)
        }
        Ok(false) => {
            info!(
                "Authorization denied for user '{}' on {}",
                auth_token.user_id, path
            );
            Err(ApiResponse::<()>::forbidden().into_response())
        }
        Err(e) => {
            warn!("Authorization check failed: {}", e);
            Err(ApiResponse::<()>::forbidden().into_response())
        }
    }
}
pub async fn conditional_permission_middleware(
State(state): State<ApiState>,
request: Request,
next: Next,
) -> Result<Response, Response> {
let path = request.uri().path();
if is_sensitive_endpoint(path) {
let auth_token = match request.extensions().get::<AuthToken>() {
Some(token) => token,
None => {
let error_response = ApiResponse::<()>::unauthorized();
return Err(error_response.into_response());
}
};
let context = build_conditional_context(&request);
let has_conditional_access: bool = state
.authorization_service
.check_permission(&auth_token.user_id, "access", path, Some(&context))
.await
.unwrap_or_default();
if !has_conditional_access {
info!(
"Conditional access denied for user '{}' on {}",
auth_token.user_id, path
);
let error_response = ApiResponse::<()>::error(
"CONDITIONAL_ACCESS_DENIED",
"Access denied due to conditional permissions (time, location, etc.)",
);
return Err(error_response.into_response());
}
}
Ok(next.run(request).await)
}
pub async fn role_elevation_middleware(
State(state): State<ApiState>,
request: Request,
next: Next,
) -> Result<Response, Response> {
let path = request.uri().path();
if requires_role_elevation(path) {
let auth_token = match request.extensions().get::<AuthToken>() {
Some(token) => token,
None => {
let error_response = ApiResponse::<()>::unauthorized();
return Err(error_response.into_response());
}
};
let has_elevated_access: bool = state
.authorization_service
.check_permission(&auth_token.user_id, "elevated", "admin", None)
.await
.unwrap_or_default();
if !has_elevated_access {
info!(
"Elevated access required for user '{}' on {}",
auth_token.user_id, path
);
let error_response = ApiResponse::<()>::error(
"ELEVATION_REQUIRED",
"This action requires elevated permissions. Please request temporary role elevation.",
);
return Err(error_response.into_response());
}
}
Ok(next.run(request).await)
}
/// Asks the authorization service whether `auth_token`'s user may perform the
/// request's HTTP method on the request's URI path, given `context`.
///
/// # Errors
/// Any error from `check_api_permission` is boxed into a
/// `Box<dyn std::error::Error + Send + Sync>` and returned unchanged.
async fn check_authorization(
    state: &ApiState,
    auth_token: &AuthToken,
    request: &Request,
    context: &HashMap<String, String>,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
    // `?` boxes the service error through the std
    // `From<E: Error + Send + Sync> for Box<dyn Error + Send + Sync>` impl.
    let granted = state
        .authorization_service
        .check_api_permission(
            &auth_token.user_id,
            request.method().as_str(),
            request.uri().path(),
            context,
        )
        .await?;
    Ok(granted)
}
/// Builds the attribute map handed to the RBAC policy check.
///
/// Keys written:
/// * `user_id`, `roles` (comma-joined) – always, from the token;
/// * `user_agent` – when the `User-Agent` header is present and valid ASCII;
/// * `ip_address` – when `X-Forwarded-For` is present and valid ASCII;
/// * `time` – `business_hours` for UTC hours 9..=17 (09:00–17:59), else
///   `after_hours`;
/// * `day_type` – `weekday` for Monday–Friday (UTC), else `weekend`.
fn build_request_context(request: &Request, auth_token: &AuthToken) -> HashMap<String, String> {
    let mut context = HashMap::new();
    context.insert("user_id".to_string(), auth_token.user_id.clone());
    context.insert("roles".to_string(), auth_token.roles.join(","));
    if let Some(user_agent) = request.headers().get("user-agent")
        && let Ok(ua_str) = user_agent.to_str()
    {
        context.insert("user_agent".to_string(), ua_str.to_string());
    }
    // NOTE(review): X-Forwarded-For can be set by the client and may hold a
    // comma-separated proxy chain; it is stored verbatim. Confirm a trusted
    // proxy sanitises it before policies rely on `ip_address`.
    if let Some(forwarded_for) = request.headers().get("x-forwarded-for")
        && let Ok(ip_str) = forwarded_for.to_str()
    {
        context.insert("ip_address".to_string(), ip_str.to_string());
    }

    // Sample the clock once so `time` and `day_type` describe the same instant
    // (two separate `now()` calls could straddle midnight). The trait methods
    // are called path-qualified so no extra `use` is needed.
    let now = chrono::Utc::now();
    let hour = chrono::Timelike::hour(&now);
    // ISO numbering: Monday = 1 … Sunday = 7 (same as the `%u` specifier).
    let weekday = chrono::Datelike::weekday(&now).number_from_monday();

    let time_bucket = if (9..=17).contains(&hour) {
        "business_hours"
    } else {
        "after_hours"
    };
    context.insert("time".to_string(), time_bucket.to_string());

    let day_type = if (1..=5).contains(&weekday) {
        "weekday"
    } else {
        "weekend"
    };
    context.insert("day_type".to_string(), day_type.to_string());

    context
}
fn build_conditional_context(request: &Request) -> HashMap<String, String> {
let mut context = HashMap::new();
if let Some(via) = request.headers().get("via")
&& let Ok(via_str) = via.to_str()
&& (via_str.contains("vpn") || via_str.contains("proxy"))
{
context.insert("connection_type".to_string(), "vpn".to_string());
}
if let Some(user_agent) = request.headers().get("user-agent")
&& let Ok(ua_str) = user_agent.to_str()
{
if ua_str.contains("Mobile") || ua_str.contains("Android") || ua_str.contains("iPhone") {
context.insert("device_type".to_string(), "mobile".to_string());
} else {
context.insert("device_type".to_string(), "desktop".to_string());
}
}
let path = request.uri().path();
if path.contains("/admin/") {
context.insert("security_level".to_string(), "high".to_string());
} else if path.contains("/api/") {
context.insert("security_level".to_string(), "medium".to_string());
} else {
context.insert("security_level".to_string(), "low".to_string());
}
context
}
/// Returns `true` when `path` may be served without an `AuthToken`.
///
/// That covers the health/metrics probes, the login/refresh/provider auth
/// routes, the OAuth authorize and token endpoints, and every path under
/// `/oauth/.well-known/`. All other matches are exact string comparisons.
fn is_public_endpoint(path: &str) -> bool {
    const EXACT_PUBLIC_PATHS: &[&str] = &[
        // Probes and metrics.
        "/health",
        "/health/detailed",
        "/metrics",
        "/readiness",
        "/liveness",
        // Authentication entry points.
        "/auth/login",
        "/auth/refresh",
        "/auth/providers",
        // OAuth endpoints.
        "/oauth/authorize",
        "/oauth/token",
        "/oauth/.well-known/openid_configuration",
    ];
    const PUBLIC_PREFIX: &str = "/oauth/.well-known/";

    EXACT_PUBLIC_PATHS.iter().any(|candidate| *candidate == path)
        || path.starts_with(PUBLIC_PREFIX)
}
/// Returns `true` for paths that must pass the conditional-access check.
///
/// A path is sensitive when it is exactly `/auth/logout`, when its first
/// segment is `admin`, or when any segment is `secrets`, `config` or `keys`.
///
/// Matching is on whole path segments rather than substrings with a trailing
/// slash. Collection endpoints such as `/admin` or `/api/secrets` are
/// therefore covered just like their children (`/admin/users`,
/// `/api/secrets/vault`). Empty segments from doubled slashes are ignored, so
/// `//admin/x` cannot slip past. Every path matched by the previous substring
/// rules is still matched.
fn is_sensitive_endpoint(path: &str) -> bool {
    if path == "/auth/logout" {
        return true;
    }
    let mut segments = path.split('/').filter(|segment| !segment.is_empty());
    match segments.next() {
        // `/admin` and everything beneath it.
        Some("admin") => true,
        Some(first) => std::iter::once(first)
            .chain(segments)
            .any(|segment| matches!(segment, "secrets" | "config" | "keys")),
        None => false,
    }
}
/// Returns `true` for paths that require temporary role elevation.
///
/// A path requires elevation when it:
/// * starts with `/admin/users/delete` (plain string prefix); or
/// * has `admin` then `system` as its first two segments, including
///   `/admin/system` itself; or
/// * has any segment equal to `sudo` or `elevate`.
///
/// Segment matching (ignoring empty segments from doubled slashes) closes the
/// gap where `/admin/system`, `/api/sudo` or `/api/elevate` without a
/// trailing slash skipped the check. Every path matched by the previous
/// substring rules is still matched.
fn requires_role_elevation(path: &str) -> bool {
    if path.starts_with("/admin/users/delete") {
        return true;
    }
    let segments = || path.split('/').filter(|segment| !segment.is_empty());

    let mut leading = segments();
    let under_admin_system =
        leading.next() == Some("admin") && leading.next() == Some("system");

    under_admin_system || segments().any(|segment| matches!(segment, "sudo" | "elevate"))
}
/// Builds a per-route middleware closure that requires `action:resource`.
///
/// The returned closure rejects requests without an [`AuthToken`] extension
/// with `ApiResponse::unauthorized()`. It rejects tokens that fail
/// `check_token_permission` with `ApiResponse::forbidden()` and forwards
/// everything else to `next`.
pub fn require_permission(
    action: &str,
    resource: &str,
) -> impl Fn(
    Request,
    Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Response>> + Send>>
+ Clone {
    let (owned_action, owned_resource) = (action.to_owned(), resource.to_owned());
    move |request: Request, next: Next| {
        // The boxed future must be `'static`, so each call gets its own copies.
        let (action, resource) = (owned_action.clone(), owned_resource.clone());
        Box::pin(async move {
            let Some(auth_token) = request.extensions().get::<AuthToken>() else {
                return Err(ApiResponse::<()>::unauthorized().into_response());
            };
            if !check_token_permission(auth_token, &action, &resource) {
                return Err(ApiResponse::<()>::forbidden().into_response());
            }
            Ok(next.run(request).await)
        })
    }
}
/// Decides, from the token alone, whether it grants `action` on `resource`.
///
/// Any token carrying the `admin` role is allowed. Otherwise the required
/// permission string is `"{action}:{resource}"`, and a token permission
/// grants it when it:
/// * equals the required string exactly; or
/// * ends in `*` and the required string starts with everything before that
///   `*` (so a bare `*` grants everything).
fn check_token_permission(auth_token: &AuthToken, action: &str, resource: &str) -> bool {
    if auth_token.roles.iter().any(|role| role == "admin") {
        return true;
    }
    let required = format!("{}:{}", action, resource);
    auth_token.permissions.iter().any(|granted| {
        granted == &required
            || granted
                .strip_suffix('*')
                .is_some_and(|prefix| required.starts_with(prefix))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_public_endpoint_detection() {
        let public = [
            "/health",
            "/auth/login",
            "/oauth/.well-known/openid_configuration",
        ];
        let protected = ["/api/users", "/admin/roles"];
        for path in public {
            assert!(is_public_endpoint(path), "expected {path} to be public");
        }
        for path in protected {
            assert!(!is_public_endpoint(path), "expected {path} to require auth");
        }
    }

    #[test]
    fn test_sensitive_endpoint_detection() {
        let sensitive = ["/admin/users", "/api/secrets/vault", "/auth/logout"];
        let ordinary = ["/api/health", "/public/info"];
        for path in sensitive {
            assert!(is_sensitive_endpoint(path), "expected {path} to be sensitive");
        }
        for path in ordinary {
            assert!(!is_sensitive_endpoint(path), "expected {path} to be ordinary");
        }
    }

    #[test]
    fn test_elevation_requirement() {
        let elevated = [
            "/admin/users/delete/123",
            "/admin/system/shutdown",
            "/api/sudo/execute",
        ];
        let regular = ["/admin/users", "/api/profile"];
        for path in elevated {
            assert!(requires_role_elevation(path), "expected {path} to need elevation");
        }
        for path in regular {
            assert!(!requires_role_elevation(path), "expected {path} not to need elevation");
        }
    }

    #[test]
    fn test_context_building() {
        // NOTE(review): placeholder only. It never calls `build_request_context`
        // or `build_conditional_context`; a real test needs a constructed
        // `Request` (and `AuthToken`).
        assert!(HashMap::<String, String>::new().is_empty());
    }
}