async_lsp/
router.rs

//! Dispatch requests, notifications and events to individual handlers.
2use std::any::TypeId;
3use std::collections::HashMap;
4use std::future::{ready, Future};
5use std::ops::ControlFlow;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use lsp_types::notification::Notification;
10use lsp_types::request::Request;
11use tower_service::Service;
12
13use crate::{
14    AnyEvent, AnyNotification, AnyRequest, ErrorCode, JsonValue, LspService, ResponseError, Result,
15};
16
17/// A router dispatching requests and notifications to individual handlers.
18pub struct Router<St, Error = ResponseError> {
19    state: St,
20    req_handlers: HashMap<&'static str, BoxReqHandler<St, Error>>,
21    notif_handlers: HashMap<&'static str, BoxNotifHandler<St>>,
22    event_handlers: HashMap<TypeId, BoxEventHandler<St>>,
23    unhandled_req: BoxReqHandler<St, Error>,
24    unhandled_notif: BoxNotifHandler<St>,
25    unhandled_event: BoxEventHandler<St>,
26}
27
28type BoxReqFuture<Error> = Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>;
29type BoxReqHandler<St, Error> = Box<dyn Fn(&mut St, AnyRequest) -> BoxReqFuture<Error> + Send>;
30type BoxNotifHandler<St> = Box<dyn Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send>;
31type BoxEventHandler<St> = Box<dyn Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send>;
32
33impl<St, Error> Default for Router<St, Error>
34where
35    St: Default,
36    Error: From<ResponseError> + Send + 'static,
37{
38    fn default() -> Self {
39        Self::new(St::default())
40    }
41}
42
43// TODO: Make it possible to construct with arbitrary `Error`, with no default handlers.
44impl<St, Error> Router<St, Error>
45where
46    Error: From<ResponseError> + Send + 'static,
47{
48    /// Create a empty `Router`.
49    #[must_use]
50    pub fn new(state: St) -> Self {
51        Self {
52            state,
53            req_handlers: HashMap::new(),
54            notif_handlers: HashMap::new(),
55            event_handlers: HashMap::new(),
56            unhandled_req: Box::new(|_, req| {
57                Box::pin(ready(Err(ResponseError {
58                    code: ErrorCode::METHOD_NOT_FOUND,
59                    message: format!("No such method {}", req.method),
60                    data: None,
61                }
62                .into())))
63            }),
64            unhandled_notif: Box::new(|_, notif| {
65                if notif.method.starts_with("$/") {
66                    ControlFlow::Continue(())
67                } else {
68                    ControlFlow::Break(Err(crate::Error::Routing(format!(
69                        "Unhandled notification: {}",
70                        notif.method,
71                    ))))
72                }
73            }),
74            unhandled_event: Box::new(|_, event| {
75                ControlFlow::Break(Err(crate::Error::Routing(format!(
76                    "Unhandled event: {event:?}"
77                ))))
78            }),
79        }
80    }
81
82    /// Add an asynchronous request handler for a specific LSP request `R`.
83    ///
84    /// If handler for the method already exists, it replaces the old one.
85    pub fn request<R: Request, Fut>(
86        &mut self,
87        handler: impl Fn(&mut St, R::Params) -> Fut + Send + 'static,
88    ) -> &mut Self
89    where
90        Fut: Future<Output = Result<R::Result, Error>> + Send + 'static,
91    {
92        self.req_handlers.insert(
93            R::METHOD,
94            Box::new(
95                move |state, req| match serde_json::from_value::<R::Params>(req.params) {
96                    Ok(params) => {
97                        let fut = handler(state, params);
98                        Box::pin(async move {
99                            Ok(serde_json::to_value(fut.await?).expect("Serialization failed"))
100                        })
101                    }
102                    Err(err) => Box::pin(ready(Err(ResponseError {
103                        code: ErrorCode::INVALID_PARAMS,
104                        message: format!("Failed to deserialize parameters: {err}"),
105                        data: None,
106                    }
107                    .into()))),
108                },
109            ),
110        );
111        self
112    }
113
114    /// Add a synchronous request handler for a specific LSP notification `N`.
115    ///
116    /// If handler for the method already exists, it replaces the old one.
117    pub fn notification<N: Notification>(
118        &mut self,
119        handler: impl Fn(&mut St, N::Params) -> ControlFlow<Result<()>> + Send + 'static,
120    ) -> &mut Self {
121        self.notif_handlers.insert(
122            N::METHOD,
123            Box::new(
124                move |state, notif| match serde_json::from_value::<N::Params>(notif.params) {
125                    Ok(params) => handler(state, params),
126                    Err(err) => ControlFlow::Break(Err(err.into())),
127                },
128            ),
129        );
130        self
131    }
132
133    /// Add a synchronous event handler for event type `E`.
134    ///
135    /// If handler for the method already exists, it replaces the old one.
136    pub fn event<E: Send + 'static>(
137        &mut self,
138        handler: impl Fn(&mut St, E) -> ControlFlow<Result<()>> + Send + 'static,
139    ) -> &mut Self {
140        self.event_handlers.insert(
141            TypeId::of::<E>(),
142            Box::new(move |state, event| {
143                let event = event.downcast::<E>().expect("Checked TypeId");
144                handler(state, event)
145            }),
146        );
147        self
148    }
149
150    /// Set an asynchronous catch-all request handler for any requests with no corresponding handler
151    /// for its `method`.
152    ///
153    /// There can only be a single catch-all request handler. New ones replace old ones.
154    ///
155    /// The default handler is to respond a error response with code
156    /// [`ErrorCode::METHOD_NOT_FOUND`].
157    pub fn unhandled_request<Fut>(
158        &mut self,
159        handler: impl Fn(&mut St, AnyRequest) -> Fut + Send + 'static,
160    ) -> &mut Self
161    where
162        Fut: Future<Output = Result<JsonValue, Error>> + Send + 'static,
163    {
164        self.unhandled_req = Box::new(move |state, req| Box::pin(handler(state, req)));
165        self
166    }
167
168    /// Set a synchronous catch-all notification handler for any notifications with no
169    /// corresponding handler for its `method`.
170    ///
171    /// There can only be a single catch-all notification handler. New ones replace old ones.
172    ///
173    /// The default handler is to do nothing for methods starting with `$/`, and break the main
174    /// loop with [`Error::Routing`][crate::Error::Routing] for other methods. Typically
175    /// notifications are critical and
176    /// losing them can break state synchronization, easily leading to catastrophic failures after
177    /// incorrect incremental changes.
178    pub fn unhandled_notification(
179        &mut self,
180        handler: impl Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send + 'static,
181    ) -> &mut Self {
182        self.unhandled_notif = Box::new(handler);
183        self
184    }
185
186    /// Set a synchronous catch-all event handler for any notifications with no
187    /// corresponding handler for its type.
188    ///
189    /// There can only be a single catch-all event handler. New ones replace old ones.
190    ///
191    /// The default handler is to break the main loop with
192    /// [`Error::Routing`][crate::Error::Routing]. Since events are
193    /// emitted internally, mishandling are typically logic errors.
194    pub fn unhandled_event(
195        &mut self,
196        handler: impl Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send + 'static,
197    ) -> &mut Self {
198        self.unhandled_event = Box::new(handler);
199        self
200    }
201}
202
203impl<St, Error> Service<AnyRequest> for Router<St, Error> {
204    type Response = JsonValue;
205    type Error = Error;
206    type Future = BoxReqFuture<Error>;
207
208    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209        Poll::Ready(Ok(()))
210    }
211
212    fn call(&mut self, req: AnyRequest) -> Self::Future {
213        let h = self
214            .req_handlers
215            .get(&*req.method)
216            .unwrap_or(&self.unhandled_req);
217        h(&mut self.state, req)
218    }
219}
220
221impl<St> LspService for Router<St> {
222    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
223        let h = self
224            .notif_handlers
225            .get(&*notif.method)
226            .unwrap_or(&self.unhandled_notif);
227        h(&mut self.state, notif)
228    }
229
230    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
231        let h = self
232            .event_handlers
233            .get(&event.inner_type_id())
234            .unwrap_or(&self.unhandled_event);
235        h(&mut self.state, event)
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    // Never called: this only compiles if `Router<St>` is `Send` for every `Send` state.
    fn _assert_send<St: Send>(router: Router<St>) -> impl Send {
        router
    }
}