use std::collections::BTreeMap;
use std::io::ErrorKind;
use std::sync::Arc;
use anyhow::{bail, ensure, Context, Result};
use percent_encoding::percent_decode_str;
use serde_json::Value;
use tokio::io::BufReader;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Mutex};
use tokio::task;
use tracing::{debug, error, info, warn, Instrument};
use uriparse::URI;
use crate::instance::{self, Instance, InstanceKey, InstanceMap};
use crate::lsp::ext::{self, LspMuxOptions, Tag};
use crate::lsp::jsonrpc::{
self, Message, Request, RequestId, ResponseError, ResponseSuccess, Version,
};
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::InitializeParams;
use crate::socketwrapper::{OwnedReadHalf, OwnedWriteHalf, Stream};
/// Handles a freshly accepted client connection.
///
/// Reads the mandatory `initialize` request, pulls the lspmux options out of its
/// `initializationOptions.lspMux` field, checks the protocol version and then
/// dispatches to the handler for the requested lspmux method (`connect`,
/// `status` or `reload`).
///
/// # Errors
///
/// Fails if the first message is missing or is not an `initialize` request, if
/// its params cannot be parsed, if `initializationOptions` or `lspMux` are
/// absent, or if the protocol version does not match
/// [`LspMuxOptions::PROTOCOL_VERSION`]. Errors from the selected handler are
/// propagated unchanged.
pub async fn process(
    socket: Stream,
    client_id: usize,
    instance_map: Arc<Mutex<InstanceMap>>,
) -> Result<()> {
    let (read_half, write_half) = socket.into_split();
    let mut reader = LspReader::new(BufReader::new(read_half), "client");
    let writer = LspWriter::new(write_half, "client");

    // The handshake must start with `initialize`; anything else is rejected.
    let first_message = reader
        .read_message()
        .await
        .context("receive `initialize` request")?
        .context("channel closed")?;
    let req = match first_message {
        Message::Request(req) if req.method == "initialize" => req,
        _ => bail!("first client message was not `initialize` request"),
    };

    let mut init_params: InitializeParams = serde_json::from_value(req.params.clone())
        .context("parse `initialize` request params")?;

    // `take()` moves the lspmux options out, so `init_params` no longer carries
    // them when it is handed to the method handlers below.
    let init_options = init_params
        .initialization_options
        .as_mut()
        .context("missing `initializationOptions` in `initialize` request")?;
    let options = init_options
        .lsp_mux
        .take()
        .context("missing `lspMux` in `initializationOptions` in `initialize` request")?;

    ensure!(
        options.version == LspMuxOptions::PROTOCOL_VERSION,
        "unsupported protocol version {:?}, expected {:?}",
        &options.version,
        LspMuxOptions::PROTOCOL_VERSION,
    );
    debug!(?options, "lspmux initialization");

    match options.method {
        ext::Request::Connect {
            server,
            args,
            env,
            cwd,
        } => {
            let target = (server, args, env, cwd);
            connect(client_id, instance_map, target, req, init_params, reader, writer).await
        }
        ext::Request::Status {} => status(instance_map, writer).await,
        ext::Request::Reload { cwd } => reload(cwd, instance_map, writer).await,
    }
}
/// Handle through which messages are delivered to one connected client.
///
/// Stores the client's numeric id and the sending side of its outgoing
/// message queue. All clones share the same queue.
#[derive(Clone)]
pub struct Client {
    id: usize,
    sender: mpsc::Sender<Message>,
}

impl Client {
    /// How many outgoing messages may be buffered before
    /// [`Client::send_message`] has to wait for room.
    const QUEUE_CAPACITY: usize = 16;

    /// Creates a client handle together with the receiving end of its queue.
    fn new(id: usize) -> (Client, mpsc::Receiver<Message>) {
        let (tx, rx) = mpsc::channel(Self::QUEUE_CAPACITY);
        let client = Self { id, sender: tx };
        (client, rx)
    }

