use crate::auth::{
permissions::PermissionChecker,
policy_engine::{AuthorizationContext, UnifiedPolicyEngine},
types::{Permission, User},
};
use crate::security_audit::{
AuditEventType, AuditLogEntry, AuditResult, SecurityAuditManager, Severity,
};
use axum::{
extract::Request,
http::{header, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, warn, Span};
use uuid::Uuid;
pub async fn security_audit_middleware(
security_auditor: Arc<SecurityAuditManager>,
request: Request,
next: Next,
) -> Response {
let method = request.method().clone();
let uri = request.uri().clone();
let path = uri.path().to_string();
let ip_address = request
.headers()
.get("x-forwarded-for")
.and_then(|h| h.to_str().ok())
.map(|s| s.split(',').next().unwrap_or(s).trim().to_string())
.or_else(|| {
request
.headers()
.get("x-real-ip")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "unknown".to_string());
let response = next.run(request).await;
let status = response.status();
let event_type = match method {
Method::GET | Method::HEAD | Method::OPTIONS => AuditEventType::DataAccess,
Method::POST | Method::PUT | Method::PATCH | Method::DELETE => {
AuditEventType::DataModification
}
_ => AuditEventType::SecurityEvent,
};
let severity = if status.is_success() {
Severity::Info
} else if status.is_client_error() {
Severity::Low
} else if status.is_server_error() {
Severity::Medium
} else {
Severity::Info
};
let result = if status.is_success() {
AuditResult::Success
} else if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
AuditResult::Denied
} else if status.is_client_error() || status.is_server_error() {
AuditResult::Failure
} else {
AuditResult::Success
};
let entry = AuditLogEntry {
timestamp: chrono::Utc::now(),
event_type,
severity,
user: None, ip_address: Some(ip_address),
resource: path,
action: method.to_string(),
result,
details: Some(format!("Status: {}", status.as_u16())),
};
let auditor = security_auditor.clone();
tokio::spawn(async move {
if let Err(e) = auditor.log_event(entry).await {
warn!("Failed to log security audit event: {}", e);
}
});
response
}
/// Adds a fixed set of browser-hardening headers to every response.
///
/// The headers are written after the inner service has produced its response, using
/// `insert`, so a handler-supplied header with the same name is replaced:
///
/// * `X-Frame-Options: DENY`
/// * `X-Content-Type-Options: nosniff`
/// * `X-XSS-Protection: 1; mode=block`
/// * `Referrer-Policy: strict-origin-when-cross-origin`
/// * `Permissions-Policy: geolocation=(), microphone=(), camera=()`
/// * `Content-Security-Policy`: sources restricted to `'self'` (inline scripts and styles
///   allowed, images also from `data:`/`https:`, fonts also from `data:`, no framing).
pub async fn security_headers(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;

    let hardening = [
        (
            header::HeaderName::from_static("x-frame-options"),
            HeaderValue::from_static("DENY"),
        ),
        (
            header::HeaderName::from_static("x-content-type-options"),
            HeaderValue::from_static("nosniff"),
        ),
        (
            header::HeaderName::from_static("x-xss-protection"),
            HeaderValue::from_static("1; mode=block"),
        ),
        (
            header::REFERRER_POLICY,
            HeaderValue::from_static("strict-origin-when-cross-origin"),
        ),
        (
            header::HeaderName::from_static("permissions-policy"),
            HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
        ),
        (
            header::HeaderName::from_static("content-security-policy"),
            HeaderValue::from_static(
                "default-src 'self'; \
                 script-src 'self' 'unsafe-inline'; \
                 style-src 'self' 'unsafe-inline'; \
                 img-src 'self' data: https:; \
                 font-src 'self' data:; \
                 connect-src 'self'; \
                 frame-ancestors 'none'; \
                 base-uri 'self'; \
                 form-action 'self'",
            ),
        ),
    ];

    let target = response.headers_mut();
    for (name, value) in hardening {
        target.insert(name, value);
    }
    response
}
/// Adds an HSTS (`Strict-Transport-Security`) header to every response.
///
/// The policy is one year (`max-age=31536000`), covers subdomains and requests preload-list
/// inclusion. It replaces any HSTS header the handler may already have set.
pub async fn https_security_headers(request: Request, next: Next) -> Response {
    const HSTS_POLICY: &str = "max-age=31536000; includeSubDomains; preload";

    let mut response = next.run(request).await;
    response.headers_mut().insert(
        header::STRICT_TRANSPORT_SECURITY,
        HeaderValue::from_static(HSTS_POLICY),
    );
    response
}
/// Assigns each request a correlation id and propagates it.
///
/// The id is taken from the incoming `x-request-id` header when that value is usable
/// (visible ASCII, non-empty after trimming, at most `MAX_REQUEST_ID_LEN` bytes). Otherwise
/// a fresh UUID v4 is generated. The id is then:
///
/// * recorded on the current tracing span's `request_id` field (per `tracing` semantics this
///   only takes effect if the span declared that field),
/// * inserted into the request extensions as [`CorrelationId`] for downstream handlers,
/// * echoed back in the response's `x-request-id` header.
pub async fn request_correlation_id(mut request: Request, next: Next) -> Response {
    // The header is client-controlled and ends up in logs, spans and the response. Cap its
    // size and reject empty values so callers cannot inject useless or oversized ids.
    const MAX_REQUEST_ID_LEN: usize = 128;

    let correlation_id = request
        .headers()
        .get("x-request-id")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|s| !s.is_empty() && s.len() <= MAX_REQUEST_ID_LEN)
        .map(str::to_owned)
        .unwrap_or_else(|| Uuid::new_v4().to_string());

    Span::current().record("request_id", &correlation_id);
    request
        .extensions_mut()
        .insert(CorrelationId(correlation_id.clone()));
    debug!(correlation_id = %correlation_id, "Request received");

    let mut response = next.run(request).await;
    response.headers_mut().insert(
        header::HeaderName::from_static("x-request-id"),
        HeaderValue::from_str(&correlation_id)
            .unwrap_or_else(|_| HeaderValue::from_static("invalid")),
    );
    response
}
/// Per-request correlation identifier.
///
/// `request_correlation_id` inserts it into the request extensions and echoes the same
/// value in the `x-request-id` response header.
#[derive(Debug, Clone)]
pub struct CorrelationId(pub String);
/// Authenticated principal carried in a request's extensions.
///
/// The RBAC and ReBAC middleware in this module read it to make authorization decisions.
/// The user is held behind an `Arc`, so cloning this wrapper out of the extensions only
/// bumps a reference count.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser(pub Arc<User>);
/// Builds a middleware closure that requires a single [`Permission`].
///
/// The function is `async` but performs no awaiting; callers must `.await` it to obtain the
/// closure. For each request, the closure:
///
/// * returns `401 Unauthorized` when no [`AuthenticatedUser`] is present in the extensions,
/// * returns `403 Forbidden` when `PermissionChecker::has_permission` rejects the user,
/// * otherwise forwards the request to `next`.
pub async fn rbac_check(
    permission: Permission,
) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
       + Clone {
    move |request: Request, next: Next| {
        // The closure is `Fn`, so each invocation needs its own owned copy.
        let needed = permission.clone();
        Box::pin(async move {
            let principal = match request.extensions().get::<AuthenticatedUser>().cloned() {
                Some(AuthenticatedUser(principal)) => principal,
                None => {
                    warn!(
                        permission = ?needed,
                        "Authentication required but no user present"
                    );
                    return (StatusCode::UNAUTHORIZED, "Authentication required").into_response();
                }
            };
            let account = &*principal;

            if !PermissionChecker::has_permission(account, &needed) {
                warn!(
                    user = %account.username,
                    permission = ?needed,
                    "Permission denied"
                );
                let message = format!(
                    "Access denied: User '{}' does not have required permission: {:?}",
                    account.username, needed
                );
                return (StatusCode::FORBIDDEN, message).into_response();
            }

            debug!(
                user = %account.username,
                permission = ?needed,
                "Permission granted"
            );
            next.run(request).await
        })
    }
}
/// Path- and method-based RBAC middleware.
///
/// * Health and metrics endpoints (`/health`, `/health/live`, `/health/ready`, `/metrics`)
///   are forwarded without any check.
/// * Requests without an [`AuthenticatedUser`] extension are forwarded unchanged; this
///   middleware does not itself enforce authentication.
/// * For authenticated requests, the method and path are mapped to exactly one
///   [`Permission`] (routes not listed explicitly require [`Permission::Read`]), and the
///   request is rejected with `403 Forbidden` if the user lacks it.
pub async fn route_based_rbac(request: Request, next: Next) -> Response {
    let path = request.uri().path();
    let public_endpoints = ["/health", "/health/live", "/health/ready", "/metrics"];
    if public_endpoints.contains(&path) {
        return next.run(request).await;
    }

    let user_arc = match request.extensions().get::<AuthenticatedUser>().cloned() {
        Some(AuthenticatedUser(user)) => user,
        None => {
            debug!(path = %path, "No authentication present, allowing request");
            return next.run(request).await;
        }
    };

    let method = request.method();
    // Every route resolves to a concrete permission; arm order matters because several
    // arms are prefix matches with method guards that fall through on mismatch.
    let permission = match path {
        "/sparql" if method == Method::GET || method == Method::POST => Permission::QueryExecute,
        "/update" if method == Method::POST => Permission::UpdateExecute,
        p if p.starts_with("/graph") || p == "/data" => match *method {
            Method::GET | Method::HEAD => Permission::Read,
            Method::PUT | Method::POST | Method::DELETE => Permission::GraphStore,
            _ => Permission::Read,
        },
        "/upload" if method == Method::POST => Permission::Upload,
        "/shacl" if method == Method::POST => Permission::QueryExecute,
        "/patch" if method == Method::POST => Permission::Write,
        p if p.starts_with("/$/datasets") => match *method {
            Method::GET => Permission::Read,
            Method::POST => Permission::DatasetCreate,
            Method::DELETE => Permission::DatasetDelete,
            Method::PUT => Permission::DatasetManage,
            _ => Permission::Admin,
        },
        p if p.starts_with("/$/admin") => Permission::Admin,
        p if p.starts_with("/$/stats") || p.starts_with("/$/logs") => Permission::Monitor,
        p if p.starts_with("/$/tasks") => match *method {
            Method::GET => Permission::Monitor,
            _ => Permission::Admin,
        },
        p if p.starts_with("/$/federation") => Permission::FederationManage,
        p if p.starts_with("/$/cluster") => Permission::ClusterManage,
        p if p.starts_with("/$/users") => Permission::UserManage,
        p if p.starts_with("/$/config") => Permission::SystemConfig,
        "/$/backup" if method == Method::POST => Permission::Backup,
        "/$/restore" if method == Method::POST => Permission::Restore,
        _ => Permission::Read,
    };

    let user_ref = &*user_arc;
    if PermissionChecker::has_permission(user_ref, &permission) {
        debug!(
            user = %user_ref.username,
            path = %path,
            method = %method,
            permission = ?permission,
            "RBAC check passed"
        );
        next.run(request).await
    } else {
        warn!(
            user = %user_ref.username,
            path = %path,
            method = %method,
            permission = ?permission,
            "RBAC check failed - permission denied"
        );
        (
            StatusCode::FORBIDDEN,
            format!(
                "Access denied: User '{}' does not have required permission {:?} for {} {}",
                user_ref.username, permission, method, path
            ),
        )
            .into_response()
    }
}
pub async fn rebac_middleware(
axum::extract::State(policy_engine): axum::extract::State<Arc<UnifiedPolicyEngine>>,
request: Request,
next: Next,
) -> Response {
let path = request.uri().path();
let public_endpoints = ["/health", "/health/live", "/health/ready", "/metrics"];
if public_endpoints.contains(&path) {
return next.run(request).await;
}
let user = match request.extensions().get::<AuthenticatedUser>().cloned() {
Some(AuthenticatedUser(user)) => user,
None => {
return next.run(request).await;
}
};
let (action, resource) = extract_action_and_resource(&request);
let context = AuthorizationContext::new((*user).clone(), action.clone(), resource.clone());
match policy_engine.authorize(&context).await {
Ok(response) if response.allowed => {
debug!(
user = %user.username,
action = %action,
resource = %resource,
"ReBAC authorization granted"
);
next.run(request).await
}
Ok(response) => {
warn!(
user = %user.username,
action = %action,
resource = %resource,
reason = ?response.reason,
"ReBAC authorization denied"
);
(
StatusCode::FORBIDDEN,
format!(
"Access denied: {}",
response
.reason
.unwrap_or_else(|| "Insufficient permissions".to_string())
),
)
.into_response()
}
Err(e) => {
warn!(
user = %user.username,
action = %action,
resource = %resource,
error = %e,
"ReBAC authorization error"
);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Authorization error: {}", e),
)
.into_response()
}
}
}
fn extract_action_and_resource(request: &Request) -> (String, String) {
let method = request.method();
let path = request.uri().path();
let query = request.uri().query();
let dataset = if let Some(ds) = path.strip_prefix("/dataset/") {
let ds_name = ds.split('/').next().unwrap_or("default");
ds_name.to_string()
} else {
"default".to_string()
};
if let Some(query_str) = query {
if let Some(graph_uri) = extract_graph_from_query(query_str) {
let action = match method {
&Method::GET | &Method::HEAD => "can_read",
&Method::POST | &Method::PUT | &Method::DELETE => "can_write",
_ => "can_read",
};
return (action.to_string(), format!("graph:{}", graph_uri));
}
}
let action = match (method, path) {
(&Method::GET, _) | (&Method::HEAD, _) => "can_read",
(&Method::POST, p) if p.contains("/sparql") || p.contains("/query") => "can_execute_query",
(&Method::POST, p) if p.contains("/update") => "can_execute_update",
(&Method::POST, _) | (&Method::PUT, _) | (&Method::PATCH, _) | (&Method::DELETE, _) => {
"can_write"
}
_ => "can_read",
};
(action.to_string(), format!("dataset:{}", dataset))
}
/// Returns the decoded value of the first `graph` or `default` parameter in a raw query
/// string, or `None` if neither is present.
///
/// Both parameter names and values are decoded as `application/x-www-form-urlencoded`
/// components (`+` → space, then percent-escapes). The check must see the same names and
/// values that a form-decoding handler sees; otherwise an encoded key such as `gr%61ph`
/// would skip graph-level authorization. Pairs without `=` are ignored. If the selected
/// value does not decode to valid UTF-8, the function returns `None`.
fn extract_graph_from_query(query: &str) -> Option<String> {
    /// Decodes one form-urlencoded component; `None` if the result is not valid UTF-8.
    fn decode_component(raw: &str) -> Option<String> {
        // `urlencoding::decode` leaves `+` untouched, so map it to a space first;
        // a literal plus sign arrives as `%2B` and is restored by the percent-decode.
        let plus_as_space = raw.replace('+', " ");
        let decoded = urlencoding::decode(&plus_as_space).ok()?.into_owned();
        Some(decoded)
    }

    for pair in query.split('&') {
        if let Some((raw_key, raw_value)) = pair.split_once('=') {
            let key = decode_component(raw_key);
            if matches!(key.as_deref(), Some("graph") | Some("default")) {
                return decode_component(raw_value);
            }
        }
    }
    None
}
/// Measures the wall-clock time spent handling a request and reports it.
///
/// * Adds an `x-response-time` response header with the elapsed whole milliseconds.
/// * Emits an `info!` "Slow request detected" event when handling took over 1000 ms.
/// * Emits a `debug!` completion event with method, URI, duration and status.
pub async fn request_timing(request: Request, next: Next) -> Response {
    const SLOW_THRESHOLD_MS: u128 = 1000;

    let started = Instant::now();
    let method = request.method().clone();
    let uri = request.uri().clone();
    let mut response = next.run(request).await;
    let elapsed_ms = started.elapsed().as_millis();

    if elapsed_ms > SLOW_THRESHOLD_MS {
        info!(
            method = %method,
            uri = %uri,
            duration_ms = %elapsed_ms,
            "Slow request detected"
        );
    }

    let header_text = elapsed_ms.to_string();
    if let Ok(value) = HeaderValue::from_str(&header_text) {
        response
            .headers_mut()
            .insert(header::HeaderName::from_static("x-response-time"), value);
    }

    debug!(
        method = %method,
        uri = %uri,
        duration_ms = %elapsed_ms,
        status = %response.status(),
        "Request completed"
    );
    response
}
/// Pass-through middleware: forwards every request to the inner service unchanged.
///
/// Health and metrics endpoints need no special handling here, because every request
/// already takes the same path. A path check that led to an identical `next.run` call has
/// no effect, so there is none.
///
/// NOTE(review): if health/metrics requests are meant to skip other middleware
/// (auth, rate limiting, ...), that has to be done where those layers are installed;
/// this function cannot short-circuit them.
pub async fn health_check_bypass(request: Request, next: Next) -> Response {
    next.run(request).await
}
/// Rejects requests whose declared `Content-Length` exceeds `max_size_bytes`.
///
/// Responds with `413 Payload Too Large` when the header parses as a number greater than the
/// limit. Requests with no `Content-Length`, or one that does not parse as `usize`, are
/// forwarded. Only the header is inspected; the body is not counted here, so bodies without
/// a `Content-Length` header are not limited by this function.
pub async fn request_size_limit(request: Request, next: Next, max_size_bytes: usize) -> Response {
    let declared_len = request
        .headers()
        .get(header::CONTENT_LENGTH)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<usize>().ok());

    match declared_len {
        Some(len) if len > max_size_bytes => (
            StatusCode::PAYLOAD_TOO_LARGE,
            format!(
                "Request body too large: {} bytes (max: {})",
                len, max_size_bytes
            ),
        )
            .into_response(),
        _ => next.run(request).await,
    }
}
/// Stamps every response with an `x-api-version` header.
///
/// The value is this crate's package version (`CARGO_PKG_VERSION`, fixed at compile time)
/// and replaces any value the handler set.
pub async fn api_version(request: Request, next: Next) -> Response {
    const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION");

    let mut response = next.run(request).await;
    let version = HeaderValue::from_static(PACKAGE_VERSION);
    response
        .headers_mut()
        .insert(header::HeaderName::from_static("x-api-version"), version);
    response
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, routing::get, Router};
    use tower::ServiceExt;

    /// A router with a single `GET /` route that answers `"Hello"`.
    fn hello_router() -> Router {
        Router::new().route("/", get(|| async { "Hello" }))
    }

    /// Sends one empty-bodied `GET /` through `app` and returns its response.
    async fn get_root(app: Router) -> Response {
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_security_headers() {
        let app = hello_router().layer(axum::middleware::from_fn(security_headers));
        let response = get_root(app).await;
        let headers = response.headers();
        for name in [
            "x-frame-options",
            "x-content-type-options",
            "x-xss-protection",
            "referrer-policy",
            "content-security-policy",
        ] {
            assert!(headers.contains_key(name), "missing header: {}", name);
        }
    }

    #[tokio::test]
    async fn test_correlation_id_generation() {
        let app = hello_router().layer(axum::middleware::from_fn(request_correlation_id));
        let response = get_root(app).await;
        let header = response
            .headers()
            .get("x-request-id")
            .expect("x-request-id header should be set");
        let id = header.to_str().expect("x-request-id should be ASCII");
        assert!(Uuid::parse_str(id).is_ok());
    }

    #[tokio::test]
    async fn test_api_version() {
        let app = hello_router().layer(axum::middleware::from_fn(api_version));
        let response = get_root(app).await;
        let version = response
            .headers()
            .get("x-api-version")
            .expect("x-api-version header should be set");
        assert_eq!(version, env!("CARGO_PKG_VERSION"));
    }
}