1use crate::{Body, HeaderValue, Request, Response, StatusCode};
90use futures_lite::future::FutureExt;
91use rama_core::{Context, Layer, Service};
92use rama_utils::macros::define_inner_service_accessors;
93use std::fmt;
94use std::{any::Any, panic::AssertUnwindSafe};
95
/// A [`Layer`] that wraps services in the [`CatchPanic`] middleware.
///
/// A panic in the inner service is caught and turned into a response by the panic handler
/// `T`. This covers a panic while the service's future is being created and a panic while
/// that future is polled.
///
/// Use [`CatchPanicLayer::new`] for the [`DefaultResponseForPanic`] handler. Use
/// [`CatchPanicLayer::custom`] to supply your own [`ResponseForPanic`] implementation.
pub struct CatchPanicLayer<T> {
    /// Handler used to build a response from a caught panic payload.
    /// [`Layer::layer`] clones it into every service it produces.
    panic_handler: T,
}
103
104impl<T: fmt::Debug> fmt::Debug for CatchPanicLayer<T> {
105 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106 f.debug_struct("CatchPanicLayer")
107 .field("panic_handler", &self.panic_handler)
108 .finish()
109 }
110}
111
112impl<T: Clone> Clone for CatchPanicLayer<T> {
113 fn clone(&self) -> Self {
114 Self {
115 panic_handler: self.panic_handler.clone(),
116 }
117 }
118}
119
120impl Default for CatchPanicLayer<DefaultResponseForPanic> {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl CatchPanicLayer<DefaultResponseForPanic> {
127 pub const fn new() -> Self {
129 CatchPanicLayer {
130 panic_handler: DefaultResponseForPanic,
131 }
132 }
133}
134
135impl<T> CatchPanicLayer<T> {
136 pub fn custom(panic_handler: T) -> Self
138 where
139 T: ResponseForPanic,
140 {
141 Self { panic_handler }
142 }
143}
144
145impl<T, S> Layer<S> for CatchPanicLayer<T>
146where
147 T: Clone,
148{
149 type Service = CatchPanic<S, T>;
150
151 fn layer(&self, inner: S) -> Self::Service {
152 CatchPanic {
153 inner,
154 panic_handler: self.panic_handler.clone(),
155 }
156 }
157
158 fn into_layer(self, inner: S) -> Self::Service {
159 CatchPanic {
160 inner,
161 panic_handler: self.panic_handler,
162 }
163 }
164}
165
/// Middleware that catches panics raised by the inner service `S` and converts them into
/// a response using the panic handler `T`.
///
/// Both a panic while calling the inner service and a panic while polling the future it
/// returns are caught. Errors the inner service returns normally are passed through
/// unchanged.
pub struct CatchPanic<S, T> {
    /// The wrapped service.
    inner: S,
    /// Builds the response returned in place of a caught panic.
    panic_handler: T,
}
173
174impl<S> CatchPanic<S, DefaultResponseForPanic> {
175 pub const fn new(inner: S) -> Self {
177 Self {
178 inner,
179 panic_handler: DefaultResponseForPanic,
180 }
181 }
182}
183
184impl<S, T> CatchPanic<S, T> {
185 define_inner_service_accessors!();
186
187 pub const fn custom(inner: S, panic_handler: T) -> Self
189 where
190 T: ResponseForPanic,
191 {
192 Self {
193 inner,
194 panic_handler,
195 }
196 }
197}
198
199impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for CatchPanic<S, T> {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("CatchPanic")
202 .field("inner", &self.inner)
203 .field("panic_handler", &self.panic_handler)
204 .finish()
205 }
206}
207
208impl<S: Clone, T: Clone> Clone for CatchPanic<S, T> {
209 fn clone(&self) -> Self {
210 CatchPanic {
211 inner: self.inner.clone(),
212 panic_handler: self.panic_handler.clone(),
213 }
214 }
215}
216
217impl<State, S, T, ReqBody, ResBody> Service<State, Request<ReqBody>> for CatchPanic<S, T>
218where
219 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
220 ResBody: Into<Body> + Send + 'static,
221 T: ResponseForPanic + Clone + Send + Sync + 'static,
222 ReqBody: Send + 'static,
223 ResBody: Send + 'static,
224 State: Clone + Send + Sync + 'static,
225{
226 type Response = Response;
227 type Error = S::Error;
228
229 async fn serve(
230 &self,
231 ctx: Context<State>,
232 req: Request<ReqBody>,
233 ) -> Result<Self::Response, Self::Error> {
234 let future = match std::panic::catch_unwind(AssertUnwindSafe(|| self.inner.serve(ctx, req)))
235 {
236 Ok(future) => future,
237 Err(panic_err) => return Ok(self.panic_handler.response_for_panic(panic_err)),
238 };
239 match AssertUnwindSafe(future).catch_unwind().await {
240 Ok(res) => match res {
241 Ok(res) => Ok(res.map(Into::into)),
242 Err(err) => Err(err),
243 },
244 Err(panic_err) => Ok(self.panic_handler.response_for_panic(panic_err)),
245 }
246 }
247}
248
/// Converts a caught panic payload into the response that [`CatchPanic`] returns.
///
/// Implemented for [`DefaultResponseForPanic`]. Also implemented for any `Clone` closure of
/// the form `Fn(Box<dyn Any + Send + 'static>) -> Response`.
pub trait ResponseForPanic: Clone {
    /// Builds the response for a caught panic.
    ///
    /// `err` is the panic payload captured during unwinding. For `panic!` with a message it
    /// is typically a `String` or `&'static str`, but it may be any type.
    fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response<Body>;
}
254
255impl<F> ResponseForPanic for F
256where
257 F: Fn(Box<dyn Any + Send + 'static>) -> Response + Clone,
258{
259 fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
260 self(err)
261 }
262}
263
/// The panic handler used by [`CatchPanicLayer::new`] and [`CatchPanic::new`].
///
/// It logs the panic at error level, including the message when the payload is a `String`
/// or `&str`. It then responds with `500 Internal Server Error` and the `text/plain` body
/// `Service panicked`.
#[derive(Debug, Default, Clone)]
#[non_exhaustive]
pub struct DefaultResponseForPanic;
271
272impl ResponseForPanic for DefaultResponseForPanic {
273 fn response_for_panic(&self, err: Box<dyn Any + Send + 'static>) -> Response {
274 if let Some(s) = err.downcast_ref::<String>() {
275 tracing::error!("Service panicked: {}", s);
276 } else if let Some(s) = err.downcast_ref::<&str>() {
277 tracing::error!("Service panicked: {}", s);
278 } else {
279 tracing::error!(
280 "Service panicked but `CatchPanic` was unable to downcast the panic info"
281 );
282 };
283
284 let mut res = Response::new(Body::from("Service panicked"));
285 *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
286
287 #[allow(clippy::declare_interior_mutable_const)]
288 const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
289 res.headers_mut()
290 .insert(rama_http_types::header::CONTENT_TYPE, TEXT_PLAIN);
291
292 res
293 }
294}
295
#[cfg(test)]
mod tests {
    // The services below deliberately panic before their return expression.
    #![allow(unreachable_code)]

