rama_http/layer/
error_handling.rs1use crate::service::web::response::IntoResponse;
49use crate::{Request, Response};
50use rama_core::{Context, Layer, Service};
51use rama_utils::macros::define_inner_service_accessors;
52use std::{convert::Infallible, fmt};
53
/// A [`Layer`] that wraps services in an [`ErrorHandler`].
///
/// `F` is the error mapper. It defaults to `()`, meaning no mapper is
/// configured. In that case the wrapped service's error type is itself
/// converted with [`IntoResponse`].
pub struct ErrorHandlerLayer<F = ()> {
    /// Function applied to inner-service errors, or `()` when none is set.
    error_mapper: F,
}
58
59impl Default for ErrorHandlerLayer {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 f.debug_struct("ErrorHandlerLayer")
68 .field("error_mapper", &self.error_mapper)
69 .finish()
70 }
71}
72
73impl<F: Clone> Clone for ErrorHandlerLayer<F> {
74 fn clone(&self) -> Self {
75 Self {
76 error_mapper: self.error_mapper.clone(),
77 }
78 }
79}
80
81impl ErrorHandlerLayer {
82 pub const fn new() -> Self {
84 Self { error_mapper: () }
85 }
86
87 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
92 ErrorHandlerLayer { error_mapper }
93 }
94}
95
96impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
97 type Service = ErrorHandler<S, F>;
98
99 fn layer(&self, inner: S) -> Self::Service {
100 ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
101 }
102
103 fn into_layer(self, inner: S) -> Self::Service {
104 ErrorHandler::new(inner).error_mapper(self.error_mapper)
105 }
106}
107
/// Service wrapper that turns both the response and the error of its inner
/// service into a [`Response`].
///
/// Its [`Service`] implementations use [`Infallible`] as the error type.
/// `F` is the optional error mapper. With the default `()`, the inner error
/// must implement [`IntoResponse`] itself.
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// Function applied to inner-service errors, or `()` when none is set.
    error_mapper: F,
}
113
114impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 f.debug_struct("ErrorHandler")
117 .field("inner", &self.inner)
118 .field("error_mapper", &self.error_mapper)
119 .finish()
120 }
121}
122
123impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
124 fn clone(&self) -> Self {
125 Self {
126 inner: self.inner.clone(),
127 error_mapper: self.error_mapper.clone(),
128 }
129 }
130}
131
132impl<S> ErrorHandler<S> {
133 pub const fn new(inner: S) -> Self {
135 Self {
136 inner,
137 error_mapper: (),
138 }
139 }
140
141 define_inner_service_accessors!();
142
143 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
148 ErrorHandler {
149 inner: self.inner,
150 error_mapper,
151 }
152 }
153}
154
155impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
156where
157 S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
158 State: Clone + Send + Sync + 'static,
159 Body: Send + 'static,
160{
161 type Response = Response;
162 type Error = Infallible;
163
164 async fn serve(
165 &self,
166 ctx: Context<State>,
167 req: Request<Body>,
168 ) -> Result<Self::Response, Self::Error> {
169 match self.inner.serve(ctx, req).await {
170 Ok(response) => Ok(response.into_response()),
171 Err(error) => Ok(error.into_response()),
172 }
173 }
174}
175
176impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
177where
178 S: Service<State, Request<Body>, Response: IntoResponse>,
179 F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
180 R: IntoResponse + 'static,
181 State: Clone + Send + Sync + 'static,
182 Body: Send + 'static,
183{
184 type Response = Response;
185 type Error = Infallible;
186
187 async fn serve(
188 &self,
189 ctx: Context<State>,
190 req: Request<Body>,
191 ) -> Result<Self::Response, Self::Error> {
192 match self.inner.serve(ctx, req).await {
193 Ok(response) => Ok(response.into_response()),
194 Err(error) => Ok((self.error_mapper)(error).into_response()),
195 }
196 }
197}