gateway_runtime/layers/
error.rs1use crate::alloc::boxed::Box;
10use crate::alloc::sync::Arc;
11use crate::{GatewayError, GatewayRequest, GatewayResponse};
12use core::task::{Context, Poll};
13use std::future::Future;
14use std::pin::Pin;
15use tower::Service;
16
17pub type ErrorHandler = Arc<dyn Fn(&GatewayRequest, GatewayError) -> GatewayResponse + Send + Sync>;
22
23#[derive(Clone)]
25pub struct ErrorLayer<S> {
26 inner: S,
27 handler: Option<ErrorHandler>,
28}
29
30impl<S> ErrorLayer<S> {
31 pub fn new(inner: S, handler: Option<ErrorHandler>) -> Self {
37 Self { inner, handler }
38 }
39}
40
41impl<S> Service<GatewayRequest> for ErrorLayer<S>
42where
43 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
44 S::Future: Send + 'static,
45{
46 type Response = GatewayResponse;
47 type Error = GatewayError;
48 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
49
50 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51 self.inner.poll_ready(cx)
52 }
53
54 fn call(&mut self, req: GatewayRequest) -> Self::Future {
55 let method = req.method().clone();
58 let uri = req.uri().clone();
59 let headers = req.headers().clone();
60
61 let handler_clone = self.handler.clone();
62
63 let fut = self.inner.call(req);
64
65 Box::pin(async move {
66 match fut.await {
67 Ok(resp) => Ok(resp),
68 Err(err) => {
69 if let Some(h) = handler_clone {
70 let mut partial_req = http::Request::builder()
73 .method(method)
74 .uri(uri)
75 .body(crate::alloc::vec::Vec::new())
76 .unwrap();
77 *partial_req.headers_mut() = headers;
78
79 Ok(h(&partial_req, err))
81 } else {
82 Err(err)
84 }
85 }
86 }
87 })
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::string::ToString;
    use http::{Method, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    // `oneshot` drives `poll_ready` before `call`, honoring the tower `Service` contract.
    use tower::ServiceExt;

    /// Builds a response with `status` and a plain-text body.
    fn text_response(status: StatusCode, text: impl Into<crate::bytes::Bytes>) -> GatewayResponse {
        let body: crate::bytes::Bytes = text.into();
        Response::builder()
            .status(status)
            .body(BodyExt::boxed_unsync(
                Full::new(body).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    /// Builds a default request (GET `/`) with an empty body.
    fn empty_request() -> GatewayRequest {
        http::Request::builder()
            .body(crate::alloc::vec::Vec::new())
            .unwrap()
    }

    #[tokio::test]
    async fn test_error_layer_catches_error() {
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::NotFound) });

        let handler: ErrorHandler = Arc::new(|_: &GatewayRequest, err: GatewayError| {
            text_response(StatusCode::NOT_FOUND, err.to_string())
        });

        let layer = ErrorLayer::new(service, Some(handler));

        let resp = layer.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_error_layer_handler_sees_request_metadata() {
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::NotFound) });

        let handler: ErrorHandler = Arc::new(|req: &GatewayRequest, err: GatewayError| {
            // The body was consumed by the inner service; only metadata is forwarded.
            assert_eq!(*req.method(), Method::POST);
            assert_eq!(req.uri(), "/missing");
            assert_eq!(req.headers()["x-trace-id"], "abc123");
            assert!(req.body().is_empty());
            assert!(matches!(err, GatewayError::NotFound));
            text_response(StatusCode::NOT_FOUND, err.to_string())
        });

        let layer = ErrorLayer::new(service, Some(handler));
        let req = http::Request::builder()
            .method(Method::POST)
            .uri("/missing")
            .header("x-trace-id", "abc123")
            .body(crate::alloc::vec::Vec::from(&b"payload"[..]))
            .unwrap();

        // A 404 here also proves the handler (and its assertions) actually ran.
        let resp = layer.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_error_layer_propagates_ok() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok(text_response(StatusCode::OK, "ok"))
        });

        let handler: ErrorHandler =
            Arc::new(|_: &GatewayRequest, _: GatewayError| -> GatewayResponse {
                panic!("Handler should not be called");
            });

        let layer = ErrorLayer::new(service, Some(handler));

        let resp = layer.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_error_layer_propagates_error_without_handler() {
        let service =
            tower::service_fn(|_req: GatewayRequest| async { Err(GatewayError::MethodNotAllowed) });

        let layer = ErrorLayer::new(service, None);

        // The exact inner error must come back, not merely "some" error.
        let res = layer.oneshot(empty_request()).await;
        assert!(matches!(res, Err(GatewayError::MethodNotAllowed)));
    }
}