1use std::collections::HashMap;
2use std::io;
3use std::io::ErrorKind;
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicI64, Ordering};
8use std::time::Duration;
9
10use lsp_types::{
11 CallHierarchyIncomingCall, CallHierarchyIncomingCallsParams, CallHierarchyItem, CallHierarchyOutgoingCall,
12 CallHierarchyOutgoingCallsParams, CallHierarchyPrepareParams, DocumentSymbolParams, DocumentSymbolResponse,
13 GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverParams, Location, PartialResultParams, Position,
14 PublishDiagnosticsParams, ReferenceContext, ReferenceParams, RenameParams, SymbolInformation,
15 TextDocumentIdentifier, TextDocumentPositionParams, Uri, WorkDoneProgressParams, WorkspaceEdit,
16 WorkspaceSymbolParams,
17};
18use serde::Serialize;
19use serde::de::DeserializeOwned;
20use serde_json::Value;
21use thiserror::Error;
22use tokio::io::{ReadHalf, WriteHalf};
23use tokio::net::UnixStream;
24use tokio::process::Command;
25use tokio::sync::{Mutex, oneshot};
26
27use crate::language_catalog::LanguageId;
28use crate::protocol::{DaemonRequest, DaemonResponse, InitializeRequest, read_frame, write_frame};
29use crate::socket_path::{ensure_socket_dir, log_file_path};
30
31#[doc = include_str!("docs/client_error.md")]
32#[derive(Debug, Error)]
33pub enum ClientError {
34 #[error("Failed to connect to daemon: {0}")]
35 ConnectionFailed(#[source] io::Error),
36
37 #[error("IO error: {0}")]
38 Io(#[from] io::Error),
39
40 #[error("Daemon error: {0}")]
41 DaemonError(String),
42
43 #[error("LSP error (code={code}): {message}")]
44 LspError { code: i32, message: String },
45
46 #[error("Failed to spawn daemon: {0}")]
47 SpawnFailed(#[source] io::Error),
48
49 #[error("Timeout waiting for daemon to start")]
50 SpawnTimeout,
51
52 #[error("Daemon binary not found: {0}")]
53 DaemonBinaryNotFound(String),
54
55 #[error("Protocol error: {0}")]
56 ProtocolError(String),
57
58 #[error("Initialization failed: {0}")]
59 InitializationFailed(String),
60}
61
62pub type ClientResult<T> = std::result::Result<T, ClientError>;
63
64#[doc = include_str!("docs/client.md")]
65pub struct LspClient {
66 writer: Mutex<WriteHalf<UnixStream>>,
67 pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>>,
68 next_id: AtomicI64,
69 reader_task: tokio::task::JoinHandle<()>,
70}
71
72impl LspClient {
73 pub async fn connect(workspace_root: &Path, language: LanguageId) -> ClientResult<Self> {
74 let socket_path = ensure_socket_dir(workspace_root, language).map_err(ClientError::Io)?;
75
76 match UnixStream::connect(&socket_path).await {
77 Ok(stream) => {
78 return Self::from_stream(stream, workspace_root, language).await;
79 }
80 Err(err) if err.kind() == ErrorKind::ConnectionRefused || err.kind() == ErrorKind::NotFound => {}
81 Err(err) => return Err(ClientError::ConnectionFailed(err)),
82 }
83
84 spawn_daemon(&socket_path).await?;
85 let stream = UnixStream::connect(&socket_path).await.map_err(ClientError::ConnectionFailed)?;
86 Self::from_stream(stream, workspace_root, language).await
87 }
88
89 pub async fn goto_definition(&self, uri: Uri, line: u32, character: u32) -> ClientResult<GotoDefinitionResponse> {
90 let params = GotoDefinitionParams {
91 text_document_position_params: TextDocumentPositionParams {
92 text_document: TextDocumentIdentifier { uri },
93 position: Position { line, character },
94 },
95 work_done_progress_params: WorkDoneProgressParams::default(),
96 partial_result_params: PartialResultParams::default(),
97 };
98 self.call("textDocument/definition", ¶ms, || GotoDefinitionResponse::Array(vec![])).await
99 }
100
101 pub async fn goto_implementation(
102 &self,
103 uri: Uri,
104 line: u32,
105 character: u32,
106 ) -> ClientResult<GotoDefinitionResponse> {
107 let params = GotoDefinitionParams {
108 text_document_position_params: TextDocumentPositionParams {
109 text_document: TextDocumentIdentifier { uri },
110 position: Position { line, character },
111 },
112 work_done_progress_params: WorkDoneProgressParams::default(),
113 partial_result_params: PartialResultParams::default(),
114 };
115 self.call("textDocument/implementation", ¶ms, || GotoDefinitionResponse::Array(vec![])).await
116 }
117
118 pub async fn find_references(
119 &self,
120 uri: Uri,
121 line: u32,
122 character: u32,
123 include_declaration: bool,
124 ) -> ClientResult<Vec<Location>> {
125 let params = ReferenceParams {
126 text_document_position: TextDocumentPositionParams {
127 text_document: TextDocumentIdentifier { uri },
128 position: Position { line, character },
129 },
130 work_done_progress_params: WorkDoneProgressParams::default(),
131 partial_result_params: PartialResultParams::default(),
132 context: ReferenceContext { include_declaration },
133 };
134 self.call("textDocument/references", ¶ms, Vec::new).await
135 }
136
137 pub async fn hover(&self, uri: Uri, line: u32, character: u32) -> ClientResult<Option<Hover>> {
138 let params = HoverParams {
139 text_document_position_params: TextDocumentPositionParams {
140 text_document: TextDocumentIdentifier { uri },
141 position: Position { line, character },
142 },
143 work_done_progress_params: WorkDoneProgressParams::default(),
144 };
145 self.call("textDocument/hover", ¶ms, || None).await
146 }
147
148 pub async fn workspace_symbol(&self, query: String) -> ClientResult<Vec<SymbolInformation>> {
149 let params = WorkspaceSymbolParams {
150 query,
151 partial_result_params: PartialResultParams::default(),
152 work_done_progress_params: WorkDoneProgressParams::default(),
153 };
154 self.call("workspace/symbol", ¶ms, Vec::new).await
155 }
156
157 pub async fn document_symbol(&self, uri: Uri) -> ClientResult<DocumentSymbolResponse> {
158 let params = DocumentSymbolParams {
159 text_document: TextDocumentIdentifier { uri },
160 work_done_progress_params: WorkDoneProgressParams::default(),
161 partial_result_params: PartialResultParams::default(),
162 };
163 self.call("textDocument/documentSymbol", ¶ms, || DocumentSymbolResponse::Flat(vec![])).await
164 }
165
166 pub async fn prepare_call_hierarchy(
167 &self,
168 uri: Uri,
169 line: u32,
170 character: u32,
171 ) -> ClientResult<Vec<CallHierarchyItem>> {
172 let params = CallHierarchyPrepareParams {
173 text_document_position_params: TextDocumentPositionParams {
174 text_document: TextDocumentIdentifier { uri },
175 position: Position { line, character },
176 },
177 work_done_progress_params: WorkDoneProgressParams::default(),
178 };
179 self.call("textDocument/prepareCallHierarchy", ¶ms, Vec::new).await
180 }
181
182 pub async fn incoming_calls(&self, item: CallHierarchyItem) -> ClientResult<Vec<CallHierarchyIncomingCall>> {
183 let params = CallHierarchyIncomingCallsParams {
184 item,
185 work_done_progress_params: WorkDoneProgressParams::default(),
186 partial_result_params: PartialResultParams::default(),
187 };
188 self.call("callHierarchy/incomingCalls", ¶ms, Vec::new).await
189 }
190
191 pub async fn outgoing_calls(&self, item: CallHierarchyItem) -> ClientResult<Vec<CallHierarchyOutgoingCall>> {
192 let params = CallHierarchyOutgoingCallsParams {
193 item,
194 work_done_progress_params: WorkDoneProgressParams::default(),
195 partial_result_params: PartialResultParams::default(),
196 };
197 self.call("callHierarchy/outgoingCalls", ¶ms, Vec::new).await
198 }
199
200 pub async fn rename(
201 &self,
202 uri: Uri,
203 line: u32,
204 character: u32,
205 new_name: String,
206 ) -> ClientResult<Option<WorkspaceEdit>> {
207 let params = RenameParams {
208 text_document_position: TextDocumentPositionParams {
209 text_document: TextDocumentIdentifier { uri },
210 position: Position { line, character },
211 },
212 new_name,
213 work_done_progress_params: WorkDoneProgressParams::default(),
214 };
215 self.call("textDocument/rename", ¶ms, || None).await
216 }
217
218 pub async fn get_diagnostics(&self, uri: Option<Uri>) -> ClientResult<Vec<PublishDiagnosticsParams>> {
219 let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
220 let request = DaemonRequest::GetDiagnostics { client_id, uri };
221
222 self.send_and_await(request, client_id)
223 .await
224 .and_then(|value| serde_json::from_value(value).map_err(|err| ClientError::ProtocolError(err.to_string())))
225 }
226
227 pub async fn queue_diagnostic_refresh(&self, uri: Uri) -> ClientResult<()> {
228 let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
229 let request = DaemonRequest::QueueDiagnosticRefresh { client_id, uri };
230 self.send_and_await(request, client_id).await.map(|_| ())
231 }
232
233 pub async fn disconnect(self) -> ClientResult<()> {
234 let request = DaemonRequest::Disconnect;
235 let mut writer = self.writer.lock().await;
236 write_frame(&mut *writer, &request).await.map_err(ClientError::Io)
237 }
238
239 pub async fn call<P: Serialize, R: DeserializeOwned>(
240 &self,
241 method: &str,
242 params: &P,
243 default: impl FnOnce() -> R,
244 ) -> ClientResult<R> {
245 let params_value = serde_json::to_value(params).map_err(|err| ClientError::ProtocolError(err.to_string()))?;
246
247 let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
248 let request = DaemonRequest::LspCall { client_id, method: method.to_string(), params: params_value };
249
250 let value = self.send_and_await(request, client_id).await?;
251
252 if value.is_null() {
253 Ok(default())
254 } else {
255 serde_json::from_value(value).map_err(|err| ClientError::ProtocolError(format!("Parse error: {err}")))
256 }
257 }
258}
259
260impl LspClient {
261 async fn from_stream(stream: UnixStream, workspace_root: &Path, language: LanguageId) -> ClientResult<Self> {
262 let (mut reader, mut writer) = tokio::io::split(stream);
263
264 let initialize =
265 DaemonRequest::Initialize(InitializeRequest { workspace_root: workspace_root.to_path_buf(), language });
266
267 write_frame(&mut writer, &initialize).await.map_err(ClientError::Io)?;
268
269 let response: Option<DaemonResponse> = read_frame(&mut reader).await.map_err(ClientError::Io)?;
270
271 match response {
272 Some(DaemonResponse::Initialized) => {}
273 Some(DaemonResponse::Error(err)) => {
274 return Err(ClientError::InitializationFailed(err.message));
275 }
276 Some(_) => {
277 return Err(ClientError::ProtocolError("Unexpected response to Initialize".into()));
278 }
279 None => {
280 return Err(ClientError::ProtocolError("Connection closed during initialization".into()));
281 }
282 }
283
284 let pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>> = Arc::new(Mutex::new(HashMap::new()));
285
286 let pending_clone = Arc::clone(&pending);
287 let reader_task = tokio::spawn(async move {
288 run_reader(reader, pending_clone).await;
289 });
290
291 Ok(Self { writer: Mutex::new(writer), pending, next_id: AtomicI64::new(1), reader_task })
292 }
293
294 async fn send_and_await(&self, request: DaemonRequest, client_id: i64) -> ClientResult<Value> {
295 let (response_tx, response_rx) = oneshot::channel();
296
297 {
298 let mut pending = self.pending.lock().await;
299 pending.insert(client_id, response_tx);
300 }
301
302 let write_result = {
303 let mut writer = self.writer.lock().await;
304 write_frame(&mut *writer, &request).await
305 };
306
307 if let Err(err) = write_result {
308 self.pending.lock().await.remove(&client_id);
309 return Err(ClientError::Io(err));
310 }
311
312 response_rx.await.map_err(|_| ClientError::ProtocolError("Response channel closed".into()))?
313 }
314}
315
316impl Drop for LspClient {
317 fn drop(&mut self) {
318 self.reader_task.abort();
319 }
320}
321
322type PendingResult = Result<Value, ClientError>;
323
324async fn run_reader(
325 mut reader: ReadHalf<UnixStream>,
326 pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>>,
327) {
328 loop {
329 let response: Option<DaemonResponse> = match read_frame(&mut reader).await {
330 Ok(Some(response)) => Some(response),
331 Ok(None) => break,
332 Err(err) => {
333 tracing::debug!(%err, "Error reading daemon response");
334 break;
335 }
336 };
337
338 match response {
339 Some(DaemonResponse::LspResult { client_id, result }) => {
340 let mut pending = pending.lock().await;
341 if let Some(tx) = pending.remove(&client_id) {
342 let value_result =
343 result.map_err(|err| ClientError::LspError { code: err.code, message: err.message });
344 let _ = tx.send(value_result);
345 }
346 }
347 Some(DaemonResponse::Error(err)) => {
348 if let Some(client_id) = err.client_id {
349 let mut pending = pending.lock().await;
350 if let Some(tx) = pending.remove(&client_id) {
351 let _ = tx.send(Err(ClientError::DaemonError(err.message)));
352 }
353 }
354 }
355 _ => {}
356 }
357 }
358}
359
360async fn spawn_daemon(socket_path: &Path) -> ClientResult<()> {
361 let (binary, subcommand) = find_daemon_binary()?;
362 let log_file = log_file_path(socket_path);
363
364 let mut cmd = Command::new(&binary);
365 if let Some(sub) = subcommand {
366 cmd.arg(sub);
367 }
368 cmd.arg("--socket")
369 .arg(socket_path)
370 .arg("--log-file")
371 .arg(&log_file)
372 .arg("--log-level")
373 .arg("debug")
374 .stdin(Stdio::null())
375 .stdout(Stdio::null())
376 .stderr(Stdio::null());
377
378 let mut child = cmd.spawn().map_err(ClientError::SpawnFailed)?;
379
380 for _ in 0..50 {
381 match child.try_wait() {
382 Ok(Some(status)) if !status.success() => {
383 return Err(ClientError::SpawnFailed(std::io::Error::other(format!(
384 "Daemon exited with status: {status}"
385 ))));
386 }
387 Ok(Some(_) | None) => {}
388 Err(err) => return Err(ClientError::SpawnFailed(err)),
389 }
390
391 tokio::time::sleep(Duration::from_millis(100)).await;
392 if UnixStream::connect(socket_path).await.is_ok() {
393 return Ok(());
394 }
395 }
396
397 Err(ClientError::SpawnTimeout)
398}
399
400fn find_daemon_binary() -> ClientResult<(PathBuf, Option<&'static str>)> {
401 let exe = std::env::current_exe().ok();
402 let exe_dir = exe.as_deref().and_then(|p| p.parent());
403
404 let standalone_candidates = [
405 exe_dir.map(|dir| dir.join("aether-lspd")),
406 exe_dir.and_then(|dir| dir.parent()).map(|dir| dir.join("aether-lspd")),
407 which_aether_lspd(),
408 Some(PathBuf::from("target/debug/aether-lspd")),
409 Some(PathBuf::from("target/release/aether-lspd")),
410 Some(PathBuf::from("../../target/debug/aether-lspd")),
411 Some(PathBuf::from("../../target/release/aether-lspd")),
412 ];
413
414 for candidate in standalone_candidates.into_iter().flatten() {
415 if candidate.exists() {
416 return Ok((candidate, None));
417 }
418 }
419
420 if let Some(exe) = exe {
421 return Ok((exe, Some("lspd")));
422 }
423
424 Err(ClientError::DaemonBinaryNotFound("aether-lspd not found".into()))
425}
426
/// Searches the directories listed in `PATH`, in order, for a regular file named
/// `aether-lspd`.
///
/// Returns the first match, or `None` when `PATH` is unset or no entry contains such
/// a file.
fn which_aether_lspd() -> Option<PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path)
        .map(|dir| dir.join("aether-lspd"))
        // `is_file` follows symlinks, and a directory of that name is not the binary.
        .find(|candidate| candidate.is_file())
}