use {
crate::{
connect::{
ipc::{Connection, IpcHandle},
lsp::{
request::Request,
response_pending::ResponsePending,
},
},
protocol::{
jsonrpc::{
self,
Id,
Message,
},
lsp::InitializeParams,
task::{
ServerState,
State,
},
},
},
dashmap::DashMap,
opentelemetry::trace::FutureExt,
smol::channel::Sender,
std::{
collections::HashMap,
sync::{
Arc,
Mutex,
atomic::{
AtomicU32,
Ordering,
},
},
},
};
mod connected;
pub mod errors;
mod id;
pub mod notification;
mod registry;
pub mod request;
mod response_pending;
mod subscription;
pub use {
connected::ConnectedClient,
id::{
ClientId,
ClientKind,
wants_lsp_notifications,
},
registry::{
ClientRegistry,
ConnectedClientInfo,
SendError,
SharedRegistry,
},
subscription::{
Subscriptions,
Topic,
},
};
/// An IPC connection to the daemon together with the LSP client that speaks
/// over it, as produced by [`DaemonConnection::connect`] /
/// [`DaemonConnection::connect_as`].
pub struct DaemonConnection {
    /// LSP client wrapping the IPC connection.
    pub client: LspClient,
    /// Handle returned alongside the connection by `Connection::ipc_as`.
    pub handle: IpcHandle,
    /// Configuration the connection was opened with.
    pub config: crate::daemon::DaemonConfig,
}
impl DaemonConnection {
    /// Connects to the daemon at `config.endpoint()` as a [`ClientKind::Cli`]
    /// client that sends no extra metadata.
    ///
    /// # Errors
    /// Propagates any I/O error raised while establishing the IPC connection.
    pub async fn connect(
        config: crate::daemon::DaemonConfig,
        version: &str,
    ) -> std::io::Result<Self> {
        let no_metadata = HashMap::new();
        Self::connect_as(config, version, ClientKind::Cli, no_metadata).await
    }

