async_lsp/
lib.rs

1//! Asynchronous [Language Server Protocol (LSP)][lsp] framework based on [tower].
2//!
3//! See project [README] for a general overview.
4//!
5//! [README]: https://github.com/oxalica/async-lsp#readme
6//! [lsp]: https://microsoft.github.io/language-server-protocol/overviews/lsp/overview/
7//! [tower]: https://github.com/tower-rs/tower
8//!
9//! This project is centered at a core service trait [`LspService`] for either Language Servers or
10//! Language Clients. The main loop driver [`MainLoop`] executes the service. The additional
//! features, called middleware, are pluggable and can be layered using the [`tower_layer`]
//! abstraction. This crate defines several common middlewares for various mandatory or optional
//! LSP functionalities; see their documentation for details.
14//! - [`concurrency::Concurrency`]: Incoming request multiplexing and cancellation.
15//! - [`panic::CatchUnwind`]: Turn panics into errors.
16//! - [`tracing::Tracing`]: Logger spans with methods instrumenting handlers.
17//! - [`server::Lifecycle`]: Server initialization, shutting down, and exit handling.
18//! - [`client_monitor::ClientProcessMonitor`]: Client process monitor.
19//! - [`router::Router`]: "Root" service to dispatch requests, notifications and events.
20//!
21//! Users are free to select and layer middlewares to run a Language Server or Language Client.
//! They can also implement their own middlewares for things like timeouts, metering, request
//! transformation, etc.
24//!
25//! ## Usages
26//!
//! There are two main ways to define a [`Router`](router::Router) root service: one is via its
//! builder API, and the other is to construct it by implementing the omnitrait [`LanguageServer`]
//! or [`LanguageClient`] for a state struct. The former is more flexible, while the latter has an
//! API more similar to that of [`tower-lsp`](https://crates.io/crates/tower-lsp).
31//!
//! Examples for both the builder API and the omnitrait, covering both Language Servers and
//! Language Clients, can be found in the
34#![doc = concat!("[`examples`](https://github.com/oxalica/async-lsp/tree/v", env!("CARGO_PKG_VERSION") , "/examples)")]
35//! directory.
36//!
37//! ## Cargo features
38//!
39//! - `client-monitor`: Client process monitor middleware [`client_monitor`].
40//!   *Enabled by default.*
41//! - `omni-trait`: Mega traits of all standard requests and notifications, namely
42//!   [`LanguageServer`] and [`LanguageClient`].
43//!   *Enabled by default.*
44//! - `stdio`: Utilities to deal with pipe-like stdin/stdout communication channel for Language
45//!   Servers.
46//!   *Enabled by default.*
47//! - `tracing`: Integration with crate [`tracing`][::tracing] and the [`tracing`] middleware.
48//!   *Enabled by default.*
//! - `forward`: Impl [`LspService`] for `{Client,Server}Socket`. This causes some method names
//!   to collide, but allows easy service forwarding. See `examples/inspector.rs` for a possible
//!   use case.
51//!   *Disabled by default.*
52//! - `tokio`: Enable compatible methods for [`tokio`](https://crates.io/crates/tokio) runtime.
53//!   *Disabled by default.*
54#![cfg_attr(docsrs, feature(doc_cfg))]
55#![warn(missing_docs)]
56use std::any::{type_name, Any, TypeId};
57use std::collections::HashMap;
58use std::future::{poll_fn, Future};
59use std::marker::PhantomData;
60use std::ops::ControlFlow;
61use std::pin::Pin;
62use std::task::{ready, Context, Poll};
63use std::{fmt, io};
64
65use futures::channel::{mpsc, oneshot};
66use futures::io::BufReader;
67use futures::stream::FuturesUnordered;
68use futures::{
69    pin_mut, select_biased, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite,
70    AsyncWriteExt, FutureExt, SinkExt, StreamExt,
71};
72use lsp_types::notification::Notification;
73use lsp_types::request::Request;
74use lsp_types::NumberOrString;
75use pin_project_lite::pin_project;
76use serde::de::DeserializeOwned;
77use serde::{Deserialize, Serialize};
78use serde_json::Value as JsonValue;
79use thiserror::Error;
80use tower_service::Service;
81
82/// Re-export of the [`lsp_types`] dependency of this crate.
83pub use lsp_types;
84
/// Generates `get_ref`, `get_mut` and `into_inner` accessors for a wrapper type whose wrapped
/// value lives in a single named field.
///
/// Usage: `define_getters!(impl[<generic params>] Wrapper<..>, field_name: FieldType);`
macro_rules! define_getters {
    (impl[$($params:tt)*] $wrapper:ty, $inner:ident : $inner_ty:ty) => {
        impl<$($params)*> $wrapper {
            /// Borrow the inner service immutably.
            #[must_use]
            pub fn get_ref(&self) -> &$inner_ty {
                &self.$inner
            }

            /// Borrow the inner service mutably.
            #[must_use]
            pub fn get_mut(&mut self) -> &mut $inner_ty {
                &mut self.$inner
            }

            /// Unwrap `self`, handing back the inner service by value.
            #[must_use]
            pub fn into_inner(self) -> $inner_ty {
                self.$inner
            }
        }
    };
}
108
109pub mod concurrency;
110pub mod panic;
111pub mod router;
112pub mod server;
113
114#[cfg(feature = "forward")]
115#[cfg_attr(docsrs, doc(cfg(feature = "forward")))]
116mod forward;
117
118#[cfg(feature = "client-monitor")]
119#[cfg_attr(docsrs, doc(cfg(feature = "client-monitor")))]
120pub mod client_monitor;
121
122#[cfg(all(feature = "stdio", unix))]
123#[cfg_attr(docsrs, doc(cfg(all(feature = "stdio", unix))))]
124pub mod stdio;
125
126#[cfg(feature = "tracing")]
127#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
128pub mod tracing;
129
130#[cfg(feature = "omni-trait")]
131mod omni_trait;
132#[cfg(feature = "omni-trait")]
133#[cfg_attr(docsrs, doc(cfg(feature = "omni-trait")))]
134pub use omni_trait::{LanguageClient, LanguageServer};
135
136/// A convenient type alias for `Result` with `E` = [`enum@crate::Error`].
137pub type Result<T, E = Error> = std::result::Result<T, E>;
138
139/// Possible errors.
140#[derive(Debug, thiserror::Error)]
141#[non_exhaustive]
142pub enum Error {
143    /// The service main loop stopped.
144    #[error("service stopped")]
145    ServiceStopped,
146    /// The peer replies undecodable or invalid responses.
147    #[error("deserialization failed: {0}")]
148    Deserialize(#[from] serde_json::Error),
149    /// The peer replies an error.
150    #[error("{0}")]
151    Response(#[from] ResponseError),
152    /// The peer violates the Language Server Protocol.
153    #[error("protocol error: {0}")]
154    Protocol(String),
155    /// Input/output errors from the underlying channels.
156    #[error("{0}")]
157    Io(#[from] io::Error),
158    /// The underlying channel reached EOF (end of file).
159    #[error("the underlying channel reached EOF")]
160    Eof,
161    /// No handlers for events or mandatory notifications (not starting with `$/`).
162    ///
163    /// Will not occur when catch-all handlers ([`router::Router::unhandled_event`] and
164    /// [`router::Router::unhandled_notification`]) are installed.
165    #[error("{0}")]
166    Routing(String),
167}
168
169/// The core service abstraction, representing either a Language Server or Language Client.
170pub trait LspService: Service<AnyRequest> {
171    /// The handler of [LSP notifications](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage).
172    ///
173    /// Notifications are delivered in order and synchronously. This is mandatory since they can
174    /// change the interpretation of later notifications or requests.
175    ///
176    /// # Return
177    ///
178    /// The return value decides the action to either break or continue the main loop.
179    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>>;
180
181    /// The handler of an arbitrary [`AnyEvent`].
182    ///
183    /// Events are emitted by users or middlewares via [`ClientSocket::emit`] or
184    /// [`ServerSocket::emit`], for user-defined purposes. Events are delivered in order and
185    /// synchronously.
186    ///
187    /// # Return
188    ///
189    /// The return value decides the action to either break or continue the main loop.
190    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>>;
191}
192
193/// A JSON-RPC error code.
194///
195/// Codes defined and/or used by LSP are defined as associated constants, eg.
196/// [`ErrorCode::REQUEST_FAILED`].
197///
198/// See:
199/// <https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#errorCodes>
200#[derive(
201    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Error,
202)]
203#[error("jsonrpc error {0}")]
204pub struct ErrorCode(pub i32);
205
206impl From<i32> for ErrorCode {
207    fn from(i: i32) -> Self {
208        Self(i)
209    }
210}
211
212impl ErrorCode {
213    /// Invalid JSON was received by the server. An error occurred on the server while parsing the
214    /// JSON text.
215    ///
216    /// Defined by [JSON-RPC](https://www.jsonrpc.org/specification#error_object).
217    pub const PARSE_ERROR: Self = Self(-32700);
218
219    /// The JSON sent is not a valid Request object.
220    ///
221    /// Defined by [JSON-RPC](https://www.jsonrpc.org/specification#error_object).
222    pub const INVALID_REQUEST: Self = Self(-32600);
223
224    /// The method does not exist / is not available.
225    ///
226    /// Defined by [JSON-RPC](https://www.jsonrpc.org/specification#error_object).
227    pub const METHOD_NOT_FOUND: Self = Self(-32601);
228
229    /// Invalid method parameter(s).
230    ///
231    /// Defined by [JSON-RPC](https://www.jsonrpc.org/specification#error_object).
232    pub const INVALID_PARAMS: Self = Self(-32602);
233
234    /// Internal JSON-RPC error.
235    ///
236    /// Defined by [JSON-RPC](https://www.jsonrpc.org/specification#error_object).
237    pub const INTERNAL_ERROR: Self = Self(-32603);
238
239    /// This is the start range of JSON-RPC reserved error codes.
240    /// It doesn't denote a real error code. No LSP error codes should
241    /// be defined between the start and end range. For backwards
242    /// compatibility the `ServerNotInitialized` and the `UnknownErrorCode`
243    /// are left in the range.
244    ///
245    /// @since 3.16.0
246    pub const JSONRPC_RESERVED_ERROR_RANGE_START: Self = Self(-32099);
247
248    /// Error code indicating that a server received a notification or
249    /// request before the server has received the `initialize` request.
250    pub const SERVER_NOT_INITIALIZED: Self = Self(-32002);
251
252    /// (Defined by LSP specification without description)
253    pub const UNKNOWN_ERROR_CODE: Self = Self(-32001);
254
255    /// This is the end range of JSON-RPC reserved error codes.
256    /// It doesn't denote a real error code.
257    ///
258    /// @since 3.16.0
259    pub const JSONRPC_RESERVED_ERROR_RANGE_END: Self = Self(-32000);
260
261    /// This is the start range of LSP reserved error codes.
262    /// It doesn't denote a real error code.
263    ///
264    /// @since 3.16.0
265    pub const LSP_RESERVED_ERROR_RANGE_START: Self = Self(-32899);
266
267    /// A request failed but it was syntactically correct, e.g the
268    /// method name was known and the parameters were valid. The error
269    /// message should contain human readable information about why
270    /// the request failed.
271    ///
272    /// @since 3.17.0
273    pub const REQUEST_FAILED: Self = Self(-32803);
274
275    /// The server cancelled the request. This error code should
276    /// only be used for requests that explicitly support being
277    /// server cancellable.
278    ///
279    /// @since 3.17.0
280    pub const SERVER_CANCELLED: Self = Self(-32802);
281
282    /// The server detected that the content of a document got
283    /// modified outside normal conditions. A server should
284    /// NOT send this error code if it detects a content change
285    /// in it unprocessed messages. The result even computed
286    /// on an older state might still be useful for the client.
287    ///
288    /// If a client decides that a result is not of any use anymore
289    /// the client should cancel the request.
290    pub const CONTENT_MODIFIED: Self = Self(-32801);
291
292    /// The client has canceled a request and a server as detected
293    /// the cancel.
294    pub const REQUEST_CANCELLED: Self = Self(-32800);
295
296    /// This is the end range of LSP reserved error codes.
297    /// It doesn't denote a real error code.
298    ///
299    /// @since 3.16.0
300    pub const LSP_RESERVED_ERROR_RANGE_END: Self = Self(-32800);
301}
302
303/// The identifier of requests and responses.
304///
305/// Though `null` is technically a valid id for responses, we reject it since it hardly makes sense
306/// for valid communication.
307pub type RequestId = NumberOrString;
308
309#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
310struct RawMessage<T> {
311    jsonrpc: RpcVersion,
312    #[serde(flatten)]
313    inner: T,
314}
315
316impl<T> RawMessage<T> {
317    fn new(inner: T) -> Self {
318        Self {
319            jsonrpc: RpcVersion::V2,
320            inner,
321        }
322    }
323}
324
/// The JSON-RPC protocol version; `"2.0"` is the only value accepted or produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum RpcVersion {
    #[serde(rename = "2.0")]
    V2,
}

