use anyhow::{Context, Result, anyhow};
use bytes::BytesMut;
use lsp_types::{
CallHierarchyIncomingCall, CallHierarchyIncomingCallsParams, CallHierarchyItem,
CallHierarchyOutgoingCall, CallHierarchyOutgoingCallsParams, CallHierarchyPrepareParams,
ClientCapabilities, CodeActionParams, CodeActionResponse, CompletionParams, CompletionResponse,
Diagnostic, DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
DocumentFormattingParams, DocumentRangeFormattingParams, DocumentSymbolParams,
DocumentSymbolResponse, GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverParams,
InitializeParams, InitializeResult, InitializedParams, PublishDiagnosticsParams,
ReferenceParams, RenameParams, SignatureHelp, SignatureHelpParams, TextEdit, TypeHierarchyItem,
TypeHierarchyPrepareParams, TypeHierarchySubtypesParams, TypeHierarchySupertypesParams, Uri,
WorkspaceEdit, WorkspaceFolder, WorkspaceSymbolParams, WorkspaceSymbolResponse,
};
use std::collections::HashMap;
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{Mutex, oneshot};
use tracing::{debug, error, trace, warn};
use super::protocol::{self, NotificationMessage, RequestId, RequestMessage, ResponseMessage};
/// Shared cache of the most recent diagnostics published by the server, keyed by document URI.
///
/// Each `textDocument/publishDiagnostics` notification replaces the entry for its URI.
pub type DiagnosticsCache = Arc<Mutex<HashMap<Uri, Vec<Diagnostic>>>>;
/// Upper bound on how long a single request waits for the server's response (30 seconds).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Client for a language server process that speaks JSON-RPC over stdio.
///
/// Outgoing messages are written to the server's stdin; a background reader
/// task consumes the server's stdout and routes responses and notifications.
pub struct LspClient {
/// Source of request ids; incremented once per outgoing request.
next_id: AtomicI64,
/// Server stdin. Guarded by a mutex so a message's header and body are
/// written without another writer interleaving.
stdin: Arc<Mutex<ChildStdin>>,
/// In-flight requests keyed by id; the reader task removes an entry and
/// sends it the matching response.
pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ResponseMessage>>>>,
/// Latest diagnostics per document, populated from `publishDiagnostics`.
diagnostics: DiagnosticsCache,
/// Set to `false` when the reader task exits (server output closed or unreadable).
alive: Arc<AtomicBool>,
/// Handle to the background reader task.
_reader_handle: tokio::task::JoinHandle<()>,
/// Handle to the spawned server process, held for the client's lifetime.
/// NOTE(review): whether dropping it terminates the server depends on how the
/// `Command` was configured (`kill_on_drop`).
_child: Child,
}
impl LspClient {
pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
let mut child = Command::new(program)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()
.with_context(|| format!("Failed to spawn LSP server: {}", program))?;
let stdin = child.stdin.take().expect("stdin not captured");
let stdout = child.stdout.take().expect("stdout not captured");
let stdin = Arc::new(Mutex::new(stdin));
let pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ResponseMessage>>>> =
Arc::new(Mutex::new(HashMap::new()));
let diagnostics: DiagnosticsCache = Arc::new(Mutex::new(HashMap::new()));
let alive = Arc::new(AtomicBool::new(true));
let reader_handle = tokio::spawn(Self::reader_task(
stdout,
pending.clone(),
diagnostics.clone(),
alive.clone(),
));
Ok(Self {
next_id: AtomicI64::new(1),
stdin,
pending,
diagnostics,
alive,
_reader_handle: reader_handle,
_child: child,
})
}
async fn reader_task(
stdout: ChildStdout,
pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ResponseMessage>>>>,
diagnostics: DiagnosticsCache,
alive: Arc<AtomicBool>,
) {
let mut reader = BufReader::new(stdout);
let mut buffer = BytesMut::with_capacity(8192);
loop {
let mut temp = [0u8; 4096];
match reader.read(&mut temp).await {
Ok(0) => {
debug!("LSP stdout closed");
break;
}
Ok(n) => {
buffer.extend_from_slice(&temp[..n]);
}
Err(e) => {
error!("Error reading from LSP stdout: {}", e);
break;
}
}
while let Ok(Some(message_str)) = protocol::try_parse_message(&mut buffer) {
trace!("Received LSP message: {}", message_str);
if let Ok(response) = serde_json::from_str::<ResponseMessage>(&message_str) {
if let Some(id) = &response.id {
let mut pending = pending.lock().await;
if let Some(sender) = pending.remove(id) {
let _ = sender.send(response);
} else {
warn!("Received response for unknown request id: {:?}", id);
}
}
continue;
}
if let Ok(notification) = serde_json::from_str::<NotificationMessage>(&message_str)
{
Self::handle_notification(¬ification, &diagnostics).await;
continue;
}
warn!("Could not parse LSP message: {}", message_str);
}
}
alive.store(false, Ordering::SeqCst);
warn!("LSP reader task exiting - server connection lost");
}
async fn handle_notification(
notification: &NotificationMessage,
diagnostics: &DiagnosticsCache,
) {
match notification.method.as_str() {
"textDocument/publishDiagnostics" => {
if let Ok(params) =
serde_json::from_value::<PublishDiagnosticsParams>(notification.params.clone())
{
debug!(
"Received {} diagnostics for {:?}",
params.diagnostics.len(),
params.uri.as_str()
);
let mut cache = diagnostics.lock().await;
cache.insert(params.uri, params.diagnostics);
} else {
warn!("Failed to parse publishDiagnostics params");
}
}
"window/logMessage" | "window/showMessage" => {
if let Some(message) = notification.params.get("message").and_then(|m| m.as_str()) {
debug!("LSP server message: {}", message);
}
}
_ => {
trace!(
"Ignoring notification: {} params={}",
notification.method, notification.params
);
}
}
}
async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: P,
) -> Result<R> {
let id = RequestId::Number(self.next_id.fetch_add(1, Ordering::SeqCst));
let request = RequestMessage {
jsonrpc: "2.0".to_string(),
id: id.clone(),
method: method.to_string(),
params: serde_json::to_value(params)?,
};
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(id.clone(), tx);
}
self.send_message(&request).await?;
let response = match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
Ok(Ok(response)) => response,
Ok(Err(_)) => {
return Err(anyhow!("LSP server closed connection"));
}
Err(_) => {
let mut pending = self.pending.lock().await;
pending.remove(&id);
return Err(anyhow!(
"LSP request '{}' timed out after {:?}",
method,
REQUEST_TIMEOUT
));
}
};
if let Some(error) = response.error {
return Err(anyhow!("LSP error {}: {}", error.code, error.message));
}
let result = response.result.unwrap_or(serde_json::Value::Null);
serde_json::from_value(result).context("Failed to parse LSP response")
}
async fn notify<P: serde::Serialize>(&self, method: &str, params: P) -> Result<()> {
let notification = NotificationMessage {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params: serde_json::to_value(params)?,
};
self.send_message(¬ification).await
}
async fn send_message<T: serde::Serialize>(&self, message: &T) -> Result<()> {
let body = serde_json::to_string(message)?;
let header = format!("Content-Length: {}\r\n\r\n", body.len());
trace!("Sending LSP message: {}", body);
let mut stdin = self.stdin.lock().await;
stdin.write_all(header.as_bytes()).await?;
stdin.write_all(body.as_bytes()).await?;
stdin.flush().await?;
Ok(())
}
pub async fn initialize(&mut self, root: &Path) -> Result<InitializeResult> {
let root_uri: Uri = format!("file://{}", root.display())
.parse()
.map_err(|e| anyhow!("Invalid root path {:?}: {}", root, e))?;
let params = InitializeParams {
process_id: Some(std::process::id()),
capabilities: ClientCapabilities {
..Default::default()
},
workspace_folders: Some(vec![WorkspaceFolder {
uri: root_uri,
name: root
.file_name()
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_else(|| "workspace".to_string()),
}]),
..Default::default()
};
let result: InitializeResult = self.request("initialize", params).await?;
self.notify("initialized", InitializedParams {}).await?;
Ok(result)
}
pub async fn shutdown(&mut self) -> Result<()> {
let _: serde_json::Value = self.request("shutdown", serde_json::Value::Null).await?;
self.notify("exit", serde_json::Value::Null).await?;
Ok(())
}
pub async fn did_open(&self, params: DidOpenTextDocumentParams) -> Result<()> {
self.notify("textDocument/didOpen", params).await
}
pub async fn did_change(&self, params: DidChangeTextDocumentParams) -> Result<()> {
self.notify("textDocument/didChange", params).await
}
pub async fn did_close(&self, params: DidCloseTextDocumentParams) -> Result<()> {
self.notify("textDocument/didClose", params).await
}
pub async fn hover(&self, params: HoverParams) -> Result<Option<Hover>> {
self.request("textDocument/hover", params).await
}
pub async fn definition(
&self,
params: GotoDefinitionParams,
) -> Result<Option<GotoDefinitionResponse>> {
self.request("textDocument/definition", params).await
}
pub async fn type_definition(
&self,
params: GotoDefinitionParams,
) -> Result<Option<GotoDefinitionResponse>> {
self.request("textDocument/typeDefinition", params).await
}
pub async fn implementation(
&self,
params: GotoDefinitionParams,
) -> Result<Option<GotoDefinitionResponse>> {
self.request("textDocument/implementation", params).await
}
pub async fn references(
&self,
params: ReferenceParams,
) -> Result<Option<Vec<lsp_types::Location>>> {
self.request("textDocument/references", params).await
}
pub async fn document_symbols(
&self,
params: DocumentSymbolParams,
) -> Result<Option<DocumentSymbolResponse>> {
self.request("textDocument/documentSymbol", params).await
}
pub async fn workspace_symbols(
&self,
params: WorkspaceSymbolParams,
) -> Result<Option<WorkspaceSymbolResponse>> {
self.request("workspace/symbol", params).await
}
pub async fn code_actions(
&self,
params: CodeActionParams,
) -> Result<Option<CodeActionResponse>> {
self.request("textDocument/codeAction", params).await
}
pub async fn rename(&self, params: RenameParams) -> Result<Option<WorkspaceEdit>> {
self.request("textDocument/rename", params).await
}
pub async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
self.request("textDocument/completion", params).await
}
pub async fn signature_help(
&self,
params: SignatureHelpParams,
) -> Result<Option<SignatureHelp>> {
self.request("textDocument/signatureHelp", params).await
}
pub async fn formatting(
&self,
params: DocumentFormattingParams,
) -> Result<Option<Vec<TextEdit>>> {
self.request("textDocument/formatting", params).await
}
pub async fn range_formatting(
&self,
params: DocumentRangeFormattingParams,
) -> Result<Option<Vec<TextEdit>>> {
self.request("textDocument/rangeFormatting", params).await
}
pub async fn prepare_call_hierarchy(
&self,
params: CallHierarchyPrepareParams,
) -> Result<Option<Vec<CallHierarchyItem>>> {
self.request("textDocument/prepareCallHierarchy", params)
.await
}
pub async fn incoming_calls(
&self,
params: CallHierarchyIncomingCallsParams,
) -> Result<Option<Vec<CallHierarchyIncomingCall>>> {
self.request("callHierarchy/incomingCalls", params).await
}
pub async fn outgoing_calls(
&self,
params: CallHierarchyOutgoingCallsParams,
) -> Result<Option<Vec<CallHierarchyOutgoingCall>>> {
self.request("callHierarchy/outgoingCalls", params).await
}
pub async fn prepare_type_hierarchy(
&self,
params: TypeHierarchyPrepareParams,
) -> Result<Option<Vec<TypeHierarchyItem>>> {
self.request("textDocument/prepareTypeHierarchy", params)
.await
}
pub async fn supertypes(
&self,
params: TypeHierarchySupertypesParams,
) -> Result<Option<Vec<TypeHierarchyItem>>> {
self.request("typeHierarchy/supertypes", params).await
}
pub async fn subtypes(
&self,
params: TypeHierarchySubtypesParams,
) -> Result<Option<Vec<TypeHierarchyItem>>> {
self.request("typeHierarchy/subtypes", params).await
}
pub async fn get_diagnostics(&self, uri: &Uri) -> Vec<Diagnostic> {
let cache = self.diagnostics.lock().await;
cache.get(uri).cloned().unwrap_or_default()
}
pub fn is_alive(&self) -> bool {
self.alive.load(Ordering::SeqCst)
}
}