lsp_async_stub/
lib.rs

1pub mod rpc;
2pub mod util;
3
4use async_trait::async_trait;
5use futures::{
6    channel::oneshot, future::FusedFuture, lock::Mutex as AsyncMutex, sink::Sink, Future,
7    FutureExt, SinkExt,
8};
9use handler::Handler;
10use lsp_types::{
11    notification::{self, Notification},
12    request as req,
13    request::Request,
14    NumberOrString,
15};
16use serde::{de::DeserializeOwned, Serialize};
17use std::{
18    collections::HashMap,
19    io, mem,
20    pin::Pin,
21    sync::{
22        atomic::{AtomicBool, Ordering},
23        Arc, Mutex,
24    },
25    task::{Poll, Waker},
26};
27use tracing::Instrument;
28
29mod handler;
30
31#[cfg(any(feature = "tokio-stdio", feature = "tokio-tcp"))]
32pub mod listen;
33
34#[derive(Debug, Clone, Default)]
35struct Cancellation {
36    cancelled: Arc<AtomicBool>,
37    waker: Arc<Mutex<Option<Waker>>>,
38}
39
40impl Cancellation {
41    pub fn token(&self) -> CancelToken {
42        CancelToken {
43            cancelled: self.cancelled.clone(),
44            waker_set: Arc::new(AtomicBool::new(false)),
45            waker: self.waker.clone(),
46        }
47    }
48
49    pub fn cancel(&mut self) {
50        self.cancelled.store(true, Ordering::SeqCst);
51
52        if let Some(w) = (*self.waker.lock().unwrap()).take() {
53            w.wake();
54        }
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct CancelToken {
60    cancelled: Arc<AtomicBool>,
61    waker_set: Arc<AtomicBool>,
62    waker: Arc<Mutex<Option<Waker>>>,
63}
64
65impl CancelToken {
66    pub fn is_cancelled(&self) -> bool {
67        self.cancelled.load(Ordering::SeqCst)
68    }
69
70    pub fn as_err(&mut self) -> CancelTokenErr {
71        CancelTokenErr(self)
72    }
73}
74
75impl Future for CancelToken {
76    type Output = ();
77
78    #[allow(unused_mut)] // differs between compiler versions
79    fn poll(
80        mut self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82    ) -> Poll<Self::Output> {
83        if self.cancelled.load(Ordering::SeqCst) {
84            Poll::Ready(())
85        } else {
86            if !self.waker_set.load(Ordering::SeqCst) {
87                *self.waker.lock().unwrap() = Some(cx.waker().clone());
88            }
89
90            Poll::Pending
91        }
92    }
93}
94
95impl FusedFuture for CancelToken {
96    fn is_terminated(&self) -> bool {
97        false
98    }
99}
100
101pub struct CancelTokenErr<'t>(&'t mut CancelToken);
102
103impl Future for CancelTokenErr<'_> {
104    type Output = Result<(), rpc::Error>;
105    fn poll(
106        mut self: std::pin::Pin<&mut Self>,
107        cx: &mut std::task::Context<'_>,
108    ) -> Poll<Self::Output> {
109        match self.0.poll_unpin(cx) {
110            Poll::Ready(_) => Poll::Ready(Err(rpc::Error::request_cancelled())),
111            Poll::Pending => Poll::Pending,
112        }
113    }
114}
115
116impl FusedFuture for CancelTokenErr<'_> {
117    fn is_terminated(&self) -> bool {
118        false
119    }
120}
121
122#[async_trait(?Send)]
123pub trait ResponseWriter: Sized {
124    async fn write_response<R: Serialize>(
125        mut self,
126        response: &rpc::Response<R>,
127    ) -> Result<(), io::Error>;
128}
129
130#[async_trait(?Send)]
131pub trait RequestWriter {
132    async fn write_request<
133        R: Request<Params = P>,
134        P: Serialize + DeserializeOwned + core::fmt::Debug,
135    >(
136        &mut self,
137        params: Option<R::Params>,
138    ) -> Result<rpc::Response<R::Result>, io::Error>;
139
140    async fn write_notification<
141        N: Notification<Params = P>,
142        P: Serialize + DeserializeOwned + core::fmt::Debug,
143    >(
144        &mut self,
145        params: Option<N::Params>,
146    ) -> Result<(), io::Error>;
147
148    async fn cancel(&mut self) -> Result<(), io::Error>;
149}
150
/// Named alias for `Future<Output = ()>`, used to type-erase the futures queued
/// by `Context::defer` (see `DeferredTasks`).
trait NewTrait: Future<Output = ()> {}

