use axum::body::Body;
use axum::http::{Request, Response, StatusCode};
use axum::middleware::Next;
use mockforge_core::security::{
emit_security_event, EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
};
use tracing::debug;
pub async fn security_middleware(req: Request<Body>, next: Next) -> Response<Body> {
let path = req.uri().path().to_string();
let method = req.method().clone();
let ip_address = req
.headers()
.get("x-forwarded-for")
.or_else(|| req.headers().get("x-real-ip"))
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.or_else(|| {
req.extensions()
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
.map(|addr| addr.ip().to_string())
});
let user_agent = req
.headers()
.get("user-agent")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let user_id: Option<String> = req
.extensions()
.get::<axum::extract::Extension<crate::auth::types::AuthClaims>>()
.and_then(|claims| claims.sub.clone());
let response = next.run(req).await;
let status = response.status();
let is_success = status.is_success();
let is_client_error = status.is_client_error();
let is_server_error = status.is_server_error();
if is_success {
let event = SecurityEvent::new(SecurityEventType::AuthzAccessGranted, None, None)
.with_actor(EventActor {
user_id: user_id.clone(),
username: user_id.clone(),
ip_address: ip_address.clone(),
user_agent: user_agent.clone(),
})
.with_target(EventTarget {
resource_type: Some("api".to_string()),
resource_id: Some(path.clone()),
method: Some(method.to_string()),
})
.with_outcome(EventOutcome {
success: true,
reason: None,
})
.with_metadata("status_code".to_string(), serde_json::json!(status.as_u16()));
emit_security_event(event).await;
} else if is_client_error && status == StatusCode::FORBIDDEN {
let event = SecurityEvent::new(SecurityEventType::AuthzAccessDenied, None, None)
.with_actor(EventActor {
user_id: user_id.clone(),
username: user_id.clone(),
ip_address: ip_address.clone(),
user_agent: user_agent.clone(),
})
.with_target(EventTarget {
resource_type: Some("api".to_string()),
resource_id: Some(path.clone()),
method: Some(method.to_string()),
})
.with_outcome(EventOutcome {
success: false,
reason: Some(format!("Access denied: {}", status)),
})
.with_metadata("status_code".to_string(), serde_json::json!(status.as_u16()));
emit_security_event(event).await;
} else if is_server_error {
debug!("Server error detected: {} for {}", status, path);
}
response
}