use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::config::error_sanitization::ErrorSanitizer;
/// Upper bound, in bytes (256 KiB), on how much of an error response body the
/// sanitization middleware will buffer in memory before giving up.
const MAX_ERROR_BODY_BYTES: usize = 256 * 1024;
/// Axum middleware that scrubs internal details out of JSON error responses.
///
/// Behaviour, in order:
/// - If the sanitizer is disabled, the request is forwarded and the response is
///   returned untouched.
/// - Successful (2xx) and redirect (3xx) responses are returned untouched.
/// - Responses whose `Content-Type` is not a JSON media type are returned
///   untouched. Both `application/json` and structured-suffix types such as
///   `application/graphql-response+json` or `application/problem+json` count
///   as JSON, matched case-insensitively.
/// - Otherwise the body is buffered (up to [`MAX_ERROR_BODY_BYTES`]), parsed,
///   sanitized in place and re-serialized. `Content-Length` is rewritten to
///   match the new body.
///
/// If the body cannot be read — which includes bodies larger than the limit —
/// a bare `500 Internal Server Error` is returned instead, so an unsanitized
/// body is never forwarded in that case. A body that does not parse as JSON is
/// forwarded byte-for-byte.
pub async fn error_sanitization_middleware(
    State(sanitizer): State<Arc<ErrorSanitizer>>,
    request: Request<Body>,
    next: Next,
) -> Response {
    if !sanitizer.is_enabled() {
        return next.run(request).await;
    }

    let response = next.run(request).await;

    // Only error responses can carry details worth hiding.
    if response.status().is_success() || response.status().is_redirection() {
        return response;
    }

    let is_json = response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .is_some_and(is_json_media_type);
    if !is_json {
        return response;
    }

    let (parts, body) = response.into_parts();
    let body_bytes = match axum::body::to_bytes(body, MAX_ERROR_BODY_BYTES).await {
        Ok(bytes) => bytes,
        // The body has been consumed and cannot be replayed; fail closed.
        Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
    };

    // Fall back to the original bytes when the body is not valid JSON or the
    // sanitized value cannot be serialized again.
    let sanitized_bytes = serde_json::from_slice::<serde_json::Value>(&body_bytes)
        .ok()
        .and_then(|mut json| {
            sanitize_json_error(&sanitizer, &mut json);
            serde_json::to_vec(&json).ok()
        })
        .unwrap_or_else(|| body_bytes.to_vec());

    // Sanitizing usually changes the body size, so the length header inherited
    // from the inner response must be replaced.
    let body_len = sanitized_bytes.len();
    let mut response = Response::from_parts(parts, Body::from(sanitized_bytes));
    response
        .headers_mut()
        .insert(header::CONTENT_LENGTH, body_len.into());
    response
}

