rama_http/layer/
error_handling.rs1use crate::{IntoResponse, Request, Response};
48use rama_core::{Context, Layer, Service};
49use rama_utils::macros::define_inner_service_accessors;
50use std::{convert::Infallible, fmt};
51
/// [`Layer`] that wraps a service in an [`ErrorHandler`].
///
/// The produced [`ErrorHandler`] turns both the successful output and the
/// error of the inner service into a [`Response`], so it never fails itself.
///
/// With the default `F = ()` the inner error is converted through its own
/// [`IntoResponse`] implementation. Use [`ErrorHandlerLayer::error_mapper`]
/// to install a function that maps the error into some other [`IntoResponse`]
/// value instead.
pub struct ErrorHandlerLayer<F = ()> {
    /// Error mapping function, or `()` when none has been configured.
    error_mapper: F,
}
56
57impl Default for ErrorHandlerLayer {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 f.debug_struct("ErrorHandlerLayer")
66 .field("error_mapper", &self.error_mapper)
67 .finish()
68 }
69}
70
71impl<F: Clone> Clone for ErrorHandlerLayer<F> {
72 fn clone(&self) -> Self {
73 Self {
74 error_mapper: self.error_mapper.clone(),
75 }
76 }
77}
78
79impl ErrorHandlerLayer {
80 pub const fn new() -> Self {
82 Self { error_mapper: () }
83 }
84
85 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
90 ErrorHandlerLayer { error_mapper }
91 }
92}
93
94impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
95 type Service = ErrorHandler<S, F>;
96
97 fn layer(&self, inner: S) -> Self::Service {
98 ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
99 }
100
101 fn into_layer(self, inner: S) -> Self::Service {
102 ErrorHandler::new(inner).error_mapper(self.error_mapper)
103 }
104}
105
/// Service that converts both the response and the error of its inner
/// service into a [`Response`].
///
/// Because errors are turned into responses, the [`Service`] implementations
/// of this type use [`Infallible`] as their error type. With the default
/// `F = ()` the inner error must itself implement [`IntoResponse`]; otherwise
/// `F` maps the error into a value that does.
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// Error mapping function, or `()` when none has been configured.
    error_mapper: F,
}
111
112impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
113 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
114 f.debug_struct("ErrorHandler")
115 .field("inner", &self.inner)
116 .field("error_mapper", &self.error_mapper)
117 .finish()
118 }
119}
120
121impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
122 fn clone(&self) -> Self {
123 Self {
124 inner: self.inner.clone(),
125 error_mapper: self.error_mapper.clone(),
126 }
127 }
128}
129
130impl<S> ErrorHandler<S> {
131 pub const fn new(inner: S) -> Self {
133 Self {
134 inner,
135 error_mapper: (),
136 }
137 }
138
139 define_inner_service_accessors!();
140
141 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
146 ErrorHandler {
147 inner: self.inner,
148 error_mapper,
149 }
150 }
151}
152
153impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
154where
155 S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
156 State: Clone + Send + Sync + 'static,
157 Body: Send + 'static,
158{
159 type Response = Response;
160 type Error = Infallible;
161
162 async fn serve(
163 &self,
164 ctx: Context<State>,
165 req: Request<Body>,
166 ) -> Result<Self::Response, Self::Error> {
167 match self.inner.serve(ctx, req).await {
168 Ok(response) => Ok(response.into_response()),
169 Err(error) => Ok(error.into_response()),
170 }
171 }
172}
173
174impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
175where
176 S: Service<State, Request<Body>, Response: IntoResponse>,
177 F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
178 R: IntoResponse + 'static,
179 State: Clone + Send + Sync + 'static,
180 Body: Send + 'static,
181{
182 type Response = Response;
183 type Error = Infallible;
184
185 async fn serve(
186 &self,
187 ctx: Context<State>,
188 req: Request<Body>,
189 ) -> Result<Self::Response, Self::Error> {
190 match self.inner.serve(ctx, req).await {
191 Ok(response) => Ok(response.into_response()),
192 Err(error) => Ok((self.error_mapper)(error).into_response()),
193 }
194 }
195}