rama_http/layer/
error_handling.rs

1//! Middleware to turn [`Service`] errors into [`Response`]s.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_core::{
7//!     service::service_fn,
8//!     Context, Service, Layer,
9//! };
10//! use rama_http::{
11//!     service::client::HttpClientExt,
12//!     layer::{error_handling::ErrorHandlerLayer, timeout::TimeoutLayer},
13//!     service::web::WebService,
14//!     Body, IntoResponse, Request, Response, StatusCode,
15//! };
16//! use std::time::Duration;
17//!
18//! # async fn some_expensive_io_operation() -> Result<(), std::io::Error> {
19//! #     Ok(())
20//! # }
21//!
22//! async fn handler<S>(_ctx: Context<S>, _req: Request) -> Result<Response, std::io::Error> {
23//!     some_expensive_io_operation().await?;
24//!     Ok(StatusCode::OK.into_response())
25//! }
26//!
27//! # #[tokio::main]
28//! # async fn main() {
29//!     let home_handler = (
30//!         ErrorHandlerLayer::new().error_mapper(|err| {
31//!             tracing::error!("Error: {:?}", err);
32//!             StatusCode::INTERNAL_SERVER_ERROR.into_response()
33//!         }),
34//!         TimeoutLayer::new(Duration::from_secs(5)),
35//!         ).layer(service_fn(handler));
36//!
37//!     let service = WebService::default().get("/", home_handler);
38//!
39//!     let _ = service.serve(Context::default(), Request::builder()
40//!         .method("GET")
41//!         .uri("/")
42//!         .body(Body::empty())
43//!         .unwrap()).await;
44//! # }
45//! ```
46
47use crate::{IntoResponse, Request, Response};
48use rama_core::{Context, Layer, Service};
49use rama_utils::macros::define_inner_service_accessors;
50use std::{convert::Infallible, fmt};
51
/// [`Layer`] producing [`ErrorHandler`] services, which turn the errors of
/// the wrapped [`Service`] into [`Response`]s.
///
/// The type parameter `F` is the error mapper. It defaults to `()`, meaning
/// that no mapper has been configured.
pub struct ErrorHandlerLayer<F = ()> {
    /// Error mapper handed (by clone) to every service this layer creates.
    error_mapper: F,
}
56
57impl Default for ErrorHandlerLayer {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        f.debug_struct("ErrorHandlerLayer")
66            .field("error_mapper", &self.error_mapper)
67            .finish()
68    }
69}
70
71impl<F: Clone> Clone for ErrorHandlerLayer<F> {
72    fn clone(&self) -> Self {
73        Self {
74            error_mapper: self.error_mapper.clone(),
75        }
76    }
77}
78
79impl ErrorHandlerLayer {
80    /// Create a new [`ErrorHandlerLayer`].
81    pub const fn new() -> Self {
82        Self { error_mapper: () }
83    }
84
85    /// Set the error mapper function (not set by default).
86    ///
87    /// The error mapper function is called with the error,
88    /// and should return an [`IntoResponse`] implementation.
89    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
90        ErrorHandlerLayer { error_mapper }
91    }
92}
93
94impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
95    type Service = ErrorHandler<S, F>;
96
97    fn layer(&self, inner: S) -> Self::Service {
98        ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
99    }
100}
101
/// [`Service`] wrapper whose error type is infallible. Errors of the inner
/// service are converted into [`Response`]s instead of being returned.
///
/// `F` is the error mapper. The default `()` means the inner error is
/// converted using its own [`IntoResponse`] implementation.
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// Function applied to inner errors, or `()` when none is configured.
    error_mapper: F,
}
107
108impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
109    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110        f.debug_struct("ErrorHandler")
111            .field("inner", &self.inner)
112            .field("error_mapper", &self.error_mapper)
113            .finish()
114    }
115}
116
117impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
118    fn clone(&self) -> Self {
119        Self {
120            inner: self.inner.clone(),
121            error_mapper: self.error_mapper.clone(),
122        }
123    }
124}
125
126impl<S> ErrorHandler<S> {
127    /// Create a new [`ErrorHandler`] wrapping the given service.
128    pub const fn new(inner: S) -> Self {
129        Self {
130            inner,
131            error_mapper: (),
132        }
133    }
134
135    define_inner_service_accessors!();
136
137    /// Set the error mapper function (not set by default).
138    ///
139    /// The error mapper function is called with the error,
140    /// and should return an [`IntoResponse`] implementation.
141    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
142        ErrorHandler {
143            inner: self.inner,
144            error_mapper,
145        }
146    }
147}
148
149impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
150where
151    S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
152    State: Clone + Send + Sync + 'static,
153    Body: Send + 'static,
154{
155    type Response = Response;
156    type Error = Infallible;
157
158    async fn serve(
159        &self,
160        ctx: Context<State>,
161        req: Request<Body>,
162    ) -> Result<Self::Response, Self::Error> {
163        match self.inner.serve(ctx, req).await {
164            Ok(response) => Ok(response.into_response()),
165            Err(error) => Ok(error.into_response()),
166        }
167    }
168}
169
170impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
171where
172    S: Service<State, Request<Body>, Response: IntoResponse>,
173    F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
174    R: IntoResponse + 'static,
175    State: Clone + Send + Sync + 'static,
176    Body: Send + 'static,
177{
178    type Response = Response;
179    type Error = Infallible;
180
181    async fn serve(
182        &self,
183        ctx: Context<State>,
184        req: Request<Body>,
185    ) -> Result<Self::Response, Self::Error> {
186        match self.inner.serve(ctx, req).await {
187            Ok(response) => Ok(response.into_response()),
188            Err(error) => Ok((self.error_mapper)(error).into_response()),
189        }
190    }
191}