rama_core/layer/
consume_err.rs

1use crate::{Context, Layer, Service, error::BoxError};
2use rama_utils::macros::define_inner_service_accessors;
3use std::{convert::Infallible, fmt};
4
5use sealed::{DefaulResponse, StaticResponse, Trace};
6
7/// Consumes this service's error value and returns [`Infallible`].
8#[derive(Clone)]
9pub struct ConsumeErr<S, F, R = DefaulResponse> {
10    inner: S,
11    f: F,
12    response: R,
13}
14
15impl<S, F, R> fmt::Debug for ConsumeErr<S, F, R>
16where
17    S: fmt::Debug,
18    R: fmt::Debug,
19{
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        f.debug_struct("ConsumeErr")
22            .field("inner", &self.inner)
23            .field("f", &format_args!("{}", std::any::type_name::<F>()))
24            .field("response", &self.response)
25            .finish()
26    }
27}
28
29/// A [`Layer`] that produces [`ConsumeErr`] services.
30///
31/// [`Layer`]: crate::Layer
32#[derive(Clone)]
33pub struct ConsumeErrLayer<F, R = DefaulResponse> {
34    f: F,
35    response: R,
36}
37
38impl<F, R: fmt::Debug> fmt::Debug for ConsumeErrLayer<F, R> {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.debug_struct("ConsumeErrLayer")
41            .field("f", &format_args!("{}", std::any::type_name::<F>()))
42            .field("response", &self.response)
43            .finish()
44    }
45}
46
47impl Default for ConsumeErrLayer<Trace> {
48    fn default() -> Self {
49        Self::trace(tracing::Level::ERROR)
50    }
51}
52
53impl<S, F> ConsumeErr<S, F, DefaulResponse> {
54    /// Creates a new [`ConsumeErr`] service.
55    pub const fn new(inner: S, f: F) -> Self {
56        ConsumeErr {
57            f,
58            inner,
59            response: DefaulResponse,
60        }
61    }
62
63    define_inner_service_accessors!();
64}
65
66impl<S, F> ConsumeErr<S, F, DefaulResponse> {
67    /// Set a response to be used in case of errors,
68    /// instead of requiring and using the [`Default::default`] implementation
69    /// of the inner service's response type.
70    pub fn with_response<R>(self, response: R) -> ConsumeErr<S, F, StaticResponse<R>> {
71        ConsumeErr {
72            f: self.f,
73            inner: self.inner,
74            response: StaticResponse(response),
75        }
76    }
77}
78
79impl<S> ConsumeErr<S, Trace, DefaulResponse> {
80    /// Trace the error passed to this [`ConsumeErr`] service for the provided trace level.
81    pub const fn trace(inner: S, level: tracing::Level) -> Self {
82        Self::new(inner, Trace(level))
83    }
84}
85
86impl<S, F, State, Request> Service<State, Request> for ConsumeErr<S, F, DefaulResponse>
87where
88    S: Service<State, Request, Response: Default>,
89    F: FnOnce(S::Error) + Clone + Send + Sync + 'static,
90    State: Clone + Send + Sync + 'static,
91    Request: Send + 'static,
92{
93    type Response = S::Response;
94    type Error = Infallible;
95
96    async fn serve(
97        &self,
98        ctx: Context<State>,
99        req: Request,
100    ) -> Result<Self::Response, Self::Error> {
101        match self.inner.serve(ctx, req).await {
102            Ok(resp) => Ok(resp),
103            Err(err) => {
104                (self.f.clone())(err);
105                Ok(S::Response::default())
106            }
107        }
108    }
109}
110
111impl<S, F, State, Request, R> Service<State, Request> for ConsumeErr<S, F, StaticResponse<R>>
112where
113    S: Service<State, Request>,
114    F: FnOnce(S::Error) + Clone + Send + Sync + 'static,
115    R: Into<S::Response> + Clone + Send + Sync + 'static,
116    State: Clone + Send + Sync + 'static,
117    Request: Send + 'static,
118{
119    type Response = S::Response;
120    type Error = Infallible;
121
122    async fn serve(
123        &self,
124        ctx: Context<State>,
125        req: Request,
126    ) -> Result<Self::Response, Self::Error> {
127        match self.inner.serve(ctx, req).await {
128            Ok(resp) => Ok(resp),
129            Err(err) => {
130                (self.f.clone())(err);
131                Ok(self.response.0.clone().into())
132            }
133        }
134    }
135}
136
137impl<S, State, Request> Service<State, Request> for ConsumeErr<S, Trace, DefaulResponse>
138where
139    S: Service<State, Request, Response: Default, Error: Into<BoxError>>,
140    State: Clone + Send + Sync + 'static,
141    Request: Send + 'static,
142{
143    type Response = S::Response;
144    type Error = Infallible;
145
146    async fn serve(
147        &self,
148        ctx: Context<State>,
149        req: Request,
150    ) -> Result<Self::Response, Self::Error> {
151        match self.inner.serve(ctx, req).await {
152            Ok(resp) => Ok(resp),
153            Err(err) => {
154                const MESSAGE: &str = "unhandled service error consumed";
155                match self.f.0 {
156                    tracing::Level::TRACE => {
157                        tracing::trace!(error = err.into(), MESSAGE);
158                    }
159                    tracing::Level::DEBUG => {
160                        tracing::debug!(error = err.into(), MESSAGE);
161                    }
162                    tracing::Level::INFO => {
163                        tracing::info!(error = err.into(), MESSAGE);
164                    }
165                    tracing::Level::WARN => {
166                        tracing::warn!(error = err.into(), MESSAGE);
167                    }
168                    tracing::Level::ERROR => {
169                        tracing::error!(error = err.into(), MESSAGE);
170                    }
171                }
172                Ok(S::Response::default())
173            }
174        }
175    }
176}
177
178impl<S, State, Request, R> Service<State, Request> for ConsumeErr<S, Trace, StaticResponse<R>>
179where
180    S: Service<State, Request, Error: Into<BoxError>>,
181    R: Into<S::Response> + Clone + Send + Sync + 'static,
182    State: Clone + Send + Sync + 'static,
183    Request: Send + 'static,
184{
185    type Response = S::Response;
186    type Error = Infallible;
187
188    async fn serve(
189        &self,
190        ctx: Context<State>,
191        req: Request,
192    ) -> Result<Self::Response, Self::Error> {
193        match self.inner.serve(ctx, req).await {
194            Ok(resp) => Ok(resp),
195            Err(err) => {
196                const MESSAGE: &str = "unhandled service error consumed";
197                match self.f.0 {
198                    tracing::Level::TRACE => {
199                        tracing::trace!(error = err.into(), MESSAGE);
200                    }
201                    tracing::Level::DEBUG => {
202                        tracing::debug!(error = err.into(), MESSAGE);
203                    }
204                    tracing::Level::INFO => {
205                        tracing::info!(error = err.into(), MESSAGE);
206                    }
207                    tracing::Level::WARN => {
208                        tracing::warn!(error = err.into(), MESSAGE);
209                    }
210                    tracing::Level::ERROR => {
211                        tracing::error!(error = err.into(), MESSAGE);
212                    }
213                }
214                Ok(self.response.0.clone().into())
215            }
216        }
217    }
218}
219
220impl<F> ConsumeErrLayer<F> {
221    /// Creates a new [`ConsumeErrLayer`].
222    pub const fn new(f: F) -> Self {
223        ConsumeErrLayer {
224            f,
225            response: DefaulResponse,
226        }
227    }
228}
229
230impl ConsumeErrLayer<Trace> {
231    /// Creates a new [`ConsumeErrLayer`] to trace the consumed error.
232    pub const fn trace(level: tracing::Level) -> Self {
233        Self::new(Trace(level))
234    }
235}
236
237impl<F> ConsumeErrLayer<F, DefaulResponse> {
238    /// Set a response to be used in case of errors,
239    /// instead of requiring and using the [`Default::default`] implementation
240    /// of the inner service's response type.
241    pub fn with_response<R>(self, response: R) -> ConsumeErrLayer<F, StaticResponse<R>> {
242        ConsumeErrLayer {
243            f: self.f,
244            response: StaticResponse(response),
245        }
246    }
247}
248
249impl<S, F, R> Layer<S> for ConsumeErrLayer<F, R>
250where
251    F: Clone,
252    R: Clone,
253{
254    type Service = ConsumeErr<S, F, R>;
255
256    fn layer(&self, inner: S) -> Self::Service {
257        ConsumeErr {
258            f: self.f.clone(),
259            inner,
260            response: self.response.clone(),
261        }
262    }
263
264    fn into_layer(self, inner: S) -> Self::Service {
265        ConsumeErr {
266            f: self.f,
267            inner,
268            response: self.response,
269        }
270    }
271}
272
273mod sealed {
274    #[derive(Debug, Clone)]
275    /// A sealed new type to prevent downstream users from
276    /// passing the trace level directly to the [`ConsumeErr::new`] method.
277    ///
278    /// [`ConsumeErr::new`]: crate::layer::ConsumeErr::new
279    pub struct Trace(pub tracing::Level);
280
281    #[derive(Debug, Clone)]
282    #[non_exhaustive]
283    /// A sealed type to indicate default response is to be used.
284    pub struct DefaulResponse;
285
286    #[derive(Debug, Clone)]
287    #[non_exhaustive]
288    /// A sealed type to indicate static response is to be used.
289    pub struct StaticResponse<R>(pub(super) R);
290}