1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
//! Dispatch requests and notifications to individual handlers.
use std::any::TypeId;
use std::collections::HashMap;
use std::future::{ready, Future};
use std::ops::ControlFlow;
use std::pin::Pin;
use std::task::{Context, Poll};

use lsp_types::notification::Notification;
use lsp_types::request::Request;
use tower_service::Service;

use crate::{
    AnyEvent, AnyNotification, AnyRequest, ErrorCode, JsonValue, LspService, ResponseError, Result,
};

/// A router dispatching requests and notifications to individual handlers.
///
/// Requests and notifications are looked up by method name, events by [`TypeId`]. Anything
/// without a registered handler is passed to the matching `unhandled_*` fallback instead.
pub struct Router<St, Error = ResponseError> {
    /// State handed to every handler as `&mut St`.
    state: St,
    /// Request handlers keyed by method name.
    req_handlers: HashMap<&'static str, BoxReqHandler<St, Error>>,
    /// Notification handlers keyed by method name.
    notif_handlers: HashMap<&'static str, BoxNotifHandler<St>>,
    /// Event handlers keyed by the `TypeId` of the event type.
    event_handlers: HashMap<TypeId, BoxEventHandler<St>>,
    /// Fallback for requests whose method has no entry in `req_handlers`.
    unhandled_req: BoxReqHandler<St, Error>,
    /// Fallback for notifications whose method has no entry in `notif_handlers`.
    unhandled_notif: BoxNotifHandler<St>,
    /// Fallback for events whose type has no entry in `event_handlers`.
    unhandled_event: BoxEventHandler<St>,
}

/// Boxed, `Send` future produced by request handlers; resolves to a JSON value or an `Error`.
type BoxReqFuture<Error> = Pin<Box<dyn Future<Output = Result<JsonValue, Error>> + Send>>;
/// Type-erased request handler: takes the state and an untyped request, returns a boxed future.
type BoxReqHandler<St, Error> = Box<dyn Fn(&mut St, AnyRequest) -> BoxReqFuture<Error> + Send>;
/// Type-erased notification handler: takes the state and an untyped notification and returns a
/// `ControlFlow` directly (no future involved).
type BoxNotifHandler<St> = Box<dyn Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send>;
/// Type-erased event handler: takes the state and an untyped event and returns a `ControlFlow`
/// directly (no future involved).
type BoxEventHandler<St> = Box<dyn Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send>;

impl<St, Error> Default for Router<St, Error>
where
    St: Default,
    Error: From<ResponseError> + Send + 'static,
{
    /// Build a router around `St::default()`, exactly as [`Router::new`] would.
    fn default() -> Self {
        let initial_state = St::default();
        Self::new(initial_state)
    }
}

