rama_http/layer/
error_handling.rs

1//! Middleware to turn [`Service`] errors into [`Response`]s.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_core::{
7//!     service::service_fn,
8//!     Context, Service, Layer,
9//! };
10//! use rama_http::{
11//!     service::client::HttpClientExt,
12//!     layer::{error_handling::ErrorHandlerLayer, timeout::TimeoutLayer},
13//!     service::web::WebService,
14//!     service::web::response::IntoResponse,
15//!     Body, Request, Response, StatusCode,
16//! };
17//! use std::time::Duration;
18//!
19//! # async fn some_expensive_io_operation() -> Result<(), std::io::Error> {
20//! #     Ok(())
21//! # }
22//!
23//! async fn handler<S>(_ctx: Context<S>, _req: Request) -> Result<Response, std::io::Error> {
24//!     some_expensive_io_operation().await?;
25//!     Ok(StatusCode::OK.into_response())
26//! }
27//!
28//! # #[tokio::main]
29//! # async fn main() {
30//!     let home_handler = (
31//!         ErrorHandlerLayer::new().error_mapper(|err| {
32//!             tracing::error!("Error: {:?}", err);
33//!             StatusCode::INTERNAL_SERVER_ERROR.into_response()
34//!         }),
35//!         TimeoutLayer::new(Duration::from_secs(5)),
36//!         ).into_layer(service_fn(handler));
37//!
38//!     let service = WebService::default().get("/", home_handler);
39//!
40//!     let _ = service.serve(Context::default(), Request::builder()
41//!         .method("GET")
42//!         .uri("/")
43//!         .body(Body::empty())
44//!         .unwrap()).await;
45//! # }
46//! ```
47
48use crate::service::web::response::IntoResponse;
49use crate::{Request, Response};
50use rama_core::{Context, Layer, Service};
51use rama_utils::macros::define_inner_service_accessors;
52use std::{convert::Infallible, fmt};
53
/// [`Layer`] producing [`ErrorHandler`] services, which turn errors of the
/// wrapped [`Service`] into [`Response`]s.
///
/// The `F` parameter holds the error mapper; the default `()` means that no
/// mapper has been configured.
pub struct ErrorHandlerLayer<F = ()> {
    /// Error mapper handed to every [`ErrorHandler`] built by this layer,
    /// or `()` when none has been set.
    error_mapper: F,
}
58
59impl Default for ErrorHandlerLayer {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        f.debug_struct("ErrorHandlerLayer")
68            .field("error_mapper", &self.error_mapper)
69            .finish()
70    }
71}
72
73impl<F: Clone> Clone for ErrorHandlerLayer<F> {
74    fn clone(&self) -> Self {
75        Self {
76            error_mapper: self.error_mapper.clone(),
77        }
78    }
79}
80
81impl ErrorHandlerLayer {
82    /// Create a new [`ErrorHandlerLayer`].
83    pub const fn new() -> Self {
84        Self { error_mapper: () }
85    }
86
87    /// Set the error mapper function (not set by default).
88    ///
89    /// The error mapper function is called with the error,
90    /// and should return an [`IntoResponse`] implementation.
91    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
92        ErrorHandlerLayer { error_mapper }
93    }
94}
95
96impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
97    type Service = ErrorHandler<S, F>;
98
99    fn layer(&self, inner: S) -> Self::Service {
100        ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
101    }
102
103    fn into_layer(self, inner: S) -> Self::Service {
104        ErrorHandler::new(inner).error_mapper(self.error_mapper)
105    }
106}
107
/// [`Service`] wrapper whose own error type is uninhabited: any error of the
/// wrapped service is turned into a [`Response`] instead of being propagated.
///
/// With the default `F = ()`, the inner error is converted through its own
/// [`IntoResponse`] implementation. Otherwise `F` is the error mapper that is
/// invoked with the inner error.
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// Error mapper, or `()` when the inner error converts itself.
    error_mapper: F,
}
113
114impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
115    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116        f.debug_struct("ErrorHandler")
117            .field("inner", &self.inner)
118            .field("error_mapper", &self.error_mapper)
119            .finish()
120    }
121}
122
123impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
124    fn clone(&self) -> Self {
125        Self {
126            inner: self.inner.clone(),
127            error_mapper: self.error_mapper.clone(),
128        }
129    }
130}
131
132impl<S> ErrorHandler<S> {
133    /// Create a new [`ErrorHandler`] wrapping the given service.
134    pub const fn new(inner: S) -> Self {
135        Self {
136            inner,
137            error_mapper: (),
138        }
139    }
140
141    define_inner_service_accessors!();
142
143    /// Set the error mapper function (not set by default).
144    ///
145    /// The error mapper function is called with the error,
146    /// and should return an [`IntoResponse`] implementation.
147    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
148        ErrorHandler {
149            inner: self.inner,
150            error_mapper,
151        }
152    }
153}
154
155impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
156where
157    S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
158    State: Clone + Send + Sync + 'static,
159    Body: Send + 'static,
160{
161    type Response = Response;
162    type Error = Infallible;
163
164    async fn serve(
165        &self,
166        ctx: Context<State>,
167        req: Request<Body>,
168    ) -> Result<Self::Response, Self::Error> {
169        match self.inner.serve(ctx, req).await {
170            Ok(response) => Ok(response.into_response()),
171            Err(error) => Ok(error.into_response()),
172        }
173    }
174}
175
176impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
177where
178    S: Service<State, Request<Body>, Response: IntoResponse>,
179    F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
180    R: IntoResponse + 'static,
181    State: Clone + Send + Sync + 'static,
182    Body: Send + 'static,
183{
184    type Response = Response;
185    type Error = Infallible;
186
187    async fn serve(
188        &self,
189        ctx: Context<State>,
190        req: Request<Body>,
191    ) -> Result<Self::Response, Self::Error> {
192        match self.inner.serve(ctx, req).await {
193            Ok(response) => Ok(response.into_response()),
194            Err(error) => Ok((self.error_mapper)(error).into_response()),
195        }
196    }
197}