# , "/examples)")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]
use std::any::{type_name, Any};
use std::collections::HashMap;
use std::future::{poll_fn, Future};
use std::marker::PhantomData;
use std::ops::ControlFlow;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{fmt, io};
use futures::channel::{mpsc, oneshot};
use futures::io::BufReader;
use futures::stream::FuturesUnordered;
use futures::{
pin_mut, select_biased, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite,
AsyncWriteExt, FutureExt, SinkExt, StreamExt,
};
use lsp_types::notification::Notification;
use lsp_types::request::Request;
use lsp_types::NumberOrString;
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use thiserror::Error;
use tower_service::Service;
pub use lsp_types;
// Generates `into_inner` / `get_ref` / `get_mut` accessors exposing one field of a type.
//
// Usage: `define_getters!(impl[<generic params>] Type<..>, field: FieldType);`
// The bracketed tokens are spliced verbatim into `impl<...>`.
macro_rules! define_getters {
    (impl[$($params:tt)*] $target:ty, $member:ident : $member_ty:ty) => {
        impl<$($params)*> $target {
            /// Consumes `self` and returns the wrapped value.
            #[must_use]
            pub fn into_inner(self) -> $member_ty {
                self.$member
            }

            /// Returns a shared reference to the wrapped value.
            #[must_use]
            pub fn get_ref(&self) -> &$member_ty {
                &self.$member
            }

            /// Returns a mutable reference to the wrapped value.
            #[must_use]
            pub fn get_mut(&mut self) -> &mut $member_ty {
                &mut self.$member
            }
        }
    };
}
pub mod concurrency;
pub mod panic;
pub mod router;
pub mod server;
#[cfg(feature = "forward")]
#[cfg_attr(docsrs, doc(cfg(feature = "forward")))]
mod forward;
#[cfg(feature = "client-monitor")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-monitor")))]
pub mod client_monitor;
#[cfg(all(feature = "stdio", unix))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "stdio", unix))))]
pub mod stdio;
#[cfg(feature = "tracing")]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub mod tracing;
#[cfg(feature = "omni-trait")]
mod omni_trait;
#[cfg(feature = "omni-trait")]
#[cfg_attr(docsrs, doc(cfg(feature = "omni-trait")))]
pub use omni_trait::{LanguageClient, LanguageServer};
/// Result alias whose error type defaults to this crate's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors produced by the main loop, the sockets, and message (de)serialization.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The main loop is unreachable: its event channel is closed, or the response channel
    /// of a pending request was dropped before a reply arrived.
    #[error("service stopped")]
    ServiceStopped,
    /// A `serde_json` failure. Despite the message, a failure to serialize an outgoing
    /// message is also converted into this variant through the `From` impl.
    #[error("deserialization failed: {0}")]
    Deserialize(#[from] serde_json::Error),
    /// The peer answered a request with an error response.
    #[error("{0}")]
    Response(#[from] ResponseError),
    /// The incoming byte stream violates the message framing, e.g. a malformed header
    /// line or a missing/invalid `Content-Length`.
    #[error("protocol error: {0}")]
    Protocol(String),
    /// An I/O error from the underlying reader or writer.
    #[error("{0}")]
    Io(#[from] io::Error),
    /// The input ended where the next header line was expected.
    #[error("the underlying channel reached EOF")]
    Eof,
    /// A routing failure. Not produced in this file; presumably raised by the `router`
    /// module (NOTE(review): confirm).
    #[error("{0}")]
    Routing(String),
}
/// A [`Service`] answering incoming requests, extended with hooks for incoming
/// notifications and loopback events, as driven by [`MainLoop`].
///
/// Returning `ControlFlow::Break(result)` from either hook stops [`MainLoop::run`]. `run`
/// then closes the output and returns `result`, or the close error if `result` is `Ok`.
pub trait LspService: Service<AnyRequest> {
    /// Handles one incoming notification from the peer.
    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>>;

    /// Handles one loopback event queued through a socket's `emit`.
    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>>;
}
/// A JSON-RPC / LSP error code, as carried in [`ResponseError::code`].
///
/// Any `i32` is accepted; well-known values are provided as associated constants.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Error,
)]
#[error("jsonrpc error {0}")]
pub struct ErrorCode(pub i32);

