use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use bytes::Bytes;
use tokio::sync::Semaphore;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Span, debug, info_span, warn};
use crate::context::Context;
use crate::documents::Documents;
use crate::error::Error;
use crate::raw::{JsonRpcError, RawMessage, RequestId};
use crate::server::LanguageServer;
use crate::transport::{Transport, TransportError, TransportReader, TransportWriter};
use crate::{LspError, Result};
pub(crate) async fn run<S, T>(server: S, transport: T, concurrency_limit: usize) -> Result<Outcome>
where
S: LanguageServer,
T: Transport,
{
let (mut reader, writer) = transport.split();
let server = Arc::new(server);
let (out_tx, out_rx) = mpsc::unbounded_channel::<RawMessage>();
let send_handle = tokio::spawn(send_loop(writer, out_rx));
let state: SharedState = Arc::new(Mutex::new(State::Uninitialized));
let registry: Registry = Arc::new(Mutex::new(HashMap::new()));
let permits = Arc::new(Semaphore::new(concurrency_limit));
let mut tasks: JoinSet<()> = JoinSet::new();
loop {
while tasks.try_join_next().is_some() {}
let msg = match reader.recv().await {
Ok(msg) => msg,
Err(TransportError::Closed) => {
warn!("transport closed by peer before exit notification");
drop(out_tx);
let _ = send_handle.await;
return Ok(Outcome::TransportClosed);
}
Err(e) => return Err(Error::Transport(e)),
};
let flow = dispatch(
&server, &out_tx, &state, ®istry, &permits, &mut tasks, msg,
)
.await?;
if let Flow::Exit(code) = flow {
tasks.shutdown().await;
drop(out_tx);
let _ = send_handle.await;
return Ok(Outcome::Exit(code));
}
}
}
/// How a [`run`] session ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Outcome {
    /// The peer closed the transport before an `exit` notification arrived.
    TransportClosed,
    /// An `exit` notification was received. The payload is the exit code chosen by
    /// the dispatcher: `0` if a `shutdown` request completed first, `1` otherwise.
    Exit(i32),
}
/// Writes every queued outbound message to `writer`, in channel order.
///
/// Stops immediately, without shutting the writer down, on the first failed write.
/// Once the channel is closed and drained, the writer is shut down; a failure there
/// is only logged.
async fn send_loop<W: TransportWriter>(mut writer: W, mut out_rx: UnboundedReceiver<RawMessage>) {
    loop {
        let next = out_rx.recv().await;
        match next {
            None => break,
            Some(outbound) => {
                if let Err(e) = writer.send(outbound).await {
                    warn!(error = %e, "send_loop: transport write failed");
                    return;
                }
            }
        }
    }
    // All senders are gone and the queue is empty: close the write side.
    if let Err(e) = writer.shutdown().await {
        warn!(error = %e, "send_loop: transport shutdown failed");
    }
}
/// Whether the dispatch loop keeps reading or stops with an exit code.
enum Flow {
    /// Keep reading messages.
    Continue,
    /// Stop the loop and report this exit code.
    Exit(i32),
}

/// Server lifecycle phase, advanced by the `initialize` and `shutdown` requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No `initialize` request has succeeded yet.
    Uninitialized,
    /// `initialize` succeeded.
    Running,
    /// `shutdown` succeeded; subsequent requests are rejected.
    ShuttingDown,
}

/// Lifecycle state shared between the dispatcher and spawned handler tasks.
type SharedState = std::sync::Arc<std::sync::Mutex<State>>;

/// Payload of a `$/cancelRequest` notification; only the target request id is read.
#[derive(serde::Deserialize)]
struct CancelParams {
    id: RequestId,
}

