Skip to main content

aether_lspd/
client.rs

1use std::collections::HashMap;
2use std::io;
3use std::io::ErrorKind;
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicI64, Ordering};
8use std::time::Duration;
9
10use lsp_types::{
11    CallHierarchyIncomingCall, CallHierarchyIncomingCallsParams, CallHierarchyItem, CallHierarchyOutgoingCall,
12    CallHierarchyOutgoingCallsParams, CallHierarchyPrepareParams, DocumentSymbolParams, DocumentSymbolResponse,
13    GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverParams, Location, PartialResultParams, Position,
14    PublishDiagnosticsParams, ReferenceContext, ReferenceParams, RenameParams, SymbolInformation,
15    TextDocumentIdentifier, TextDocumentPositionParams, Uri, WorkDoneProgressParams, WorkspaceEdit,
16    WorkspaceSymbolParams,
17};
18use serde::Serialize;
19use serde::de::DeserializeOwned;
20use serde_json::Value;
21use thiserror::Error;
22use tokio::io::{ReadHalf, WriteHalf};
23use tokio::net::UnixStream;
24use tokio::process::Command;
25use tokio::sync::{Mutex, oneshot};
26
27use crate::language_catalog::LanguageId;
28use crate::protocol::{DaemonRequest, DaemonResponse, InitializeRequest, read_frame, write_frame};
29use crate::socket_path::{ensure_socket_dir, log_file_path};
30
31#[doc = include_str!("docs/client_error.md")]
32#[derive(Debug, Error)]
33pub enum ClientError {
34    #[error("Failed to connect to daemon: {0}")]
35    ConnectionFailed(#[source] io::Error),
36
37    #[error("IO error: {0}")]
38    Io(#[from] io::Error),
39
40    #[error("Daemon error: {0}")]
41    DaemonError(String),
42
43    #[error("LSP error (code={code}): {message}")]
44    LspError { code: i32, message: String },
45
46    #[error("Failed to spawn daemon: {0}")]
47    SpawnFailed(#[source] io::Error),
48
49    #[error("Timeout waiting for daemon to start")]
50    SpawnTimeout,
51
52    #[error("Daemon binary not found: {0}")]
53    DaemonBinaryNotFound(String),
54
55    #[error("Protocol error: {0}")]
56    ProtocolError(String),
57
58    #[error("Initialization failed: {0}")]
59    InitializationFailed(String),
60}
61
62pub type ClientResult<T> = std::result::Result<T, ClientError>;
63
64#[doc = include_str!("docs/client.md")]
65pub struct LspClient {
66    writer: Mutex<WriteHalf<UnixStream>>,
67    pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>>,
68    next_id: AtomicI64,
69    reader_task: tokio::task::JoinHandle<()>,
70}
71
72impl LspClient {
73    pub async fn connect(workspace_root: &Path, language: LanguageId) -> ClientResult<Self> {
74        let socket_path = ensure_socket_dir(workspace_root, language).map_err(ClientError::Io)?;
75
76        match UnixStream::connect(&socket_path).await {
77            Ok(stream) => {
78                return Self::from_stream(stream, workspace_root, language).await;
79            }
80            Err(err) if err.kind() == ErrorKind::ConnectionRefused || err.kind() == ErrorKind::NotFound => {}
81            Err(err) => return Err(ClientError::ConnectionFailed(err)),
82        }
83
84        spawn_daemon(&socket_path).await?;
85        let stream = UnixStream::connect(&socket_path).await.map_err(ClientError::ConnectionFailed)?;
86        Self::from_stream(stream, workspace_root, language).await
87    }
88
89    pub async fn goto_definition(&self, uri: Uri, line: u32, character: u32) -> ClientResult<GotoDefinitionResponse> {
90        let params = GotoDefinitionParams {
91            text_document_position_params: TextDocumentPositionParams {
92                text_document: TextDocumentIdentifier { uri },
93                position: Position { line, character },
94            },
95            work_done_progress_params: WorkDoneProgressParams::default(),
96            partial_result_params: PartialResultParams::default(),
97        };
98        self.call("textDocument/definition", &params, || GotoDefinitionResponse::Array(vec![])).await
99    }
100
101    pub async fn goto_implementation(
102        &self,
103        uri: Uri,
104        line: u32,
105        character: u32,
106    ) -> ClientResult<GotoDefinitionResponse> {
107        let params = GotoDefinitionParams {
108            text_document_position_params: TextDocumentPositionParams {
109                text_document: TextDocumentIdentifier { uri },
110                position: Position { line, character },
111            },
112            work_done_progress_params: WorkDoneProgressParams::default(),
113            partial_result_params: PartialResultParams::default(),
114        };
115        self.call("textDocument/implementation", &params, || GotoDefinitionResponse::Array(vec![])).await
116    }
117
118    pub async fn find_references(
119        &self,
120        uri: Uri,
121        line: u32,
122        character: u32,
123        include_declaration: bool,
124    ) -> ClientResult<Vec<Location>> {
125        let params = ReferenceParams {
126            text_document_position: TextDocumentPositionParams {
127                text_document: TextDocumentIdentifier { uri },
128                position: Position { line, character },
129            },
130            work_done_progress_params: WorkDoneProgressParams::default(),
131            partial_result_params: PartialResultParams::default(),
132            context: ReferenceContext { include_declaration },
133        };
134        self.call("textDocument/references", &params, Vec::new).await
135    }
136
137    pub async fn hover(&self, uri: Uri, line: u32, character: u32) -> ClientResult<Option<Hover>> {
138        let params = HoverParams {
139            text_document_position_params: TextDocumentPositionParams {
140                text_document: TextDocumentIdentifier { uri },
141                position: Position { line, character },
142            },
143            work_done_progress_params: WorkDoneProgressParams::default(),
144        };
145        self.call("textDocument/hover", &params, || None).await
146    }
147
148    pub async fn workspace_symbol(&self, query: String) -> ClientResult<Vec<SymbolInformation>> {
149        let params = WorkspaceSymbolParams {
150            query,
151            partial_result_params: PartialResultParams::default(),
152            work_done_progress_params: WorkDoneProgressParams::default(),
153        };
154        self.call("workspace/symbol", &params, Vec::new).await
155    }
156
157    pub async fn document_symbol(&self, uri: Uri) -> ClientResult<DocumentSymbolResponse> {
158        let params = DocumentSymbolParams {
159            text_document: TextDocumentIdentifier { uri },
160            work_done_progress_params: WorkDoneProgressParams::default(),
161            partial_result_params: PartialResultParams::default(),
162        };
163        self.call("textDocument/documentSymbol", &params, || DocumentSymbolResponse::Flat(vec![])).await
164    }
165
166    pub async fn prepare_call_hierarchy(
167        &self,
168        uri: Uri,
169        line: u32,
170        character: u32,
171    ) -> ClientResult<Vec<CallHierarchyItem>> {
172        let params = CallHierarchyPrepareParams {
173            text_document_position_params: TextDocumentPositionParams {
174                text_document: TextDocumentIdentifier { uri },
175                position: Position { line, character },
176            },
177            work_done_progress_params: WorkDoneProgressParams::default(),
178        };
179        self.call("textDocument/prepareCallHierarchy", &params, Vec::new).await
180    }
181
182    pub async fn incoming_calls(&self, item: CallHierarchyItem) -> ClientResult<Vec<CallHierarchyIncomingCall>> {
183        let params = CallHierarchyIncomingCallsParams {
184            item,
185            work_done_progress_params: WorkDoneProgressParams::default(),
186            partial_result_params: PartialResultParams::default(),
187        };
188        self.call("callHierarchy/incomingCalls", &params, Vec::new).await
189    }
190
191    pub async fn outgoing_calls(&self, item: CallHierarchyItem) -> ClientResult<Vec<CallHierarchyOutgoingCall>> {
192        let params = CallHierarchyOutgoingCallsParams {
193            item,
194            work_done_progress_params: WorkDoneProgressParams::default(),
195            partial_result_params: PartialResultParams::default(),
196        };
197        self.call("callHierarchy/outgoingCalls", &params, Vec::new).await
198    }
199
200    pub async fn rename(
201        &self,
202        uri: Uri,
203        line: u32,
204        character: u32,
205        new_name: String,
206    ) -> ClientResult<Option<WorkspaceEdit>> {
207        let params = RenameParams {
208            text_document_position: TextDocumentPositionParams {
209                text_document: TextDocumentIdentifier { uri },
210                position: Position { line, character },
211            },
212            new_name,
213            work_done_progress_params: WorkDoneProgressParams::default(),
214        };
215        self.call("textDocument/rename", &params, || None).await
216    }
217
218    pub async fn get_diagnostics(&self, uri: Option<Uri>) -> ClientResult<Vec<PublishDiagnosticsParams>> {
219        let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
220        let request = DaemonRequest::GetDiagnostics { client_id, uri };
221
222        self.send_and_await(request, client_id)
223            .await
224            .and_then(|value| serde_json::from_value(value).map_err(|err| ClientError::ProtocolError(err.to_string())))
225    }
226
227    pub async fn queue_diagnostic_refresh(&self, uri: Uri) -> ClientResult<()> {
228        let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
229        let request = DaemonRequest::QueueDiagnosticRefresh { client_id, uri };
230        self.send_and_await(request, client_id).await.map(|_| ())
231    }
232
233    pub async fn disconnect(self) -> ClientResult<()> {
234        let request = DaemonRequest::Disconnect;
235        let mut writer = self.writer.lock().await;
236        write_frame(&mut *writer, &request).await.map_err(ClientError::Io)
237    }
238
239    pub async fn call<P: Serialize, R: DeserializeOwned>(
240        &self,
241        method: &str,
242        params: &P,
243        default: impl FnOnce() -> R,
244    ) -> ClientResult<R> {
245        let params_value = serde_json::to_value(params).map_err(|err| ClientError::ProtocolError(err.to_string()))?;
246
247        let client_id = self.next_id.fetch_add(1, Ordering::SeqCst);
248        let request = DaemonRequest::LspCall { client_id, method: method.to_string(), params: params_value };
249
250        let value = self.send_and_await(request, client_id).await?;
251
252        if value.is_null() {
253            Ok(default())
254        } else {
255            serde_json::from_value(value).map_err(|err| ClientError::ProtocolError(format!("Parse error: {err}")))
256        }
257    }
258}
259
260impl LspClient {
261    async fn from_stream(stream: UnixStream, workspace_root: &Path, language: LanguageId) -> ClientResult<Self> {
262        let (mut reader, mut writer) = tokio::io::split(stream);
263
264        let initialize =
265            DaemonRequest::Initialize(InitializeRequest { workspace_root: workspace_root.to_path_buf(), language });
266
267        write_frame(&mut writer, &initialize).await.map_err(ClientError::Io)?;
268
269        let response: Option<DaemonResponse> = read_frame(&mut reader).await.map_err(ClientError::Io)?;
270
271        match response {
272            Some(DaemonResponse::Initialized) => {}
273            Some(DaemonResponse::Error(err)) => {
274                return Err(ClientError::InitializationFailed(err.message));
275            }
276            Some(_) => {
277                return Err(ClientError::ProtocolError("Unexpected response to Initialize".into()));
278            }
279            None => {
280                return Err(ClientError::ProtocolError("Connection closed during initialization".into()));
281            }
282        }
283
284        let pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>> = Arc::new(Mutex::new(HashMap::new()));
285
286        let pending_clone = Arc::clone(&pending);
287        let reader_task = tokio::spawn(async move {
288            run_reader(reader, pending_clone).await;
289        });
290
291        Ok(Self { writer: Mutex::new(writer), pending, next_id: AtomicI64::new(1), reader_task })
292    }
293
294    async fn send_and_await(&self, request: DaemonRequest, client_id: i64) -> ClientResult<Value> {
295        let (response_tx, response_rx) = oneshot::channel();
296
297        {
298            let mut pending = self.pending.lock().await;
299            pending.insert(client_id, response_tx);
300        }
301
302        let write_result = {
303            let mut writer = self.writer.lock().await;
304            write_frame(&mut *writer, &request).await
305        };
306
307        if let Err(err) = write_result {
308            self.pending.lock().await.remove(&client_id);
309            return Err(ClientError::Io(err));
310        }
311
312        response_rx.await.map_err(|_| ClientError::ProtocolError("Response channel closed".into()))?
313    }
314}
315
316impl Drop for LspClient {
317    fn drop(&mut self) {
318        self.reader_task.abort();
319    }
320}
321
322type PendingResult = Result<Value, ClientError>;
323
324async fn run_reader(
325    mut reader: ReadHalf<UnixStream>,
326    pending: Arc<Mutex<HashMap<i64, oneshot::Sender<PendingResult>>>>,
327) {
328    loop {
329        let response: Option<DaemonResponse> = match read_frame(&mut reader).await {
330            Ok(Some(response)) => Some(response),
331            Ok(None) => break,
332            Err(err) => {
333                tracing::debug!(%err, "Error reading daemon response");
334                break;
335            }
336        };
337
338        match response {
339            Some(DaemonResponse::LspResult { client_id, result }) => {
340                let mut pending = pending.lock().await;
341                if let Some(tx) = pending.remove(&client_id) {
342                    let value_result =
343                        result.map_err(|err| ClientError::LspError { code: err.code, message: err.message });
344                    let _ = tx.send(value_result);
345                }
346            }
347            Some(DaemonResponse::Error(err)) => {
348                if let Some(client_id) = err.client_id {
349                    let mut pending = pending.lock().await;
350                    if let Some(tx) = pending.remove(&client_id) {
351                        let _ = tx.send(Err(ClientError::DaemonError(err.message)));
352                    }
353                }
354            }
355            _ => {}
356        }
357    }
358}
359
360async fn spawn_daemon(socket_path: &Path) -> ClientResult<()> {
361    let (binary, subcommand) = find_daemon_binary()?;
362    let log_file = log_file_path(socket_path);
363
364    let mut cmd = Command::new(&binary);
365    if let Some(sub) = subcommand {
366        cmd.arg(sub);
367    }
368    cmd.arg("--socket")
369        .arg(socket_path)
370        .arg("--log-file")
371        .arg(&log_file)
372        .arg("--log-level")
373        .arg("debug")
374        .stdin(Stdio::null())
375        .stdout(Stdio::null())
376        .stderr(Stdio::null());
377
378    let mut child = cmd.spawn().map_err(ClientError::SpawnFailed)?;
379
380    for _ in 0..50 {
381        match child.try_wait() {
382            Ok(Some(status)) if !status.success() => {
383                return Err(ClientError::SpawnFailed(std::io::Error::other(format!(
384                    "Daemon exited with status: {status}"
385                ))));
386            }
387            Ok(Some(_) | None) => {}
388            Err(err) => return Err(ClientError::SpawnFailed(err)),
389        }
390
391        tokio::time::sleep(Duration::from_millis(100)).await;
392        if UnixStream::connect(socket_path).await.is_ok() {
393            return Ok(());
394        }
395    }
396
397    Err(ClientError::SpawnTimeout)
398}
399
400fn find_daemon_binary() -> ClientResult<(PathBuf, Option<&'static str>)> {
401    let exe = std::env::current_exe().ok();
402    let exe_dir = exe.as_deref().and_then(|p| p.parent());
403
404    let standalone_candidates = [
405        exe_dir.map(|dir| dir.join("aether-lspd")),
406        exe_dir.and_then(|dir| dir.parent()).map(|dir| dir.join("aether-lspd")),
407        which_aether_lspd(),
408        Some(PathBuf::from("target/debug/aether-lspd")),
409        Some(PathBuf::from("target/release/aether-lspd")),
410        Some(PathBuf::from("../../target/debug/aether-lspd")),
411        Some(PathBuf::from("../../target/release/aether-lspd")),
412    ];
413
414    for candidate in standalone_candidates.into_iter().flatten() {
415        if candidate.exists() {
416            return Ok((candidate, None));
417        }
418    }
419
420    if let Some(exe) = exe {
421        return Ok((exe, Some("lspd")));
422    }
423
424    Err(ClientError::DaemonBinaryNotFound("aether-lspd not found".into()))
425}
426
/// Searches each directory on `PATH`, in order, for a regular file named `aether-lspd` and
/// returns the first match.
///
/// Entries with that name that are directories are skipped. The executable bit is not
/// checked. Returns `None` when `PATH` is unset or nothing matches.
fn which_aether_lspd() -> Option<PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path).map(|dir| dir.join("aether-lspd")).find(|candidate| candidate.is_file())
}