// TODO: Make it possible to construct with arbitrary `Error`, with no default handlers.
impl<St, Error> Router<St, Error>
where
    Error: From<ResponseError> + Send + 'static,
{
    /// Create a empty `Router`.
    #[must_use]
    pub fn new(state: St) -> Self {
        Self {
            state,
            req_handlers: HashMap::new(),
            notif_handlers: HashMap::new(),
            event_handlers: HashMap::new(),
            unhandled_req: Box::new(|_, req| {
                Box::pin(ready(Err(ResponseError {
                    code: ErrorCode::METHOD_NOT_FOUND,
                    message: format!("No such method {}", req.method),
                    data: None,
                }
                .into())))
            }),
            unhandled_notif: Box::new(|_, notif| {
                if notif.method.starts_with("$/") {
                    ControlFlow::Continue(())
                } else {
                    ControlFlow::Break(Err(crate::Error::Routing(format!(
                        "Unhandled notification: {}",
                        notif.method,
                    ))))
                }
            }),
            unhandled_event: Box::new(|_, event| {
                ControlFlow::Break(Err(crate::Error::Routing(format!(
                    "Unhandled event: {event:?}"
                ))))
            }),
        }
    }

    /// Register an asynchronous handler for the LSP request `R`.
    ///
    /// Incoming parameters are deserialized into `R::Params` before `handler` runs. If that
    /// fails, `handler` is not called and the request is answered with an
    /// [`ErrorCode::INVALID_PARAMS`] error. The handler's successful result is serialized back
    /// into a JSON value.
    ///
    /// A handler previously registered for `R::METHOD` is replaced.
    ///
    /// # Panics
    ///
    /// The returned future panics if the handler's result fails to serialize.
    pub fn request<R: Request, Fut>(
        &mut self,
        handler: impl Fn(&mut St, R::Params) -> Fut + Send + 'static,
    ) -> &mut Self
    where
        Fut: Future<Output = Result<R::Result, Error>> + Send + 'static,
    {
        self.req_handlers.insert(
            R::METHOD,
            Box::new(move |state, req| {
                let params = match serde_json::from_value::<R::Params>(req.params) {
                    Ok(params) => params,
                    Err(err) => {
                        let invalid = ResponseError {
                            code: ErrorCode::INVALID_PARAMS,
                            message: format!("Failed to deserialize parameters: {err}"),
                            data: None,
                        };
                        return Box::pin(ready(Err(invalid.into())));
                    }
                };
                // `handler` itself is invoked right away; only the future it returns is deferred.
                let fut = handler(state, params);
                Box::pin(async move {
                    let output = fut.await?;
                    Ok(serde_json::to_value(output).expect("Serialization failed"))
                })
            }),
        );
        self
    }

    /// Register a synchronous handler for the LSP notification `N`.
    ///
    /// Incoming parameters are deserialized into `N::Params` before `handler` runs. If that
    /// fails, `handler` is skipped and the converted error is returned inside
    /// [`ControlFlow::Break`].
    ///
    /// A handler previously registered for `N::METHOD` is replaced.
    pub fn notification<N: Notification>(
        &mut self,
        handler: impl Fn(&mut St, N::Params) -> ControlFlow<Result<()>> + Send + 'static,
    ) -> &mut Self {
        self.notif_handlers.insert(
            N::METHOD,
            Box::new(move |state, notif| {
                let params = match serde_json::from_value::<N::Params>(notif.params) {
                    Ok(params) => params,
                    Err(err) => return ControlFlow::Break(Err(err.into())),
                };
                handler(state, params)
            }),
        );
        self
    }

    /// Register a synchronous handler for events of type `E`.
    ///
    /// The handler is stored under `TypeId::of::<E>()`. A handler previously registered for the
    /// same type is replaced.
    ///
    /// # Panics
    ///
    /// The stored handler panics if the event it receives cannot be downcast to `E`.
    pub fn event<E: Send + 'static>(
        &mut self,
        handler: impl Fn(&mut St, E) -> ControlFlow<Result<()>> + Send + 'static,
    ) -> &mut Self {
        let key = TypeId::of::<E>();
        self.event_handlers.insert(
            key,
            Box::new(move |state, any_event| {
                // Stored under `E`'s `TypeId`, so the downcast is expected to succeed.
                let typed = any_event.downcast::<E>().expect("Checked TypeId");
                handler(state, typed)
            }),
        );
        self
    }

    /// Set the asynchronous catch-all handler used for requests whose `method` has no
    /// dedicated handler.
    ///
    /// Only one catch-all request handler exists at a time; setting a new one replaces the
    /// previous one.
    ///
    /// The default handler responds with an error response carrying the code
    /// [`ErrorCode::METHOD_NOT_FOUND`].
    pub fn unhandled_request<Fut>(
        &mut self,
        handler: impl Fn(&mut St, AnyRequest) -> Fut + Send + 'static,
    ) -> &mut Self
    where
        Fut: Future<Output = Result<JsonValue, Error>> + Send + 'static,
    {
        self.unhandled_req = Box::new(move |state, req| {
            let fut = handler(state, req);
            Box::pin(fut)
        });
        self
    }

    /// Set the synchronous catch-all handler used for notifications whose `method` has no
    /// dedicated handler.
    ///
    /// Only one catch-all notification handler exists at a time; setting a new one replaces
    /// the previous one.
    ///
    /// The default handler does nothing for methods starting with `$/` and breaks the main
    /// loop with [`Error::Routing`][crate::Error::Routing] for every other method.
    /// Notifications are typically critical: losing one can break state synchronization, which
    /// easily leads to catastrophic failures after incorrect incremental changes.
    pub fn unhandled_notification(
        &mut self,
        handler: impl Fn(&mut St, AnyNotification) -> ControlFlow<Result<()>> + Send + 'static,
    ) -> &mut Self {
        let fallback = Box::new(handler);
        self.unhandled_notif = fallback;
        self
    }

    /// Set the synchronous catch-all handler used for events whose type has no dedicated
    /// handler.
    ///
    /// Only one catch-all event handler exists at a time; setting a new one replaces the
    /// previous one.
    ///
    /// The default handler breaks the main loop with
    /// [`Error::Routing`][crate::Error::Routing]. Events are emitted internally, so an
    /// unhandled one is typically a logic error.
    pub fn unhandled_event(
        &mut self,
        handler: impl Fn(&mut St, AnyEvent) -> ControlFlow<Result<()>> + Send + 'static,
    ) -> &mut Self {
        let fallback = Box::new(handler);
        self.unhandled_event = fallback;
        self
    }
}

impl<St, Error> Service<AnyRequest> for Router<St, Error> {
    type Response = JsonValue;
    type Error = Error;
    type Future = BoxReqFuture<Error>;

    /// Always reports ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Route `req` to the handler registered for its method, or to the catch-all request
    /// handler when there is none.
    fn call(&mut self, req: AnyRequest) -> Self::Future {
        let handler = match self.req_handlers.get(&*req.method) {
            Some(registered) => registered,
            None => &self.unhandled_req,
        };
        handler(&mut self.state, req)
    }
}

impl<St> LspService for Router<St> {
    /// Route `notif` to the handler registered for its method, or to the catch-all
    /// notification handler when there is none.
    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
        let handler = match self.notif_handlers.get(&*notif.method) {
            Some(registered) => registered,
            None => &self.unhandled_notif,
        };
        handler(&mut self.state, notif)
    }

    /// Route `event` to the handler registered for its inner type id, or to the catch-all
    /// event handler when there is none.
    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
        let type_id = event.inner_type_id();
        let handler = match self.event_handlers.get(&type_id) {
            Some(registered) => registered,
            None => &self.unhandled_event,
        };
        handler(&mut self.state, event)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check, not called anywhere in this module: returning `router` as
    // `impl Send` only type-checks if `Router<St>` is `Send` whenever `St: Send`.
    fn _assert_send<St: Send>(router: Router<St>) -> impl Send {
        router
    }
}