    /// Returns the id this client was created with.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Queues `message` for delivery to the client, waiting while the queue is full.
    ///
    /// # Errors
    ///
    /// Hands the message back inside [`SendError`] if the receiving side has
    /// been dropped.
    pub async fn send_message(&self, message: Message) -> Result<(), SendError<Message>> {
        let sender = &self.sender;
        sender.send(message).await
    }
}
/// Answers an lspmux `status` request.
///
/// The payload is whatever `InstanceMap::get_status` returns, serialized to
/// JSON. The response always uses request id `0`.
///
/// # Errors
///
/// Fails if the response cannot be written to the client.
async fn status(
    instance_map: Arc<Mutex<InstanceMap>>,
    mut writer: LspWriter<OwnedWriteHalf>,
) -> Result<()> {
    let map = instance_map.lock().await;
    let status = map.get_status().await;
    // Release the map before touching the client socket.
    drop(map);

    let response = ResponseSuccess {
        jsonrpc: Version,
        result: serde_json::to_value(status).unwrap(),
        id: RequestId::Number(0),
    };
    writer
        .write_message(&Message::ResponseSuccess(response))
        .await
        .context("writing response")
}
/// Asks the instance responsible for `cwd` to reload its workspace, then tells
/// the lspmux client whether such an instance was found.
///
/// The instance is looked up with `InstanceMap::get_by_cwd`.
///
/// - If one exists, it is sent a `rust-analyzer/reloadWorkspace` request whose
///   id is tagged [`Tag::Drop`] (presumably so the server's reply is discarded
///   rather than routed to a client; confirm in the instance code). The client
///   then receives a null success response.
/// - Otherwise the client receives an error response with code `0` and the
///   message "no instance found".
///
/// Both responses use request id `0`.
///
/// # Errors
///
/// Fails if the instance's message channel is closed or if writing the
/// response to the client fails.
async fn reload(
    cwd: String,
    instance_map: Arc<Mutex<InstanceMap>>,
    mut writer: LspWriter<OwnedWriteHalf>,
) -> Result<()> {
    // Hold the instance map lock only for the lookup and for queuing the reload
    // request. The guard is dropped at the end of this block, before anything is
    // written to the client socket. Otherwise a client that stops reading could
    // stall every other connection that needs the map.
    let instance_found = {
        let map = instance_map.lock().await;
        let instance = map.get_by_cwd(&cwd);
        match instance {
            Some(instance) => {
                instance
                    .send_message(Message::Request(Request {
                        jsonrpc: Version,
                        method: "rust-analyzer/reloadWorkspace".into(),
                        params: Value::Null,
                        id: RequestId::Number(0).tag(Tag::Drop),
                    }))
                    .await
                    .ok()
                    .context("instance closed")?;
                true
            }
            None => false,
        }
    };

    if instance_found {
        writer
            .write_message(&Message::ResponseSuccess(ResponseSuccess::null(
                RequestId::Number(0),
            )))
            .await
            .context("writing response")?;
    } else {
        writer
            .write_message(&Message::ResponseError(ResponseError {
                jsonrpc: Version,
                error: jsonrpc::Error {
                    code: 0,
                    message: "no instance found".into(),
                    data: None,
                },
                id: RequestId::Number(0),
            }))
            .await
            .context("writing response")?;
        debug!(?cwd, "no instance found for path");
    }
    Ok(())
}
async fn connect(
client_id: usize,
instance_map: Arc<Mutex<InstanceMap>>,
(server, args, env, cwd): (
String,
Vec<String>,
BTreeMap<String, String>,
Option<String>,
),
req: Request,
init_params: InitializeParams,
mut reader: LspReader<BufReader<OwnedReadHalf>>,
mut writer: LspWriter<OwnedWriteHalf>,
) -> Result<()> {
let workspace_root = select_workspace_root(&init_params, cwd.as_deref())
.context("could not get any workspace_root")?;
let key = InstanceKey {
server,
args,
env,
workspace_root,
};
let instance = instance::get_or_spawn(instance_map, key, init_params).await?;
let res = ResponseSuccess {
jsonrpc: Version,
result: serde_json::to_value(instance.initialize_result()).unwrap(),
id: req.id,
};
writer
.write_message(&res.into())
.await
.context("send `initialize` request response")?;
match reader
.read_message()
.await
.context("receive `initialized` notification")?
.context("channel closed")?
{
Message::Notification(notif) if notif.method == "initialized" => {
}
_ => bail!("second client message was not `initialized` notification"),
}
info!("initialized client");
let (client, client_rx) = Client::new(client_id);
task::spawn(input_task(client_rx, writer).in_current_span());
instance.add_client(client.clone()).await;
task::spawn(output_task(reader, client, instance).in_current_span());
Ok(())
}
fn parse_root_uri(root_uri: &str) -> Result<String> {
let (scheme, _, mut path, _, _) = URI::try_from(root_uri)
.context("failed to parse URI")?
.into_parts();
if scheme != uriparse::Scheme::File {
bail!("only `file://` URIs are supported");
}
path.normalize(false);
let root = percent_decode_str(&path.to_string())
.decode_utf8()
.context("decoded URI was not valid utf-8")?
.to_string();
let root = match root.as_bytes() {
#[cfg(any(windows, test))]
[b'/', drive, b':', b'/', ..] if drive.is_ascii_alphabetic() => {
root.strip_prefix('/').unwrap().to_owned()
}
_ => root,
};
Ok(root)
}
#[cfg(test)]
#[test]
fn parsing_root_uris() {
    use parse_root_uri as p;

    // Unix-style paths, including the filesystem root.
    assert_eq!(p("file:///home/user/proj").unwrap(), "/home/user/proj");
    assert_eq!(p("file:///proj").unwrap(), "/proj");
    assert_eq!(p("file:///").unwrap(), "/");

    // Windows drive letters lose the slash that precedes them in the URI path.
    assert_eq!(p("file:///c:/dev/proj").unwrap(), "c:/dev/proj");
    assert_eq!(p("file:///d:/proj").unwrap(), "d:/proj");
    assert_eq!(p("file:///e:/").unwrap(), "e:/");

    // Percent-encoded characters are decoded, including an encoded drive colon.
    assert_eq!(p("file:///home/user/my%20proj").unwrap(), "/home/user/my proj");
    assert_eq!(p("file:///c%3A/dev/proj").unwrap(), "c:/dev/proj");

    // Dot segments are resolved by path normalization.
    assert_eq!(p("file:///home/user/proj/../other").unwrap(), "/home/user/other");

    // Only the `file` scheme is accepted.
    assert!(p("https://example.com/proj").is_err());
}
/// Chooses the workspace root directory for a new client.
///
/// Candidates are tried in this order:
/// 1. the only entry of `workspaceFolders`, when the client sent exactly one;
/// 2. `rootUri`;
/// 3. `rootPath`, used verbatim;
/// 4. `proxy_cwd`, the `cwd` from the lspmux `connect` options.
///
/// Multiple workspace folders are not supported. In that case a warning is
/// logged and the remaining candidates are tried instead of failing.
///
/// # Errors
///
/// Fails if the selected URI cannot be parsed by `parse_root_uri`, or if none
/// of the candidates is available.
fn select_workspace_root(
    init_params: &InitializeParams,
    proxy_cwd: Option<&str>,
) -> Result<String> {
    let folders = &init_params.workspace_folders;
    match folders.len() {
        0 => {}
        1 => {
            return parse_root_uri(&folders[0].uri)
                .context("parse initParams.workspaceFolders[0].uri");
        }
        _ => {
            // Previously this case fell through to an `assert!(is_empty())` and
            // panicked the connection task. Instead, fall back to the
            // single-root fields below.
            warn!("initialize request with multiple workspace folders isn't supported");
            debug!(workspace_folders = ?init_params.workspace_folders);
        }
    }
    if let Some(root_uri) = &init_params.root_uri {
        return parse_root_uri(root_uri).context("parse initParams.rootUri");
    }
    if let Some(root_path) = &init_params.root_path {
        return Ok(root_path.to_owned());
    }
    if let Some(proxy_cwd) = proxy_cwd {
        return Ok(proxy_cwd.to_owned());
    }
    bail!("could not determine a suitable workspace_root");
}
async fn input_task(mut rx: mpsc::Receiver<Message>, mut writer: LspWriter<OwnedWriteHalf>) {
while let Some(message) = rx.recv().await {
if let Err(err) = writer.write_message(&message).await {
match err.kind() {
ErrorKind::BrokenPipe => {}
_ => error!(?err, "error writing client input: {err}"),
}
break; }
}
debug!("client input closed");
info!("client disconnected");
}
/// Reads messages from the client socket and routes them to `instance`, then
/// unregisters the client via `Instance::cleanup_client`.
///
/// Routing rules:
/// - `shutdown` requests are answered locally with a null result, are not
///   forwarded to the instance, and end the loop.
/// - Other requests have their id tagged with [`Tag::ClientId`] before forwarding.
/// - Success responses tagged [`Tag::Forward`] are untagged and forwarded.
///   Responses tagged [`Tag::Drop`] are discarded. Any other response is
///   logged at debug level.
/// - Error responses are only logged.
/// - `textDocument/didOpen` and `textDocument/didClose` go to
///   `Instance::open_file` and `Instance::close_file` respectively.
/// - All other notifications are forwarded unchanged.
///
/// The loop also ends when the client closes its side, or when forwarding to
/// the instance fails. Read errors are logged and reading continues.
async fn output_task(
    mut reader: LspReader<BufReader<OwnedReadHalf>>,
    client: Client,
    instance: Arc<Instance>,
) {
    loop {
        let message = match reader.read_message().await {
            Ok(Some(message)) => message,
            Ok(None) => {
                debug!("client output closed");
                break;
            }
            Err(err) => {
                error!(?err, "error reading client output");
                continue;
            }
        };
        instance.keep_alive();

        // Each arm reports whether the connection should stay open.
        let keep_going = match message {
            Message::Request(req) if req.method == "shutdown" => {
                info!("client sent shutdown request, sending a response and closing connection");
                let _ = client
                    .send_message(ResponseSuccess::null(req.id).into())
                    .await;
                false
            }
            Message::Request(mut req) => {
                req.id = req.id.tag(Tag::ClientId(client.id));
                instance.send_message(req.into()).await.is_ok()
            }
            Message::ResponseSuccess(mut res) => match res.id.untag() {
                (Some(Tag::Forward), id) => {
                    res.id = id;
                    instance.send_message(res.into()).await.is_ok()
                }
                (Some(Tag::Drop), _) => true,
                _ => {
                    debug!(?res, "unexpected client response");
                    true
                }
            },
            Message::ResponseError(res) => {
                warn!(?res, "client responded with error");
                true
            }
            Message::Notification(notif) if notif.method == "textDocument/didOpen" => {
                if let Err(err) = instance.open_file(client.id, notif.params).await {
                    warn!(?err, "error opening file");
                }
                true
            }
            Message::Notification(notif) if notif.method == "textDocument/didClose" => {
                if let Err(err) = instance.close_file(client.id, notif.params).await {
                    warn!(?err, "error closing file");
                }
                true
            }
            Message::Notification(notif) => instance.send_message(notif.into()).await.is_ok(),
        };

        if !keep_going {
            break;
        }
    }
    if let Err(err) = instance.cleanup_client(client).await {
        warn!(?err, "error cleaning up after a client");
    }
}