1use std::any::Any;
5use std::future::Future;
6use std::ops::ControlFlow;
7use std::panic::{catch_unwind, AssertUnwindSafe};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11use pin_project_lite::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{AnyEvent, AnyNotification, AnyRequest, ErrorCode, LspService, ResponseError, Result};
16
/// Middleware that catches panics raised by the inner service's request handling.
///
/// Both the synchronous `call` and every `poll` of the returned future run
/// inside `catch_unwind`; a caught panic is turned into an error via `handler`.
pub struct CatchUnwind<S: LspService> {
    // The wrapped service.
    service: S,
    // Converts `(method, panic payload)` into the service's error type.
    handler: Handler<S::Error>,
}

// Generates accessors for the wrapped `service` field.
define_getters!(impl[S: LspService] CatchUnwind<S>, service: S);
26
/// Panic handler: receives the method name of the request whose handler
/// panicked and the panic payload, and produces the error to respond with.
type Handler<E> = fn(method: &str, payload: Box<dyn Any + Send>) -> E;
28
29fn default_handler(method: &str, payload: Box<dyn Any + Send>) -> ResponseError {
30 let msg = match payload.downcast::<String>() {
31 Ok(msg) => *msg,
32 Err(payload) => match payload.downcast::<&'static str>() {
33 Ok(msg) => (*msg).into(),
34 Err(_payload) => "unknown".into(),
35 },
36 };
37 ResponseError {
38 code: ErrorCode::INTERNAL_ERROR,
39 message: format!("Request handler of {method} panicked: {msg}"),
40 data: None,
41 }
42}
43
44impl<S: LspService> Service<AnyRequest> for CatchUnwind<S> {
45 type Response = S::Response;
46 type Error = S::Error;
47 type Future = ResponseFuture<S::Future, S::Error>;
48
49 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 self.service.poll_ready(cx)
51 }
52
53 fn call(&mut self, req: AnyRequest) -> Self::Future {
54 let method = req.method.clone();
55 match catch_unwind(AssertUnwindSafe(|| self.service.call(req)))
57 .map_err(|err| (self.handler)(&method, err))
58 {
59 Ok(fut) => ResponseFuture {
60 inner: ResponseFutureInner::Future {
61 fut,
62 method,
63 handler: self.handler,
64 },
65 },
66 Err(err) => ResponseFuture {
67 inner: ResponseFutureInner::Ready { err: Some(err) },
68 },
69 }
70 }
71}
72
pin_project! {
    /// The [`Future`] returned by [`CatchUnwind`]'s `call`.
    ///
    /// It resolves to the inner future's output, or to the handler-produced
    /// error if the inner call or any poll panicked.
    pub struct ResponseFuture<Fut, Error> {
        // Structurally pinned so the inner future can be polled in place.
        #[pin]
        inner: ResponseFutureInner<Fut, Error>,
    }
}
80
pin_project! {
    // State of a `ResponseFuture`.
    #[project = ResponseFutureProj]
    enum ResponseFutureInner<Fut, Error> {
        // The inner `call` returned normally; its future is polled under
        // `catch_unwind`, with `method` and `handler` kept for panic reporting.
        Future {
            #[pin]
            fut: Fut,
            method: String,
            handler: Handler<Error>,
        },
        // The inner `call` panicked; holds the converted error until the first
        // poll takes it (it is `None` afterwards).
        Ready {
            err: Option<Error>,
        },
    }
}
95
96impl<Response, Fut, Error> Future for ResponseFuture<Fut, Error>
97where
98 Fut: Future<Output = Result<Response, Error>>,
99{
100 type Output = Fut::Output;
101
102 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103 match self.project().inner.project() {
104 ResponseFutureProj::Future {
105 fut,
106 method,
107 handler,
108 } => {
109 match catch_unwind(AssertUnwindSafe(|| fut.poll(cx))) {
111 Ok(poll) => poll,
112 Err(payload) => Poll::Ready(Err(handler(method, payload))),
113 }
114 }
115 ResponseFutureProj::Ready { err } => Poll::Ready(Err(err.take().expect("Completed"))),
116 }
117 }
118}
119
120impl<S: LspService> LspService for CatchUnwind<S> {
121 fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
122 self.service.notify(notif)
123 }
124
125 fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
126 self.service.emit(event)
127 }
128}
129
130#[derive(Clone)]
136#[must_use]
137pub struct CatchUnwindBuilder<Error = ResponseError> {
138 handler: Handler<Error>,
139}
140
141impl Default for CatchUnwindBuilder<ResponseError> {
142 fn default() -> Self {
143 Self::new_with_handler(default_handler)
144 }
145}
146
147impl<Error> CatchUnwindBuilder<Error> {
148 pub fn new_with_handler(handler: Handler<Error>) -> Self {
151 Self { handler }
152 }
153}
154
/// Alias of [`CatchUnwindBuilder`], named for its use as a [`Layer`].
pub type CatchUnwindLayer<Error = ResponseError> = CatchUnwindBuilder<Error>;
157
158impl<S: LspService> Layer<S> for CatchUnwindBuilder<S::Error> {
159 type Service = CatchUnwind<S>;
160
161 fn layer(&self, inner: S) -> Self::Service {
162 CatchUnwind {
163 service: inner,
164 handler: self.handler,
165 }
166 }
167}