// gateway_runtime/layers/error.rs

//! # Error Handling Layer
//!
//! This layer intercepts errors returned by the inner service and converts them into
//! valid HTTP responses using a configured [`ErrorHandler`].
//!
//! This is crucial for returning user-friendly error messages (e.g., JSON) instead of
//! raw server errors or dropped connections.
8
9use crate::alloc::boxed::Box;
10use crate::alloc::sync::Arc;
11use crate::{GatewayError, GatewayRequest, GatewayResponse};
12use core::task::{Context, Poll};
13use std::future::Future;
14use std::pin::Pin;
15use tower::Service;
16
17/// A handler for converting errors into HTTP responses.
18///
19/// Implementations of this function type are responsible for mapping domain-specific
20/// [GatewayError]s into user-facing [GatewayResponse]s (e.g., setting status codes, JSON bodies).
21pub type ErrorHandler = Arc<dyn Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync>;
22
23/// A Tower middleware that handles errors from the inner service.
24#[derive(Clone)]
25pub struct ErrorLayer<S> {
26    inner: S,
27    handler: Option<ErrorHandler>,
28}
29
30impl<S> ErrorLayer<S> {
31    /// Creates a new `ErrorLayer`.
32    ///
33    /// # Parameters
34    /// *   `inner`: The inner service to wrap.
35    /// *   `handler`: An optional custom error handler. If `None`, errors are propagated.
36    pub fn new(inner: S, handler: Option<ErrorHandler>) -> Self {
37        Self { inner, handler }
38    }
39}
40
41impl<S> Service<GatewayRequest> for ErrorLayer<S>
42where
43    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
44    S::Future: Send + 'static,
45{
46    type Response = GatewayResponse;
47    type Error = GatewayError;
48    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
49
50    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51        self.inner.poll_ready(cx)
52    }
53
54    fn call(&mut self, req: GatewayRequest) -> Self::Future {
55        // Capture minimal request context (Method, URI, Headers) to pass to the error handler.
56        // The body is not cloned to avoid performance penalties.
57        let method = req.method().clone();
58        let uri = req.uri().clone();
59        let headers = req.headers().clone();
60
61        let handler_clone = self.handler.clone();
62
63        let fut = self.inner.call(req);
64
65        Box::pin(async move {
66            match fut.await {
67                Ok(resp) => Ok(resp),
68                Err(err) => {
69                    if let Some(h) = handler_clone {
70                        // Reconstruct a partial request for the handler context.
71                        // The body is empty since the original request has been consumed.
72                        let mut partial_req = http::Request::builder()
73                            .method(method)
74                            .uri(uri)
75                            .body(crate::alloc::vec::Vec::new())
76                            .unwrap();
77                        *partial_req.headers_mut() = headers;
78
79                        // Execute the custom error handler
80                        Ok(h(&partial_req, err))
81                    } else {
82                        // Propagate the error if no handler is configured
83                        Err(err)
84                    }
85                }
86            }
87        })
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::string::ToString;
    use http::{Response, StatusCode};
    use http_body_util::{BodyExt, Full};

    /// Builds a request with the default method and URI and an empty body.
    fn empty_request() -> GatewayRequest {
        http::Request::builder()
            .body(crate::alloc::vec::Vec::new())
            .unwrap()
    }

    /// Builds a response with the given status whose body is `text`.
    fn text_response(status: StatusCode, text: crate::alloc::string::String) -> GatewayResponse {
        Response::builder()
            .status(status)
            .body(BodyExt::boxed_unsync(
                // `Full` never fails, so the error mapping can never run.
                Full::new(crate::bytes::Bytes::from(text)).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    #[tokio::test]
    async fn test_error_layer_catches_error() {
        // Mock service that always fails
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::NotFound) });

        // Handler that converts error to 404 response
        let handler: ErrorHandler =
            Arc::new(|_, err| text_response(StatusCode::NOT_FOUND, err.to_string()));

        let mut layer = ErrorLayer::new(service, Some(handler));

        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_error_layer_passes_request_context_to_handler() {
        // Mock service that always fails, forcing the handler to run
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::NotFound) });

        // Handler that checks the request context it is given before responding
        let handler: ErrorHandler = Arc::new(|req, err| {
            assert_eq!(req.method(), &http::Method::POST);
            assert_eq!(req.uri().path(), "/widgets");
            assert_eq!(
                req.headers()
                    .get("x-request-id")
                    .and_then(|value| value.to_str().ok()),
                Some("abc123")
            );
            text_response(StatusCode::NOT_FOUND, err.to_string())
        });

        let mut layer = ErrorLayer::new(service, Some(handler));
        let req = http::Request::builder()
            .method(http::Method::POST)
            .uri("/widgets")
            .header("x-request-id", "abc123")
            .body(crate::alloc::vec::Vec::new())
            .unwrap();

        // The status proves the handler (and therefore its assertions) actually ran.
        let resp = layer.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_error_layer_propagates_ok() {
        // Mock service that succeeds
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok(text_response(StatusCode::OK, "ok".to_string()))
        });

        let handler: ErrorHandler = Arc::new(|_, _| {
            panic!("Handler should not be called");
        });

        let mut layer = ErrorLayer::new(service, Some(handler));

        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_error_layer_propagates_error_without_handler() {
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::MethodNotAllowed) });

        let mut layer = ErrorLayer::new(service, None);

        // The original error must come back unchanged, not merely "some error".
        let res = layer.call(empty_request()).await;
        assert!(matches!(res, Err(GatewayError::MethodNotAllowed)));
    }
}