use crate::language::language_id_for_path;
use crate::server::LspServerHandle;
use crate::types::*;
use anyhow::{anyhow, Context, Result};
use lsp_types::request::{
CallHierarchyIncomingCalls, CallHierarchyOutgoingCalls, CallHierarchyPrepare,
DocumentSymbolRequest, GotoDefinition, GotoImplementation, HoverRequest, References,
WorkspaceSymbolRequest,
};
use lsp_types::{
CallHierarchyIncomingCallsParams, CallHierarchyOutgoingCallsParams, ClientCapabilities,
DidChangeTextDocumentParams, DidOpenTextDocumentParams, DocumentSymbolParams,
GotoDefinitionParams, HoverParams, InitializeParams, InitializedParams, PartialResultParams,
ReferenceContext, ReferenceParams, TextDocumentContentChangeEvent, TextDocumentItem,
VersionedTextDocumentIdentifier, WorkDoneProgressParams, WorkspaceSymbolParams,
};
use lsp_types::GotoDefinitionParams as GotoImplementationParams;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
/// Converts a filesystem path into the LSP document `Uri` type.
///
/// The path is first turned into a `file://` URL via `Url::from_file_path`
/// (which rejects paths it cannot represent, e.g. relative ones) and the URL
/// text is then re-parsed as a `Uri`.
fn path_to_uri(path: &Path) -> Result<Uri> {
    match Url::from_file_path(path) {
        Ok(file_url) => file_url
            .as_str()
            .parse::<Uri>()
            .map_err(|err| anyhow!("Failed to parse URI: {}", err)),
        Err(()) => Err(anyhow!("Invalid file path: {:?}", path)),
    }
}
/// Converts an LSP document `Uri` back into a filesystem path.
///
/// Returns `None` if the URI text is not a valid URL or does not denote a
/// local file path.
fn uri_to_path(uri: &Uri) -> Option<PathBuf> {
    uri.as_str()
        .parse::<url::Url>()
        .ok()
        .and_then(|file_url| file_url.to_file_path().ok())
}
use std::process::{Child, ChildStdin, ChildStdout};
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, trace, warn};
/// How long, in milliseconds, an ordinary LSP request waits for its response.
const REQUEST_TIMEOUT_MS: u64 = 10 * 1_000;
/// Longer timeout budget, in milliseconds, named for the `initialize` handshake.
const INIT_TIMEOUT_MS: u64 = 45 * 1_000;
/// Outgoing JSON-RPC 2.0 request envelope carrying typed `params`.
#[derive(Serialize, Debug)]
struct JsonRpcRequest<T: Serialize> {
    /// Protocol version marker; the client always sets `"2.0"`.
    jsonrpc: &'static str,
    /// Client-chosen id that the server echoes in its response.
    id: i64,
    /// LSP method name, e.g. `textDocument/hover`.
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
/// Outgoing JSON-RPC 2.0 notification: like a request but without an `id`,
/// so no response is expected.
#[derive(Serialize, Debug)]
struct JsonRpcNotification<T: Serialize> {
    /// Protocol version marker; the client always sets `"2.0"`.
    jsonrpc: &'static str,
    /// LSP notification method name.
    method: &'static str,
    /// Method-specific parameters.
    params: T,
}
/// Typed JSON-RPC response envelope.
///
/// Currently unused in this module: incoming traffic is decoded through
/// `JsonRpcMessage` instead. Because nothing reads its fields, the
/// `dead_code` allow is applied to the whole struct rather than to a single
/// field.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct JsonRpcResponse<T> {
    /// Protocol version marker.
    jsonrpc: String,
    /// Id of the request this response answers.
    id: Option<i64>,
    /// Successful result payload.
    result: Option<T>,
    /// Error payload when the request failed.
    error: Option<JsonRpcError>,
}
/// The `error` member of a JSON-RPC response.
#[derive(Deserialize, Debug)]
struct JsonRpcError {
    /// Numeric JSON-RPC error code.
    code: i64,
    /// Human-readable error description.
    message: String,
}
/// Any message read from the server, decoded loosely so that responses,
/// notifications and server-to-client requests all fit one shape.
///
/// NOTE(review): `id` is typed as `i64`, so a message whose `id` is a JSON
/// string fails to deserialize as a whole.
#[derive(Deserialize, Debug)]
struct JsonRpcMessage {
    /// Protocol version marker; deserialized but not otherwise used.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Set on responses and on requests initiated by the server.
    id: Option<i64>,
    /// Set on notifications and on requests initiated by the server.
    method: Option<String>,
    /// Method parameters; `Value::Null` when the member is absent.
    #[serde(default)]
    params: serde_json::Value,
    /// Successful response payload.
    result: Option<serde_json::Value>,
    /// Error response payload.
    error: Option<JsonRpcError>,
}
/// An in-flight request awaiting its response.
///
/// `sender` hands the server's result, or an error built from the server's
/// error object, back to the task that issued the request.
struct PendingRequest {
    sender: oneshot::Sender<Result<serde_json::Value>>,
}
/// A connection to one language server process, speaking LSP over its
/// stdin/stdout pipes.
pub struct LspClient {
    /// Caller-supplied identifier, used in logs and returned by `server_id()`.
    server_id: String,
    /// Workspace root passed to the server during initialization.
    root: PathBuf,
    /// Counter from which outgoing request ids are drawn.
    request_id: AtomicI64,
    /// In-flight requests keyed by id; shared with the reader thread.
    pending: Arc<RwLock<HashMap<i64, PendingRequest>>>,
    /// The server's stdin; locked while a framed message is written.
    writer: Arc<Mutex<ChildStdin>>,
    /// Latest diagnostics per file, as published by the server.
    diagnostics: Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>>,
    /// Last document version sent for each opened file.
    file_versions: Arc<RwLock<HashMap<PathBuf, i32>>>,
    /// Sender that `shutdown` takes and signals.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// Handle to the spawned server process.
    #[allow(dead_code)]
    process: Child,
}
impl LspClient {
pub async fn new(
server_id: impl Into<String>,
mut handle: LspServerHandle,
root: PathBuf,
) -> Result<Self> {
let server_id = server_id.into();
info!(server_id = %server_id, root = ?root, "Initializing LSP client");
let stdin = handle
.process
.stdin
.take()
.context("Failed to get stdin")?;
let stdout = handle
.process
.stdout
.take()
.context("Failed to get stdout")?;
let pending: Arc<RwLock<HashMap<i64, PendingRequest>>> =
Arc::new(RwLock::new(HashMap::new()));
let diagnostics: Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>> =
Arc::new(RwLock::new(HashMap::new()));
let file_versions: Arc<RwLock<HashMap<PathBuf, i32>>> =
Arc::new(RwLock::new(HashMap::new()));
let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
let pending_clone = pending.clone();
let diagnostics_clone = diagnostics.clone();
let server_id_clone = server_id.clone();
std::thread::spawn(move || {
Self::reader_loop(stdout, pending_clone, diagnostics_clone, server_id_clone);
});
let writer = Arc::new(Mutex::new(stdin));
let mut client = Self {
server_id,
root: root.clone(),
request_id: AtomicI64::new(1),
pending,
writer,
diagnostics,
file_versions,
shutdown_tx: Some(shutdown_tx),
process: handle.process,
};
client.initialize(&root, handle.initialization).await?;
Ok(client)
}
pub fn server_id(&self) -> &str {
&self.server_id
}
pub fn root(&self) -> &Path {
&self.root
}
pub fn diagnostics(&self) -> HashMap<PathBuf, Vec<Diagnostic>> {
self.diagnostics.read().unwrap().clone()
}
pub fn diagnostics_for_file(&self, path: &Path) -> Vec<Diagnostic> {
self.diagnostics
.read()
.unwrap()
.get(path)
.cloned()
.unwrap_or_default()
}
async fn initialize(
&mut self,
root: &Path,
initialization_options: Option<serde_json::Value>,
) -> Result<()> {
let root_uri = path_to_uri(root)?;
let params = InitializeParams {
process_id: Some(std::process::id()),
root_uri: Some(root_uri.clone()),
root_path: None,
initialization_options,
capabilities: Self::client_capabilities(),
trace: None,
workspace_folders: Some(vec![lsp_types::WorkspaceFolder {
uri: root_uri,
name: root
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("workspace")
.to_string(),
}]),
client_info: Some(lsp_types::ClientInfo {
name: "codive-lsp".to_string(),
version: Some(env!("CARGO_PKG_VERSION").to_string()),
}),
locale: None,
work_done_progress_params: WorkDoneProgressParams::default(),
};
let _result: lsp_types::InitializeResult =
self.request::<lsp_types::request::Initialize>(params).await?;
self.notify::<lsp_types::notification::Initialized>(InitializedParams {})?;
info!(server_id = %self.server_id, "LSP server initialized");
Ok(())
}
pub async fn open_file(&self, path: &Path) -> Result<()> {
let path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
let existing_version = {
let versions = self.file_versions.read().unwrap();
versions.get(&path).copied()
};
let content = tokio::fs::read_to_string(&path).await?;
if let Some(version) = existing_version {
return self.change_file(&path, &content, version + 1);
}
let uri = path_to_uri(&path)?;
let language_id = language_id_for_path(&path);
debug!(path = ?path, language_id = %language_id, "Opening file in LSP");
let params = DidOpenTextDocumentParams {
text_document: TextDocumentItem {
uri,
language_id: language_id.to_string(),
version: 0,
text: content,
},
};
self.notify::<lsp_types::notification::DidOpenTextDocument>(params)?;
self.file_versions.write().unwrap().insert(path, 0);
Ok(())
}
fn change_file(&self, path: &Path, content: &str, version: i32) -> Result<()> {
let uri = path_to_uri(path)?;
let params = DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier { uri, version },
content_changes: vec![TextDocumentContentChangeEvent {
range: None,
range_length: None,
text: content.to_string(),
}],
};
self.notify::<lsp_types::notification::DidChangeTextDocument>(params)?;
self.file_versions.write().unwrap().insert(path.to_path_buf(), version);
Ok(())
}
pub async fn hover(&self, path: &Path, line: u32, character: u32) -> Result<Option<Hover>> {
let uri = path_to_uri(path)?;
let params = HoverParams {
text_document_position_params: TextDocumentPositionParams {
text_document: TextDocumentIdentifier { uri },
position: Position { line, character },
},
work_done_progress_params: WorkDoneProgressParams::default(),
};
self.request::<HoverRequest>(params).await
}
pub async fn definition(&self, path: &Path, line: u32, character: u32) -> Result<Vec<Location>> {
let uri = path_to_uri(path)?;
let params = GotoDefinitionParams {
text_document_position_params: TextDocumentPositionParams {
text_document: TextDocumentIdentifier { uri },
position: Position { line, character },
},
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<GotoDefinitionResponse> =
self.request::<GotoDefinition>(params).await?;
Ok(match result {
Some(GotoDefinitionResponse::Scalar(loc)) => vec![loc],
Some(GotoDefinitionResponse::Array(locs)) => locs,
Some(GotoDefinitionResponse::Link(links)) => links
.into_iter()
.map(|l| Location {
uri: l.target_uri,
range: l.target_selection_range,
})
.collect(),
None => vec![],
})
}
pub async fn references(
&self,
path: &Path,
line: u32,
character: u32,
include_declaration: bool,
) -> Result<Vec<Location>> {
let uri = path_to_uri(path)?;
let params = ReferenceParams {
text_document_position: TextDocumentPositionParams {
text_document: TextDocumentIdentifier { uri },
position: Position { line, character },
},
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
context: ReferenceContext {
include_declaration,
},
};
let result: Option<Vec<Location>> = self.request::<References>(params).await?;
Ok(result.unwrap_or_default())
}
pub async fn implementation(
&self,
path: &Path,
line: u32,
character: u32,
) -> Result<Vec<Location>> {
let uri = path_to_uri(path)?;
let params = GotoImplementationParams {
text_document_position_params: TextDocumentPositionParams {
text_document: TextDocumentIdentifier { uri },
position: Position { line, character },
},
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<GotoDefinitionResponse> =
self.request::<GotoImplementation>(params).await?;
Ok(match result {
Some(GotoDefinitionResponse::Scalar(loc)) => vec![loc],
Some(GotoDefinitionResponse::Array(locs)) => locs,
Some(GotoDefinitionResponse::Link(links)) => links
.into_iter()
.map(|l| Location {
uri: l.target_uri,
range: l.target_selection_range,
})
.collect(),
None => vec![],
})
}
pub async fn document_symbols(&self, path: &Path) -> Result<DocumentSymbolResponse> {
let uri = path_to_uri(path)?;
let params = DocumentSymbolParams {
text_document: TextDocumentIdentifier { uri },
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<DocumentSymbolResponse> =
self.request::<DocumentSymbolRequest>(params).await?;
Ok(result.unwrap_or(DocumentSymbolResponse::Flat(vec![])))
}
pub async fn workspace_symbols(&self, query: &str) -> Result<Vec<SymbolInformation>> {
let params = WorkspaceSymbolParams {
query: query.to_string(),
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<WorkspaceSymbolResponse> =
self.request::<WorkspaceSymbolRequest>(params).await?;
Ok(match result {
Some(WorkspaceSymbolResponse::Flat(symbols)) => symbols,
Some(WorkspaceSymbolResponse::Nested(symbols)) => {
symbols
.into_iter()
.filter_map(|s| {
let location = match s.location {
lsp_types::OneOf::Left(loc) => loc,
lsp_types::OneOf::Right(doc_id) => Location {
uri: doc_id.uri,
range: Range::default(),
},
};
Some(SymbolInformation {
name: s.name,
kind: s.kind,
tags: s.tags,
deprecated: None,
location,
container_name: s.container_name,
})
})
.collect()
}
None => vec![],
})
}
pub async fn prepare_call_hierarchy(
&self,
path: &Path,
line: u32,
character: u32,
) -> Result<Vec<CallHierarchyItem>> {
let uri = path_to_uri(path)?;
let params = lsp_types::CallHierarchyPrepareParams {
text_document_position_params: TextDocumentPositionParams {
text_document: TextDocumentIdentifier { uri },
position: Position { line, character },
},
work_done_progress_params: WorkDoneProgressParams::default(),
};
let result: Option<Vec<CallHierarchyItem>> =
self.request::<CallHierarchyPrepare>(params).await?;
Ok(result.unwrap_or_default())
}
pub async fn incoming_calls(
&self,
item: CallHierarchyItem,
) -> Result<Vec<CallHierarchyIncomingCall>> {
let params = CallHierarchyIncomingCallsParams {
item,
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<Vec<CallHierarchyIncomingCall>> =
self.request::<CallHierarchyIncomingCalls>(params).await?;
Ok(result.unwrap_or_default())
}
pub async fn outgoing_calls(
&self,
item: CallHierarchyItem,
) -> Result<Vec<CallHierarchyOutgoingCall>> {
let params = CallHierarchyOutgoingCallsParams {
item,
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
};
let result: Option<Vec<CallHierarchyOutgoingCall>> =
self.request::<CallHierarchyOutgoingCalls>(params).await?;
Ok(result.unwrap_or_default())
}
pub async fn shutdown(mut self) {
info!(server_id = %self.server_id, "Shutting down LSP client");
let _ = self.request::<lsp_types::request::Shutdown>(()).await;
let _ = self.notify::<lsp_types::notification::Exit>(());
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(()).await;
}
}
async fn request<R>(&self, params: R::Params) -> Result<R::Result>
where
R: lsp_types::request::Request,
R::Params: Serialize,
R::Result: DeserializeOwned,
{
let id = self.request_id.fetch_add(1, Ordering::SeqCst);
let request = JsonRpcRequest {
jsonrpc: "2.0",
id,
method: R::METHOD,
params,
};
let message = serde_json::to_string(&request)?;
let header = format!("Content-Length: {}\r\n\r\n", message.len());
trace!(id = id, method = R::METHOD, "Sending LSP request");
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.write().unwrap();
pending.insert(id, PendingRequest { sender: tx });
}
{
let mut writer = self.writer.lock().unwrap();
writer.write_all(header.as_bytes())?;
writer.write_all(message.as_bytes())?;
writer.flush()?;
}
let result = tokio::time::timeout(
std::time::Duration::from_millis(REQUEST_TIMEOUT_MS),
rx,
)
.await
.map_err(|_| anyhow!("LSP request timed out"))??;
let value = result?;
let result: R::Result = serde_json::from_value(value)?;
Ok(result)
}
fn notify<N>(&self, params: N::Params) -> Result<()>
where
N: lsp_types::notification::Notification,
N::Params: Serialize,
{
let notification = JsonRpcNotification {
jsonrpc: "2.0",
method: N::METHOD,
params,
};
let message = serde_json::to_string(¬ification)?;
let header = format!("Content-Length: {}\r\n\r\n", message.len());
trace!(method = N::METHOD, "Sending LSP notification");
let mut writer = self.writer.lock().unwrap();
writer.write_all(header.as_bytes())?;
writer.write_all(message.as_bytes())?;
writer.flush()?;
Ok(())
}
fn reader_loop(
stdout: ChildStdout,
pending: Arc<RwLock<HashMap<i64, PendingRequest>>>,
diagnostics: Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>>,
server_id: String,
) {
let mut reader = BufReader::new(stdout);
let mut headers = String::new();
loop {
headers.clear();
let mut content_length: Option<usize> = None;
loop {
let mut line = String::new();
match reader.read_line(&mut line) {
Ok(0) => {
debug!(server_id = %server_id, "LSP server stdout closed");
return;
}
Ok(_) => {
if line == "\r\n" {
break;
}
if line.to_lowercase().starts_with("content-length:") {
if let Some(len_str) = line.split(':').nth(1) {
content_length = len_str.trim().parse().ok();
}
}
}
Err(e) => {
error!(server_id = %server_id, error = ?e, "Error reading from LSP server");
return;
}
}
}
let content_length = match content_length {
Some(len) => len,
None => {
warn!(server_id = %server_id, "No Content-Length header");
continue;
}
};
let mut content = vec![0u8; content_length];
if let Err(e) = std::io::Read::read_exact(&mut reader, &mut content) {
error!(server_id = %server_id, error = ?e, "Error reading LSP content");
continue;
}
let message: JsonRpcMessage = match serde_json::from_slice(&content) {
Ok(msg) => msg,
Err(e) => {
warn!(server_id = %server_id, error = ?e, "Failed to parse LSP message");
continue;
}
};
if let Some(id) = message.id {
if message.method.is_none() {
let mut pending = pending.write().unwrap();
if let Some(req) = pending.remove(&id) {
let result = if let Some(error) = message.error {
Err(anyhow!("LSP error {}: {}", error.code, error.message))
} else {
Ok(message.result.unwrap_or(serde_json::Value::Null))
};
let _ = req.sender.send(result);
}
continue;
}
}
if let Some(method) = &message.method {
match method.as_str() {
"textDocument/publishDiagnostics" => {
if let Ok(params) =
serde_json::from_value::<lsp_types::PublishDiagnosticsParams>(
message.params,
)
{
if let Some(path) = uri_to_path(¶ms.uri) {
debug!(
server_id = %server_id,
path = ?path,
count = params.diagnostics.len(),
"Received diagnostics"
);
diagnostics.write().unwrap().insert(path, params.diagnostics);
}
}
}
"window/logMessage" | "window/showMessage" => {
if let Ok(params) =
serde_json::from_value::<lsp_types::LogMessageParams>(message.params)
{
debug!(server_id = %server_id, message = %params.message, "LSP server message");
}
}
_ => {
trace!(server_id = %server_id, method = %method, "Unhandled LSP notification");
}
}
}
}
}
fn client_capabilities() -> ClientCapabilities {
ClientCapabilities {
text_document: Some(lsp_types::TextDocumentClientCapabilities {
synchronization: Some(lsp_types::TextDocumentSyncClientCapabilities {
dynamic_registration: Some(false),
will_save: Some(false),
will_save_wait_until: Some(false),
did_save: Some(true),
}),
hover: Some(lsp_types::HoverClientCapabilities {
dynamic_registration: Some(false),
content_format: Some(vec![MarkupKind::Markdown, MarkupKind::PlainText]),
}),
definition: Some(lsp_types::GotoCapability {
dynamic_registration: Some(false),
link_support: Some(true),
}),
references: Some(lsp_types::DynamicRegistrationClientCapabilities {
dynamic_registration: Some(false),
}),
implementation: Some(lsp_types::GotoCapability {
dynamic_registration: Some(false),
link_support: Some(true),
}),
document_symbol: Some(lsp_types::DocumentSymbolClientCapabilities {
dynamic_registration: Some(false),
symbol_kind: None,
hierarchical_document_symbol_support: Some(true),
tag_support: None,
}),
publish_diagnostics: Some(lsp_types::PublishDiagnosticsClientCapabilities {
related_information: Some(true),
tag_support: None,
version_support: Some(true),
code_description_support: Some(true),
data_support: Some(true),
}),
call_hierarchy: Some(lsp_types::CallHierarchyClientCapabilities {
dynamic_registration: Some(false),
}),
..Default::default()
}),
workspace: Some(lsp_types::WorkspaceClientCapabilities {
workspace_folders: Some(true),
symbol: Some(lsp_types::WorkspaceSymbolClientCapabilities {
dynamic_registration: Some(false),
symbol_kind: None,
tag_support: None,
resolve_support: None,
}),
..Default::default()
}),
window: Some(lsp_types::WindowClientCapabilities {
work_done_progress: Some(true),
show_message: None,
show_document: None,
}),
..Default::default()
}
}
}