1use crate::{Context, Layer, Service, error::BoxError};
2use rama_utils::macros::define_inner_service_accessors;
3use std::{convert::Infallible, fmt};
4
5use sealed::{DefaulResponse, StaticResponse, Trace};
6
7#[derive(Clone)]
9pub struct ConsumeErr<S, F, R = DefaulResponse> {
10 inner: S,
11 f: F,
12 response: R,
13}
14
15impl<S, F, R> fmt::Debug for ConsumeErr<S, F, R>
16where
17 S: fmt::Debug,
18 R: fmt::Debug,
19{
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 f.debug_struct("ConsumeErr")
22 .field("inner", &self.inner)
23 .field("f", &format_args!("{}", std::any::type_name::<F>()))
24 .field("response", &self.response)
25 .finish()
26 }
27}
28
29#[derive(Clone)]
33pub struct ConsumeErrLayer<F, R = DefaulResponse> {
34 f: F,
35 response: R,
36}
37
38impl<F, R: fmt::Debug> fmt::Debug for ConsumeErrLayer<F, R> {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("ConsumeErrLayer")
41 .field("f", &format_args!("{}", std::any::type_name::<F>()))
42 .field("response", &self.response)
43 .finish()
44 }
45}
46
47impl Default for ConsumeErrLayer<Trace> {
48 fn default() -> Self {
49 Self::trace(tracing::Level::ERROR)
50 }
51}
52
53impl<S, F> ConsumeErr<S, F, DefaulResponse> {
54 pub const fn new(inner: S, f: F) -> Self {
56 ConsumeErr {
57 f,
58 inner,
59 response: DefaulResponse,
60 }
61 }
62
63 define_inner_service_accessors!();
64}
65
66impl<S, F> ConsumeErr<S, F, DefaulResponse> {
67 pub fn with_response<R>(self, response: R) -> ConsumeErr<S, F, StaticResponse<R>> {
71 ConsumeErr {
72 f: self.f,
73 inner: self.inner,
74 response: StaticResponse(response),
75 }
76 }
77}
78
79impl<S> ConsumeErr<S, Trace, DefaulResponse> {
80 pub const fn trace(inner: S, level: tracing::Level) -> Self {
82 Self::new(inner, Trace(level))
83 }
84}
85
86impl<S, F, State, Request> Service<State, Request> for ConsumeErr<S, F, DefaulResponse>
87where
88 S: Service<State, Request, Response: Default>,
89 F: FnOnce(S::Error) + Clone + Send + Sync + 'static,
90 State: Clone + Send + Sync + 'static,
91 Request: Send + 'static,
92{
93 type Response = S::Response;
94 type Error = Infallible;
95
96 async fn serve(
97 &self,
98 ctx: Context<State>,
99 req: Request,
100 ) -> Result<Self::Response, Self::Error> {
101 match self.inner.serve(ctx, req).await {
102 Ok(resp) => Ok(resp),
103 Err(err) => {
104 (self.f.clone())(err);
105 Ok(S::Response::default())
106 }
107 }
108 }
109}
110
111impl<S, F, State, Request, R> Service<State, Request> for ConsumeErr<S, F, StaticResponse<R>>
112where
113 S: Service<State, Request>,
114 F: FnOnce(S::Error) + Clone + Send + Sync + 'static,
115 R: Into<S::Response> + Clone + Send + Sync + 'static,
116 State: Clone + Send + Sync + 'static,
117 Request: Send + 'static,
118{
119 type Response = S::Response;
120 type Error = Infallible;
121
122 async fn serve(
123 &self,
124 ctx: Context<State>,
125 req: Request,
126 ) -> Result<Self::Response, Self::Error> {
127 match self.inner.serve(ctx, req).await {
128 Ok(resp) => Ok(resp),
129 Err(err) => {
130 (self.f.clone())(err);
131 Ok(self.response.0.clone().into())
132 }
133 }
134 }
135}
136
137impl<S, State, Request> Service<State, Request> for ConsumeErr<S, Trace, DefaulResponse>
138where
139 S: Service<State, Request, Response: Default, Error: Into<BoxError>>,
140 State: Clone + Send + Sync + 'static,
141 Request: Send + 'static,
142{
143 type Response = S::Response;
144 type Error = Infallible;
145
146 async fn serve(
147 &self,
148 ctx: Context<State>,
149 req: Request,
150 ) -> Result<Self::Response, Self::Error> {
151 match self.inner.serve(ctx, req).await {
152 Ok(resp) => Ok(resp),
153 Err(err) => {
154 const MESSAGE: &str = "unhandled service error consumed";
155 match self.f.0 {
156 tracing::Level::TRACE => {
157 tracing::trace!(error = err.into(), MESSAGE);
158 }
159 tracing::Level::DEBUG => {
160 tracing::debug!(error = err.into(), MESSAGE);
161 }
162 tracing::Level::INFO => {
163 tracing::info!(error = err.into(), MESSAGE);
164 }
165 tracing::Level::WARN => {
166 tracing::warn!(error = err.into(), MESSAGE);
167 }
168 tracing::Level::ERROR => {
169 tracing::error!(error = err.into(), MESSAGE);
170 }
171 }
172 Ok(S::Response::default())
173 }
174 }
175 }
176}
177
178impl<S, State, Request, R> Service<State, Request> for ConsumeErr<S, Trace, StaticResponse<R>>
179where
180 S: Service<State, Request, Error: Into<BoxError>>,
181 R: Into<S::Response> + Clone + Send + Sync + 'static,
182 State: Clone + Send + Sync + 'static,
183 Request: Send + 'static,
184{
185 type Response = S::Response;
186 type Error = Infallible;
187
188 async fn serve(
189 &self,
190 ctx: Context<State>,
191 req: Request,
192 ) -> Result<Self::Response, Self::Error> {
193 match self.inner.serve(ctx, req).await {
194 Ok(resp) => Ok(resp),
195 Err(err) => {
196 const MESSAGE: &str = "unhandled service error consumed";
197 match self.f.0 {
198 tracing::Level::TRACE => {
199 tracing::trace!(error = err.into(), MESSAGE);
200 }
201 tracing::Level::DEBUG => {
202 tracing::debug!(error = err.into(), MESSAGE);
203 }
204 tracing::Level::INFO => {
205 tracing::info!(error = err.into(), MESSAGE);
206 }
207 tracing::Level::WARN => {
208 tracing::warn!(error = err.into(), MESSAGE);
209 }
210 tracing::Level::ERROR => {
211 tracing::error!(error = err.into(), MESSAGE);
212 }
213 }
214 Ok(self.response.0.clone().into())
215 }
216 }
217 }
218}
219
220impl<F> ConsumeErrLayer<F> {
221 pub const fn new(f: F) -> Self {
223 ConsumeErrLayer {
224 f,
225 response: DefaulResponse,
226 }
227 }
228}
229
230impl ConsumeErrLayer<Trace> {
231 pub const fn trace(level: tracing::Level) -> Self {
233 Self::new(Trace(level))
234 }
235}
236
237impl<F> ConsumeErrLayer<F, DefaulResponse> {
238 pub fn with_response<R>(self, response: R) -> ConsumeErrLayer<F, StaticResponse<R>> {
242 ConsumeErrLayer {
243 f: self.f,
244 response: StaticResponse(response),
245 }
246 }
247}
248
249impl<S, F, R> Layer<S> for ConsumeErrLayer<F, R>
250where
251 F: Clone,
252 R: Clone,
253{
254 type Service = ConsumeErr<S, F, R>;
255
256 fn layer(&self, inner: S) -> Self::Service {
257 ConsumeErr {
258 f: self.f.clone(),
259 inner,
260 response: self.response.clone(),
261 }
262 }
263
264 fn into_layer(self, inner: S) -> Self::Service {
265 ConsumeErr {
266 f: self.f,
267 inner,
268 response: self.response,
269 }
270 }
271}
272
273mod sealed {
274 #[derive(Debug, Clone)]
275 pub struct Trace(pub tracing::Level);
280
281 #[derive(Debug, Clone)]
282 #[non_exhaustive]
283 pub struct DefaulResponse;
285
286 #[derive(Debug, Clone)]
287 #[non_exhaustive]
288 pub struct StaticResponse<R>(pub(super) R);
290}