impl From<i32> for ErrorCode {
    fn from(i: i32) -> Self {
        Self(i)
    }
}

/// Codes defined by the JSON-RPC 2.0 and LSP specifications.
impl ErrorCode {
    /// Invalid JSON was received.
    pub const PARSE_ERROR: Self = Self(-32700);
    /// The JSON sent is not a valid request object.
    pub const INVALID_REQUEST: Self = Self(-32600);
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: Self = Self(-32601);
    /// Invalid method parameters.
    pub const INVALID_PARAMS: Self = Self(-32602);
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: Self = Self(-32603);
    /// Start of the JSON-RPC reserved range (a range marker, not a real error code).
    pub const JSONRPC_RESERVED_ERROR_RANGE_START: Self = Self(-32099);
    /// A request or notification arrived before the `initialize` request.
    pub const SERVER_NOT_INITIALIZED: Self = Self(-32002);
    /// Unknown error.
    pub const UNKNOWN_ERROR_CODE: Self = Self(-32001);
    /// End of the JSON-RPC reserved range (a range marker, not a real error code).
    pub const JSONRPC_RESERVED_ERROR_RANGE_END: Self = Self(-32000);
    /// Start of the LSP reserved range (a range marker, not a real error code).
    pub const LSP_RESERVED_ERROR_RANGE_START: Self = Self(-32899);
    /// The request was syntactically correct but failed.
    pub const REQUEST_FAILED: Self = Self(-32803);
    /// The server cancelled the request.
    pub const SERVER_CANCELLED: Self = Self(-32802);
    /// The document content changed in a way that invalidates the request's result.
    pub const CONTENT_MODIFIED: Self = Self(-32801);
    /// The client cancelled the request.
    pub const REQUEST_CANCELLED: Self = Self(-32800);
    /// End of the LSP reserved range; same value as [`Self::REQUEST_CANCELLED`].
    pub const LSP_RESERVED_ERROR_RANGE_END: Self = Self(-32800);
}
/// Identifier of a request: either a number or a string (`lsp_types::NumberOrString`).
pub type RequestId = NumberOrString;
/// Wire envelope: the `"jsonrpc"` version member plus the payload's own members,
/// flattened into the same JSON object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct RawMessage<T> {
    jsonrpc: RpcVersion,
    #[serde(flatten)]
    inner: T,
}
impl<T> RawMessage<T> {
    /// Wraps `inner` in an envelope tagged with JSON-RPC version 2.0.
    fn new(inner: T) -> Self {
        let jsonrpc = RpcVersion::V2;
        Self { inner, jsonrpc }
    }
}
/// The JSON-RPC protocol version. Only `"2.0"` exists, so any other `jsonrpc` value fails
/// deserialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum RpcVersion {
    #[serde(rename = "2.0")]
    V2,
}
/// Any JSON-RPC message.
///
/// The enum is untagged, so serde tries the variants in declaration order and keeps the
/// first that fits. `Request` must stay before `Response`: a request object (which has
/// `id` and `method`) would also satisfy `AnyResponse`, whose only required field is `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
enum Message {
    Request(AnyRequest),
    Response(AnyResponse),
    Notification(AnyNotification),
}
/// A request with untyped params.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AnyRequest {
    /// The request id. For requests sent through a socket, the main loop overwrites it
    /// with a fresh number before sending.
    pub id: RequestId,
    /// The method name.
    pub method: String,
    /// The parameters. `null` when absent on input, and omitted from the output when `null`.
    #[serde(default)]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: serde_json::Value,
}
/// A notification with untyped params.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AnyNotification {
    /// The method name.
    pub method: String,
    /// The parameters. `null` when absent on input, and omitted from the output when `null`.
    #[serde(default)]
    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
    pub params: JsonValue,
}
/// A response, matched to its request by `id`.
///
/// Fields that are `None` are omitted from the serialized output. On input, an explicit
/// `"result": null` also deserializes to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
struct AnyResponse {
    id: RequestId,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<JsonValue>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<ResponseError>,
}
/// The `error` member of a JSON-RPC response. Displays as `"{message} ({code})"`.
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
#[non_exhaustive]
#[error("{message} ({code})")]
pub struct ResponseError {
    /// The error code.
    pub code: ErrorCode,
    /// A short human-readable description.
    pub message: String,
    /// Optional additional information. With no `skip_serializing_if` on this field,
    /// `None` is serialized as `"data": null`.
    pub data: Option<JsonValue>,
}
impl ResponseError {
    /// Builds an error from `code` and the `Display` rendering of `message`, with no `data`.
    #[must_use]
    pub fn new(code: ErrorCode, message: impl fmt::Display) -> Self {
        let message = message.to_string();
        Self {
            code,
            message,
            data: None,
        }
    }