impl<Fut> NewTrait for Fut where Fut: Future<Output = ()> {}
153
154type DeferredTasks = Arc<AsyncMutex<Vec<Pin<Box<dyn NewTrait>>>>>;
155
156#[derive(Clone)]
157pub struct Context<W: Clone> {
158    inner: Arc<AsyncMutex<Inner<W>>>,
159    cancel_token: CancelToken,
160    last_req_id: Option<rpc::RequestId>, // For cancellation
161    rw: Arc<AsyncMutex<Box<dyn MessageWriter>>>,
162    world: W,
163    deferred: DeferredTasks,
164}
165
166impl<W: Clone> std::ops::Deref for Context<W> {
167    type Target = W;
168
169    fn deref(&self) -> &Self::Target {
170        &self.world
171    }
172}
173
174impl<W: Clone> Context<W> {
175    pub async fn is_initialized(&self) -> bool {
176        self.inner.lock().await.initialized
177    }
178
179    pub async fn is_shutting_down(&self) -> bool {
180        self.inner.lock().await.shutting_down
181    }
182
183    pub fn world(&self) -> &W {
184        &self.world
185    }
186
187    pub fn cancel_token(&mut self) -> &mut CancelToken {
188        &mut self.cancel_token
189    }
190
191    /// Defer the execution of the future until after the
192    /// handler returned (and response was sent if applicable).
193    ///
194    /// If sending a response fails, deferred futures
195    /// won't be executed.
196    pub async fn defer<F: Future<Output = ()> + 'static>(&self, fut: F) {
197        self.deferred.lock().await.push(Box::pin(fut));
198    }
199}
200
201#[async_trait(?Send)]
202impl<W: Clone> RequestWriter for Context<W> {
203    #[tracing::instrument(level = tracing::Level::TRACE, skip(self))]
204    async fn write_request<
205        R: Request<Params = P>,
206        P: Serialize + DeserializeOwned + core::fmt::Debug,
207    >(
208        &mut self,
209        params: Option<R::Params>,
210    ) -> Result<rpc::Response<R::Result>, io::Error> {
211        let mut inner = self.inner.lock().await;
212        let req_id = inner.next_request_id;
213        inner.next_request_id += 1;
214
215        let mut rw = self.rw.lock().await;
216
217        let id = NumberOrString::Number(req_id);
218
219        let request = rpc::Request::new()
220            .with_id(id.clone().into())
221            .with_method(R::METHOD)
222            .with_params(params);
223
224        let span = tracing::debug_span!("sending request", ?request);
225
226        rw.send(request.into_message()).instrument(span).await?;
227
228        self.last_req_id = Some(id.clone());
229
230        let (send, recv) = oneshot::channel();
231        inner.requests.insert(id, send);
232
233        drop(inner);
234
235        let res = recv.await.unwrap();
236
237        tracing::trace!(response = ?res, "received response");
238
239        self.last_req_id = None;
240
241        Ok(res.into_params())
242    }
243
244    #[tracing::instrument(level = tracing::Level::TRACE, skip(self))]
245    async fn write_notification<
246        N: Notification<Params = P>,
247        P: Serialize + DeserializeOwned + core::fmt::Debug,
248    >(
249        &mut self,
250        params: Option<N::Params>,
251    ) -> Result<(), io::Error> {
252        let mut rw = self.rw.lock().await;
253        rw.send(
254            rpc::Request::new()
255                .with_method(N::METHOD)
256                .with_params(params)
257                .into_message(),
258        )
259        .await
260    }
261
262    async fn cancel(&mut self) -> Result<(), io::Error> {
263        if let Some(id) = Option::take(&mut self.last_req_id) {
264            self.write_notification::<notification::Cancel, _>(Some(lsp_types::CancelParams { id }))
265                .await
266        } else {
267            Ok(())
268        }
269    }
270}
271
272pub trait MessageWriter: Sink<rpc::Message, Error = io::Error> + Unpin {}
273impl<T: Sink<rpc::Message, Error = io::Error> + Unpin> MessageWriter for T {}
274
275struct Inner<W: Clone> {
276    next_request_id: i32,
277    initialized: bool,
278    shutting_down: bool,
279    handlers: HashMap<String, Box<dyn Handler<W>>>,
280    tasks: HashMap<rpc::RequestId, Cancellation>,
281    requests: HashMap<rpc::RequestId, oneshot::Sender<rpc::Response<serde_json::Value>>>,
282}
283
284impl<W: Clone> Inner<W> {
285    fn task_done(&mut self, id: &rpc::RequestId) {
286        if let Some(mut t) = self.tasks.remove(id) {
287            t.cancel();
288            tracing::trace!(?id, "task completed");
289        }
290    }
291}
292
293pub struct Server<W: Clone> {
294    inner: Arc<AsyncMutex<Inner<W>>>,
295}
296
297impl<W: Clone> Server<W> {
298    #[allow(clippy::new_ret_no_self)]
299    pub fn new() -> ServerBuilder<W> {
300        ServerBuilder {
301            inner: Inner {
302                next_request_id: 0,
303                initialized: false,
304                shutting_down: false,
305                handlers: HashMap::new(),
306                tasks: HashMap::new(),
307                requests: HashMap::new(),
308            },
309        }
310    }
311
312    pub fn handle_message(
313        &self,
314        world: W,
315        message: rpc::Message,
316        writer: impl MessageWriter + Clone + 'static,
317    ) -> impl Future<Output = Result<(), io::Error>> {
318        let inner = self.inner.clone();
319
320        async move {
321            if message.is_response() {
322                Server::handle_response(inner, message.into_response()).await;
323                Ok(())
324            } else {
325                Server::handle_request(inner, world, message.into_request(), writer).await
326            }
327        }
328    }
329
330    pub async fn is_shutting_down(&self) -> bool {
331        self.inner.lock().await.shutting_down
332    }
333
334    #[tracing::instrument(level = tracing::Level::TRACE)]
335    async fn handle_response(
336        inner: Arc<AsyncMutex<Inner<W>>>,
337        response: rpc::Response<serde_json::Value>,
338    ) {
339        if let Some(sender) = inner.lock().await.requests.remove(&response.id) {
340            sender.send(response).ok();
341        } else {
342            tracing::error!(?response, "unexpected response")
343        }
344    }
345
346    #[tracing::instrument(level = tracing::Level::TRACE, skip(data, writer))]
347    async fn handle_request(
348        inner: Arc<AsyncMutex<Inner<W>>>,
349        data: W,
350        request: rpc::Request<serde_json::Value>,
351        mut writer: impl MessageWriter + Clone + 'static,
352    ) -> Result<(), io::Error> {
353        if &request.jsonrpc != "2.0" {
354            tracing::error!("JSON-RPC message version is not 2.0");
355            return writer
356                .send(
357                    rpc::Response::error(
358                        rpc::Error::invalid_request()
359                            .with_data("only JSON-RPC version 2.0 is accepted"),
360                    )
361                    .into_message(),
362                )
363                .await;
364        }
365
366        if request.id.is_some() {
367            tracing::debug!(
368                id = ?request.id.as_ref().unwrap(),
369                method = %request.method,
370                "request received"
371            );
372            let mut s = inner.lock().await;
373
374            if s.shutting_down {
375                tracing::warn!(
376                    id = ?request.id.as_ref().unwrap(),
377                    method = %request.method,
378                    "received request while shutting down"
379                );
380
381                writer
382                    .send(
383                        rpc::Response::error(
384                            rpc::Error::invalid_request().with_data("server is shutting down"),
385                        )
386                        .into_message(),
387                    )
388                    .await?;
389                return Ok(());
390            }
391
392            if request.method == req::Shutdown::METHOD {
393                tracing::info!(
394                    id = ?request.id.as_ref().unwrap(),
395                    method = %request.method,
396                    "received shutdown request"
397                );
398
399                s.shutting_down = true;
400            }
401
402            let is_initialize = request.method == req::Initialize::METHOD;
403
404            if !s.initialized && !is_initialize {
405                tracing::error!(
406                    id = ?request.id.as_ref().unwrap(),
407                    method = %request.method,
408                    "server not yet initialized"
409                );
410
411                writer
412                    .send(rpc::Response::error(rpc::Error::server_not_initialized()).into_message())
413                    .await?;
414                return Ok(());
415            }
416
417            if s.handlers.contains_key(&request.method) {
418                let mut handler = s.handlers.get_mut(&request.method).unwrap().clone();
419
420                let id = request.id.clone().unwrap();
421
422                // We expect the handler to run for a longer time
423                drop(s);
424
425                let ctx = Server::create_context(
426                    inner.clone(),
427                    Arc::new(AsyncMutex::new(Box::new(writer.clone()))),
428                    data,
429                    &request,
430                )
431                .await;
432
433                let handler_span = tracing::trace_span!(
434                    "request handler",
435                    method = %request.method,
436                );
437
438                let method = request.method.clone();
439
440                handler
441                    .handle(ctx.clone(), request, Some(&mut writer))
442                    .instrument(handler_span)
443                    .await;
444
445                let deferred = mem::take(&mut (*ctx.deferred.lock().await));
446
447                for d in deferred {
448                    let deferred_span = tracing::trace_span!(
449                        "deferred task",
450                        %method,
451                    );
452
453                    d.instrument(deferred_span).await
454                }
455
456                let mut s = inner.lock().await;
457
458                s.task_done(&id);
459                if is_initialize {
460                    s.initialized = true;
461                }
462                drop(s);
463
464                Ok(())
465            } else if request.method == req::Shutdown::METHOD {
466                // Shutting down without handler, everything should be OK.
467                writer
468                    .send(
469                        rpc::Response::success(())
470                            .with_request_id(request.id.unwrap())
471                            .into_message(),
472                    )
473                    .await
474            } else {
475                tracing::error!(
476                    method = %request.method,
477                    "no request handler registered"
478                );
479
480                writer
481                    .send(
482                        rpc::Response::error(rpc::Error::method_not_found())
483                            .with_request_id(request.id.unwrap())
484                            .into_message(),
485                    )
486                    .await
487            }
488        } else {
489            tracing::debug!(
490                method = %request.method,
491                "notification received"
492            );
493
494            if request.method == lsp_types::notification::Cancel::METHOD {
495                if let Some(p) = request.params {
496                    if let Ok(c) = serde_json::from_value::<lsp_types::CancelParams>(p) {
497                        inner.lock().await.task_done(&c.id);
498                    }
499                }
500                return Ok(());
501            }
502
503            let mut s = inner.lock().await;
504
505            if s.handlers.contains_key(&request.method) {
506                let mut handler = s.handlers.get_mut(&request.method).unwrap().clone();
507                drop(s);
508
509                let ctx = Server::create_context(
510                    inner,
511                    Arc::new(AsyncMutex::new(Box::new(writer))),
512                    data,
513                    &request,
514                )
515                .await;
516
517                let handler_span = tracing::trace_span!(
518                    "notification handler",
519                    method = %request.method,
520                );
521
522                let method = request.method.clone();
523
524                handler
525                    .handle(ctx.clone(), request, None)
526                    .instrument(handler_span)
527                    .await;
528
529                let deferred = mem::take(&mut (*ctx.deferred.lock().await));
530
531                for d in deferred {
532                    let deferred_span = tracing::trace_span!(
533                        "deferred task",
534                        %method,
535                    );
536
537                    d.instrument(deferred_span).await
538                }
539            } else {
540                tracing::warn!(
541                    method = %request.method,
542                    "no notification handler registered"
543                );
544            }
545
546            Ok(())
547        }
548    }
549
550    async fn create_context<D>(
551        inner: Arc<AsyncMutex<Inner<W>>>,
552        rw: Arc<AsyncMutex<Box<dyn MessageWriter>>>,
553        world: W,
554        req: &rpc::Request<D>,
555    ) -> Context<W> {
556        let cancel = Cancellation::default();
557        let cancel_token = cancel.token();
558
559        if let Some(id) = &req.id {
560            inner.lock().await.tasks.insert(id.clone(), cancel);
561        }
562
563        Context {
564            cancel_token,
565            world,
566            inner,
567            last_req_id: None,
568            rw,
569            deferred: Default::default(),
570        }
571    }
572}
573
574pub struct ServerBuilder<W: Clone + 'static> {
575    inner: Inner<W>,
576}
577
578impl<W: Clone + 'static> ServerBuilder<W> {
579    pub fn on_notification<N, F>(mut self, handler: fn(Context<W>, Params<N::Params>) -> F) -> Self
580    where
581        N: Notification + 'static,
582        F: Future<Output = ()> + 'static,
583    {
584        self.inner.handlers.insert(
585            N::METHOD.into(),
586            Box::new(handler::NotificationHandler::<N, _, _>::new(handler)),
587        );
588        tracing::info!(method = N::METHOD, "registered notification handler");
589        self
590    }
591
592    pub fn on_request<R, F>(mut self, handler: fn(Context<W>, Params<R::Params>) -> F) -> Self
593    where
594        R: Request + 'static,
595        F: Future<Output = Result<R::Result, rpc::Error>> + 'static,
596    {
597        self.inner.handlers.insert(
598            R::METHOD.into(),
599            Box::new(handler::RequestHandler::<R, _, _>::new(handler)),
600        );
601        tracing::info!(method = R::METHOD, "registered request handler");
602        self
603    }
604
605    pub fn build(self) -> Server<W> {
606        Server {
607            inner: Arc::new(AsyncMutex::new(self.inner)),
608        }
609    }
610}
611
/// Parameters of an incoming request or notification; `None` when the client
/// omitted them.
pub struct Params<P>(
    Option<P>,
);
613
614impl<P> Params<P> {
615    pub fn optional(self) -> Option<P> {
616        self.0
617    }
618
619    pub fn required(self) -> Result<P, rpc::Error> {
620        match self.0 {
621            None => Err(rpc::Error::invalid_params().with_data("params are required")),
622            Some(p) => Ok(p),
623        }
624    }
625}
626
627impl<P> From<Option<P>> for Params<P> {
628    fn from(p: Option<P>) -> Self {
629        Self(p)
630    }
631}