Skip to main content

modo/middleware/
error_handler.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::response::{IntoResponse, Response};
6use http::request::Parts;
7use tower::{Layer, Service};
8
9use crate::error::render_error_body;
10
11/// Creates an error-handler layer that intercepts responses containing a
12/// [`crate::error::Error`] in their extensions and rewrites them through
13/// the supplied handler function.
14///
15/// Any middleware that stores a `modo::Error` in response extensions
16/// (`Error::into_response()`, `catch_panic`, `csrf`, `rate_limit`, etc.)
17/// will be caught by this layer, giving the application a single place to
18/// control the error response format (JSON, HTML, plain text, etc.).
19///
20/// The handler receives the error and the original request parts (method,
21/// URI, headers, extensions) by value.
22///
23/// # Example
24///
25/// ```rust,no_run
26/// use axum::{Router, routing::get};
27/// use axum::response::IntoResponse;
28///
29/// async fn render_error(err: modo::Error, parts: http::request::Parts) -> axum::response::Response {
30///     (err.status(), err.message().to_string()).into_response()
31/// }
32///
33/// let app: Router = Router::new()
34///     .route("/", get(|| async { "ok" }))
35///     .layer(modo::middleware::error_handler(render_error));
36/// ```
37pub fn error_handler<F, Fut>(handler: F) -> ErrorHandlerLayer<F>
38where
39    F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
40    Fut: Future<Output = Response> + Send + 'static,
41{
42    ErrorHandlerLayer { handler }
43}
44
45/// Default error responder suitable for passing directly to [`error_handler`].
46///
47/// Produces the same JSON shape as [`crate::Error::into_response`]:
48///
49/// ```json
50/// { "error": { "status": 404, "message": "..." } }
51/// ```
52///
53/// When the error carries a translation key (via
54/// [`Error::localized`](crate::Error::localized) or
55/// [`Error::with_locale_key`](crate::Error::with_locale_key)) **and** the
56/// request has a [`Translator`](crate::i18n::Translator) in its extensions
57/// (typically injected by [`I18nLayer`](crate::i18n::I18nLayer)), the key is
58/// resolved at response-build time and the translated string is used as the
59/// response `message`. Otherwise the error's stored `message` is used
60/// unchanged.
61///
62/// # Layer ordering
63///
64/// When pairing with [`I18nLayer`](crate::i18n::I18nLayer), install `I18nLayer`
65/// **outside** [`error_handler`] (apply `i18n.layer()` *after* `error_handler`
66/// in `.layer(...)` calls) so the [`Translator`](crate::i18n::Translator) is
67/// inserted into the request extensions before `error_handler` clones the
68/// request parts. Reversed ordering silently falls back to the raw key.
69///
70/// # Example
71///
72/// ```rust,no_run
73/// use axum::{Router, routing::get};
74/// use modo::middleware::{default_error_handler, error_handler};
75///
76/// # fn wire(i18n: modo::i18n::I18n) {
77/// let app: Router = Router::new()
78///     .route("/", get(|| async { "ok" }))
79///     .layer(error_handler(default_error_handler))
80///     .layer(i18n.layer());  // outer — must run before error_handler
81/// # }
82/// ```
83pub async fn default_error_handler(err: crate::error::Error, parts: Parts) -> Response {
84    let status = err.status();
85    let details = err.details().cloned();
86
87    let message = match (
88        err.locale_key(),
89        parts.extensions.get::<crate::i18n::Translator>(),
90    ) {
91        (Some(key), Some(tr)) => tr.t(key, &[]),
92        _ => err.message().to_string(),
93    };
94
95    let body = render_error_body(status, &message, details.as_ref());
96    (status, axum::Json(body)).into_response()
97}
98
99/// Tower [`Layer`] produced by [`error_handler`].
100#[derive(Clone)]
101pub struct ErrorHandlerLayer<F> {
102    handler: F,
103}
104
105impl<S, F> Layer<S> for ErrorHandlerLayer<F>
106where
107    F: Clone,
108{
109    type Service = ErrorHandlerService<S, F>;
110
111    fn layer(&self, inner: S) -> Self::Service {
112        ErrorHandlerService {
113            inner,
114            handler: self.handler.clone(),
115        }
116    }
117}
118
119/// Tower [`Service`] that wraps an inner service and rewrites error responses
120/// through a user-provided handler.
121#[derive(Clone)]
122pub struct ErrorHandlerService<S, F> {
123    inner: S,
124    handler: F,
125}
126
127impl<S, F, Fut> Service<http::Request<axum::body::Body>> for ErrorHandlerService<S, F>
128where
129    S: Service<http::Request<axum::body::Body>, Response = Response> + Clone + Send + 'static,
130    S::Future: Send,
131    S::Error: Into<std::convert::Infallible>,
132    F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
133    Fut: Future<Output = Response> + Send + 'static,
134{
135    type Response = Response;
136    type Error = S::Error;
137    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
138
139    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140        self.inner.poll_ready(cx)
141    }
142
143    fn call(&mut self, req: http::Request<axum::body::Body>) -> Self::Future {
144        // Clone parts before consuming the request so the error handler can
145        // inspect method, URI, headers, etc.
146        let (parts, body) = req.into_parts();
147        // NOTE: parts are cloned on every request so the handler callback can
148        // read headers / extensions on the error path. For 2xx responses this
149        // clone is unused — if this becomes a hot-path bottleneck, wrap parts
150        // in an Arc (copy-on-write) or redesign the handler to take only the
151        // extensions slice that default_error_handler actually reads.
152        let saved_parts = parts.clone();
153        let req = http::Request::from_parts(parts, body);
154
155        let handler = self.handler.clone();
156        let future = self.inner.call(req);
157
158        Box::pin(async move {
159            let response = future.await?;
160
161            if let Some(error) = response.extensions().get::<crate::error::Error>() {
162                let error = error.clone();
163                let new_response = handler(error, saved_parts).await;
164                Ok(new_response)
165            } else {
166                Ok(response)
167            }
168        })
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error as ModoError;
    use crate::i18n::{I18n, I18nConfig};
    use axum::body::Body;
    use axum::{Router, routing::get};
    use http::{Request, StatusCode};
    use tower::ServiceExt;

    /// Builds an `I18n` instance backed by `en` and `uk` locale files written
    /// into `dir`, each defining the `errors.user.not_found` key.
    fn test_i18n(dir: &std::path::Path) -> I18n {
        let en_dir = dir.join("en");
        let uk_dir = dir.join("uk");
        std::fs::create_dir_all(&en_dir).unwrap();
        std::fs::create_dir_all(&uk_dir).unwrap();
        std::fs::write(
            en_dir.join("errors.yaml"),
            "user:\n  not_found: User not found\n",
        )
        .unwrap();
        std::fs::write(
            uk_dir.join("errors.yaml"),
            "user:\n  not_found: Користувача не знайдено\n",
        )
        .unwrap();

        let config = I18nConfig {
            locales_path: dir.to_str().unwrap().to_string(),
            default_locale: "en".into(),
            ..I18nConfig::default()
        };
        I18n::new(&config).unwrap()
    }

    async fn decode_json(resp: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    async fn localized_handler() -> Result<&'static str, ModoError> {
        Err(ModoError::localized(
            StatusCode::NOT_FOUND,
            "errors.user.not_found",
        ))
    }

    async fn plain_handler() -> Result<&'static str, ModoError> {
        Err(ModoError::bad_request("boom"))
    }

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Custom responder that echoes the request path as a plain-text body,
    /// so a test can prove the layer invoked it with the original parts.
    async fn path_echo_handler(err: ModoError, parts: Parts) -> Response {
        (err.status(), parts.uri.path().to_string()).into_response()
    }

    #[tokio::test]
    async fn default_handler_uses_translator_when_present() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        assert_eq!(body["error"]["status"], 404);
        assert_eq!(body["error"]["message"], "User not found");
    }

    #[tokio::test]
    async fn default_handler_translates_using_resolved_locale() {
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder()
            .uri("/?lang=uk")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "Користувача не знайдено");
    }

    #[tokio::test]
    async fn default_handler_falls_back_to_key_without_translator() {
        // No I18nLayer is installed, so no Translator exists in the extensions.
        let app = Router::new()
            .route("/", get(localized_handler))
            .layer(error_handler(default_error_handler));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let body = decode_json(resp).await;
        // Fallback is the raw translation key.
        assert_eq!(body["error"]["message"], "errors.user.not_found");
    }

    #[tokio::test]
    async fn default_handler_passes_through_plain_errors() {
        // With a Translator installed.
        let dir = tempfile::tempdir().unwrap();
        let i18n = test_i18n(dir.path());

        let app = Router::new()
            .route("/", get(plain_handler))
            .layer(error_handler(default_error_handler))
            .layer(i18n.layer());

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "boom");

        // And without one.
        let app = Router::new()
            .route("/", get(plain_handler))
            .layer(error_handler(default_error_handler));
        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = decode_json(resp).await;
        assert_eq!(body["error"]["message"], "boom");
    }

    #[tokio::test]
    async fn successful_responses_pass_through_unchanged() {
        let app = Router::new()
            .route("/", get(ok_handler))
            .layer(error_handler(default_error_handler));

        let req = Request::builder().uri("/").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], b"ok");
    }

    #[tokio::test]
    async fn custom_handler_receives_original_request_parts() {
        let app = Router::new()
            .route("/users/42", get(plain_handler))
            .layer(error_handler(path_echo_handler));

        let req = Request::builder()
            .uri("/users/42")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // A plain-text path body (not the default JSON shape) proves the
        // layer replaced the response using the custom handler and the
        // request parts it captured.
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], b"/users/42");
    }
}