    /// Like [`ResponseError::new`], but also attaches `data`.
    #[must_use]
    pub fn new_with_data(code: ErrorCode, message: impl fmt::Display, data: JsonValue) -> Self {
        Self {
            data: Some(data),
            ..Self::new(code, message)
        }
    }
}
impl Message {
const CONTENT_LENGTH: &'static str = "Content-Length";
async fn read(mut reader: impl AsyncBufRead + Unpin) -> Result<Self> {
let mut line = String::new();
let mut content_len = None;
loop {
line.clear();
reader.read_line(&mut line).await?;
if line.is_empty() {
return Err(Error::Eof);
}
if line == "\r\n" {
break;
}
let (name, value) = line
.strip_suffix("\r\n")
.and_then(|line| line.split_once(": "))
.ok_or_else(|| Error::Protocol(format!("Invalid header: {line:?}")))?;
if name.eq_ignore_ascii_case(Self::CONTENT_LENGTH) {
let value = value
.parse::<usize>()
.map_err(|_| Error::Protocol(format!("Invalid content-length: {value}")))?;
content_len = Some(value);
}
}
let content_len =
content_len.ok_or_else(|| Error::Protocol("Missing content-length".into()))?;
let mut buf = vec![0u8; content_len];
reader.read_exact(&mut buf).await?;
#[cfg(feature = "tracing")]
::tracing::trace!(msg = %String::from_utf8_lossy(&buf), "incoming");
let msg = serde_json::from_slice::<RawMessage<Self>>(&buf)?;
Ok(msg.inner)
}
async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
let buf = serde_json::to_string(&RawMessage::new(self))?;
#[cfg(feature = "tracing")]
::tracing::trace!(msg = %buf, "outgoing");
writer
.write_all(format!("{}: {}\r\n\r\n", Self::CONTENT_LENGTH, buf.len()).as_bytes())
.await?;
writer.write_all(buf.as_bytes()).await?;
writer.flush().await?;
Ok(())
}
}
/// The event loop driving one side of an LSP connection.
///
/// It reads and writes framed messages and dispatches incoming requests and notifications
/// to the service. It also delivers responses to requests sent through the peer socket.
pub struct MainLoop<S: LspService> {
    /// Handler for incoming requests, notifications and loopback events.
    service: S,
    /// Receives events queued through the peer socket(s).
    rx: mpsc::UnboundedReceiver<MainLoopEvent>,
    /// Id assigned to the next outgoing request.
    outgoing_id: i32,
    /// Outgoing requests still awaiting the peer's response, keyed by id.
    outgoing: HashMap<RequestId, oneshot::Sender<AnyResponse>>,
    /// In-flight service calls for incoming requests; each resolves to the response to send.
    tasks: FuturesUnordered<RequestFuture<S::Future>>,
}
/// Work queued into the main loop by a `PeerSocket`.
enum MainLoopEvent {
    /// A message to write out as-is (e.g. a notification).
    Outgoing(Message),
    /// A request to send. The main loop assigns its id and later delivers the peer's
    /// response through the sender.
    OutgoingRequest(AnyRequest, oneshot::Sender<AnyResponse>),
    /// A loopback event for the service's `emit` hook.
    Any(AnyEvent),
}
// `get_ref` / `get_mut` / `into_inner` accessors for the wrapped service.
define_getters!(impl[S: LspService] MainLoop<S>, service: S);
impl<S> MainLoop<S>
where
    S: LspService<Response = JsonValue>,
    ResponseError: From<S::Error>,
{
    /// Creates a main loop for the server side of a connection.
    ///
    /// `builder` receives a [`ClientSocket`] for talking to the client and returns the
    /// service. A second handle to the same socket is returned alongside the loop.
    #[must_use]
    pub fn new_server(builder: impl FnOnce(ClientSocket) -> S) -> (Self, ClientSocket) {
        let (this, socket) = Self::new(|socket| builder(ClientSocket(socket)));
        (this, ClientSocket(socket))
    }

    /// Creates a main loop for the client side of a connection.
    ///
    /// `builder` receives a [`ServerSocket`] for talking to the server and returns the
    /// service. A second handle to the same socket is returned alongside the loop.
    #[must_use]
    pub fn new_client(builder: impl FnOnce(ServerSocket) -> S) -> (Self, ServerSocket) {
        let (this, socket) = Self::new(|socket| builder(ServerSocket(socket)));
        (this, ServerSocket(socket))
    }

    /// Shared constructor. Creates the event channel, gives one sender handle to
    /// `builder`, and returns another.
    fn new(builder: impl FnOnce(PeerSocket) -> S) -> (Self, PeerSocket) {
        let (tx, rx) = mpsc::unbounded();
        let socket = PeerSocket { tx };
        let this = Self {
            service: builder(socket.clone()),
            rx,
            outgoing_id: 0,
            outgoing: HashMap::new(),
            tasks: FuturesUnordered::new(),
        };
        (this, socket)
    }

    /// Same as [`MainLoop::run`], but wraps `input` in a [`BufReader`] first.
    #[allow(clippy::missing_errors_doc)]
    pub async fn run_buffered(self, input: impl AsyncRead, output: impl AsyncWrite) -> Result<()> {
        self.run(BufReader::new(input), output).await
    }

    /// Drives the connection until a service hook requests a stop or an error occurs.
    ///
    /// Each iteration waits on these sources, preferring them in this order when several
    /// are ready: completion of the pending flush, a finished request handler, a queued
    /// socket event, and the next incoming message. Every produced message is fed to the
    /// output and then flushed concurrently with later iterations.
    ///
    /// # Errors
    /// Returns immediately, without closing the output, on:
    /// - any read error, including [`Error::Eof`];
    /// - any write or flush error.
    ///
    /// When a hook returns `ControlFlow::Break(result)`, the output is closed and `result`
    /// is returned. If `result` is `Ok`, any close error is returned instead.
    ///
    /// # Panics
    /// If every sender of the event channel has been dropped (`"Sender is alive"`). The loop
    /// itself holds no sender.
    pub async fn run(mut self, input: impl AsyncBufRead, output: impl AsyncWrite) -> Result<()> {
        pin_mut!(input, output);
        // The reader becomes an endless stream of parse results (errors included), and the
        // writer becomes a sink of messages.
        let incoming = futures::stream::unfold(input, |mut input| async move {
            Some((Message::read(&mut input).await, input))
        });
        let outgoing = futures::sink::unfold(output, |mut output, msg| async move {
            Message::write(&msg, &mut output).await.map(|()| output)
        });
        pin_mut!(incoming, outgoing);
        // Flush of the most recently fed message; starts out terminated, i.e. nothing pending.
        let mut flush_fut = futures::future::Fuse::terminated();
        let ret = loop {
            // `select_biased!` polls top to bottom: pending output first, then internal
            // work, and incoming data last, so a flood of incoming messages cannot starve
            // outgoing traffic.
            let ctl = select_biased! {
                ret = flush_fut => { ret?; continue; }
                resp = self.tasks.select_next_some() => ControlFlow::Continue(Some(Message::Response(resp))),
                event = self.rx.next() => self.dispatch_event(event.expect("Sender is alive")),
                msg = incoming.next() => {
                    // `unfold` above always yields `Some`, so the stream never ends.
                    let dispatch_fut = self.dispatch_message(msg.expect("Never ends")?).fuse();
                    pin_mut!(dispatch_fut);
                    // Keep flushing while dispatch waits (e.g. on `poll_ready`). If the
                    // service is waiting for the peer's reply to a message still stuck in
                    // the writer, waiting on dispatch alone could deadlock.
                    loop {
                        select_biased! {
                            ctl = dispatch_fut => break ctl,
                            ret = flush_fut => { ret?; continue }
                        }
                    }
                }
            };
            let msg = match ctl {
                ControlFlow::Continue(Some(msg)) => msg,
                ControlFlow::Continue(None) => continue,
                ControlFlow::Break(ret) => break ret,
            };
            // Queue the message, then let its flush progress alongside later iterations.
            outgoing.feed(msg).await?;
            flush_fut = outgoing.flush().fuse();
        };
        // Best-effort delivery of anything still buffered. The loop's own error takes
        // precedence over a close error.
        let flush_ret = outgoing.close().await;
        ret.and(flush_ret)
    }

    /// Routes one incoming message.
    ///
    /// - Request: waits until the service is ready, then queues the call in `tasks`. If
    ///   readiness fails, an error response is produced immediately instead.
    /// - Response: completes the matching pending outgoing request. Responses with an
    ///   unknown id are dropped.
    /// - Notification: passed to the service's `notify` hook; a `Break` stops the loop.
    ///
    /// Returns `Continue(Some(msg))` when `msg` must be sent right away.
    async fn dispatch_message(&mut self, msg: Message) -> ControlFlow<Result<()>, Option<Message>> {
        match msg {
            Message::Request(req) => {
                if let Err(err) = poll_fn(|cx| self.service.poll_ready(cx)).await {
                    let resp = AnyResponse {
                        id: req.id,
                        result: None,
                        error: Some(err.into()),
                    };
                    return ControlFlow::Continue(Some(Message::Response(resp)));
                }
                let id = req.id.clone();
                let fut = self.service.call(req);
                self.tasks.push(RequestFuture { fut, id: Some(id) });
            }
            Message::Response(resp) => {
                if let Some(resp_tx) = self.outgoing.remove(&resp.id) {
                    // The requester may have stopped waiting (receiver dropped); that is fine.
                    let _: Result<_, _> = resp_tx.send(resp);
                }
            }
            Message::Notification(notif) => {
                self.service.notify(notif)?;
            }
        }
        ControlFlow::Continue(None)
    }

    /// Handles one event queued through a socket.
    ///
    /// Outgoing requests get sequential numeric ids starting at 0. Their response channel
    /// is registered before the request is handed back for sending.
    fn dispatch_event(&mut self, event: MainLoopEvent) -> ControlFlow<Result<()>, Option<Message>> {
        match event {
            MainLoopEvent::OutgoingRequest(mut req, resp_tx) => {
                req.id = RequestId::Number(self.outgoing_id);
                // Ids are handed out sequentially, so a collision indicates a bug.
                assert!(self.outgoing.insert(req.id.clone(), resp_tx).is_none());
                // NOTE(review): plain `i32` addition; overflows (panicking in debug builds)
                // after `i32::MAX` outgoing requests.
                self.outgoing_id += 1;
                ControlFlow::Continue(Some(Message::Request(req)))
            }
            MainLoopEvent::Outgoing(msg) => ControlFlow::Continue(Some(msg)),
            MainLoopEvent::Any(event) => {
                self.service.emit(event)?;
                ControlFlow::Continue(None)
            }
        }
    }
}
pin_project! {
    // An in-flight service call for an incoming request, paired with that request's id so
    // its outcome can be turned into a response.
    struct RequestFuture<Fut> {
        #[pin]
        fut: Fut,
        // `Option` so the id can be moved out when the response is produced.
        id: Option<RequestId>,
    }
}
impl<Fut, Error> Future for RequestFuture<Fut>
where
Fut: Future<Output = Result<JsonValue, Error>>,
ResponseError: From<Error>,
{
type Output = AnyResponse;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let (mut result, mut error) = (None, None);
match ready!(this.fut.poll(cx)) {
Ok(v) => result = Some(v),
Err(err) => error = Some(err.into()),
}
Poll::Ready(AnyResponse {
id: this.id.take().expect("Future is consumed"),
result,
error,
})
}
}
macro_rules! impl_socket_wrapper {
    ($wrapper:ident) => {
        impl $wrapper {
            /// Creates a socket that is not attached to any main loop.
            ///
            /// Every operation on it fails immediately with [`Error::ServiceStopped`]. This
            /// makes it usable as a placeholder where a socket is required but never used.
            #[must_use]
            pub fn new_closed() -> Self {
                Self(PeerSocket::new_closed())
            }

            /// Sends a request of type `R` to the peer and waits for its response.
            ///
            /// # Errors
            /// - [`Error::ServiceStopped`] if the main loop stopped before a response arrived.
            /// - [`Error::Response`] if the peer replied with an error.
            /// - [`Error::Deserialize`] if the result does not deserialize into `R::Result`.
            ///
            /// # Panics
            /// If `params` cannot be serialized to JSON.
            pub async fn request<R: Request>(&self, params: R::Params) -> Result<R::Result> {
                PeerSocket::request::<R>(&self.0, params).await
            }

            /// Queues a notification of type `N` for the peer.
            ///
            /// `Ok` only means the notification was queued on the main loop, not that it
            /// has been written out yet.
            ///
            /// # Errors
            /// [`Error::ServiceStopped`] if the main loop has stopped.
            ///
            /// # Panics
            /// If `params` cannot be serialized to JSON.
            pub fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
                PeerSocket::notify::<N>(&self.0, params)
            }

            /// Queues an arbitrary loopback event for the local service's `emit` hook.
            ///
            /// `Ok` only means the event was queued, not that it has been handled yet.
            ///
            /// # Errors
            /// [`Error::ServiceStopped`] if the main loop has stopped.
            pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
                PeerSocket::emit(&self.0, event)
            }
        }
    };
}
/// The server side's handle to its peer. It sends requests and notifications to the
/// client and emits loopback events to the server's own service.
#[derive(Debug, Clone)]
pub struct ClientSocket(PeerSocket);
impl_socket_wrapper!(ClientSocket);

