Skip to main content

modkit/api/
error_layer.rs

//! Centralized error mapping for Axum.
//!
//! This module provides helpers for converting framework and module errors into
//! consistent RFC 9457 Problem+JSON responses (via [`IntoProblem`] and
//! [`map_error_to_problem`]), reducing per-route boilerplate.
6
7use axum::{extract::Request, http::HeaderMap, middleware::Next, response::Response};
8use http::StatusCode;
9use std::any::Any;
10
11use crate::api::problem::Problem;
12use crate::config::ConfigError;
13use modkit_odata::Error as ODataError;
14
15/// Middleware function that provides centralized error mapping
16///
17/// This middleware can be applied to routes to automatically extract request context
18/// and provide it to error handlers. The actual error conversion happens in the
19/// `IntoProblem` trait implementations and `map_error_to_problem` function.
20pub async fn error_mapping_middleware(request: Request, next: Next) -> Response {
21    let _uri = request.uri().clone();
22    let _headers = request.headers().clone();
23
24    let response = next.run(request).await;
25
26    // If the response is already successful or is already a Problem response, pass it through
27    if response.status().is_success() || is_problem_response(&response) {
28        return response;
29    }
30
31    // For error responses, the actual error conversion should happen in the handlers
32    // using the IntoProblem trait or map_error_to_problem function
33    // This middleware provides the infrastructure for extracting request context
34    response
35}
36
37/// Check if a response is already a Problem+JSON response
38fn is_problem_response(response: &Response) -> bool {
39    response
40        .headers()
41        .get(axum::http::header::CONTENT_TYPE)
42        .and_then(|v| v.to_str().ok())
43        .is_some_and(|ct| ct.contains("application/problem+json"))
44}
45
46/// Extract trace ID from headers or generate one
47pub fn extract_trace_id(headers: &HeaderMap) -> Option<String> {
48    // Try to get trace ID from various common headers
49    headers
50        .get("x-trace-id")
51        .or_else(|| headers.get("x-request-id"))
52        .or_else(|| headers.get("traceparent"))
53        .and_then(|v| v.to_str().ok())
54        .map(ToString::to_string)
55        .or_else(|| {
56            // Try to get from current tracing span
57            tracing::Span::current()
58                .id()
59                .map(|id| id.into_u64().to_string())
60        })
61}
62
63/// Centralized error mapping function
64///
65/// This function provides a single place to convert all framework and module errors
66/// into consistent Problem responses with proper trace IDs and instance paths.
67pub fn map_error_to_problem(error: &dyn Any, instance: &str, trace_id: Option<String>) -> Problem {
68    // Try to downcast to known error types
69    if let Some(odata_err) = error.downcast_ref::<ODataError>() {
70        return crate::api::odata::error::odata_error_to_problem(odata_err, instance, trace_id);
71    }
72
73    if let Some(config_err) = error.downcast_ref::<ConfigError>() {
74        let mut problem = match config_err {
75            ConfigError::ModuleNotFound { module } => Problem::new(
76                StatusCode::INTERNAL_SERVER_ERROR,
77                "Configuration Error",
78                format!("Module '{module}' configuration not found"),
79            )
80            .with_code("CONFIG_MODULE_NOT_FOUND")
81            .with_type("https://errors.example.com/CONFIG_MODULE_NOT_FOUND"),
82
83            ConfigError::InvalidModuleStructure { module } => Problem::new(
84                StatusCode::INTERNAL_SERVER_ERROR,
85                "Configuration Error",
86                format!("Module '{module}' has invalid configuration structure"),
87            )
88            .with_code("CONFIG_INVALID_STRUCTURE")
89            .with_type("https://errors.example.com/CONFIG_INVALID_STRUCTURE"),
90
91            ConfigError::MissingConfigSection { module } => Problem::new(
92                StatusCode::INTERNAL_SERVER_ERROR,
93                "Configuration Error",
94                format!("Module '{module}' is missing required config section"),
95            )
96            .with_code("CONFIG_MISSING_SECTION")
97            .with_type("https://errors.example.com/CONFIG_MISSING_SECTION"),
98
99            ConfigError::InvalidConfig { module, .. } => Problem::new(
100                StatusCode::INTERNAL_SERVER_ERROR,
101                "Configuration Error",
102                format!("Module '{module}' has invalid configuration"),
103            )
104            .with_code("CONFIG_INVALID")
105            .with_type("https://errors.example.com/CONFIG_INVALID"),
106
107            ConfigError::VarExpand { module, source } => {
108                tracing::error!(
109                    module = %module,
110                    error = %source,
111                    "Environment variable expansion failed in module config"
112                );
113                Problem::new(
114                    StatusCode::INTERNAL_SERVER_ERROR,
115                    "Configuration Error",
116                    format!("Module '{module}' has invalid environment-backed configuration"),
117                )
118                .with_code("CONFIG_ENV_EXPAND")
119                .with_type("https://errors.example.com/CONFIG_ENV_EXPAND")
120            }
121        };
122
123        problem = problem.with_instance(instance);
124        if let Some(tid) = trace_id {
125            problem = problem.with_trace_id(tid);
126        }
127        return problem;
128    }
129
130    // Handle anyhow::Error
131    if let Some(anyhow_err) = error.downcast_ref::<anyhow::Error>() {
132        let mut problem = Problem::new(
133            StatusCode::INTERNAL_SERVER_ERROR,
134            "Internal Server Error",
135            "An internal error occurred",
136        )
137        .with_code("INTERNAL_ERROR")
138        .with_type("https://errors.example.com/INTERNAL_ERROR");
139
140        problem = problem.with_instance(instance);
141        if let Some(tid) = trace_id {
142            problem = problem.with_trace_id(tid);
143        }
144
145        // Log the full error for debugging
146        tracing::error!(error = %anyhow_err, "Internal server error");
147        return problem;
148    }
149
150    // Fallback for unknown error types
151    let mut problem = Problem::new(
152        StatusCode::INTERNAL_SERVER_ERROR,
153        "Unknown Error",
154        "An unknown error occurred",
155    )
156    .with_code("UNKNOWN_ERROR")
157    .with_type("https://errors.example.com/UNKNOWN_ERROR");
158
159    problem = problem.with_instance(instance);
160    if let Some(tid) = trace_id {
161        problem = problem.with_trace_id(tid);
162    }
163
164    tracing::error!("Unknown error type in error mapping layer");
165    problem
166}
167
168/// Helper trait for converting errors to Problem responses with context
169pub trait IntoProblem {
170    fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem;
171}
172
173impl IntoProblem for ODataError {
174    fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
175        crate::api::odata::error::odata_error_to_problem(&self, instance, trace_id)
176    }
177}
178
179impl IntoProblem for ConfigError {
180    fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
181        map_error_to_problem(&self as &dyn Any, instance, trace_id)
182    }
183}
184
185impl IntoProblem for anyhow::Error {
186    fn into_problem(self, instance: &str, trace_id: Option<String>) -> Problem {
187        map_error_to_problem(&self as &dyn Any, instance, trace_id)
188    }
189}
190
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    #[test]
    fn test_odata_error_mapping() {
        let error = ODataError::InvalidFilter("malformed".to_owned());
        let problem = error.into_problem("/tests/v1/test", Some("trace123".to_owned()));

        assert_eq!(problem.status, StatusCode::UNPROCESSABLE_ENTITY);
        assert!(problem.code.contains("invalid_filter"));
        assert_eq!(problem.instance, "/tests/v1/test");
        assert_eq!(problem.trace_id, Some("trace123".to_owned()));
    }

    #[test]
    fn test_config_error_mapping() {
        let error = ConfigError::ModuleNotFound {
            module: "test_module".to_owned(),
        };
        let problem = error.into_problem("/tests/v1/test", None);

        assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(problem.code, "CONFIG_MODULE_NOT_FOUND");
        assert_eq!(problem.instance, "/tests/v1/test");
        assert!(problem.detail.contains("test_module"));
    }

    #[test]
    fn test_anyhow_error_mapping() {
        let error = anyhow::anyhow!("Something went wrong");
        let problem = error.into_problem("/tests/v1/test", Some("trace456".to_owned()));

        assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(problem.code, "INTERNAL_ERROR");
        assert_eq!(problem.instance, "/tests/v1/test");
        assert_eq!(problem.trace_id, Some("trace456".to_owned()));
    }

    #[test]
    fn test_unknown_error_type_falls_back_to_unknown_error() {
        // Any type that is not ODataError / ConfigError / anyhow::Error hits the fallback.
        let problem =
            map_error_to_problem(&42_u32, "/tests/v1/test", Some("trace-unknown".to_owned()));

        assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(problem.code, "UNKNOWN_ERROR");
        assert_eq!(problem.instance, "/tests/v1/test");
        assert_eq!(problem.trace_id, Some("trace-unknown".to_owned()));
    }

    #[test]
    fn test_config_var_expand_error_sanitizes_detail() {
        let source = modkit_utils::var_expand::ExpandVarsError::Var {
            name: "SECRET_API_KEY".to_owned(),
            source: std::env::VarError::NotPresent,
        };
        let error = ConfigError::VarExpand {
            module: "my_mod".to_owned(),
            source,
        };
        let problem = error.into_problem("/tests/v1/test", Some("trace789".to_owned()));

        assert_eq!(problem.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(problem.code, "CONFIG_ENV_EXPAND");
        assert_eq!(
            problem.type_url,
            "https://errors.example.com/CONFIG_ENV_EXPAND"
        );
        assert_eq!(problem.instance, "/tests/v1/test");
        assert_eq!(problem.trace_id, Some("trace789".to_owned()));

        // Detail MUST NOT leak the env var name or the underlying error message.
        assert!(
            !problem.detail.contains("SECRET_API_KEY"),
            "detail must not contain env var name, got: {}",
            problem.detail,
        );
        assert!(
            !problem.detail.contains("not present"),
            "detail must not contain source error text, got: {}",
            problem.detail,
        );
        // It should still mention the module name (non-sensitive).
        assert!(problem.detail.contains("my_mod"));
    }

    #[test]
    fn test_extract_trace_id_from_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("x-trace-id", "test-trace-123".parse().unwrap());

        let trace_id = extract_trace_id(&headers);
        assert_eq!(trace_id, Some("test-trace-123".to_owned()));
    }

    #[test]
    fn test_extract_trace_id_falls_back_to_request_id() {
        let mut headers = HeaderMap::new();
        headers.insert("x-request-id", "req-456".parse().unwrap());

        assert_eq!(extract_trace_id(&headers), Some("req-456".to_owned()));
    }

    #[test]
    fn test_extract_trace_id_prefers_x_trace_id_over_request_id() {
        let mut headers = HeaderMap::new();
        headers.insert("x-request-id", "req-456".parse().unwrap());
        headers.insert("x-trace-id", "trace-123".parse().unwrap());

        assert_eq!(extract_trace_id(&headers), Some("trace-123".to_owned()));
    }
}