1# , "/examples)")]
35#![cfg_attr(docsrs, feature(doc_cfg))]
55#![warn(missing_docs)]
56use std::any::{type_name, Any, TypeId};
57use std::collections::HashMap;
58use std::future::{poll_fn, Future};
59use std::marker::PhantomData;
60use std::ops::ControlFlow;
61use std::pin::Pin;
62use std::task::{ready, Context, Poll};
63use std::{fmt, io};
64
65use futures::channel::{mpsc, oneshot};
66use futures::io::BufReader;
67use futures::stream::FuturesUnordered;
68use futures::{
69 pin_mut, select_biased, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite,
70 AsyncWriteExt, FutureExt, SinkExt, StreamExt,
71};
72use lsp_types::notification::Notification;
73use lsp_types::request::Request;
74use lsp_types::NumberOrString;
75use pin_project_lite::pin_project;
76use serde::de::DeserializeOwned;
77use serde::{Deserialize, Serialize};
78use serde_json::Value as JsonValue;
79use thiserror::Error;
80use tower_service::Service;
81
82pub use lsp_types;
84
// Generates the `get_ref` / `get_mut` / `into_inner` accessor trio for a wrapper type.
//
// Usage: `define_getters!(impl[<generic params>] <Type>, <field>: <FieldType>);`
// The generic parameter list is pasted verbatim into `impl<...>`.
macro_rules! define_getters {
    (impl[$($params:tt)*] $wrapper:ty, $member:ident : $member_ty:ty) => {
        impl<$($params)*> $wrapper {
            /// Returns a shared reference to the wrapped inner value.
            #[must_use]
            pub fn get_ref(&self) -> &$member_ty {
                &self.$member
            }

            /// Returns an exclusive reference to the wrapped inner value.
            #[must_use]
            pub fn get_mut(&mut self) -> &mut $member_ty {
                &mut self.$member
            }

            /// Consumes the wrapper and returns the wrapped inner value.
            #[must_use]
            pub fn into_inner(self) -> $member_ty {
                self.$member
            }
        }
    };
}
108
109pub mod concurrency;
110pub mod panic;
111pub mod router;
112pub mod server;
113
114#[cfg(feature = "forward")]
115#[cfg_attr(docsrs, doc(cfg(feature = "forward")))]
116mod forward;
117
118#[cfg(feature = "client-monitor")]
119#[cfg_attr(docsrs, doc(cfg(feature = "client-monitor")))]
120pub mod client_monitor;
121
122#[cfg(all(feature = "stdio", unix))]
123#[cfg_attr(docsrs, doc(cfg(all(feature = "stdio", unix))))]
124pub mod stdio;
125
126#[cfg(feature = "tracing")]
127#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
128pub mod tracing;
129
130#[cfg(feature = "omni-trait")]
131mod omni_trait;
132#[cfg(feature = "omni-trait")]
133#[cfg_attr(docsrs, doc(cfg(feature = "omni-trait")))]
134pub use omni_trait::{LanguageClient, LanguageServer};
135
/// Convenience alias for [`std::result::Result`] whose error type defaults to this crate's
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
138
/// Errors produced by the main loop, the peer sockets and message (de)serialization.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The service main loop stopped, so events can no longer be delivered to it and pending
    /// requests will never be answered.
    #[error("service stopped")]
    ServiceStopped,
    /// A JSON payload could not be encoded or decoded.
    #[error("deserialization failed: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// The peer replied to a request with an error response.
    #[error("{0}")]
    Response(#[from] ResponseError),
    /// The peer sent something violating the wire protocol, e.g. a malformed header.
    #[error("protocol error: {0}")]
    Protocol(String),
    /// Input/output error from the underlying channel.
    #[error("{0}")]
    Io(#[from] io::Error),
    /// The underlying input channel was closed while waiting for the next message.
    #[error("the underlying channel reached EOF")]
    Eof,
    /// A message or event could not be routed to a handler.
    // NOTE(review): not produced anywhere in this file; presumably raised by the `router`
    // module when no handler is registered — confirm against that module.
    #[error("{0}")]
    Routing(String),
}
168
/// A [`Service`] handling [`AnyRequest`]s that can additionally process notifications and
/// loopback events.
///
/// Requests go through [`Service::call`]; notifications and events are handled synchronously
/// by the two methods below.
pub trait LspService: Service<AnyRequest> {
    /// Handles a notification received from the peer.
    ///
    /// Returning `ControlFlow::Break` makes the main loop stop and finish with the carried
    /// result.
    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>>;

    /// Handles an arbitrary loopback event submitted through a socket's `emit` method.
    ///
    /// Returning `ControlFlow::Break` makes the main loop stop and finish with the carried
    /// result.
    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>>;
}
192
/// A JSON-RPC error code, carried as a plain `i32`.
///
/// Well-known values are available as associated constants. The `Display` form is
/// `jsonrpc error <code>`.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Error,
)]
#[error("jsonrpc error {0}")]
pub struct ErrorCode(pub i32);
205
206impl From<i32> for ErrorCode {
207 fn from(i: i32) -> Self {
208 Self(i)
209 }
210}
211
212impl ErrorCode {
213 pub const PARSE_ERROR: Self = Self(-32700);
218
219 pub const INVALID_REQUEST: Self = Self(-32600);
223
224 pub const METHOD_NOT_FOUND: Self = Self(-32601);
228
229 pub const INVALID_PARAMS: Self = Self(-32602);
233
234 pub const INTERNAL_ERROR: Self = Self(-32603);
238
239 pub const JSONRPC_RESERVED_ERROR_RANGE_START: Self = Self(-32099);
247
248 pub const SERVER_NOT_INITIALIZED: Self = Self(-32002);
251
252 pub const UNKNOWN_ERROR_CODE: Self = Self(-32001);
254
255 pub const JSONRPC_RESERVED_ERROR_RANGE_END: Self = Self(-32000);
260
261 pub const LSP_RESERVED_ERROR_RANGE_START: Self = Self(-32899);
266
267 pub const REQUEST_FAILED: Self = Self(-32803);
274
275 pub const SERVER_CANCELLED: Self = Self(-32802);
281
282 pub const CONTENT_MODIFIED: Self = Self(-32801);
291
292 pub const REQUEST_CANCELLED: Self = Self(-32800);
295
296 pub const LSP_RESERVED_ERROR_RANGE_END: Self = Self(-32800);
301}
302
/// The identifier of a request, used to pair it with its response: either a number or a
/// string.
pub type RequestId = NumberOrString;
308
/// Wire-level envelope: the `jsonrpc` version tag plus the flattened fields of the wrapped
/// message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct RawMessage<T> {
    // Protocol version marker; only "2.0" is representable.
    jsonrpc: RpcVersion,
    // The actual message; its fields are (de)serialized at the same level as `jsonrpc`.
    #[serde(flatten)]
    inner: T,
}
315
316impl<T> RawMessage<T> {
317 fn new(inner: T) -> Self {
318 Self {
319 jsonrpc: RpcVersion::V2,
320 inner,
321 }
322 }
323}
324
/// The JSON-RPC protocol version tag. Only the string `"2.0"` (de)serializes successfully.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum RpcVersion {
    // Serialized as the literal string "2.0".
    #[serde(rename = "2.0")]
    V2,
}
330
/// Any message exchanged with the peer.
///
/// The enum is untagged, so on deserialization the variants are tried in declaration order
/// and the first one whose fields match wins. The order therefore matters: a request has
/// both `id` and `method`, a response has `id`, and a notification has `method`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Message {
    // A call expecting a response.
    Request(AnyRequest),
    // The reply to an earlier request.
    Response(AnyResponse),
    // A one-way message without an id.
    Notification(AnyNotification),
}
338
/// A dynamically typed request: the method name plus its parameters as raw JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AnyRequest {
    /// The request id, echoed in the matching response.
    pub id: RequestId,
    /// The name of the method to invoke.
    pub method: String,
    /// The method parameters. Defaults to JSON `null` when absent from the input, and is left
    /// out of the serialized form when it is `null`.
    #[serde(default)]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: serde_json::Value,
}
352
/// A dynamically typed notification: the method name plus its parameters as raw JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AnyNotification {
    /// The name of the notification method.
    pub method: String,
    /// The notification parameters. Defaults to JSON `null` when absent from the input, and
    /// is left out of the serialized form when it is `null`.
    #[serde(default)]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: JsonValue,
}
364
/// A dynamically typed response to a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
struct AnyResponse {
    // Id of the request this response answers.
    id: RequestId,
    // Successful result; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<JsonValue>,
    // Failure description; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<ResponseError>,
}
375
/// The error object carried by a failed response.
///
/// Its `Display` form is `<message> (<code>)`.
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
#[non_exhaustive]
#[error("{message} ({code})")]
pub struct ResponseError {
    /// A number indicating the kind of error.
    pub code: ErrorCode,
    /// A human-readable description of the error.
    pub message: String,
    /// Optional additional information about the error.
    pub data: Option<JsonValue>,
}
392
393impl ResponseError {
394 #[must_use]
396 pub fn new(code: ErrorCode, message: impl fmt::Display) -> Self {
397 Self {
398 code,
399 message: message.to_string(),
400 data: None,
401 }
402 }
403
404 #[must_use]
406 pub fn new_with_data(code: ErrorCode, message: impl fmt::Display, data: JsonValue) -> Self {
407 Self {
408 code,
409 message: message.to_string(),
410 data: Some(data),
411 }
412 }
413}
414
415impl Message {
416 const CONTENT_LENGTH: &'static str = "Content-Length";
417
418 async fn read(mut reader: impl AsyncBufRead + Unpin) -> Result<Self> {
419 let mut line = String::new();
420 let mut content_len = None;
421 loop {
422 line.clear();
423 reader.read_line(&mut line).await?;
424 if line.is_empty() {
425 return Err(Error::Eof);
426 }
427 if line == "\r\n" {
428 break;
429 }
430 let (name, value) = line
433 .strip_suffix("\r\n")
434 .and_then(|line| line.split_once(": "))
435 .ok_or_else(|| Error::Protocol(format!("Invalid header: {line:?}")))?;
436 if name.eq_ignore_ascii_case(Self::CONTENT_LENGTH) {
437 let value = value
438 .parse::<usize>()
439 .map_err(|_| Error::Protocol(format!("Invalid content-length: {value}")))?;
440 content_len = Some(value);
441 }
442 }
443 let content_len =
444 content_len.ok_or_else(|| Error::Protocol("Missing content-length".into()))?;
445 let mut buf = vec![0u8; content_len];
446 reader.read_exact(&mut buf).await?;
447 #[cfg(feature = "tracing")]
448 ::tracing::trace!(msg = %String::from_utf8_lossy(&buf), "incoming");
449 let msg = serde_json::from_slice::<RawMessage<Self>>(&buf)?;
450 Ok(msg.inner)
451 }
452
453 async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
454 let buf = serde_json::to_string(&RawMessage::new(self))?;
455 #[cfg(feature = "tracing")]
456 ::tracing::trace!(msg = %buf, "outgoing");
457 writer
458 .write_all(format!("{}: {}\r\n\r\n", Self::CONTENT_LENGTH, buf.len()).as_bytes())
459 .await?;
460 writer.write_all(buf.as_bytes()).await?;
461 writer.flush().await?;
462 Ok(())
463 }
464}
465
/// The main loop driver: owns the service and shuttles messages between it and the peer.
pub struct MainLoop<S: LspService> {
    // The user-provided service handling requests, notifications and events.
    service: S,
    // Receiving end of the channel fed by the peer sockets.
    rx: mpsc::UnboundedReceiver<MainLoopEvent>,
    // Id to assign to the next outgoing request.
    outgoing_id: i32,
    // Outgoing requests still waiting for the peer's response, keyed by request id.
    outgoing: HashMap<RequestId, oneshot::Sender<AnyResponse>>,
    // In-flight futures of incoming requests being processed by the service.
    tasks: FuturesUnordered<RequestFuture<S::Future>>,
}
474
/// Events sent from the peer sockets to the main loop.
enum MainLoopEvent {
    // A ready-made message to forward to the peer as is.
    Outgoing(Message),
    // A request to send to the peer; the main loop assigns its id and later passes the
    // response to the sender.
    OutgoingRequest(AnyRequest, oneshot::Sender<AnyResponse>),
    // A loopback event to hand to the service's `emit`.
    Any(AnyEvent),
}
480
// Expose the wrapped service through `get_ref`, `get_mut` and `into_inner`.
define_getters!(impl[S: LspService] MainLoop<S>, service: S);
482
483impl<S> MainLoop<S>
484where
485 S: LspService<Response = JsonValue>,
486 ResponseError: From<S::Error>,
487{
488 #[must_use]
490 pub fn new_server(builder: impl FnOnce(ClientSocket) -> S) -> (Self, ClientSocket) {
491 let (this, socket) = Self::new(|socket| builder(ClientSocket(socket)));
492 (this, ClientSocket(socket))
493 }
494
495 #[must_use]
497 pub fn new_client(builder: impl FnOnce(ServerSocket) -> S) -> (Self, ServerSocket) {
498 let (this, socket) = Self::new(|socket| builder(ServerSocket(socket)));
499 (this, ServerSocket(socket))
500 }
501
502 fn new(builder: impl FnOnce(PeerSocket) -> S) -> (Self, PeerSocket) {
503 let (tx, rx) = mpsc::unbounded();
504 let socket = PeerSocket { tx };
505 let this = Self {
506 service: builder(socket.clone()),
507 rx,
508 outgoing_id: 0,
509 outgoing: HashMap::new(),
510 tasks: FuturesUnordered::new(),
511 };
512 (this, socket)
513 }
514
515 #[allow(clippy::missing_errors_doc)]
521 pub async fn run_buffered(self, input: impl AsyncRead, output: impl AsyncWrite) -> Result<()> {
522 self.run(BufReader::new(input), output).await
523 }
524
    /// Drives the service main loop: reads messages from `input`, dispatches them to the
    /// service, and writes responses and outgoing messages to `output`.
    ///
    /// The loop ends when the service returns `ControlFlow::Break` from a notification or
    /// event handler, or when an error occurs.
    ///
    /// # Errors
    ///
    /// - Any error from reading or decoding an incoming message, including [`Error::Eof`]
    ///   when the input is closed.
    /// - Any error from writing or flushing an outgoing message.
    /// - The error carried by the service's `ControlFlow::Break`, if it is an `Err`. It takes
    ///   precedence over an error from the final flush.
    ///
    /// # Panics
    ///
    /// Panics if the event channel yields `None`, which the code treats as impossible
    /// ("Sender is alive").
    pub async fn run(mut self, input: impl AsyncBufRead, output: impl AsyncWrite) -> Result<()> {
        pin_mut!(input, output);
        // An endless stream of read results; a read error is yielded as an item rather than
        // ending the stream.
        let incoming = futures::stream::unfold(input, |mut input| async move {
            Some((Message::read(&mut input).await, input))
        });
        // A sink writing one framed message per item.
        let outgoing = futures::sink::unfold(output, |mut output, msg| async move {
            Message::write(&msg, &mut output).await.map(|()| output)
        });
        pin_mut!(incoming, outgoing);

        // Flush of the most recently fed message; terminated while nothing is pending.
        let mut flush_fut = futures::future::Fuse::terminated();
        let ret = loop {
            // `select_biased!` polls the branches in order, so the priority is:
            // pending flush > finished request tasks > socket events > incoming messages.
            let ctl = select_biased! {
                // Finish writing the previous message concurrently with everything else.
                ret = flush_fut => { ret?; continue; }

                resp = self.tasks.select_next_some() => ControlFlow::Continue(Some(Message::Response(resp))),
                event = self.rx.next() => self.dispatch_event(event.expect("Sender is alive")),
                msg = incoming.next() => {
                    let dispatch_fut = self.dispatch_message(msg.expect("Never ends")?).fuse();
                    pin_mut!(dispatch_fut);
                    // Dispatching may wait for the service to become ready, so keep driving
                    // the pending flush in the meantime instead of stalling the output.
                    loop {
                        select_biased! {
                            ctl = dispatch_fut => break ctl,
                            ret = flush_fut => { ret?; continue }
                        }
                    }
                }
            };
            let msg = match ctl {
                ControlFlow::Continue(Some(msg)) => msg,
                ControlFlow::Continue(None) => continue,
                ControlFlow::Break(ret) => break ret,
            };
            // Wait until the sink accepts the new message, then flush it in the background.
            outgoing.feed(msg).await?;
            flush_fut = outgoing.flush().fuse();
        };

        // Flush whatever is still queued before exiting. The result of the loop wins over a
        // flush error: `and` only returns `flush_ret` when `ret` is `Ok`.
        let flush_ret = outgoing.close().await;
        ret.and(flush_ret)
    }
587
588 async fn dispatch_message(&mut self, msg: Message) -> ControlFlow<Result<()>, Option<Message>> {
589 match msg {
590 Message::Request(req) => {
591 if let Err(err) = poll_fn(|cx| self.service.poll_ready(cx)).await {
592 let resp = AnyResponse {
593 id: req.id,
594 result: None,
595 error: Some(err.into()),
596 };
597 return ControlFlow::Continue(Some(Message::Response(resp)));
598 }
599 let id = req.id.clone();
600 let fut = self.service.call(req);
601 self.tasks.push(RequestFuture { fut, id: Some(id) });
602 }
603 Message::Response(resp) => {
604 if let Some(resp_tx) = self.outgoing.remove(&resp.id) {
605 let _: Result<_, _> = resp_tx.send(resp);
607 }
608 }
609 Message::Notification(notif) => {
610 self.service.notify(notif)?;
611 }
612 }
613 ControlFlow::Continue(None)
614 }
615
616 fn dispatch_event(&mut self, event: MainLoopEvent) -> ControlFlow<Result<()>, Option<Message>> {
617 match event {
618 MainLoopEvent::OutgoingRequest(mut req, resp_tx) => {
619 req.id = RequestId::Number(self.outgoing_id);
620 assert!(self.outgoing.insert(req.id.clone(), resp_tx).is_none());
621 self.outgoing_id += 1;
622 ControlFlow::Continue(Some(Message::Request(req)))
623 }
624 MainLoopEvent::Outgoing(msg) => ControlFlow::Continue(Some(msg)),
625 MainLoopEvent::Any(event) => {
626 self.service.emit(event)?;
627 ControlFlow::Continue(None)
628 }
629 }
630 }
631}
632
pin_project! {
    // Pairs the service's future for an incoming request with the id of that request, so the
    // finished result can be turned into a response carrying the right id.
    struct RequestFuture<Fut> {
        // The future returned by the service for this request.
        #[pin]
        fut: Fut,
        // Id of the request; taken out when the response is produced.
        id: Option<RequestId>,
    }
}
640
641impl<Fut, Error> Future for RequestFuture<Fut>
642where
643 Fut: Future<Output = Result<JsonValue, Error>>,
644 ResponseError: From<Error>,
645{
646 type Output = AnyResponse;
647
648 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
649 let this = self.project();
650 let (mut result, mut error) = (None, None);
651 match ready!(this.fut.poll(cx)) {
652 Ok(v) => result = Some(v),
653 Err(err) => error = Some(err.into()),
654 }
655 Poll::Ready(AnyResponse {
656 id: this.id.take().expect("Future is consumed"),
657 result,
658 error,
659 })
660 }
661}
662
// Generates the public API of a socket newtype wrapping `PeerSocket`. Every method simply
// forwards to the wrapped socket.
macro_rules! impl_socket_wrapper {
    ($socket:ident) => {
        impl $socket {
            /// Creates a closed socket that is not connected to any main loop.
            ///
            /// Every interaction with it fails with [`Error::ServiceStopped`].
            #[must_use]
            pub fn new_closed() -> Self {
                $socket(PeerSocket::new_closed())
            }

            /// Sends a request to the peer and waits for its response.
            ///
            /// # Errors
            ///
            /// - [`Error::ServiceStopped`] when the service main loop stopped.
            /// - [`Error::Response`] when the peer replies with an error.
            /// - [`Error::Deserialize`] when the peer's result cannot be decoded.
            ///
            /// # Panics
            ///
            /// Panics if `params` cannot be serialized to JSON.
            pub async fn request<R: Request>(&self, params: R::Params) -> Result<R::Result> {
                let Self(peer) = self;
                peer.request::<R>(params).await
            }

            /// Queues a notification to be sent to the peer.
            ///
            /// An `Ok` result only means the notification was queued for the main loop.
            ///
            /// # Errors
            ///
            /// [`Error::ServiceStopped`] when the service main loop stopped.
            ///
            /// # Panics
            ///
            /// Panics if `params` cannot be serialized to JSON.
            pub fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
                let Self(peer) = self;
                peer.notify::<N>(params)
            }

            /// Queues an arbitrary loopback event for the service's `emit` handler.
            ///
            /// An `Ok` result only means the event was queued for the main loop.
            ///
            /// # Errors
            ///
            /// [`Error::ServiceStopped`] when the service main loop stopped.
            pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
                let Self(peer) = self;
                peer.emit::<E>(event)
            }
        }
    };
}
713
/// The socket a Language Server uses to talk to its Language Client peer.
///
/// It is handed to the service builder of `MainLoop::new_server`. Clones refer to the same
/// main loop.
#[derive(Debug, Clone)]
pub struct ClientSocket(PeerSocket);
// Provides `new_closed`, `request`, `notify` and `emit`.
impl_socket_wrapper!(ClientSocket);
718
/// The socket a Language Client uses to talk to its Language Server peer.
///
/// It is handed to the service builder of `MainLoop::new_client`. Clones refer to the same
/// main loop.
#[derive(Debug, Clone)]
pub struct ServerSocket(PeerSocket);
// Provides `new_closed`, `request`, `notify` and `emit`.
impl_socket_wrapper!(ServerSocket);
723
/// Direction-agnostic socket shared by `ClientSocket` and `ServerSocket`: a handle for
/// submitting events to a main loop.
#[derive(Debug, Clone)]
struct PeerSocket {
    // Sending end of the main loop's event channel.
    tx: mpsc::UnboundedSender<MainLoopEvent>,
}
728
729impl PeerSocket {
730 fn new_closed() -> Self {
731 let (tx, _rx) = mpsc::unbounded();
732 Self { tx }
733 }
734
735 fn send(&self, v: MainLoopEvent) -> Result<()> {
736 self.tx.unbounded_send(v).map_err(|_| Error::ServiceStopped)
737 }
738
739 fn request<R: Request>(&self, params: R::Params) -> PeerSocketRequestFuture<R::Result> {
740 let req = AnyRequest {
741 id: RequestId::Number(0),
742 method: R::METHOD.into(),
743 params: serde_json::to_value(params).expect("Failed to serialize"),
744 };
745 let (tx, rx) = oneshot::channel();
746 let _: Result<_, _> = self.send(MainLoopEvent::OutgoingRequest(req, tx));
749 PeerSocketRequestFuture {
750 rx,
751 _marker: PhantomData,
752 }
753 }
754
755 fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
756 let notif = AnyNotification {
757 method: N::METHOD.into(),
758 params: serde_json::to_value(params).expect("Failed to serialize"),
759 };
760 self.send(MainLoopEvent::Outgoing(Message::Notification(notif)))
761 }
762
763 pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
764 self.send(MainLoopEvent::Any(AnyEvent::new(event)))
765 }
766}
767
/// Future of an outgoing request, resolving to the peer's response decoded as `T`.
struct PeerSocketRequestFuture<T> {
    // Receives the raw response from the main loop.
    rx: oneshot::Receiver<AnyResponse>,
    // Records the result type without storing a `T`.
    _marker: PhantomData<fn() -> T>,
}
772
773impl<T: DeserializeOwned> Future for PeerSocketRequestFuture<T> {
774 type Output = Result<T>;
775
776 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
777 let resp = ready!(Pin::new(&mut self.rx)
778 .poll(cx)
779 .map_err(|_| Error::ServiceStopped))?;
780 Poll::Ready(match resp.error {
781 None => Ok(serde_json::from_value(resp.result.unwrap_or_default())?),
782 Some(err) => Err(Error::Response(err)),
783 })
784 }
785}
786
/// A dynamically typed event.
///
/// It wraps a `Box<dyn Any + Send>` and additionally remembers the name of the concrete type
/// it was created from, which is what its `Debug` output shows.
pub struct AnyEvent {
    // The type-erased event value.
    inner: Box<dyn Any + Send>,
    // Name of the concrete type stored in `inner`, captured at construction.
    type_name: &'static str,
}
797
798impl fmt::Debug for AnyEvent {
799 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
800 f.debug_struct("AnyEvent")
801 .field("type_name", &self.type_name)
802 .finish_non_exhaustive()
803 }
804}
805
806impl AnyEvent {
807 #[must_use]
808 fn new<T: Send + 'static>(v: T) -> Self {
809 AnyEvent {
810 inner: Box::new(v),
811 type_name: type_name::<T>(),
812 }
813 }
814
815 #[must_use]
816 fn inner_type_id(&self) -> TypeId {
817 Any::type_id(&*self.inner)
819 }
820
821 #[must_use]
825 pub fn type_name(&self) -> &'static str {
826 self.type_name
827 }
828
829 #[must_use]
831 pub fn is<T: Send + 'static>(&self) -> bool {
832 self.inner.is::<T>()
833 }
834
835 #[must_use]
837 pub fn downcast_ref<T: Send + 'static>(&self) -> Option<&T> {
838 self.inner.downcast_ref::<T>()
839 }
840
841 #[must_use]
844 pub fn downcast_mut<T: Send + 'static>(&mut self) -> Option<&mut T> {
845 self.inner.downcast_mut::<T>()
846 }
847
848 pub fn downcast<T: Send + 'static>(self) -> Result<T, Self> {
854 match self.inner.downcast::<T>() {
855 Ok(v) => Ok(*v),
856 Err(inner) => Err(Self {
857 inner,
858 type_name: self.type_name,
859 }),
860 }
861 }
862}
863
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check only: the future returned by `MainLoop::run` must be `Send` when the
    // service, its futures, its errors and both channels are `Send`. It is never called.
    fn _main_loop_future_is_send<S>(
        f: MainLoop<S>,
        input: impl AsyncBufRead + Send,
        output: impl AsyncWrite + Send,
    ) -> impl Send
    where
        S: LspService<Response = JsonValue> + Send,
        S::Future: Send,
        S::Error: From<Error> + Send,
        ResponseError: From<S::Error>,
    {
        f.run(input, output)
    }

    // Every operation on a closed client socket must fail with `ServiceStopped`.
    #[tokio::test]
    async fn closed_client_socket() {
        let socket = ClientSocket::new_closed();
        assert!(matches!(
            socket.notify::<lsp_types::notification::Exit>(()),
            Err(Error::ServiceStopped)
        ));
        assert!(matches!(
            socket.request::<lsp_types::request::Shutdown>(()).await,
            Err(Error::ServiceStopped)
        ));
        assert!(matches!(socket.emit(42i32), Err(Error::ServiceStopped)));
    }

    // Every operation on a closed server socket must fail with `ServiceStopped`.
    #[tokio::test]
    async fn closed_server_socket() {
        let socket = ServerSocket::new_closed();
        assert!(matches!(
            socket.notify::<lsp_types::notification::Exit>(()),
            Err(Error::ServiceStopped)
        ));
        assert!(matches!(
            socket.request::<lsp_types::request::Shutdown>(()).await,
            Err(Error::ServiceStopped)
        ));
        assert!(matches!(socket.emit(42i32), Err(Error::ServiceStopped)));
    }

    // Exercises the type checks and all downcast flavors of `AnyEvent`.
    #[test]
    fn any_event() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct MyEvent<T>(T);

        let event = MyEvent("hello".to_owned());
        let mut any_event = AnyEvent::new(event.clone());
        assert!(any_event.type_name().contains("MyEvent"));

        // Only the exact type matches, including its generic argument.
        assert!(!any_event.is::<String>());
        assert!(!any_event.is::<MyEvent<i32>>());
        assert!(any_event.is::<MyEvent<String>>());

        assert_eq!(any_event.downcast_ref::<i32>(), None);
        assert_eq!(any_event.downcast_ref::<MyEvent<String>>(), Some(&event));

        assert_eq!(any_event.downcast_mut::<MyEvent<i32>>(), None);
        any_event.downcast_mut::<MyEvent<String>>().unwrap().0 += " world";

        // A failed downcast returns the event intact, so it can be downcast again.
        let any_event = any_event.downcast::<()>().unwrap_err();
        let inner = any_event.downcast::<MyEvent<String>>().unwrap();
        assert_eq!(inner.0, "hello world");
    }
}
933}