    /// Connects to the daemon at `config.endpoint()`, identifying as
    /// `client_kind` with `metadata`, and wraps the resulting connection in a
    /// fresh [`LspClient`] built by `LspClient::new` (recording disabled, no
    /// client id).
    ///
    /// # Errors
    /// Propagates any I/O error raised while establishing the IPC connection.
    pub async fn connect_as(
        config: crate::daemon::DaemonConfig,
        version: &str,
        client_kind: ClientKind,
        metadata: HashMap<String, String>,
    ) -> std::io::Result<Self> {
        let (connection, handle) =
            Connection::ipc_as(&config.endpoint(), version, client_kind, metadata).await?;
        Ok(Self {
            client: LspClient::new(connection),
            handle,
            config,
        })
    }
}
// Declares this module's tracer. The spans below refer to it as
// `LSP_CLIENT_TRACER` (presumably the static this macro generates from
// `lsp_client` — confirm against the `otel` crate).
otel::tracer!(lsp_client);
/// A request sent by the client paired with the response it received,
/// captured only when recording is enabled (see `MessageRecording`).
#[derive(Debug, Clone)]
struct RecordedRequestResponse {
    /// JSON-RPC method of the request.
    method: String,
    /// Id shared by the request and its response.
    request_id: Id,
    /// The request params as serialized JSON.
    request_params: serde_json::Value,
    /// The whole response message serialized to JSON.
    response: serde_json::Value,
}
/// A notification captured by the recording, either received from or sent
/// to the server.
#[derive(Debug, Clone)]
struct RecordedNotification {
    /// JSON-RPC method name.
    method: String,
    /// Notification params, if the message carried any.
    params: Option<serde_json::Value>,
}
/// In-memory log of client/server traffic. It exists only for clients
/// constructed with recording enabled (e.g. `LspClient::new_test`) and backs
/// snapshotting and the `get_received_*` helpers.
struct MessageRecording {
    /// Completed request/response pairs, in the order the responses arrived.
    requests_responses: Mutex<Vec<RecordedRequestResponse>>,
    /// Notifications received from the server, in arrival order.
    notifications_from_server: Mutex<Vec<RecordedNotification>>,
    /// Notifications sent by this client, recorded just before sending.
    notifications_to_server: Mutex<Vec<RecordedNotification>>,
    /// Requests sent but not yet answered: id -> (method, params). An entry is
    /// removed and moved into `requests_responses` when its response arrives.
    pending_requests: Mutex<HashMap<Id, (String, serde_json::Value)>>,
}
/// Shared state behind an [`LspClient`].
pub struct LspClientInner {
    /// Underlying connection: outgoing messages go through its `sender`, and a
    /// background reader task drains a clone of its `receiver`.
    conn: Connection,
    /// Counter used to allocate JSON-RPC request ids.
    request_id: AtomicU32,
    /// In-flight request tracking: `send_request` registers with `wait(id)` and
    /// the reader task completes entries with `insert(response)`.
    response_pending: Arc<ResponsePending>,
    /// Client-side view of the server lifecycle (`Initialized`, `ShutDown`, ...).
    state: Arc<ServerState>,
    /// Message log; `Some` only when recording was enabled at construction.
    recording: Option<Arc<MessageRecording>>,
    /// How many times a failed send is retried.
    default_retry_count: usize,
    /// How long `send_request` waits for a response before giving up.
    default_request_timeout: std::time::Duration,
    /// Pause configured for after notifications, exposed through
    /// `LspClient::wait_after_notification`.
    wait_after_notification: Option<std::time::Duration>,
    /// Waiters registered by `wait_for_notification`, keyed by notification method.
    notification_waiters: Arc<DashMap<String, Vec<Sender<serde_json::Value>>>>,
    /// Id attached to outgoing requests and notifications, if set.
    client_id: Option<ClientId>,
}
/// Handle to an LSP connection. All state lives in the shared
/// [`LspClientInner`].
pub struct LspClient {
    inner: Arc<LspClientInner>,
}
impl LspClient {
pub fn new(conn: Connection) -> Self {
Self::new_with_options(conn, false, None, None)
}
pub fn new_test(conn: Connection) -> Self {
Self::new_with_options(
conn,
true,
Some(std::time::Duration::from_millis(1000)),
None,
)
}
pub(crate) fn new_with_options(
conn: Connection,
enable_recording: bool,
wait_after_notification: Option<std::time::Duration>,
client_id: Option<ClientId>,
) -> Self {
otel::span!(@LSP_CLIENT_TRACER, "laburnum.lsp_client.new", in |cx| {
let response_pending = Arc::new(ResponsePending::new());
let notification_waiters = Arc::new(DashMap::new());
let recording = if enable_recording {
Some(Arc::new(MessageRecording {
requests_responses: Mutex::new(Vec::new()),
notifications_from_server: Mutex::new(Vec::new()),
notifications_to_server: Mutex::new(Vec::new()),
pending_requests: Mutex::new(HashMap::new()),
}))
} else {
None
};
let receiver = conn.receiver.clone();
let response_pending_clone = response_pending.clone();
let recording_clone = recording.clone();
let notification_waiters_clone = notification_waiters.clone();
let spawn_cx = cx.clone();
smol::spawn(
async move {
use opentelemetry::trace::{SpanKind, FutureExt};
if let Some(recording) = recording_clone {
loop {
match receiver.recv().with_context(cx.clone()).await {
| Ok(Message::Response(response)) => {
let cx = otel::span!(^
@LSP_CLIENT_TRACER,
"lsp_client.receive_response",
kind = SpanKind::Consumer
);
async {
if let Some((method, params)) = recording
.pending_requests
.lock()
.unwrap_or_else(|e| e.into_inner())
.remove(response.id())
{
recording.requests_responses.lock().unwrap_or_else(|e| e.into_inner()).push(
RecordedRequestResponse {
method,
request_id: response.id().clone(),
request_params: params,
response: serde_json::to_value(&response)
.unwrap_or(serde_json::Value::Null),
},
);
}
response_pending_clone.insert(response);
}
.with_context(cx)
.await;
},
| Ok(Message::Notification(notification)) => {
let method = notification.method().to_string();
let cx = otel::span!(^
@LSP_CLIENT_TRACER,
"lsp_client.receive_notification",
kind = SpanKind::Consumer,
"rpc.method" = method.clone()
);
async {
recording.notifications_from_server.lock().unwrap_or_else(|e| e.into_inner()).push(
RecordedNotification {
method: notification.method().to_string(),
params: notification.params().cloned(),
},
);
if let Some(params) = notification.params() {
let method = notification.method();
if let Some(mut entry) =
notification_waiters_clone.get_mut(method)
{
let waiters: Vec<Sender<serde_json::Value>> =
std::mem::take(&mut *entry);
drop(entry);
for sender in waiters {
let _ = sender.try_send(params.clone());
}
notification_waiters_clone.remove(method);
}
}
}
.with_context(cx)
.await;
},
| Ok(_) => {},
| Err(_e) => {
response_pending_clone.close_all();
break;
},
}
}
} else {
loop {
match receiver.recv().with_current_context().await {
| Ok(Message::Response(response)) => {
let cx = otel::span!(^
@LSP_CLIENT_TRACER,
"lsp_client.receive_response",
kind = SpanKind::Consumer
);
async {
response_pending_clone.insert(response);
}
.with_context(cx)
.await;
},
| Ok(_) => {},
| Err(_e) => {
response_pending_clone.close_all();
break;
},
}
}
}
}
.with_context(spawn_cx.clone()),
)
.detach();
let default_timeout = if cfg!(test) || cfg!(feature = "test") {
std::time::Duration::from_secs(5)
} else {
std::time::Duration::from_secs(30)
};
let inner = Arc::new(LspClientInner {
conn,
request_id: AtomicU32::new(0),
response_pending,
state: Arc::new(ServerState::new()),
recording,
default_retry_count: 3,
default_request_timeout: default_timeout,
wait_after_notification,
notification_waiters,
client_id,
});
Self { inner }
})
}
#[allow(unused_variables)]
pub fn set_client_id(&self, id: ClientId) {
}
pub fn new_with_client_id(conn: Connection, client_id: ClientId) -> Self {
Self::new_with_options(conn, false, None, Some(client_id))
}
pub fn new_test_with_client_id(conn: Connection, client_id: ClientId) -> Self {
Self::new_with_options(
conn,
true,
Some(std::time::Duration::from_millis(1000)),
Some(client_id),
)
}
fn get_next_request_id(&self) -> Id {
let num = self.inner.request_id.fetch_add(1, Ordering::Relaxed);
Id::Number(num as i64)
}
async fn retry_with_backoff<T, F, Fut>(
&self,
mut operation: F,
retry_count: usize,
) -> Result<T, errors::LspClientError>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, errors::LspClientError>>,
{
let mut attempts = 0;
let mut last_error: Option<errors::LspClientError> = None;
while attempts <= retry_count {
match operation().await {
| Ok(result) => return Ok(result),
| Err(e) => {
last_error = Some(e);
attempts += 1;
if attempts <= retry_count {
let base_delay_ms = 100;
let exponential_delay_ms = base_delay_ms * (1 << attempts);
let capped_delay_ms = exponential_delay_ms.min(3000);
smol::Timer::after(std::time::Duration::from_millis(
capped_delay_ms,
))
.await;
}
},
}
}
Err(last_error.unwrap_or(errors::LspClientError::ConnectionClosed))
}
pub fn is_initialized(&self) -> bool {
matches!(self.inner.state.get(), State::Initialized | State::ShutDown)
}
pub fn wait_after_notification(&self) -> Option<std::time::Duration> {
self.inner.wait_after_notification
}
pub async fn start(
&self,
params: InitializeParams,
) -> Result<crate::protocol::lsp::InitializeResult, errors::LspClientError>
{
otel::span!(@LSP_CLIENT_TRACER, "laburnum.lsp_client.start");
let result = self.initialize(params).await?;
self.initialized().await?;
Ok(result)
}
pub async fn stop(&self) -> Result<(), errors::LspClientError> {
self.shutdown().await?;
self.exit().await?;
Ok(())
}
pub async fn stop_test(
&self,
snapshot: &mut ferrotype::Ferrotype,
) -> Result<(), errors::LspClientError> {
otel::span!(@LSP_CLIENT_TRACER, "laburnum.lsp_client.stop_test");
let result = self.shutdown().await;
self.write_to_snapshot(snapshot);
result
}
pub async fn send_request<R: Request>(
&self,
params: R::Params,
) -> Result<R::Result, errors::LspClientError> {
if !self.is_initialized()
&& R::METHOD != "initialize"
&& R::METHOD != "workspace/executeCommand"
{
return Err(errors::LspClientError::NotInitialized);
}
let id = self.get_next_request_id();
let rx = self.inner.response_pending.wait(id.clone());
let params_value = serde_json::to_value(¶ms)?;
if let Some(recording) = &self.inner.recording {
recording
.pending_requests
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(id.clone(), (R::METHOD.to_string(), params_value.clone()));
}
let trace_ctx = crate::protocol::otel::TraceContext::from_current_span();
let mut request_builder = jsonrpc::Request::build(R::METHOD, id.clone())
.params(params_value)
.with_trace_context(trace_ctx);
if let Some(client_id) = self.inner.client_id {
request_builder = request_builder.client_id(client_id);
}
let request = request_builder.finish();
let sender = &self.inner.conn.sender;
let retry_count = self.inner.default_retry_count;
use opentelemetry::trace::{
FutureExt,
SpanKind,
};
let cx = otel::span!(^
@LSP_CLIENT_TRACER,
"lsp_client.send_request",
kind = SpanKind::Producer,
"rpc.method" = R::METHOD
);
self
.retry_with_backoff(
|| {
async {
sender
.send(Message::Request(request.clone()))
.with_current_context()
.await
.map_err(errors::LspClientError::SendFailed)
}
},
retry_count,
)
.with_context(cx)
.await?;
let timeout = self.inner.default_request_timeout;
let response = smol::future::or(
async {
rx.recv()
.with_current_context()
.await
.map_err(|_| errors::LspClientError::ConnectionClosed)
},
async {
smol::Timer::after(timeout).with_current_context().await;
Err(errors::LspClientError::Timeout(timeout))
},
)
.with_current_context()
.await?;
if let Some(result) = response.result() {
match R::METHOD {
| "shutdown" => self.inner.state.set(State::ShutDown),
| _ => {},
}
serde_json::from_value(result.clone())
.map_err(errors::LspClientError::DeserializationFailed)
} else if let Some(error) = response.error() {
Err(errors::LspClientError::JsonRpcError {
code: error.code.into(),
message: error.message.to_string(),
})
} else {
Err(errors::LspClientError::InvalidResponse)
}
}
pub async fn send_notification_with_retry<N: notification::Notification>(
&self,
params: N::Params,
retry_count: usize,
) -> Result<(), errors::LspClientError> {
let params_value = serde_json::to_value(¶ms).unwrap_or_default();
if let Some(recording) = &self.inner.recording {
recording.notifications_to_server.lock().unwrap_or_else(|e| e.into_inner()).push(
RecordedNotification {
method: N::METHOD.to_string(),
params: Some(params_value.clone()),
},
);
}
let trace_ctx = crate::protocol::otel::TraceContext::from_current_span();
let mut notification_builder = jsonrpc::Notification::build(N::METHOD)
.params(params_value)
.with_trace_context(trace_ctx);
if let Some(client_id) = self.inner.client_id {
notification_builder = notification_builder.client_id(client_id);
}
let notification = notification_builder.finish();
let sender = &self.inner.conn.sender;
use opentelemetry::trace::{
FutureExt,
SpanKind,
};
let cx = otel::span!(^
@LSP_CLIENT_TRACER,
"lsp_client.send_notification",
kind = SpanKind::Producer,
"rpc.method" = N::METHOD
);
self
.retry_with_backoff(
|| {
async {
sender
.send(Message::Notification(notification.clone()))
.with_current_context()
.await
.map_err(errors::LspClientError::SendFailed)
}
},
retry_count,
)
.with_context(cx)
.await?;
match N::METHOD {
| "initialized" => self.inner.state.set(State::Initialized),
| _ => {},
}
Ok(())
}
pub async fn send_notification<N: notification::Notification>(
&self,
params: N::Params,
) -> Result<(), errors::LspClientError> {
let retry_count = if N::METHOD == "exit" {
0
} else {
self.inner.default_retry_count
};
self
.send_notification_with_retry::<N>(params, retry_count)
.await
}
pub async fn wait_for_notification(
&self,
method: &str,
) -> Result<serde_json::Value, String> {
let _initial_count = if let Some(recording) = &self.inner.recording {
recording.notifications_from_server.lock().unwrap_or_else(|e| e.into_inner()).len()
} else {
0
};
let (sender, receiver) = smol::channel::bounded(1);
self
.inner
.notification_waiters
.entry(method.to_string())
.or_default()
.push(sender);
let timeout = std::time::Duration::from_secs(3);
match smol::future::or(
async {
receiver
.recv()
.await
.map_err(|_| "Channel closed".to_string())
},
async {
smol::Timer::after(timeout).await;
Err(format!("Timeout waiting for notification: {}", method))
},
)
.await
{
| Ok(value) => Ok(value),
| Err(_) => Err(format!("Timeout waiting for notification: {}", method)),
}
}
pub async fn wait_for_progress_end(
&self,
token_pattern: &str,
timeout_secs: u64,
) -> Result<(), String> {
use crate::protocol::lsp::{
ProgressParams,
ProgressParamsValue,
WorkDoneProgress,
};
let Some(recording) = &self.inner.recording else {
return Err("Recording not enabled".to_string());
};
let timeout = std::time::Duration::from_secs(timeout_secs);
let start = std::time::Instant::now();
let poll_interval = std::time::Duration::from_millis(50);
loop {
{
let notifications = recording.notifications_from_server.lock().unwrap_or_else(|e| e.into_inner());
for notif in notifications.iter().rev() {
if notif.method == "$/progress"
&& let Some(params) = ¬if.params
&& let Ok(progress) =
serde_json::from_value::<ProgressParams>(params.clone())
{
let token_str = match &progress.token {
| crate::protocol::lsp::NumberOrString::String(s) => s.as_str(),
| crate::protocol::lsp::NumberOrString::Number(_) => {
continue;
},
};
if token_str.contains(token_pattern)
&& let ProgressParamsValue::WorkDone(WorkDoneProgress::End(_)) =
progress.value
{
return Ok(());
}
}
}
}
if start.elapsed() >= timeout {
return Err(format!(
"Timeout waiting for progress end with token containing '{}'",
token_pattern
));
}
smol::Timer::after(poll_interval).await;
}
}
pub fn get_received_diagnostics(
&self,
) -> Vec<crate::protocol::lsp::PublishDiagnosticsParams> {
if let Some(recording) = &self.inner.recording {
let notifications = recording.notifications_from_server.lock().unwrap_or_else(|e| e.into_inner());
notifications
.iter()
.filter(|n| n.method == "textDocument/publishDiagnostics")
.filter_map(|n| {
n.params
.as_ref()
.and_then(|p| serde_json::from_value(p.clone()).ok())
})
.collect()
} else {
Vec::new()
}
}
pub fn get_received_progress_notifications(
&self,
) -> Vec<crate::protocol::lsp::ProgressParams> {
if let Some(recording) = &self.inner.recording {
let notifications = recording.notifications_from_server.lock().unwrap_or_else(|e| e.into_inner());
notifications
.iter()
.filter(|n| n.method == "$/progress")
.filter_map(|n| {
n.params
.as_ref()
.and_then(|p| serde_json::from_value(p.clone()).ok())
})
.collect()
} else {
Vec::new()
}
}
pub fn write_to_snapshot(&self, snapshot: &mut ferrotype::Ferrotype) {
fn redact_volatile_fields(
method: &str,
response: &serde_json::Value,
) -> serde_json::Value {
let mut value = response.clone();
if method == "initialize"
&& let Some(server_info) = value
.get_mut("result")
.and_then(|r| r.get_mut("serverInfo"))
.and_then(|si| si.as_object_mut())
&& server_info.contains_key("version")
{
server_info.insert(
"version".to_string(),
serde_json::Value::String("<redacted>".to_string()),
);
}
value
}
if let Some(recording) = &self.inner.recording {
let requests_responses = recording.requests_responses.lock().unwrap_or_else(|e| e.into_inner());
if !requests_responses.is_empty() {
let mut rr_output = String::new();
for rr in requests_responses.iter() {
if rr.method == "workspace/executeCommand"
&& let Some(cmd) =
rr.request_params.get("command").and_then(|v| v.as_str())
&& matches!(cmd, "laburnum/queryRecords" | "laburnum/dbStats")
{
continue;
}
rr_output.push_str(&format!(
"\n[{}] ({}) ->> {}\n\n",
rr.request_id,
rr.method,
serde_json::to_string_pretty(&rr.request_params)
.unwrap_or_default()
));
let response_for_snapshot = redact_volatile_fields(&rr.method, &rr.response);
rr_output.push_str(&format!(
"[{}] <<- {}\n\n",
rr.request_id,
serde_json::to_string_pretty(&response_for_snapshot).unwrap_or_default()
));
rr_output.push_str("---\n\n");
}
snapshot.add("Request/Response", rr_output);
}
}
}
}