async_lsp/
panic.rs

1//! Catch panics of underlying handlers and turn them into error responses.
2//!
3//! *Applies to both Language Servers and Language Clients.*
4use std::any::Any;
5use std::future::Future;
6use std::ops::ControlFlow;
7use std::panic::{catch_unwind, AssertUnwindSafe};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use pin_project_lite::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{AnyEvent, AnyNotification, AnyRequest, ErrorCode, LspService, ResponseError, Result};
16
/// The middleware catching panics of underlying handlers and turning them into error responses.
///
/// Only request handling is guarded: both the synchronous [`Service::call`] of the inner service
/// and every poll of the future it returns. Notifications and events are forwarded to the inner
/// service without any panic protection.
///
/// See [module level documentation](self) for details.
pub struct CatchUnwind<S: LspService> {
    /// The wrapped service whose request handlers are guarded.
    service: S,
    /// Converts a caught panic payload (together with the request method name) into the inner
    /// service's error type.
    handler: Handler<S::Error>,
}

// Presumably generates accessor methods for the inner `service` field (macro defined elsewhere
// in the crate) — NOTE(review): confirm against `define_getters!`.
define_getters!(impl[S: LspService] CatchUnwind<S>, service: S);
26
/// Signature of a panic handler: given the method name of the request whose handler panicked
/// and the raw panic payload, produce the error value reported to the caller.
type Handler<E> = fn(&str, Box<dyn Any + Send>) -> E;
28
29fn default_handler(method: &str, payload: Box<dyn Any + Send>) -> ResponseError {
30    let msg = match payload.downcast::<String>() {
31        Ok(msg) => *msg,
32        Err(payload) => match payload.downcast::<&'static str>() {
33            Ok(msg) => (*msg).into(),
34            Err(_payload) => "unknown".into(),
35        },
36    };
37    ResponseError {
38        code: ErrorCode::INTERNAL_ERROR,
39        message: format!("Request handler of {method} panicked: {msg}"),
40        data: None,
41    }
42}
43
44impl<S: LspService> Service<AnyRequest> for CatchUnwind<S> {
45    type Response = S::Response;
46    type Error = S::Error;
47    type Future = ResponseFuture<S::Future, S::Error>;
48
49    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        self.service.poll_ready(cx)
51    }
52
53    fn call(&mut self, req: AnyRequest) -> Self::Future {
54        let method = req.method.clone();
55        // FIXME: Clarify conditions of UnwindSafe.
56        match catch_unwind(AssertUnwindSafe(|| self.service.call(req)))
57            .map_err(|err| (self.handler)(&method, err))
58        {
59            Ok(fut) => ResponseFuture {
60                inner: ResponseFutureInner::Future {
61                    fut,
62                    method,
63                    handler: self.handler,
64                },
65            },
66            Err(err) => ResponseFuture {
67                inner: ResponseFutureInner::Ready { err: Some(err) },
68            },
69        }
70    }
71}
72
pin_project! {
    /// The [`Future`] type used by the [`CatchUnwind`] middleware.
    ///
    /// Resolves to the inner future's output, or to the error produced by the panic handler if
    /// the inner service panicked either while handling `call` or while the future was polled.
    pub struct ResponseFuture<Fut, Error> {
        // Either the still-running inner future or an already-converted error.
        #[pin]
        inner: ResponseFutureInner<Fut, Error>,
    }
}
80
pin_project! {
    // Internal state of `ResponseFuture`.
    #[project = ResponseFutureProj]
    enum ResponseFutureInner<Fut, Error> {
        // The inner service returned a future; it is polled under `catch_unwind`.
        Future {
            #[pin]
            fut: Fut,
            // Request method name, passed to the panic handler if polling panics.
            method: String,
            // Converts a panic payload into `Error`.
            handler: Handler<Error>,
        },
        // The inner service panicked synchronously in `call`; the converted error is stored here
        // and taken out by the first poll.
        Ready {
            err: Option<Error>,
        },
    }
}
95
96impl<Response, Fut, Error> Future for ResponseFuture<Fut, Error>
97where
98    Fut: Future<Output = Result<Response, Error>>,
99{
100    type Output = Fut::Output;
101
102    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103        match self.project().inner.project() {
104            ResponseFutureProj::Future {
105                fut,
106                method,
107                handler,
108            } => {
109                // FIXME: Clarify conditions of UnwindSafe.
110                match catch_unwind(AssertUnwindSafe(|| fut.poll(cx))) {
111                    Ok(poll) => poll,
112                    Err(payload) => Poll::Ready(Err(handler(method, payload))),
113                }
114            }
115            ResponseFutureProj::Ready { err } => Poll::Ready(Err(err.take().expect("Completed"))),
116        }
117    }
118}
119
120impl<S: LspService> LspService for CatchUnwind<S> {
121    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
122        self.service.notify(notif)
123    }
124
125    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
126        self.service.emit(event)
127    }
128}
129
130/// The builder of [`CatchUnwind`] middleware.
131///
132/// It's [`Default`] configuration tries to downcast the panic payload into `String` or `&str`, and
133/// fallback to format it via [`std::fmt::Display`], as the error message.
134/// The error code is set to [`ErrorCode::INTERNAL_ERROR`].
135#[derive(Clone)]
136#[must_use]
137pub struct CatchUnwindBuilder<Error = ResponseError> {
138    handler: Handler<Error>,
139}
140
141impl Default for CatchUnwindBuilder<ResponseError> {
142    fn default() -> Self {
143        Self::new_with_handler(default_handler)
144    }
145}
146
147impl<Error> CatchUnwindBuilder<Error> {
148    /// Create the builder of [`CatchUnwind`] middleware with a custom handler converting panic
149    /// payloads into [`ResponseError`].
150    pub fn new_with_handler(handler: Handler<Error>) -> Self {
151        Self { handler }
152    }
153}
154
/// A type alias of [`CatchUnwindBuilder`] conforming to the naming convention of [`tower_layer`].
///
/// It is the same type, so it implements [`Layer`] exactly like [`CatchUnwindBuilder`] does.
pub type CatchUnwindLayer<Error = ResponseError> = CatchUnwindBuilder<Error>;
157
158impl<S: LspService> Layer<S> for CatchUnwindBuilder<S::Error> {
159    type Service = CatchUnwind<S>;
160
161    fn layer(&self, inner: S) -> Self::Service {
162        CatchUnwind {
163            service: inner,
164            handler: self.handler,
165        }
166    }
167}