/// A JSON-RPC message body of any kind (the `jsonrpc` member lives in `RawMessage`).
///
/// With `#[serde(untagged)]`, deserialization tries the variants in declaration order and keeps
/// the first one that fits, so the order below is significant. A request also carries an `id`,
/// and `AnyResponse` needs nothing beyond `id` (its `result`/`error` are optional), so a request
/// would also fit `Response`. `Request` must therefore be tried first. A notification has no
/// `id`, so it can only match `Notification`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Message {
    Request(AnyRequest),
    Response(AnyResponse),
    Notification(AnyNotification),
}
338
339/// A dynamic runtime [LSP request](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#requestMessage).
340#[derive(Debug, Clone, Serialize, Deserialize)]
341#[non_exhaustive]
342pub struct AnyRequest {
343    /// The request id.
344    pub id: RequestId,
345    /// The method to be invoked.
346    pub method: String,
347    /// The method's params.
348    #[serde(default)]
349    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
350    pub params: serde_json::Value,
351}
352
353/// A dynamic runtime [LSP notification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#notificationMessage).
354#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
355#[non_exhaustive]
356pub struct AnyNotification {
357    /// The method to be invoked.
358    pub method: String,
359    /// The notification's params.
360    #[serde(default)]
361    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
362    pub params: JsonValue,
363}
364
365/// A dynamic runtime response.
366#[derive(Debug, Clone, Serialize, Deserialize)]
367#[non_exhaustive]
368struct AnyResponse {
369    id: RequestId,
370    #[serde(skip_serializing_if = "Option::is_none")]
371    result: Option<JsonValue>,
372    #[serde(skip_serializing_if = "Option::is_none")]
373    error: Option<ResponseError>,
374}
375
376/// The error object in case a request fails.
377///
378/// See:
379/// <https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#responseError>
380#[derive(Debug, Clone, Serialize, Deserialize, Error)]
381#[non_exhaustive]
382#[error("{message} ({code})")]
383pub struct ResponseError {
384    /// A number indicating the error type that occurred.
385    pub code: ErrorCode,
386    /// A string providing a short description of the error.
387    pub message: String,
388    /// A primitive or structured value that contains additional
389    /// information about the error. Can be omitted.
390    pub data: Option<JsonValue>,
391}
392
393impl ResponseError {
394    /// Create a new error object with a JSON-RPC error code and a message.
395    #[must_use]
396    pub fn new(code: ErrorCode, message: impl fmt::Display) -> Self {
397        Self {
398            code,
399            message: message.to_string(),
400            data: None,
401        }
402    }
403
404    /// Create a new error object with a JSON-RPC error code, a message, and any additional data.
405    #[must_use]
406    pub fn new_with_data(code: ErrorCode, message: impl fmt::Display, data: JsonValue) -> Self {
407        Self {
408            code,
409            message: message.to_string(),
410            data: Some(data),
411        }
412    }
413}
414
415impl Message {
416    const CONTENT_LENGTH: &'static str = "Content-Length";
417
418    async fn read(mut reader: impl AsyncBufRead + Unpin) -> Result<Self> {
419        let mut line = String::new();
420        let mut content_len = None;
421        loop {
422            line.clear();
423            reader.read_line(&mut line).await?;
424            if line.is_empty() {
425                return Err(Error::Eof);
426            }
427            if line == "\r\n" {
428                break;
429            }
430            // NB. LSP spec is stricter than HTTP spec, the spaces here is required and it's not
431            // explicitly permitted to include extra spaces. We reject them here.
432            let (name, value) = line
433                .strip_suffix("\r\n")
434                .and_then(|line| line.split_once(": "))
435                .ok_or_else(|| Error::Protocol(format!("Invalid header: {line:?}")))?;
436            if name.eq_ignore_ascii_case(Self::CONTENT_LENGTH) {
437                let value = value
438                    .parse::<usize>()
439                    .map_err(|_| Error::Protocol(format!("Invalid content-length: {value}")))?;
440                content_len = Some(value);
441            }
442        }
443        let content_len =
444            content_len.ok_or_else(|| Error::Protocol("Missing content-length".into()))?;
445        let mut buf = vec![0u8; content_len];
446        reader.read_exact(&mut buf).await?;
447        #[cfg(feature = "tracing")]
448        ::tracing::trace!(msg = %String::from_utf8_lossy(&buf), "incoming");
449        let msg = serde_json::from_slice::<RawMessage<Self>>(&buf)?;
450        Ok(msg.inner)
451    }
452
453    async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
454        let buf = serde_json::to_string(&RawMessage::new(self))?;
455        #[cfg(feature = "tracing")]
456        ::tracing::trace!(msg = %buf, "outgoing");
457        writer
458            .write_all(format!("{}: {}\r\n\r\n", Self::CONTENT_LENGTH, buf.len()).as_bytes())
459            .await?;
460        writer.write_all(buf.as_bytes()).await?;
461        writer.flush().await?;
462        Ok(())
463    }
464}
465
/// Service main loop driver for either Language Servers or Language Clients.
///
/// Created by [`MainLoop::new_server`] or [`MainLoop::new_client`], and driven by
/// [`MainLoop::run`] or [`MainLoop::run_buffered`].
pub struct MainLoop<S: LspService> {
    // The wrapped service that handles incoming requests, notifications and events.
    service: S,
    // Receives events queued through `ClientSocket`/`ServerSocket` handles.
    rx: mpsc::UnboundedReceiver<MainLoopEvent>,
    // Id to assign to the next outgoing request.
    outgoing_id: i32,
    // Outgoing requests still awaiting the peer's response, keyed by the id they were sent with.
    outgoing: HashMap<RequestId, oneshot::Sender<AnyResponse>>,
    // In-flight service calls for incoming requests; each resolves to the response to send back.
    tasks: FuturesUnordered<RequestFuture<S::Future>>,
}

