rama_http/layer/
error_handling.rs

1//! Middleware to turn [`Service`] errors into [`Response`]s.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_core::{
7//!     service::service_fn,
8//!     Context, Service, Layer,
9//! };
10//! use rama_http::{
11//!     service::client::HttpClientExt,
12//!     layer::{error_handling::ErrorHandlerLayer, timeout::TimeoutLayer},
13//!     service::web::WebService,
14//!     Body, IntoResponse, Request, Response, StatusCode,
15//! };
16//! use std::time::Duration;
17//!
18//! # async fn some_expensive_io_operation() -> Result<(), std::io::Error> {
19//! #     Ok(())
20//! # }
21//!
22//! async fn handler<S>(_ctx: Context<S>, _req: Request) -> Result<Response, std::io::Error> {
23//!     some_expensive_io_operation().await?;
24//!     Ok(StatusCode::OK.into_response())
25//! }
26//!
27//! # #[tokio::main]
28//! # async fn main() {
29//!     let home_handler = (
30//!         ErrorHandlerLayer::new().error_mapper(|err| {
31//!             tracing::error!("Error: {:?}", err);
32//!             StatusCode::INTERNAL_SERVER_ERROR.into_response()
33//!         }),
34//!         TimeoutLayer::new(Duration::from_secs(5)),
35//!         ).into_layer(service_fn(handler));
36//!
37//!     let service = WebService::default().get("/", home_handler);
38//!
39//!     let _ = service.serve(Context::default(), Request::builder()
40//!         .method("GET")
41//!         .uri("/")
42//!         .body(Body::empty())
43//!         .unwrap()).await;
44//! # }
45//! ```
46
47use crate::{IntoResponse, Request, Response};
48use rama_core::{Context, Layer, Service};
49use rama_utils::macros::define_inner_service_accessors;
50use std::{convert::Infallible, fmt};
51
/// A [`Layer`] that wraps a [`Service`] and converts errors into [`Response`]s.
///
/// The type parameter `F` is the error mapper. It defaults to `()`, meaning no
/// mapper is installed: the produced [`ErrorHandler`] then requires the inner
/// service's error type to implement [`IntoResponse`] and converts it directly.
/// Use [`ErrorHandlerLayer::error_mapper`] to install a custom mapper instead.
pub struct ErrorHandlerLayer<F = ()> {
    error_mapper: F,
}

impl Default for ErrorHandlerLayer {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("ErrorHandlerLayer")
            .field("error_mapper", &self.error_mapper)
            .finish()
    }
}

impl<F: Clone> Clone for ErrorHandlerLayer<F> {
    fn clone(&self) -> Self {
        Self {
            error_mapper: self.error_mapper.clone(),
        }
    }
}

impl ErrorHandlerLayer {
    /// Create a new [`ErrorHandlerLayer`] without an error mapper.
    ///
    /// Services produced by this layer convert the inner service's error
    /// into a [`Response`] through the error's own [`IntoResponse`] implementation.
    #[must_use]
    pub const fn new() -> Self {
        Self { error_mapper: () }
    }

    /// Set the error mapper function (not set by default).
    ///
    /// The error mapper function is called with the error,
    /// and should return an [`IntoResponse`] implementation.
    ///
    /// This consumes `self` and returns a new layer holding `error_mapper`;
    /// ignoring the returned value discards the mapper.
    #[must_use]
    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
        ErrorHandlerLayer { error_mapper }
    }
}
93
94impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
95    type Service = ErrorHandler<S, F>;
96
97    fn layer(&self, inner: S) -> Self::Service {
98        ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
99    }
100
101    fn into_layer(self, inner: S) -> Self::Service {
102        ErrorHandler::new(inner).error_mapper(self.error_mapper)
103    }
104}
105
/// A [`Service`] adapter that handles errors by converting them into [`Response`]s.
///
/// `S` is the wrapped service and `F` is the error mapper; the default
/// `F = ()` means no mapper has been set.
pub struct ErrorHandler<S, F = ()> {
    inner: S,
    error_mapper: F,
}

impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            inner,
            error_mapper,
        } = self;
        let mut builder = f.debug_struct("ErrorHandler");
        builder.field("inner", inner);
        builder.field("error_mapper", error_mapper);
        builder.finish()
    }
}

impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
    fn clone(&self) -> Self {
        let Self {
            inner,
            error_mapper,
        } = self;
        // Clone the service first, then the mapper (same order as the fields).
        let inner = inner.clone();
        let error_mapper = error_mapper.clone();
        Self {
            inner,
            error_mapper,
        }
    }
}
129
130impl<S> ErrorHandler<S> {
131    /// Create a new [`ErrorHandler`] wrapping the given service.
132    pub const fn new(inner: S) -> Self {
133        Self {
134            inner,
135            error_mapper: (),
136        }
137    }
138
139    define_inner_service_accessors!();
140
141    /// Set the error mapper function (not set by default).
142    ///
143    /// The error mapper function is called with the error,
144    /// and should return an [`IntoResponse`] implementation.
145    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
146        ErrorHandler {
147            inner: self.inner,
148            error_mapper,
149        }
150    }
151}
152
153impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
154where
155    S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
156    State: Clone + Send + Sync + 'static,
157    Body: Send + 'static,
158{
159    type Response = Response;
160    type Error = Infallible;
161
162    async fn serve(
163        &self,
164        ctx: Context<State>,
165        req: Request<Body>,
166    ) -> Result<Self::Response, Self::Error> {
167        match self.inner.serve(ctx, req).await {
168            Ok(response) => Ok(response.into_response()),
169            Err(error) => Ok(error.into_response()),
170        }
171    }
172}
173
174impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
175where
176    S: Service<State, Request<Body>, Response: IntoResponse>,
177    F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
178    R: IntoResponse + 'static,
179    State: Clone + Send + Sync + 'static,
180    Body: Send + 'static,
181{
182    type Response = Response;
183    type Error = Infallible;
184
185    async fn serve(
186        &self,
187        ctx: Context<State>,
188        req: Request<Body>,
189    ) -> Result<Self::Response, Self::Error> {
190        match self.inner.serve(ctx, req).await {
191            Ok(response) => Ok(response.into_response()),
192            Err(error) => Ok((self.error_mapper)(error).into_response()),
193        }
194    }
195}