/// Cancellation tokens of in-flight cancellable requests, keyed by request id.
type Registry = std::sync::Arc<std::sync::Mutex<std::collections::HashMap<RequestId, CancellationToken>>>;
async fn dispatch<S>(
server: &Arc<S>,
out_tx: &UnboundedSender<RawMessage>,
state: &SharedState,
registry: &Registry,
permits: &Arc<Semaphore>,
tasks: &mut JoinSet<()>,
msg: RawMessage,
) -> Result<Flow>
where
S: LanguageServer,
{
match msg {
RawMessage::Request { id, method, params } => {
let span = info_span!("request", method = %method, id = ?id);
if method != "initialize" && *state.lock().unwrap() == State::Uninitialized {
enqueue_error(out_tx, id, LspError::ServerNotInitialized);
return Ok(Flow::Continue);
}
if *state.lock().unwrap() == State::ShuttingDown {
enqueue_error(out_tx, id, LspError::invalid_request("invalid request"));
return Ok(Flow::Continue);
}
match method.as_ref() {
"initialize" => {
if *state.lock().unwrap() != State::Uninitialized {
enqueue_error(
out_tx,
id,
LspError::ServerError {
code: -32600,
message: "server already initialized".into(),
data: None,
},
);
return Ok(Flow::Continue);
}
let params = parse_params(¶ms)?;
let ctx = Context::for_request(
id.clone(),
span.clone(),
out_tx.clone(),
server.documents().clone(),
);
let result = server
.initialize(&ctx, params, CancellationToken::new())
.instrument(span)
.await;
if result.is_ok() {
*state.lock().unwrap() = State::Running;
}
enqueue_value_response(out_tx, id, result.and_then(to_value));
}
"shutdown" => {
let server = Arc::clone(server);
let state = Arc::clone(state);
let documents = server.documents().clone();
let permit = acquire_permit(permits).await;
spawn_request(
tasks,
registry,
out_tx,
span,
id,
permit,
documents,
move |ctx, ct| async move {
let result = server.shutdown(&ctx, ct).await;
if result.is_ok() {
*state.lock().unwrap() = State::ShuttingDown;
}
result.map(|()| serde_json::Value::Null)
},
);
}
other => {
enqueue_error(out_tx, id, LspError::MethodNotFound(other.to_string()));
}
}
}
RawMessage::Notification { method, params } => {
let span = info_span!("notification", method = %method);
if method != "initialized"
&& method != "exit"
&& *state.lock().unwrap() == State::Uninitialized
{
debug!(method = %method, "dropping notification before initialize");
return Ok(Flow::Continue);
}
match method.as_ref() {
"exit" => {
let ctx = Context::for_notification(
span.clone(),
out_tx.clone(),
server.documents().clone(),
);
server.exit(&ctx).instrument(span).await;
let code = if *state.lock().unwrap() == State::ShuttingDown {
0
} else {
1
};
return Ok(Flow::Exit(code));
}
"$/cancelRequest" => {
handle_cancel(registry, out_tx, ¶ms);
}
"initialized" => {
let params = parse_params(¶ms)?;
let permit = acquire_permit(permits).await;
spawn_notification(
tasks,
server,
out_tx,
span,
permit,
move |server, ctx| async move {
server.initialized(&ctx, params).await;
},
);
}
"textDocument/didOpen" => {
let params: lsp_types::DidOpenTextDocumentParams = parse_params(¶ms)?;
server.documents().open(params.text_document.clone());
let permit = acquire_permit(permits).await;
spawn_notification(
tasks,
server,
out_tx,
span,
permit,
move |server, ctx| async move {
server.text_document_did_open(&ctx, params).await;
},
);
}
"textDocument/didChange" => {
let params: lsp_types::DidChangeTextDocumentParams = parse_params(¶ms)?;
let uri = params.text_document.uri.clone();
let version = params.text_document.version;
for change in ¶ms.content_changes {
if let Err(e) = server.documents().apply_incremental_change(
&uri,
version,
change.clone(),
) {
warn!(error = %e, "textDocument/didChange: failed to apply change");
}
}
let permit = acquire_permit(permits).await;
spawn_notification(
tasks,
server,
out_tx,
span,
permit,
move |server, ctx| async move {
server.text_document_did_change(&ctx, params).await;
},
);
}
"textDocument/didClose" => {
let params: lsp_types::DidCloseTextDocumentParams = parse_params(¶ms)?;
server.documents().close(¶ms.text_document.uri);
let permit = acquire_permit(permits).await;
spawn_notification(
tasks,
server,
out_tx,
span,
permit,
move |server, ctx| async move {
server.text_document_did_close(&ctx, params).await;
},
);
}
"textDocument/didSave" => {
let params: lsp_types::DidSaveTextDocumentParams = parse_params(¶ms)?;
server.documents().save(¶ms.text_document.uri);
let permit = acquire_permit(permits).await;
spawn_notification(
tasks,
server,
out_tx,
span,
permit,
move |server, ctx| async move {
server.text_document_did_save(&ctx, params).await;
},
);
}
other => {
debug!(method = other, "unhandled notification");
}
}
}
RawMessage::Response { .. } => {
warn!("ignoring unexpected response");
}
}
Ok(Flow::Continue)
}
fn spawn_request<F, Fut>(
tasks: &mut JoinSet<()>,
registry: &Registry,
out_tx: &UnboundedSender<RawMessage>,
span: Span,
id: RequestId,
permit: tokio::sync::OwnedSemaphorePermit,
documents: Documents,
body: F,
) where
F: FnOnce(Context, CancellationToken) -> Fut + Send + 'static,
Fut: std::future::Future<Output = std::result::Result<serde_json::Value, LspError>>
+ Send
+ 'static,
{
let ct = CancellationToken::new();
let ct_for_handler = ct.clone();
let ct_for_select = ct.clone();
let registry_for_task = Arc::clone(registry);
let out_tx_for_task = out_tx.clone();
let id_for_task = id.clone();
let id_for_ctx = id.clone();
let span_for_ctx = span.clone();
let out_tx_for_ctx = out_tx.clone();
tasks.spawn(
async move {
let _permit = permit;
let ctx = Context::for_request(id_for_ctx, span_for_ctx, out_tx_for_ctx, documents);
let result = tokio::select! {
biased;
r = body(ctx, ct_for_handler) => r,
_ = ct_for_select.cancelled() => Err(LspError::RequestCancelled),
};
let still_present = registry_for_task
.lock()
.unwrap()
.remove(&id_for_task)
.is_some();
if still_present {
enqueue_value_response(&out_tx_for_task, id_for_task, result);
}
}
.instrument(span),
);
registry.lock().unwrap().insert(id, ct);
}
fn handle_cancel(registry: &Registry, out_tx: &UnboundedSender<RawMessage>, params: &Bytes) {
let bytes: &[u8] = if params.is_empty() { b"{}" } else { params };
let parsed: CancelParams = match serde_json::from_slice(bytes) {
Ok(p) => p,
Err(e) => {
debug!(error = %e, "ignoring malformed $/cancelRequest");
return;
}
};
let token = registry.lock().unwrap().remove(&parsed.id);
if let Some(token) = token {
token.cancel();
enqueue_error(out_tx, parsed.id, LspError::RequestCancelled);
}
}
/// Spawns a notification handler onto `tasks`, instrumented with `span`.
///
/// The handler receives its own `Arc` to the server and a notification
/// [`Context`] built from `span`, a clone of `out_tx` and the server's documents.
/// `permit` is held until the handler completes.
fn spawn_notification<S, F, Fut>(
    tasks: &mut JoinSet<()>,
    server: &Arc<S>,
    out_tx: &UnboundedSender<RawMessage>,
    span: tracing::Span,
    permit: tokio::sync::OwnedSemaphorePermit,
    body: F,
) where
    S: LanguageServer,
    F: FnOnce(Arc<S>, Context) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    let handler_server = Arc::clone(server);
    let handler_tx = out_tx.clone();
    let ctx_span = span.clone();
    let job = async move {
        // Keep the concurrency permit alive for the handler's entire run.
        let _permit = permit;
        let documents = handler_server.documents().clone();
        let ctx = Context::for_notification(ctx_span, handler_tx, documents);
        body(handler_server, ctx).await;
    };
    tasks.spawn(job.instrument(span));
}
async fn acquire_permit(permits: &Arc<Semaphore>) -> tokio::sync::OwnedSemaphorePermit {
Arc::clone(permits)
.acquire_owned()
.instrument(info_span!("handler.acquire_permit"))
.await
.expect("dispatcher semaphore is never closed")
}
/// Deserializes a message's `params` payload into `P`.
///
/// An empty payload (presumably a message without a `params` member) is decoded as
/// `{}`. Decoding failures become `LspError::invalid_params`, converted into the
/// crate error type.
fn parse_params<P: serde::de::DeserializeOwned>(params: &Bytes) -> Result<P> {
    let json: &[u8] = if params.is_empty() { b"{}" } else { &params[..] };
    match serde_json::from_slice::<P>(json) {
        Ok(decoded) => Ok(decoded),
        Err(e) => Err(LspError::invalid_params(e).into()),
    }
}
/// Converts a handler result into a JSON value, mapping serialization failures to
/// an internal `LspError`.
fn to_value<R: serde::Serialize>(value: R) -> std::result::Result<serde_json::Value, LspError> {
    match serde_json::to_value(value) {
        Ok(json) => Ok(json),
        Err(e) => Err(LspError::internal(format!("serialization failed: {e}"))),
    }
}
/// Queues the response for request `id`.
///
/// A successful value is serialized to JSON bytes. A serialization failure is
/// reported as an internal error, and an `Err` result becomes a JSON-RPC error
/// response. A send failure (receiver already dropped) is ignored.
fn enqueue_value_response(
    out_tx: &UnboundedSender<RawMessage>,
    id: RequestId,
    result: std::result::Result<serde_json::Value, LspError>,
) {
    let encoded = result.and_then(|value| {
        serde_json::to_vec(&value)
            .map_err(|e| LspError::internal(format!("serialization failed: {e}")))
    });
    let response = match encoded {
        Ok(bytes) => RawMessage::Response {
            id,
            result: Ok(Bytes::from(bytes)),
        },
        Err(err) => error_response(id, &err),
    };
    let _ = out_tx.send(response);
}
/// Builds a JSON-RPC error response for request `id` from `err`'s code, message and
/// optional data.
fn error_response(id: RequestId, err: &LspError) -> RawMessage {
    let payload = JsonRpcError {
        code: err.code(),
        message: err.message(),
        data: err.data().cloned(),
    };
    RawMessage::Response {
        id,
        result: Err(payload),
    }
}
/// Queues an error response for request `id`.
///
/// The send result is ignored: it only fails once the receiving end has been dropped.
fn enqueue_error(out_tx: &UnboundedSender<RawMessage>, id: RequestId, err: LspError) {
    let response = error_response(id, &err);
    let _ = out_tx.send(response);
}