use std::collections::HashMap;
use std::io::{self, BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::process::{Child, Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
use serde::de::DeserializeOwned;
use serde_json::{json, Value};
use crate::lsp::child_registry::LspChildRegistry;
use crate::lsp::jsonrpc::{
Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
};
use crate::lsp::registry::ServerKind;
use crate::lsp::{transport, LspError};
/// Maximum time a request waits for its response before failing with `LspError::Timeout`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `LspClient::shutdown` waits for the process to exit after the shutdown
/// handshake before killing its process group.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Sleep between `try_wait` polls while waiting for the server process to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
/// In-flight requests: request id -> one-shot sender the reader thread uses to deliver the
/// matching response.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
/// Lifecycle state of an `LspClient`'s connection to its language server.
///
/// Requests and notifications are rejected once the client reaches `ShuttingDown` or
/// `Exited`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// Process spawned; `initialize` has not been sent yet.
    Starting,
    /// The `initialize` request has been sent and the client is waiting for the handshake to finish.
    Initializing,
    /// The handshake completed (`initialized` notification sent); normal traffic allowed.
    Ready,
    /// A shutdown has been attempted; no further caller requests are accepted.
    ShuttingDown,
    /// The server process has been observed to exit (or was killed).
    Exited,
}
/// Events the client's background reader thread emits on the `event_tx` channel passed to
/// `LspClient::spawn`.
///
/// Every event carries the server kind and workspace root it originated from, so a single
/// channel can be shared by several clients.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification sent by the server (no response expected).
    Notification {
        server_kind: ServerKind,
        root: PathBuf,
        method: String,
        params: Option<Value>,
    },
    /// A request initiated by the server. The reader thread has already written an
    /// automatic reply (`null`, or one `null` per item for `workspace/configuration`);
    /// this event is informational.
    ServerRequest {
        server_kind: ServerKind,
        root: PathBuf,
        id: RequestId,
        method: String,
        params: Option<Value>,
    },
    /// The server's stdout reached EOF or failed to parse/read (or the pending-response
    /// map lock was poisoned); the reader thread has stopped.
    ServerExited {
        server_kind: ServerKind,
        root: PathBuf,
    },
}
/// Diagnostic-related capabilities extracted from a server's `initialize` result.
#[derive(Debug, Clone, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` was advertised (as an options object or as `true`).
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when provided.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport` in the server capabilities object.
    /// NOTE(review): the LSP spec models refresh support as a *client* capability; confirm
    /// which servers actually report this here.
    pub refresh_support: bool,
}
/// A connection to one language-server process for one workspace root.
///
/// Requests are written under `writer`'s lock and answered through one-shot channels
/// stored in `pending`, which a background reader thread fills in.
pub struct LspClient {
    /// Which kind of server this is (used in events and error messages).
    kind: ServerKind,
    /// Workspace root the server was started in.
    root: PathBuf,
    /// Current lifecycle state; gates which operations are allowed.
    state: ServerState,
    /// The server process handle.
    child: Child,
    /// PID of `child`, cached for registry bookkeeping.
    child_pid: u32,
    /// Buffered server stdin; shared with the reader thread, which writes auto-replies.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,
    /// In-flight requests awaiting responses; shared with the reader thread.
    pending: Arc<Mutex<PendingMap>>,
    /// Next JSON-RPC request id to allocate.
    next_id: AtomicI64,
    /// Parsed diagnostic capabilities; `None` until `initialize` succeeds.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the server's capabilities mention `workspace.didChangeWatchedFiles`.
    supports_watched_files: bool,
    /// Registry the child PID is tracked in while the process is alive.
    child_registry: LspChildRegistry,
}
impl LspClient {
pub fn spawn(
kind: ServerKind,
root: PathBuf,
binary: &Path,
args: &[String],
env: &HashMap<String, String>,
event_tx: Sender<LspEvent>,
child_registry: LspChildRegistry,
) -> io::Result<Self> {
let mut command = Command::new(binary);
command
.args(args)
.current_dir(&root)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null());
for (key, value) in env {
command.env(key, value);
}
#[cfg(unix)]
unsafe {
use std::os::unix::process::CommandExt;
command.pre_exec(|| {
if libc::setsid() == -1 {
return Err(io::Error::last_os_error());
}
Ok(())
});
}
let mut child = command.spawn()?;
let child_pid = child.id();
child_registry.track(child_pid);
let stdout = child
.stdout
.take()
.ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
let pending = Arc::new(Mutex::new(PendingMap::new()));
let reader_pending = Arc::clone(&pending);
let reader_writer = Arc::clone(&writer);
let reader_kind = kind.clone();
let reader_root = root.clone();
thread::spawn(move || {
let mut reader = BufReader::new(stdout);
loop {
match transport::read_message(&mut reader) {
Ok(Some(ServerMessage::Response(response))) => {
if let Ok(mut guard) = reader_pending.lock() {
if let Some(tx) = guard.remove(&response.id) {
if tx.send(response).is_err() {
log::debug!("response channel closed");
}
}
} else {
let _ = event_tx.send(LspEvent::ServerExited {
server_kind: reader_kind.clone(),
root: reader_root.clone(),
});
break;
}
}
Ok(Some(ServerMessage::Notification { method, params })) => {
let _ = event_tx.send(LspEvent::Notification {
server_kind: reader_kind.clone(),
root: reader_root.clone(),
method,
params,
});
}
Ok(Some(ServerMessage::Request { id, method, params })) => {
let response_value = if method == "workspace/configuration" {
let item_count = params
.as_ref()
.and_then(|p| p.get("items"))
.and_then(|items| items.as_array())
.map_or(1, |arr| arr.len());
serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
} else {
serde_json::Value::Null
};
if let Ok(mut w) = reader_writer.lock() {
let response = super::jsonrpc::OutgoingResponse::success(
id.clone(),
response_value,
);
let _ = transport::write_response(&mut *w, &response);
}
let _ = event_tx.send(LspEvent::ServerRequest {
server_kind: reader_kind.clone(),
root: reader_root.clone(),
id,
method,
params,
});
}
Ok(None) | Err(_) => {
if let Ok(mut guard) = reader_pending.lock() {
guard.clear();
}
let _ = event_tx.send(LspEvent::ServerExited {
server_kind: reader_kind.clone(),
root: reader_root.clone(),
});
break;
}
}
}
});
Ok(Self {
kind,
root,
state: ServerState::Starting,
child,
child_pid,
writer,
pending,
next_id: AtomicI64::new(1),
diagnostic_caps: None,
supports_watched_files: false,
child_registry,
})
}
pub fn initialize(
&mut self,
workspace_root: &Path,
initialization_options: Option<serde_json::Value>,
) -> Result<lsp_types::InitializeResult, LspError> {
self.ensure_can_send()?;
self.state = ServerState::Initializing;
let normalized = normalize_windows_path(workspace_root);
let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
LspError::NotFound(format!(
"failed to convert workspace root '{}' to file URI",
workspace_root.display()
))
})?;
let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
LspError::NotFound(format!(
"failed to convert workspace root '{}' to file URI",
workspace_root.display()
))
})?;
let mut params_value = json!({
"processId": std::process::id(),
"rootUri": root_uri,
"capabilities": {
"workspace": {
"workspaceFolders": true,
"configuration": true,
"diagnostic": {
"refreshSupport": false
}
},
"textDocument": {
"synchronization": {
"dynamicRegistration": false,
"didSave": true,
"willSave": false,
"willSaveWaitUntil": false
},
"publishDiagnostics": {
"relatedInformation": true,
"versionSupport": true,
"codeDescriptionSupport": true,
"dataSupport": true
},
"diagnostic": {
"dynamicRegistration": false,
"relatedDocumentSupport": true
}
}
},
"clientInfo": {
"name": "aft",
"version": env!("CARGO_PKG_VERSION")
},
"workspaceFolders": [
{
"uri": root_uri,
"name": workspace_root
.file_name()
.and_then(|name| name.to_str())
.unwrap_or("workspace")
}
]
});
if let Some(initialization_options) = initialization_options {
params_value["initializationOptions"] = initialization_options;
}
let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
let result_value = self.send_request_value(
<lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
params,
)?;
let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
let caps_value = result_value
.get("capabilities")
.cloned()
.unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
self.supports_watched_files = caps_value
.pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
.and_then(|v| v.as_bool())
.unwrap_or(false)
|| caps_value
.pointer("/workspace/didChangeWatchedFiles")
.map(|v| v.is_object() || v.as_bool() == Some(true))
.unwrap_or(false);
self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
json!({}),
)?)?;
self.state = ServerState::Ready;
Ok(result)
}
pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
self.diagnostic_caps.as_ref()
}
pub fn supports_watched_files(&self) -> bool {
self.supports_watched_files
}
pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
where
R: lsp_types::request::Request,
R::Params: serde::Serialize,
R::Result: DeserializeOwned,
{
self.ensure_can_send()?;
let value = self.send_request_value(R::METHOD, params)?;
serde_json::from_value(value).map_err(Into::into)
}
fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
where
P: serde::Serialize,
{
self.ensure_can_send()?;
let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
let (tx, rx) = bounded(1);
{
let mut pending = self.lock_pending()?;
pending.insert(id.clone(), tx);
}
let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
{
let mut writer = self
.writer
.lock()
.map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
if let Err(err) = transport::write_request(&mut *writer, &request) {
self.remove_pending(&id);
return Err(err.into());
}
}
let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
Ok(response) => response,
Err(RecvTimeoutError::Timeout) => {
self.remove_pending(&id);
return Err(LspError::Timeout(format!(
"timed out waiting for '{}' response from {:?}",
method, self.kind
)));
}
Err(RecvTimeoutError::Disconnected) => {
self.remove_pending(&id);
return Err(LspError::ServerNotReady(format!(
"language server {:?} disconnected while waiting for '{}'",
self.kind, method
)));
}
};
if let Some(error) = response.error {
return Err(LspError::ServerError {
code: error.code,
message: error.message,
});
}
Ok(response.result.unwrap_or(Value::Null))
}
pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
where
N: lsp_types::notification::Notification,
N::Params: serde::Serialize,
{
self.ensure_can_send()?;
let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
let mut writer = self
.writer
.lock()
.map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
transport::write_notification(&mut *writer, ¬ification)?;
Ok(())
}
pub fn shutdown(&mut self) -> Result<(), LspError> {
if self.state == ServerState::Exited {
self.child_registry.untrack(self.child_pid);
return Ok(());
}
if self.child.try_wait()?.is_some() {
self.state = ServerState::Exited;
self.child_registry.untrack(self.child_pid);
return Ok(());
}
if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
self.state = ServerState::ShuttingDown;
if self.child.try_wait()?.is_some() {
self.state = ServerState::Exited;
return Ok(());
}
return Err(err);
}
self.state = ServerState::ShuttingDown;
if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
if self.child.try_wait()?.is_some() {
self.state = ServerState::Exited;
return Ok(());
}
return Err(err);
}
let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
loop {
if self.child.try_wait()?.is_some() {
self.state = ServerState::Exited;
return Ok(());
}
if Instant::now() >= deadline {
kill_lsp_child_group(&mut self.child);
self.state = ServerState::Exited;
return Err(LspError::Timeout(format!(
"timed out waiting for {:?} to exit",
self.kind
)));
}
thread::sleep(EXIT_POLL_INTERVAL);
}
}
pub fn state(&self) -> ServerState {
self.state
}
pub fn kind(&self) -> ServerKind {
self.kind.clone()
}
pub fn root(&self) -> &Path {
&self.root
}
fn ensure_can_send(&self) -> Result<(), LspError> {
if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
return Err(LspError::ServerNotReady(format!(
"language server {:?} is not ready (state: {:?})",
self.kind, self.state
)));
}
Ok(())
}
fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
self.pending
.lock()
.map_err(|_| io::Error::other("pending response map poisoned").into())
}
fn remove_pending(&self, id: &RequestId) {
if let Ok(mut pending) = self.pending.lock() {
pending.remove(id);
}
}
}
/// Ensures the server process never outlives its client: the PID is removed from the
/// registry and the process (group) is terminated and reaped.
impl Drop for LspClient {
    fn drop(&mut self) {
        // Untrack first, then kill; this also runs after a successful `shutdown`.
        // NOTE(review): in that case the child has already been reaped, so this signals a
        // group whose leader is gone — confirm `terminate_pgid` tolerates that.
        self.child_registry.untrack(self.child_pid);
        kill_lsp_child_group(&mut self.child);
    }
}
/// Terminates a language-server child and reaps it.
///
/// On Unix the child was made a session leader at spawn time, so its PID doubles as its
/// process-group id and the whole group is signalled. On other platforms only the child
/// process itself is terminated. The final `wait` result is deliberately ignored because
/// this runs on cleanup paths (including `Drop`) with no caller to report to.
fn kill_lsp_child_group(child: &mut std::process::Child) {
    #[cfg(unix)]
    {
        // Unix PIDs fit in `pid_t` (i32), so this cast does not truncate.
        let process_group = child.id() as i32;
        crate::bash_background::process::terminate_pgid(process_group, Some(child));
    }
    #[cfg(not(unix))]
    {
        crate::bash_background::process::terminate_process(child);
    }
    let _ = child.wait();
}
/// Removes the Win32 verbatim prefix (`\\?\`) so the path can be converted to a `file://` URL.
///
/// Verbatim UNC paths (`\\?\UNC\server\share\...`) are rewritten to ordinary UNC form
/// (`\\server\share\...`). Merely dropping `\\?\` would produce `UNC\server\share`, which
/// is a relative path. Paths without the prefix are returned unchanged.
fn normalize_windows_path(path: &Path) -> PathBuf {
    let s = path.to_string_lossy();
    if let Some(unc) = s.strip_prefix(r"\\?\UNC\") {
        PathBuf::from(format!(r"\\{unc}"))
    } else if let Some(stripped) = s.strip_prefix(r"\\?\") {
        PathBuf::from(stripped)
    } else {
        path.to_path_buf()
    }
}
fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
let mut caps = ServerDiagnosticCapabilities::default();
if let Some(provider) = value.get("diagnosticProvider") {
if provider.is_object() || provider.as_bool() == Some(true) {
caps.pull_diagnostics = true;
}
if let Some(obj) = provider.as_object() {
if obj
.get("workspaceDiagnostics")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
caps.workspace_diagnostics = true;
}
if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
caps.identifier = Some(identifier.to_string());
}
}
}
if let Some(refresh) = value
.get("workspace")
.and_then(|w| w.get("diagnostic"))
.and_then(|d| d.get("refreshSupport"))
.and_then(|r| r.as_bool())
{
caps.refresh_support = refresh;
}
caps
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a capabilities JSON literal (keeps each test focused on its fixture).
    fn caps_from(capabilities: Value) -> ServerDiagnosticCapabilities {
        parse_diagnostic_capabilities(&capabilities)
    }

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let caps = caps_from(json!({}));
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert_eq!(caps.identifier, None);
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let caps = caps_from(json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        }));
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let caps = caps_from(json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        }));
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let caps = caps_from(json!({ "diagnosticProvider": true }));
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let caps = caps_from(json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        }));
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }
}