use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::config::error_sanitization::ErrorSanitizer;
const MAX_ERROR_BODY_BYTES: usize = 256 * 1024;
pub async fn error_sanitization_middleware(
State(sanitizer): State<Arc<ErrorSanitizer>>,
request: Request<Body>,
next: Next,
) -> Response {
if !sanitizer.is_enabled() {
return next.run(request).await;
}
let response = next.run(request).await;
if response.status().is_success() || response.status().is_redirection() {
return response;
}
let is_json = response
.headers()
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|ct| ct.contains("application/json"));
if !is_json {
return response;
}
let (parts, body) = response.into_parts();
let body_bytes = match axum::body::to_bytes(body, MAX_ERROR_BODY_BYTES).await {
Ok(bytes) => bytes,
Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(),
};
let sanitized_bytes = match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
Ok(mut json) => {
sanitize_json_error(&sanitizer, &mut json);
match serde_json::to_vec(&json) {
Ok(bytes) => bytes,
Err(_) => body_bytes.to_vec(),
}
},
Err(_) => body_bytes.to_vec(),
};
let body_len = sanitized_bytes.len();
let mut response = Response::from_parts(parts, Body::from(sanitized_bytes));
response.headers_mut().insert(header::CONTENT_LENGTH, body_len.into());
response
}
/// Sanitizes a parsed JSON error payload in place.
///
/// Accepted shapes:
/// - a top-level array (e.g. a batched GraphQL response): each element is
///   sanitized recursively with this same function;
/// - a value with an `errors` member that is an array: each entry is passed
///   to `sanitize_single_error`;
/// - a value with an `errors` member that is not an array (e.g. a single
///   error object): that member is passed to `sanitize_single_error`;
/// - anything else: the whole value is treated as a single error.
pub(crate) fn sanitize_json_error(sanitizer: &ErrorSanitizer, json: &mut serde_json::Value) {
    // Batched responses: without this, the array would be handed to
    // `sanitize_single_error` as if it were one error and left untouched.
    if let serde_json::Value::Array(items) = json {
        for item in items {
            sanitize_json_error(sanitizer, item);
        }
        return;
    }

    if let Some(errors) = json.get_mut("errors") {
        match errors {
            serde_json::Value::Array(list) => {
                for error in list {
                    sanitize_single_error(sanitizer, error);
                }
            },
            other => sanitize_single_error(sanitizer, other),
        }
        return;
    }

    sanitize_single_error(sanitizer, json);
}
fn sanitize_single_error(sanitizer: &ErrorSanitizer, error: &mut serde_json::Value) {
let code = error.get("code").and_then(|c| c.as_str()).unwrap_or("");
let is_internal = matches!(code, "INTERNAL_SERVER_ERROR" | "DATABASE_ERROR");
if is_internal {
if let Some(message) = error.get("message").and_then(|m| m.as_str()) {
let code_enum = if code == "DATABASE_ERROR" {
crate::error::ErrorCode::DatabaseError
} else {
crate::error::ErrorCode::InternalServerError
};
let temp = crate::error::GraphQLError::new(message, code_enum);
let sanitized = sanitizer.sanitize(temp);
error["message"] = serde_json::Value::String(sanitized.message);
}
}
if let Some(extensions) = error.get_mut("extensions") {
if let Some(obj) = extensions.as_object_mut() {
obj.remove("detail");
}
}
}