use axum::{extract::Request, http::HeaderMap, middleware::Next, response::Response};
use http::StatusCode;
use std::any::Any;
use crate::api::problem::Problem;
use crate::config::ConfigError;
use modkit_odata::Error as ODataError;
pub async fn error_mapping_middleware(request: Request, next: Next) -> Response {
let _uri = request.uri().clone();
let _headers = request.headers().clone();
let response = next.run(request).await;
if response.status().is_success() || is_problem_response(&response) {
return response;
}
response
}
fn is_problem_response(response: &Response) -> bool {
response
.headers()
.get(axum::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|ct| ct.contains("application/problem+json"))
}
pub fn extract_trace_id(headers: &HeaderMap) -> Option<String> {
headers
.get("x-trace-id")
.or_else(|| headers.get("x-request-id"))
.or_else(|| headers.get("traceparent"))
.and_then(|v| v.to_str().ok())
.map(str::to_owned)
.or_else(|| {
tracing::Span::current()
.id()
.map(|id| id.into_u64().to_string())
})
}
pub fn map_error_to_problem(error: &dyn Any, instance: &str, trace_id: Option<String>) -> Problem {
if let Some(odata_err) = error.downcast_ref::<ODataError>() {
return crate::api::odata::error::odata_error_to_problem(odata_err, instance, trace_id);
}
if let Some(config_err) = error.downcast_ref::<ConfigError>() {
let mut problem = match config_err {
ConfigError::ModuleNotFound { module } => Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Configuration Error",
format!("Module '{module}' configuration not found"),
)
.with_code("CONFIG_MODULE_NOT_FOUND")
.with_type("https://errors.example.com/CONFIG_MODULE_NOT_FOUND"),
ConfigError::InvalidModuleStructure { module } => Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Configuration Error",
format!("Module '{module}' has invalid configuration structure"),
)
.with_code("CONFIG_INVALID_STRUCTURE")
.with_type("https://errors.example.com/CONFIG_INVALID_STRUCTURE"),
ConfigError::MissingConfigSection { module } => Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Configuration Error",
format!("Module '{module}' is missing required config section"),
)
.with_code("CONFIG_MISSING_SECTION")
.with_type("https://errors.example.com/CONFIG_MISSING_SECTION"),
ConfigError::InvalidConfig { module, .. } => Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Configuration Error",
format!("Module '{module}' has invalid configuration"),
)
.with_code("CONFIG_INVALID")
.with_type("https://errors.example.com/CONFIG_INVALID"),
ConfigError::VarExpand { module, source } => {
tracing::error!(
module = %module,
error = %source,
"Environment variable expansion failed in module config"
);
Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Configuration Error",
format!("Module '{module}' has invalid environment-backed configuration"),
)
.with_code("CONFIG_ENV_EXPAND")
.with_type("https://errors.example.com/CONFIG_ENV_EXPAND")
}
};
problem = problem.with_instance(instance);
if let Some(tid) = trace_id {
problem = problem.with_trace_id(tid);
}
return problem;
}
if let Some(anyhow_err) = error.downcast_ref::<anyhow::Error>() {
let mut problem = Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error",
"An internal error occurred",
)
.with_code("INTERNAL_ERROR")
.with_type("https://errors.example.com/INTERNAL_ERROR");
problem = problem.with_instance(instance);
if let Some(tid) = trace_id {
problem = problem.with_trace_id(tid);
}
tracing::error!(error = %anyhow_err, "Internal server error");
return problem;
}
let mut problem = Problem::new(
StatusCode::INTERNAL_SERVER_ERROR,
"Unknown Error",
"An unknown error occurred",
)
.with_code("UNKNOWN_ERROR")
.with_type("https://errors.example.com/UNKNOWN_ERROR");
problem = problem.with_instance(instance);
if let Some(tid) = trace_id {
problem = problem.with_trace_id(tid);
}
tracing::error!("Unknown error type in error mapping layer");
problem
}
pub trait IntoProblem {
fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem;
}
impl IntoProblem for ODataError {
fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
crate::api::odata::error::odata_error_to_problem(&self, instance, trace_id)
}
}
impl IntoProblem for ConfigError {
fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
map_error_to_problem(&self as &dyn Any, instance, trace_id)
}
}
impl IntoProblem for anyhow::Error {
fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
map_error_to_problem(&self as &dyn Any, instance, trace_id)
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
#[test]
fn test_odata_error_mapping() {
let error = ODataError::InvalidFilter("malformed".to_owned());
let problem = error.into_problem("/tests/v1/test", Some("trace123".to_owned()));
assert_eq!(problem.status, StatusCode::UNPROCESSABLE_ENTITY);
assert!(problem.code.contains("invalid_filter"));
assert_eq!(problem.instance, "/tests/v1/test");
assert_eq!(problem.trace_id, Some("trace123".to_owned()));
}
#[test]
fn test_config_error_mapping() {
let error = ConfigError::ModuleNotFound {
module: "test_module".to_owned(),
};
let problem = error.into_problem("/tests/v1/test", None);
assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(problem.code, "CONFIG_MODULE_NOT_FOUND");
assert_eq!(problem.instance, "/tests/v1/test");
assert!(problem.detail.contains("test_module"));
}
#[test]
fn test_anyhow_error_mapping() {
let error = anyhow::anyhow!("Something went wrong");
let problem = error.into_problem("/tests/v1/test", Some("trace456".to_owned()));
assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(problem.code, "INTERNAL_ERROR");
assert_eq!(problem.instance, "/tests/v1/test");
assert_eq!(problem.trace_id, Some("trace456".to_owned()));
}
#[test]
fn test_config_var_expand_error_sanitizes_detail() {
let source = modkit_utils::var_expand::ExpandVarsError::Var {
name: "SECRET_API_KEY".to_owned(),
source: std::env::VarError::NotPresent,
};
let error = ConfigError::VarExpand {
module: "my_mod".to_owned(),
source,
};
let problem = error.into_problem("/tests/v1/test", Some("trace789".to_owned()));
assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(problem.code, "CONFIG_ENV_EXPAND");
assert_eq!(
problem.type_url,
"https://errors.example.com/CONFIG_ENV_EXPAND"
);
assert_eq!(problem.instance, "/tests/v1/test");
assert_eq!(problem.trace_id, Some("trace789".to_owned()));
assert!(
!problem.detail.contains("SECRET_API_KEY"),
"detail must not contain env var name, got: {}",
problem.detail,
);
assert!(
!problem.detail.contains("not present"),
"detail must not contain source error text, got: {}",
problem.detail,
);
assert!(problem.detail.contains("my_mod"));
}
#[test]
fn test_extract_trace_id_from_headers() {
let mut headers = HeaderMap::new();
headers.insert("x-trace-id", "test-trace-123".parse().unwrap());
let trace_id = extract_trace_id(&headers);
assert_eq!(trace_id, Some("test-trace-123".to_owned()));
}
}