/// Events sent from socket handles to the main loop.
enum MainLoopEvent {
    /// A message to write to the peer unchanged.
    Outgoing(Message),
    /// A request to send to the peer. The main loop assigns its id and routes the peer's
    /// response back through the sender.
    OutgoingRequest(AnyRequest, oneshot::Sender<AnyResponse>),
    /// A loopback event delivered to [`LspService::emit`].
    Any(AnyEvent),
}

// Generates `get_ref`, `get_mut` and `into_inner` accessors for the wrapped service.
define_getters!(impl[S: LspService] MainLoop<S>, service: S);
482
483impl<S> MainLoop<S>
484where
485    S: LspService<Response = JsonValue>,
486    ResponseError: From<S::Error>,
487{
488    /// Create a Language Server main loop.
489    #[must_use]
490    pub fn new_server(builder: impl FnOnce(ClientSocket) -> S) -> (Self, ClientSocket) {
491        let (this, socket) = Self::new(|socket| builder(ClientSocket(socket)));
492        (this, ClientSocket(socket))
493    }
494
495    /// Create a Language Client main loop.
496    #[must_use]
497    pub fn new_client(builder: impl FnOnce(ServerSocket) -> S) -> (Self, ServerSocket) {
498        let (this, socket) = Self::new(|socket| builder(ServerSocket(socket)));
499        (this, ServerSocket(socket))
500    }
501
502    fn new(builder: impl FnOnce(PeerSocket) -> S) -> (Self, PeerSocket) {
503        let (tx, rx) = mpsc::unbounded();
504        let socket = PeerSocket { tx };
505        let this = Self {
506            service: builder(socket.clone()),
507            rx,
508            outgoing_id: 0,
509            outgoing: HashMap::new(),
510            tasks: FuturesUnordered::new(),
511        };
512        (this, socket)
513    }
514
515    /// Drive the service main loop to provide the service.
516    ///
517    /// Shortcut to [`MainLoop::run`] that accept an `impl AsyncRead` and implicit wrap it in a
518    /// [`BufReader`].
519    // Documented in `Self::run`.
520    #[allow(clippy::missing_errors_doc)]
521    pub async fn run_buffered(self, input: impl AsyncRead, output: impl AsyncWrite) -> Result<()> {
522        self.run(BufReader::new(input), output).await
523    }
524
525    /// Drive the service main loop to provide the service.
526    ///
527    /// # Errors
528    ///
529    /// - `Error::Io` when the underlying `input` or `output` raises an error.
530    /// - `Error::Deserialize` when the peer sends undecodable or invalid message.
531    /// - `Error::Protocol` when the peer violates Language Server Protocol.
532    /// - Other errors raised from service handlers.
533    pub async fn run(mut self, input: impl AsyncBufRead, output: impl AsyncWrite) -> Result<()> {
534        pin_mut!(input, output);
535        let incoming = futures::stream::unfold(input, |mut input| async move {
536            Some((Message::read(&mut input).await, input))
537        });
538        let outgoing = futures::sink::unfold(output, |mut output, msg| async move {
539            Message::write(&msg, &mut output).await.map(|()| output)
540        });
541        pin_mut!(incoming, outgoing);
542
543        let mut flush_fut = futures::future::Fuse::terminated();
544        let ret = loop {
545            // Outgoing > internal > incoming.
546            // Preference on outgoing data provides back pressure in case of
547            // flooding incoming requests.
548            let ctl = select_biased! {
549                // Concurrently flush out the previous message.
550                ret = flush_fut => { ret?; continue; }
551
552                resp = self.tasks.select_next_some() => ControlFlow::Continue(Some(Message::Response(resp))),
553                event = self.rx.next() => self.dispatch_event(event.expect("Sender is alive")),
554                msg = incoming.next() => {
555                    let dispatch_fut = self.dispatch_message(msg.expect("Never ends")?).fuse();
556                    pin_mut!(dispatch_fut);
557                    // NB. Concurrently wait for `poll_ready`, and write out the last message.
558                    // If the service is waiting for client's response of the last request, while
559                    // the last message is not delivered on the first write, it can deadlock.
560                    loop {
561                        select_biased! {
562                            // Dispatch first. It usually succeeds immediately for non-requests,
563                            // and the service is hardly busy.
564                            ctl = dispatch_fut => break ctl,
565                            ret = flush_fut => { ret?; continue }
566                        }
567                    }
568                }
569            };
570            let msg = match ctl {
571                ControlFlow::Continue(Some(msg)) => msg,
572                ControlFlow::Continue(None) => continue,
573                ControlFlow::Break(ret) => break ret,
574            };
575            // Flush the previous one and load a new message to send.
576            outgoing.feed(msg).await?;
577            flush_fut = outgoing.flush().fuse();
578        };
579
580        // Flush the last message. It is enqueued before the event returning `ControlFlow::Break`.
581        // To preserve the order at best effort, we send it before exiting the main loop.
582        // But the more significant `ControlFlow::Break` error will override the flushing error,
583        // if there is any.
584        let flush_ret = outgoing.close().await;
585        ret.and(flush_ret)
586    }
587
588    async fn dispatch_message(&mut self, msg: Message) -> ControlFlow<Result<()>, Option<Message>> {
589        match msg {
590            Message::Request(req) => {
591                if let Err(err) = poll_fn(|cx| self.service.poll_ready(cx)).await {
592                    let resp = AnyResponse {
593                        id: req.id,
594                        result: None,
595                        error: Some(err.into()),
596                    };
597                    return ControlFlow::Continue(Some(Message::Response(resp)));
598                }
599                let id = req.id.clone();
600                let fut = self.service.call(req);
601                self.tasks.push(RequestFuture { fut, id: Some(id) });
602            }
603            Message::Response(resp) => {
604                if let Some(resp_tx) = self.outgoing.remove(&resp.id) {
605                    // The result may be ignored.
606                    let _: Result<_, _> = resp_tx.send(resp);
607                }
608            }
609            Message::Notification(notif) => {
610                self.service.notify(notif)?;
611            }
612        }
613        ControlFlow::Continue(None)
614    }
615
616    fn dispatch_event(&mut self, event: MainLoopEvent) -> ControlFlow<Result<()>, Option<Message>> {
617        match event {
618            MainLoopEvent::OutgoingRequest(mut req, resp_tx) => {
619                req.id = RequestId::Number(self.outgoing_id);
620                assert!(self.outgoing.insert(req.id.clone(), resp_tx).is_none());
621                self.outgoing_id += 1;
622                ControlFlow::Continue(Some(Message::Request(req)))
623            }
624            MainLoopEvent::Outgoing(msg) => ControlFlow::Continue(Some(msg)),
625            MainLoopEvent::Any(event) => {
626                self.service.emit(event)?;
627                ControlFlow::Continue(None)
628            }
629        }
630    }
631}
632
633pin_project! {
634    struct RequestFuture<Fut> {
635        #[pin]
636        fut: Fut,
637        id: Option<RequestId>,
638    }
639}
640
641impl<Fut, Error> Future for RequestFuture<Fut>
642where
643    Fut: Future<Output = Result<JsonValue, Error>>,
644    ResponseError: From<Error>,
645{
646    type Output = AnyResponse;
647
648    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
649        let this = self.project();
650        let (mut result, mut error) = (None, None);
651        match ready!(this.fut.poll(cx)) {
652            Ok(v) => result = Some(v),
653            Err(err) => error = Some(err.into()),
654        }
655        Poll::Ready(AnyResponse {
656            id: this.id.take().expect("Future is consumed"),
657            result,
658            error,
659        })
660    }
661}
662
macro_rules! impl_socket_wrapper {
    ($name:ident) => {
        impl $name {
            /// Create a closed socket outside a main loop. Any interaction will immediately return
            /// an error of [`Error::ServiceStopped`].
            ///
            /// This works as a placeholder where a socket is required but actually unused.
            ///
            /// # Note
            ///
            /// To prevent accidental misuse, this method is intentionally NOT implemented as
            /// [`Default::default`].
            #[must_use]
            pub fn new_closed() -> Self {
                Self(PeerSocket::new_closed())
            }

            /// Send a request to the peer and wait for its response.
            ///
            /// # Errors
            /// - [`Error::ServiceStopped`] when the service main loop stopped.
            /// - [`Error::Response`] when the peer replies with an error.
            /// - [`Error::Deserialize`] when the peer's result cannot be decoded as `R::Result`.
            ///
            /// # Panics
            /// Panics if `params` cannot be serialized into JSON.
            pub async fn request<R: Request>(&self, params: R::Params) -> Result<R::Result> {
                self.0.request::<R>(params).await
            }

            /// Send a notification to the peer.
            ///
            /// Notifications have no response, so this does not wait for anything. It only
            /// enqueues the message: an `Ok` result indicates it was queued successfully, but it
            /// may not have been sent to the peer yet.
            ///
            /// # Errors
            /// - [`Error::ServiceStopped`] when the service main loop stopped.
            ///
            /// # Panics
            /// Panics if `params` cannot be serialized into JSON.
            pub fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
                self.0.notify::<N>(params)
            }

            /// Emit an arbitrary loopback event object to the service's
            /// [`LspService::emit`] handler.
            ///
            /// This is done asynchronously. An `Ok` result indicates the event was successfully
            /// queued, but it may not have been processed yet.
            ///
            /// # Errors
            /// - [`Error::ServiceStopped`] when the service main loop stopped.
            pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
                self.0.emit::<E>(event)
            }
        }
    };
}
713
714/// The socket for Language Server to communicate with the Language Client peer.
715#[derive(Debug, Clone)]
716pub struct ClientSocket(PeerSocket);
717impl_socket_wrapper!(ClientSocket);
718
719/// The socket for Language Client to communicate with the Language Server peer.
720#[derive(Debug, Clone)]
721pub struct ServerSocket(PeerSocket);
722impl_socket_wrapper!(ServerSocket);
723
724#[derive(Debug, Clone)]
725struct PeerSocket {
726    tx: mpsc::UnboundedSender<MainLoopEvent>,
727}
728
729impl PeerSocket {
730    fn new_closed() -> Self {
731        let (tx, _rx) = mpsc::unbounded();
732        Self { tx }
733    }
734
735    fn send(&self, v: MainLoopEvent) -> Result<()> {
736        self.tx.unbounded_send(v).map_err(|_| Error::ServiceStopped)
737    }
738
739    fn request<R: Request>(&self, params: R::Params) -> PeerSocketRequestFuture<R::Result> {
740        let req = AnyRequest {
741            id: RequestId::Number(0),
742            method: R::METHOD.into(),
743            params: serde_json::to_value(params).expect("Failed to serialize"),
744        };
745        let (tx, rx) = oneshot::channel();
746        // If this fails, the oneshot channel will also be closed, and it is handled by
747        // `PeerSocketRequestFuture`.
748        let _: Result<_, _> = self.send(MainLoopEvent::OutgoingRequest(req, tx));
749        PeerSocketRequestFuture {
750            rx,
751            _marker: PhantomData,
752        }
753    }
754
755    fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
756        let notif = AnyNotification {
757            method: N::METHOD.into(),
758            params: serde_json::to_value(params).expect("Failed to serialize"),
759        };
760        self.send(MainLoopEvent::Outgoing(Message::Notification(notif)))
761    }
762
763    pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
764        self.send(MainLoopEvent::Any(AnyEvent::new(event)))
765    }
766}
767
768struct PeerSocketRequestFuture<T> {
769    rx: oneshot::Receiver<AnyResponse>,
770    _marker: PhantomData<fn() -> T>,
771}
772
773impl<T: DeserializeOwned> Future for PeerSocketRequestFuture<T> {
774    type Output = Result<T>;
775
776    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
777        let resp = ready!(Pin::new(&mut self.rx)
778            .poll(cx)
779            .map_err(|_| Error::ServiceStopped))?;
780        Poll::Ready(match resp.error {
781            None => Ok(serde_json::from_value(resp.result.unwrap_or_default())?),
782            Some(err) => Err(Error::Response(err)),
783        })
784    }
785}
786
/// A type-erased runtime event value.
///
/// Wraps a `Box<dyn Any + Send>` and additionally remembers the concrete type name, so that the
/// `Debug` output is informative.
///
/// See [`LspService::emit`] for how these are consumed.
pub struct AnyEvent {
    inner: Box<dyn Any + Send>,
    type_name: &'static str,
}

