rama_http/layer/
error_handling.rs1use crate::{IntoResponse, Request, Response};
48use rama_core::{Context, Layer, Service};
49use rama_utils::macros::define_inner_service_accessors;
50use std::{convert::Infallible, fmt};
51
/// A [`Layer`] that wraps services in an [`ErrorHandler`], so that errors of
/// the wrapped service are turned into [`Response`]s.
///
/// With the default `F = ()` errors are converted through their own
/// [`IntoResponse`] implementation; a custom mapper can be installed with
/// [`ErrorHandlerLayer::error_mapper`].
pub struct ErrorHandlerLayer<F = ()> {
    /// The error mapper, or `()` when errors are converted directly.
    error_mapper: F,
}
56
57impl Default for ErrorHandlerLayer {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl<F: fmt::Debug> fmt::Debug for ErrorHandlerLayer<F> {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 f.debug_struct("ErrorHandlerLayer")
66 .field("error_mapper", &self.error_mapper)
67 .finish()
68 }
69}
70
71impl<F: Clone> Clone for ErrorHandlerLayer<F> {
72 fn clone(&self) -> Self {
73 Self {
74 error_mapper: self.error_mapper.clone(),
75 }
76 }
77}
78
79impl ErrorHandlerLayer {
80 pub const fn new() -> Self {
82 Self { error_mapper: () }
83 }
84
85 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
90 ErrorHandlerLayer { error_mapper }
91 }
92}
93
94impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
95 type Service = ErrorHandler<S, F>;
96
97 fn layer(&self, inner: S) -> Self::Service {
98 ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
99 }
100}
101
/// Middleware [`Service`] that turns every error of its inner service into a
/// [`Response`], so serving through it never fails.
///
/// With `F = ()` errors are converted through their own [`IntoResponse`]
/// implementation; otherwise `F` maps the error to a value that implements
/// [`IntoResponse`].
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// The error mapper, or `()` when errors are converted directly.
    error_mapper: F,
}
107
108impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for ErrorHandler<S, F> {
109 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110 f.debug_struct("ErrorHandler")
111 .field("inner", &self.inner)
112 .field("error_mapper", &self.error_mapper)
113 .finish()
114 }
115}
116
117impl<S: Clone, F: Clone> Clone for ErrorHandler<S, F> {
118 fn clone(&self) -> Self {
119 Self {
120 inner: self.inner.clone(),
121 error_mapper: self.error_mapper.clone(),
122 }
123 }
124}
125
126impl<S> ErrorHandler<S> {
127 pub const fn new(inner: S) -> Self {
129 Self {
130 inner,
131 error_mapper: (),
132 }
133 }
134
135 define_inner_service_accessors!();
136
137 pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
142 ErrorHandler {
143 inner: self.inner,
144 error_mapper,
145 }
146 }
147}
148
149impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
150where
151 S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
152 State: Clone + Send + Sync + 'static,
153 Body: Send + 'static,
154{
155 type Response = Response;
156 type Error = Infallible;
157
158 async fn serve(
159 &self,
160 ctx: Context<State>,
161 req: Request<Body>,
162 ) -> Result<Self::Response, Self::Error> {
163 match self.inner.serve(ctx, req).await {
164 Ok(response) => Ok(response.into_response()),
165 Err(error) => Ok(error.into_response()),
166 }
167 }
168}
169
170impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
171where
172 S: Service<State, Request<Body>, Response: IntoResponse>,
173 F: Fn(S::Error) -> R + Clone + Send + Sync + 'static,
174 R: IntoResponse + 'static,
175 State: Clone + Send + Sync + 'static,
176 Body: Send + 'static,
177{
178 type Response = Response;
179 type Error = Infallible;
180
181 async fn serve(
182 &self,
183 ctx: Context<State>,
184 req: Request<Body>,
185 ) -> Result<Self::Response, Self::Error> {
186 match self.inner.serve(ctx, req).await {
187 Ok(response) => Ok(response.into_response()),
188 Err(error) => Ok((self.error_mapper)(error).into_response()),
189 }
190 }
191}