/// The client side's handle to its peer. It sends requests and notifications to the
/// server and emits loopback events to the client's own service.
#[derive(Debug, Clone)]
pub struct ServerSocket(PeerSocket);
impl_socket_wrapper!(ServerSocket);
/// Cloneable handle that queues events into a [`MainLoop`]'s channel.
#[derive(Debug, Clone)]
struct PeerSocket {
    tx: mpsc::UnboundedSender<MainLoopEvent>,
}
impl PeerSocket {
fn new_closed() -> Self {
let (tx, _rx) = mpsc::unbounded();
Self { tx }
}
fn send(&self, v: MainLoopEvent) -> Result<()> {
self.tx.unbounded_send(v).map_err(|_| Error::ServiceStopped)
}
fn request<R: Request>(&self, params: R::Params) -> PeerSocketRequestFuture<R::Result> {
let req = AnyRequest {
id: RequestId::Number(0),
method: R::METHOD.into(),
params: serde_json::to_value(params).expect("Failed to serialize"),
};
let (tx, rx) = oneshot::channel();
let _: Result<_, _> = self.send(MainLoopEvent::OutgoingRequest(req, tx));
PeerSocketRequestFuture {
rx,
_marker: PhantomData,
}
}
fn notify<N: Notification>(&self, params: N::Params) -> Result<()> {
let notif = AnyNotification {
method: N::METHOD.into(),
params: serde_json::to_value(params).expect("Failed to serialize"),
};
self.send(MainLoopEvent::Outgoing(Message::Notification(notif)))
}
pub fn emit<E: Send + 'static>(&self, event: E) -> Result<()> {
self.send(MainLoopEvent::Any(AnyEvent::new(event)))
}
}
/// Future resolving to the deserialized response of an outgoing request.
struct PeerSocketRequestFuture<T> {
    rx: oneshot::Receiver<AnyResponse>,
    // `fn() -> T` records that a `T` is produced without storing one, so the auto traits
    // of this future do not depend on `T`.
    _marker: PhantomData<fn() -> T>,
}
impl<T: DeserializeOwned> Future for PeerSocketRequestFuture<T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let resp = ready!(Pin::new(&mut self.rx)
.poll(cx)
.map_err(|_| Error::ServiceStopped))?;
Poll::Ready(match resp.error {
None => Ok(serde_json::from_value(resp.result.unwrap_or_default())?),
Some(err) => Err(Error::Response(err)),
})
}
}
/// A type-erased loopback event, as handed to [`LspService::emit`].
///
/// It keeps the original type's name for diagnostics, since the boxed value itself
/// cannot be printed.
pub struct AnyEvent {
    inner: Box<dyn Any + Send>,
    type_name: &'static str,
}

