1use std::any::TypeId;
3use std::collections::HashMap;
4use std::future::{ready, Future};
5use std::ops::ControlFlow;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use lsp_types::notification::Notification;
10use lsp_types::request::Request;
11use tower_service::Service;
12
13use crate::{
14 AnyEvent, AnyNotification, AnyRequest, ErrorCode, JsonValue, LspService, ResponseError, Result,
15};
16
17pub struct Router<St, Error = ResponseError> {
19 state: St,
20 req_handlers: HashMap<&'static str, BoxReqHandler<St, Error>>,
21 notif_handlers: HashMap<&'static str, BoxNotifHandler<St>>,
22 event_handlers: HashMap<TypeId, BoxEventHandler<St>>,
23 unhandled_req: BoxReqHandler<St, Error>,
24 unhandled_notif: BoxNotifHandler<St>,
25 unhandled_event: BoxEventHandler<St>,
26}
27
28type BoxReqFuture<Error> = Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>;
29type BoxReqHandler<St, Error> = Box<dyn Fn(&mut St, AnyRequest) -> BoxReqFuture<Error> + Send>;
30type BoxNotifHandler<St> = Box<dyn Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send>;
31type BoxEventHandler<St> = Box<dyn Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send>;
32
33impl<St, Error> Default for Router<St, Error>
34where
35 St: Default,
36 Error: From<ResponseError> + Send + 'static,
37{
38 fn default() -> Self {
39 Self::new(St::default())
40 }
41}
42
43impl<St, Error> Router<St, Error>
45where
46 Error: From<ResponseError> + Send + 'static,
47{
48 #[must_use]
50 pub fn new(state: St) -> Self {
51 Self {
52 state,
53 req_handlers: HashMap::new(),
54 notif_handlers: HashMap::new(),
55 event_handlers: HashMap::new(),
56 unhandled_req: Box::new(|_, req| {
57 Box::pin(ready(Err(ResponseError {
58 code: ErrorCode::METHOD_NOT_FOUND,
59 message: format!("No such method {}", req.method),
60 data: None,
61 }
62 .into())))
63 }),
64 unhandled_notif: Box::new(|_, notif| {
65 if notif.method.starts_with("$/") {
66 ControlFlow::Continue(())
67 } else {
68 ControlFlow::Break(Err(crate::Error::Routing(format!(
69 "Unhandled notification: {}",
70 notif.method,
71 ))))
72 }
73 }),
74 unhandled_event: Box::new(|_, event| {
75 ControlFlow::Break(Err(crate::Error::Routing(format!(
76 "Unhandled event: {event:?}"
77 ))))
78 }),
79 }
80 }
81
82 pub fn request<R: Request, Fut>(
86 &mut self,
87 handler: impl Fn(&mut St, R::Params) -> Fut + Send + 'static,
88 ) -> &mut Self
89 where
90 Fut: Future<Output = Result<R::Result, Error>> + Send + 'static,
91 {
92 self.req_handlers.insert(
93 R::METHOD,
94 Box::new(
95 move |state, req| match serde_json::from_value::<R::Params>(req.params) {
96 Ok(params) => {
97 let fut = handler(state, params);
98 Box::pin(async move {
99 Ok(serde_json::to_value(fut.await?).expect("Serialization failed"))
100 })
101 }
102 Err(err) => Box::pin(ready(Err(ResponseError {
103 code: ErrorCode::INVALID_PARAMS,
104 message: format!("Failed to deserialize parameters: {err}"),
105 data: None,
106 }
107 .into()))),
108 },
109 ),
110 );
111 self
112 }
113
114 pub fn notification<N: Notification>(
118 &mut self,
119 handler: impl Fn(&mut St, N::Params) -> ControlFlow<Result<()>> + Send + 'static,
120 ) -> &mut Self {
121 self.notif_handlers.insert(
122 N::METHOD,
123 Box::new(
124 move |state, notif| match serde_json::from_value::<N::Params>(notif.params) {
125 Ok(params) => handler(state, params),
126 Err(err) => ControlFlow::Break(Err(err.into())),
127 },
128 ),
129 );
130 self
131 }
132
133 pub fn event<E: Send + 'static>(
137 &mut self,
138 handler: impl Fn(&mut St, E) -> ControlFlow<Result<()>> + Send + 'static,
139 ) -> &mut Self {
140 self.event_handlers.insert(
141 TypeId::of::<E>(),
142 Box::new(move |state, event| {
143 let event = event.downcast::<E>().expect("Checked TypeId");
144 handler(state, event)
145 }),
146 );
147 self
148 }
149
150 pub fn unhandled_request<Fut>(
158 &mut self,
159 handler: impl Fn(&mut St, AnyRequest) -> Fut + Send + 'static,
160 ) -> &mut Self
161 where
162 Fut: Future<Output = Result<JsonValue, Error>> + Send + 'static,
163 {
164 self.unhandled_req = Box::new(move |state, req| Box::pin(handler(state, req)));
165 self
166 }
167
168 pub fn unhandled_notification(
179 &mut self,
180 handler: impl Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send + 'static,
181 ) -> &mut Self {
182 self.unhandled_notif = Box::new(handler);
183 self
184 }
185
186 pub fn unhandled_event(
195 &mut self,
196 handler: impl Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send + 'static,
197 ) -> &mut Self {
198 self.unhandled_event = Box::new(handler);
199 self
200 }
201}
202
203impl<St, Error> Service<AnyRequest> for Router<St, Error> {
204 type Response = JsonValue;
205 type Error = Error;
206 type Future = BoxReqFuture<Error>;
207
208 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
209 Poll::Ready(Ok(()))
210 }
211
212 fn call(&mut self, req: AnyRequest) -> Self::Future {
213 let h = self
214 .req_handlers
215 .get(&*req.method)
216 .unwrap_or(&self.unhandled_req);
217 h(&mut self.state, req)
218 }
219}
220
221impl<St> LspService for Router<St> {
222 fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
223 let h = self
224 .notif_handlers
225 .get(&*notif.method)
226 .unwrap_or(&self.unhandled_notif);
227 h(&mut self.state, notif)
228 }
229
230 fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
231 let h = self
232 .event_handlers
233 .get(&event.inner_type_id())
234 .unwrap_or(&self.unhandled_event);
235 h(&mut self.state, event)
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check only: a `Router<St>` must be `Send` whenever its state is `Send`.
    /// The function is never called; it merely has to type-check.
    fn _assert_send<St>(router: Router<St>) -> impl Send
    where
        St: Send,
    {
        router
    }
}