use crate::{AuditEvent, AuditLogger, AuditSeverity, AuditStatus};
use armature_core::{Error, HttpRequest, HttpResponse, Middleware};
use std::sync::Arc;
use std::time::Instant;
pub struct AuditMiddleware {
logger: Arc<AuditLogger>,
log_request_body: bool,
log_response_body: bool,
max_body_size: usize,
}
impl AuditMiddleware {
pub fn new(logger: Arc<AuditLogger>) -> Self {
Self {
logger,
log_request_body: true,
log_response_body: true,
max_body_size: 10_000, }
}
pub fn log_request_body(mut self, log: bool) -> Self {
self.log_request_body = log;
self
}
pub fn log_response_body(mut self, log: bool) -> Self {
self.log_response_body = log;
self
}
pub fn max_body_size(mut self, size: usize) -> Self {
self.max_body_size = size;
self
}
fn extract_user_id(&self, request: &HttpRequest) -> Option<String> {
request.headers.get("authorization").and_then(|auth| {
if auth.starts_with("Bearer ") {
Some("authenticated_user".to_string())
} else {
None
}
})
}
fn extract_ip(&self, request: &HttpRequest) -> Option<String> {
if let Some(forwarded) = request.headers.get("x-forwarded-for") {
return Some(
forwarded
.split(',')
.next()
.unwrap_or(forwarded)
.trim()
.to_string(),
);
}
if let Some(real_ip) = request.headers.get("x-real-ip") {
return Some(real_ip.clone());
}
None
}
fn extract_user_agent(&self, request: &HttpRequest) -> Option<String> {
request.headers.get("user-agent").cloned()
}
fn truncate_body(&self, body: &[u8]) -> Option<String> {
if body.is_empty() {
return None;
}
if body.len() > self.max_body_size {
let truncated = &body[..self.max_body_size];
let mut text = String::from_utf8_lossy(truncated).to_string();
text.push_str("... [TRUNCATED]");
Some(text)
} else {
Some(String::from_utf8_lossy(body).to_string())
}
}
}
#[async_trait::async_trait]
impl Middleware for AuditMiddleware {
async fn handle(
&self,
request: HttpRequest,
next: armature_core::middleware::Next,
) -> Result<HttpResponse, Error> {
let start = Instant::now();
let method = request.method.clone();
let path = request.path.clone();
let user_id = self.extract_user_id(&request);
let ip_address = self.extract_ip(&request);
let user_agent = self.extract_user_agent(&request);
let request_body = if self.log_request_body {
self.truncate_body(&request.body)
} else {
None
};
let result = next(request).await;
let duration_ms = start.elapsed().as_millis() as u64;
let event = match &result {
Ok(response) => {
let status_code = response.status;
let response_body = if self.log_response_body {
self.truncate_body(&response.body)
} else {
None
};
let status = if status_code < 400 {
AuditStatus::Success
} else if status_code < 500 {
AuditStatus::Denied
} else {
AuditStatus::Error
};
let severity = if status_code < 400 {
AuditSeverity::Info
} else if status_code < 500 {
AuditSeverity::Warning
} else {
AuditSeverity::Error
};
let mut event = AuditEvent::new("http.request")
.action("http_request")
.method(method)
.path(path)
.status_code(status_code)
.status(status)
.severity(severity)
.duration_ms(duration_ms);
if let Some(user) = user_id {
event = event.user(user);
}
if let Some(ip) = ip_address {
event = event.ip(ip);
}
if let Some(ua) = user_agent {
event = event.user_agent(ua);
}
if let Some(body) = request_body {
event = event.request_body(body);
}
if let Some(body) = response_body {
event = event.response_body(body);
}
event
}
Err(err) => {
let status_code = err.status_code();
let mut event = AuditEvent::new("http.request")
.action("http_request")
.method(method)
.path(path)
.status_code(status_code)
.status(AuditStatus::Error)
.severity(AuditSeverity::Error)
.duration_ms(duration_ms)
.error(err.to_string());
if let Some(user) = user_id {
event = event.user(user);
}
if let Some(ip) = ip_address {
event = event.ip(ip);
}
if let Some(ua) = user_agent {
event = event.user_agent(ua);
}
if let Some(body) = request_body {
event = event.request_body(body);
}
event
}
};
if let Err(e) = self.logger.log(event).await {
tracing::error!("Failed to log audit event: {}", e);
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{AuditLogger, MemoryBackend};
#[test]
fn test_audit_middleware_creation() {
let logger = Arc::new(AuditLogger::builder().backend(MemoryBackend::new()).build());
let middleware = AuditMiddleware::new(logger);
assert!(middleware.log_request_body);
assert!(middleware.log_response_body);
}
#[test]
fn test_audit_middleware_configuration() {
let logger = Arc::new(AuditLogger::builder().backend(MemoryBackend::new()).build());
let middleware = AuditMiddleware::new(logger)
.log_request_body(false)
.log_response_body(false)
.max_body_size(5000);
assert!(!middleware.log_request_body);
assert!(!middleware.log_response_body);
assert_eq!(middleware.max_body_size, 5000);
}
#[test]
fn test_truncate_body() {
let logger = Arc::new(AuditLogger::builder().backend(MemoryBackend::new()).build());
let middleware = AuditMiddleware::new(logger).max_body_size(10);
let body = b"This is a very long body that should be truncated";
let truncated = middleware.truncate_body(body).unwrap();
assert!(truncated.len() <= 30); assert!(truncated.contains("[TRUNCATED]"));
}
}