impl fmt::Debug for AnyEvent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AnyEvent");
        out.field("type_name", &self.type_name);
        out.finish_non_exhaustive()
    }
}

impl AnyEvent {
    /// Box `value`, capturing the name of `T` for diagnostics.
    #[must_use]
    fn new<T: Send + 'static>(value: T) -> Self {
        let name = type_name::<T>();
        Self {
            inner: Box::new(value),
            type_name: name,
        }
    }

    /// The `TypeId` of the boxed value itself.
    #[must_use]
    fn inner_type_id(&self) -> TypeId {
        // Dispatch through the inner `dyn Any`; calling `type_id` on the `Box` (or a reference
        // to it) would report the box's own type instead.
        Any::type_id(&*self.inner)
    }

    /// The name of the underlying type, for debugging only.
    ///
    /// The exact string is not stable and must not be relied upon.
    #[must_use]
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }

    /// Whether the boxed value is a `T`.
    #[must_use]
    pub fn is<T: Send + 'static>(&self) -> bool {
        self.downcast_ref::<T>().is_some()
    }

    /// Borrow the boxed value as a `T`, or get `None` if it has another type.
    #[must_use]
    pub fn downcast_ref<T: Send + 'static>(&self) -> Option<&T> {
        self.inner.downcast_ref::<T>()
    }

    /// Mutably borrow the boxed value as a `T`, or get `None` if it has another type.
    #[must_use]
    pub fn downcast_mut<T: Send + 'static>(&mut self) -> Option<&mut T> {
        self.inner.downcast_mut::<T>()
    }

    /// Take the boxed value out as a `T`.
    ///
    /// # Errors
    ///
    /// Gives `self` back unchanged when the boxed value is not a `T`.
    pub fn downcast<T: Send + 'static>(self) -> Result<T, Self> {
        let Self { inner, type_name } = self;
        inner
            .downcast::<T>()
            .map(|boxed| *boxed)
            .map_err(|inner| Self { inner, type_name })
    }
}
863
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-only check: it type-checks only if the future returned by `MainLoop::run` is
    /// `Send` under the listed bounds. It is not invoked by any test.
    fn _main_loop_future_is_send<S>(
        main_loop: MainLoop<S>,
        input: impl AsyncBufRead + Send,
        output: impl AsyncWrite + Send,
    ) -> impl Send
    where
        S: LspService<Response = JsonValue> + Send,
        S::Future: Send,
        S::Error: From<Error> + Send,
        ResponseError: From<S::Error>,
    {
        main_loop.run(input, output)
    }

    #[tokio::test]
    async fn closed_client_socket() {
        let socket = ClientSocket::new_closed();

        let notified = socket.notify::<lsp_types::notification::Exit>(());
        assert!(matches!(notified, Err(Error::ServiceStopped)));

        let requested = socket.request::<lsp_types::request::Shutdown>(()).await;
        assert!(matches!(requested, Err(Error::ServiceStopped)));

        let emitted = socket.emit(42i32);
        assert!(matches!(emitted, Err(Error::ServiceStopped)));
    }

    #[tokio::test]
    async fn closed_server_socket() {
        let socket = ServerSocket::new_closed();

        let notified = socket.notify::<lsp_types::notification::Exit>(());
        assert!(matches!(notified, Err(Error::ServiceStopped)));

        let requested = socket.request::<lsp_types::request::Shutdown>(()).await;
        assert!(matches!(requested, Err(Error::ServiceStopped)));

        let emitted = socket.emit(42i32);
        assert!(matches!(emitted, Err(Error::ServiceStopped)));
    }

    #[test]
    fn any_event() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct MyEvent<T>(T);

        let original = MyEvent("hello".to_owned());
        let mut erased = AnyEvent::new(original.clone());
        assert!(erased.type_name().contains("MyEvent"));

        // Only the exact concrete type matches.
        assert!(!erased.is::<String>());
        assert!(!erased.is::<MyEvent<i32>>());
        assert!(erased.is::<MyEvent<String>>());

        // Shared access.
        assert_eq!(erased.downcast_ref::<i32>(), None);
        assert_eq!(erased.downcast_ref::<MyEvent<String>>(), Some(&original));

        // Exclusive access mutates the stored value in place.
        assert_eq!(erased.downcast_mut::<MyEvent<i32>>(), None);
        let stored = erased.downcast_mut::<MyEvent<String>>().unwrap();
        stored.0.push_str(" world");

        // A failed by-value downcast hands the event back so it can be retried.
        let erased = erased.downcast::<()>().unwrap_err();
        let recovered = erased.downcast::<MyEvent<String>>().unwrap();
        assert_eq!(recovered.0, "hello world");
    }
}