impl fmt::Debug for AnyEvent {
    /// Shows only the recorded type name; the payload is elided as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AnyEvent");
        builder.field("type_name", &self.type_name);
        builder.finish_non_exhaustive()
    }
}
impl AnyEvent {
#[must_use]
fn new<T: Send + 'static>(v: T) -> Self {
AnyEvent {
inner: Box::new(v),
type_name: type_name::<T>(),
}
}
#[must_use]
pub fn inner(&self) -> &(dyn Any + Send) {
&*self.inner
}
#[must_use]
pub fn into_inner(self) -> Box<dyn Any + Send> {
self.inner
}
#[must_use]
pub fn type_name(&self) -> &'static str {
self.type_name
}
#[must_use]
pub fn is<T: Send + 'static>(&self) -> bool {
self.inner.is::<T>()
}
#[must_use]
pub fn downcast_ref<T: Send + 'static>(&self) -> Option<&T> {
self.inner.downcast_ref::<T>()
}
#[must_use]
pub fn downcast_mut<T: Send + 'static>(&mut self) -> Option<&mut T> {
self.inner.downcast_mut::<T>()
}
pub fn downcast<T: Send + 'static>(self) -> Result<T, Self> {
match self.inner.downcast::<T>() {
Ok(v) => Ok(*v),
Err(inner) => Err(Self {
inner,
type_name: self.type_name,
}),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check only: `MainLoop::run` must produce a `Send` future whenever the
    // service, its futures and errors, and both I/O halves are `Send`.
    fn _main_loop_future_is_send<S>(
        f: MainLoop<S>,
        input: impl AsyncBufRead + Send,
        output: impl AsyncWrite + Send,
    ) -> impl Send
    where
        S: LspService<Response = JsonValue> + Send,
        S::Future: Send,
        S::Error: From<Error> + Send,
        ResponseError: From<S::Error>,
    {
        f.run(input, output)
    }

    // Every operation on a closed socket must fail with `Error::ServiceStopped`.
    // Must be expanded inside an async context because it awaits the request.
    macro_rules! assert_all_ops_stopped {
        ($socket:expr) => {{
            let socket = $socket;
            let notified = socket.notify::<lsp_types::notification::Exit>(());
            assert!(matches!(notified, Err(Error::ServiceStopped)));
            let requested = socket.request::<lsp_types::request::Shutdown>(()).await;
            assert!(matches!(requested, Err(Error::ServiceStopped)));
            assert!(matches!(socket.emit(42i32), Err(Error::ServiceStopped)));
        }};
    }

    #[tokio::test]
    async fn closed_client_socket() {
        assert_all_ops_stopped!(ClientSocket::new_closed());
    }

    #[tokio::test]
    async fn closed_server_socket() {
        assert_all_ops_stopped!(ServerSocket::new_closed());
    }

    #[test]
    fn any_event() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct MyEvent<T>(T);

        let original = MyEvent("hello".to_owned());
        let mut erased = AnyEvent::new(original.clone());

        assert!(erased.type_name().contains("MyEvent"));
        assert!(!erased.is::<String>());
        assert!(!erased.is::<MyEvent<i32>>());
        assert!(erased.is::<MyEvent<String>>());
        assert_eq!(erased.downcast_ref::<i32>(), None);
        assert_eq!(erased.downcast_ref::<MyEvent<String>>(), Some(&original));
        assert_eq!(erased.downcast_mut::<MyEvent<i32>>(), None);

        erased
            .downcast_mut::<MyEvent<String>>()
            .expect("stored type must downcast")
            .0
            .push_str(" world");

        // A failed downcast hands the event back intact.
        let erased = erased.downcast::<()>().unwrap_err();
        let recovered = erased.downcast::<MyEvent<String>>().unwrap();
        assert_eq!(recovered, MyEvent("hello world".to_owned()));
    }
}