/// Returns `true` when a `Content-Type` header value names a JSON media type.
///
/// Media types are case-insensitive, so the value is lower-cased first. A value
/// matches when it mentions `application/json` anywhere, or when its media type
/// (the part before any `;` parameters) ends with the `+json` structured
/// syntax suffix.
fn is_json_media_type(content_type: &str) -> bool {
    let lowered = content_type.to_ascii_lowercase();
    if lowered.contains("application/json") {
        return true;
    }
    lowered
        .split(';')
        .next()
        .is_some_and(|essence| essence.trim().ends_with("+json"))
}
/// Sanitizes, in place, every error carried by a decoded JSON error payload.
///
/// Two payload shapes are recognised:
/// - an envelope with an `errors` array: each element of the array is
///   sanitized individually;
/// - anything else: the payload itself is treated as a single error object.
///
/// A payload whose `errors` key is present but is not an array (for example
/// `null` or an object) falls into the second case, so such a payload cannot
/// bypass sanitization.
fn sanitize_json_error(sanitizer: &ErrorSanitizer, json: &mut serde_json::Value) {
    match json
        .get_mut("errors")
        .and_then(serde_json::Value::as_array_mut)
    {
        Some(errors) => {
            for error in errors {
                sanitize_single_error(sanitizer, error);
            }
        }
        None => sanitize_single_error(sanitizer, json),
    }
}
fn sanitize_single_error(sanitizer: &ErrorSanitizer, error: &mut serde_json::Value) {
let code = error.get("code").and_then(|c| c.as_str()).unwrap_or("");
let is_internal = matches!(code, "INTERNAL_SERVER_ERROR" | "DATABASE_ERROR");
if is_internal {
if let Some(message) = error.get("message").and_then(|m| m.as_str()) {
let code_enum = if code == "DATABASE_ERROR" {
crate::error::ErrorCode::DatabaseError
} else {
crate::error::ErrorCode::InternalServerError
};
let temp = crate::error::GraphQLError::new(message, code_enum);
let sanitized = sanitizer.sanitize(temp);
error["message"] = serde_json::Value::String(sanitized.message);
}
}
if let Some(extensions) = error.get_mut("extensions") {
if let Some(obj) = extensions.as_object_mut() {
obj.remove("detail");
}
}
}
#[cfg(test)]
mod tests {
    // Lint relaxations scoped to this test module only.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]

    use super::*;
    use crate::config::error_sanitization::ErrorSanitizationConfig;

    /// Builds an enabled sanitizer with both hiding flags set and no custom
    /// replacement message.
    fn test_sanitizer() -> ErrorSanitizer {
        let config = ErrorSanitizationConfig {
            enabled: true,
            hide_implementation_details: true,
            sanitize_database_errors: true,
            custom_error_message: None,
        };
        ErrorSanitizer::new(config)
    }

    /// A database error inside an `errors` array gets a generic message and
    /// loses `extensions.detail`.
    #[test]
    fn test_sanitize_graphql_db_error() {
        let mut payload = serde_json::json!({
            "errors": [{
                "message": "ERROR: relation \"tb_users\" does not exist",
                "code": "DATABASE_ERROR",
                "extensions": {
                    "detail": "at line 42 in query.rs"
                }
            }]
        });

        sanitize_json_error(&test_sanitizer(), &mut payload);

        let first = &payload["errors"][0];
        assert_eq!(first["message"], "An internal error occurred");
        assert!(first["extensions"].get("detail").is_none());
    }

    /// Errors with a non-internal code keep their original message.
    #[test]
    fn test_sanitize_preserves_validation_error() {
        let mut payload = serde_json::json!({
            "errors": [{
                "message": "Field 'email' is required",
                "code": "VALIDATION_ERROR"
            }]
        });

        sanitize_json_error(&test_sanitizer(), &mut payload);

        assert_eq!(
            payload["errors"][0]["message"],
            "Field 'email' is required"
        );
    }

    /// A flat error object (no `errors` array) is sanitized as a single error.
    #[test]
    fn test_sanitize_rest_internal_error() {
        let mut payload = serde_json::json!({
            "message": "connection refused: postgres://user:pass@host/db",
            "code": "INTERNAL_SERVER_ERROR",
            "extensions": {
                "detail": "panic at src/db.rs:123"
            }
        });

        sanitize_json_error(&test_sanitizer(), &mut payload);

        assert_eq!(payload["message"], "An internal error occurred");
        assert!(payload["extensions"].get("detail").is_none());
    }

    /// With a disabled sanitizer the message of an internal error is untouched.
    #[test]
    fn test_disabled_sanitizer_passes_through() {
        let disabled = ErrorSanitizer::disabled();
        let mut payload = serde_json::json!({
            "errors": [{
                "message": "ERROR: relation \"tb_users\" does not exist",
                "code": "DATABASE_ERROR"
            }]
        });

        sanitize_json_error(&disabled, &mut payload);

        assert_eq!(
            payload["errors"][0]["message"],
            "ERROR: relation \"tb_users\" does not exist"
        );
    }
}