    use super::*;

    use crate::dep::http_body_util::BodyExt;
    use crate::{Body, Response};
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    /// A panic raised synchronously, before the inner future exists, is caught.
    #[tokio::test]
    async fn panic_before_returning_future() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(|_: Request| {
            panic!("service panic");
            async { Ok::<_, Infallible>(Response::new(Body::empty())) }
        }));

        let req = Request::new(Body::empty());

        let res = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Service panicked");
    }

    /// A panic raised while the inner future is polled is caught.
    #[tokio::test]
    async fn panic_in_future() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(async |_: Request<Body>| {
            panic!("future panic");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let req = Request::new(Body::empty());

        let res = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Service panicked");
    }

    /// Without a panic, the inner response is returned unchanged.
    #[tokio::test]
    async fn no_panic_passes_response_through() {
        let svc = CatchPanicLayer::new().into_layer(service_fn(async |_: Request<Body>| {
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        }));

        let req = Request::new(Body::empty());

        let res = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }

    /// A closure passed to `CatchPanicLayer::custom` builds the panic response.
    #[tokio::test]
    async fn custom_panic_handler_is_used() {
        let svc = CatchPanicLayer::custom(|_: Box<dyn Any + Send + 'static>| {
            let mut res = Response::new(Body::from("custom"));
            *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
            res
        })
        .into_layer(service_fn(async |_: Request<Body>| {
            panic!("future panic");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }));

        let req = Request::new(Body::empty());

        let res = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"custom");
    }
}