Skip to main content

modo/middleware/
error_handler.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::response::Response;
6use http::request::Parts;
7use tower::{Layer, Service};
8
9/// Creates an error-handler layer that intercepts responses containing a
10/// [`crate::error::Error`] in their extensions and rewrites them through
11/// the supplied handler function.
12///
13/// Any middleware that stores a `modo::Error` in response extensions
14/// (`Error::into_response()`, `catch_panic`, `csrf`, `rate_limit`, etc.)
15/// will be caught by this layer, giving the application a single place to
16/// control the error response format (JSON, HTML, plain text, etc.).
17///
18/// The handler receives the error and the original request parts (method,
19/// URI, headers, extensions) by value.
20///
21/// # Example
22///
23/// ```
24/// use axum::{Router, routing::get};
25/// use axum::response::IntoResponse;
26///
27/// async fn render_error(err: modo::Error, parts: http::request::Parts) -> axum::response::Response {
28///     (err.status(), err.message().to_string()).into_response()
29/// }
30///
31/// let app: Router = Router::new()
32///     .route("/", get(|| async { "ok" }))
33///     .layer(modo::middleware::error_handler(render_error));
34/// ```
35pub fn error_handler<F, Fut>(handler: F) -> ErrorHandlerLayer<F>
36where
37    F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
38    Fut: Future<Output = Response> + Send + 'static,
39{
40    ErrorHandlerLayer { handler }
41}
42
/// Tower [`Layer`] returned by [`error_handler`].
///
/// Every service it wraps receives its own clone of the stored handler.
#[derive(Clone)]
pub struct ErrorHandlerLayer<F> {
    /// User-supplied callback that renders a replacement response for errors.
    handler: F,
}
48
49impl<S, F> Layer<S> for ErrorHandlerLayer<F>
50where
51    F: Clone,
52{
53    type Service = ErrorHandlerService<S, F>;
54
55    fn layer(&self, inner: S) -> Self::Service {
56        ErrorHandlerService {
57            inner,
58            handler: self.handler.clone(),
59        }
60    }
61}
62
/// Tower [`Service`] produced by [`ErrorHandlerLayer`].
///
/// Forwards each request to `inner`. If the resulting response carries a
/// [`crate::error::Error`] extension, the response is replaced by the one the
/// user-provided `handler` returns.
#[derive(Clone)]
pub struct ErrorHandlerService<S, F> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Callback used to render a replacement response for errors.
    handler: F,
}
70
71impl<S, F, Fut> Service<http::Request<axum::body::Body>> for ErrorHandlerService<S, F>
72where
73    S: Service<http::Request<axum::body::Body>, Response = Response> + Clone + Send + 'static,
74    S::Future: Send,
75    S::Error: Into<std::convert::Infallible>,
76    F: Fn(crate::error::Error, Parts) -> Fut + Clone + Send + Sync + 'static,
77    Fut: Future<Output = Response> + Send + 'static,
78{
79    type Response = Response;
80    type Error = S::Error;
81    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
82
83    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.inner.poll_ready(cx)
85    }
86
87    fn call(&mut self, req: http::Request<axum::body::Body>) -> Self::Future {
88        // Clone parts before consuming the request so the error handler can
89        // inspect method, URI, headers, etc.
90        let (parts, body) = req.into_parts();
91        let saved_parts = parts.clone();
92        let req = http::Request::from_parts(parts, body);
93
94        let handler = self.handler.clone();
95        let future = self.inner.call(req);
96
97        Box::pin(async move {
98            let response = future.await?;
99
100            if let Some(error) = response.extensions().get::<crate::error::Error>() {
101                let error = error.clone();
102                let new_response = handler(error, saved_parts).await;
103                Ok(new_response)
104            } else {
105                Ok(response)
106            }
107        })
108    }
109}