Skip to main content

aft/lsp/
client.rs

1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17    Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::registry::ServerKind;
20use crate::lsp::{transport, LspError};
21
/// Upper bound on how long a single request waits for its response (30 s).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `shutdown` waits for the process to exit after `exit` (5 s)
/// before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sleep between `try_wait` polls while waiting for the process to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
25
26type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
27
/// Where a language server is in its lifecycle.
///
/// `Starting` (just spawned) → `Initializing` (while the `initialize` request
/// is in flight) → `Ready`; later `ShuttingDown` once shutdown begins and
/// `Exited` when the process is known to be gone. The client refuses to send
/// requests or notifications in the last two states.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started.
    Starting,
    /// `initialize` request sent, response not yet processed.
    Initializing,
    /// Handshake complete; normal requests may be sent.
    Ready,
    /// Shutdown sequence has started.
    ShuttingDown,
    /// The server process has terminated.
    Exited,
}
37
38/// Events sent from background reader threads into the main loop.
39#[derive(Debug)]
40pub enum LspEvent {
41    /// Server sent a notification (e.g. publishDiagnostics).
42    Notification {
43        server_kind: ServerKind,
44        root: PathBuf,
45        method: String,
46        params: Option<Value>,
47    },
48    /// Server sent a request (e.g. workspace/configuration).
49    ServerRequest {
50        server_kind: ServerKind,
51        root: PathBuf,
52        id: RequestId,
53        method: String,
54        params: Option<Value>,
55    },
56    /// Server process exited or the transport stream closed.
57    ServerExited {
58        server_kind: ServerKind,
59        root: PathBuf,
60    },
61}
62
/// Diagnostic-related capabilities a server reported in its `initialize`
/// response.
///
/// Captured once per server and used to choose a diagnostics strategy:
/// - `pull_diagnostics` → request `textDocument/diagnostic` instead of waiting
///   for a push;
/// - `workspace_diagnostics` → request `workspace/diagnostic` in directory mode.
///
/// Every field defaults to "unsupported" (`false` / `None`), meaning "fall back
/// to push semantics".
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// Per-file pull (`textDocument/diagnostic`, LSP 3.17) is available.
    pub pull_diagnostics: bool,
    /// Workspace-wide pull (`workspace/diagnostic`, LSP 3.17) is available.
    pub workspace_diagnostics: bool,
    /// The `identifier` from the server's `diagnosticProvider`, if any; used to
    /// keep `previousResultId` tracking separate when several servers cover
    /// the same file.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport` found in the server's
    /// capabilities. Our client capabilities declare `refreshSupport: false`,
    /// so this is expected to stay `false`; it is kept for completeness.
    pub refresh_support: bool,
}
84
85/// A client connected to one language server process.
86pub struct LspClient {
87    kind: ServerKind,
88    root: PathBuf,
89    state: ServerState,
90    child: Child,
91    /// Child PID captured at spawn time. Used by Drop to untrack the
92    /// PID from the shared registry; we capture once rather than reading
93    /// `child.id()` later because Drop ordering with the Child can race.
94    child_pid: u32,
95    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,
96
97    /// Pending request responses, keyed by request ID.
98    pending: Arc<Mutex<PendingMap>>,
99    /// Next request ID counter.
100    next_id: AtomicI64,
101    /// Diagnostic capabilities reported by the server in its initialize response.
102    /// `None` until `initialize()` succeeds; conservative defaults thereafter
103    /// when the server doesn't advertise diagnosticProvider.
104    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
105    /// Whether the server advertised `workspace.didChangeWatchedFiles` support
106    /// during `initialize`. When `false` (or `None` pre-init), we skip sending
107    /// `workspace/didChangeWatchedFiles` notifications to avoid spec violations.
108    /// Intentional default: `false` (conservative — requires server opt-in).
109    supports_watched_files: bool,
110    /// Shared registry that tracks live LSP child PIDs across the process
111    /// so the signal handler can SIGKILL them on SIGTERM/SIGINT before
112    /// aft exits. Cloned via `Arc` — multiple clients share the same set.
113    child_registry: LspChildRegistry,
114}
115
116impl LspClient {
117    /// Spawn a new language server process and start the background reader thread.
118    ///
119    /// `child_registry` is a shared handle that records this child's PID so
120    /// the signal handler can SIGKILL it on SIGTERM/SIGINT. Tests that don't
121    /// care about signal cleanup can pass `LspChildRegistry::new()`.
122    pub fn spawn(
123        kind: ServerKind,
124        root: PathBuf,
125        binary: &Path,
126        args: &[String],
127        env: &HashMap<String, String>,
128        event_tx: Sender<LspEvent>,
129        child_registry: LspChildRegistry,
130    ) -> io::Result<Self> {
131        let mut command = Command::new(binary);
132        command
133            .args(args)
134            .current_dir(&root)
135            .stdin(Stdio::piped())
136            .stdout(Stdio::piped())
137            // Use null() instead of piped() to prevent deadlock when the server
138            // writes more than ~64KB to stderr (piped buffer fills, server blocks)
139            .stderr(Stdio::null());
140        for (key, value) in env {
141            command.env(key, value);
142        }
143
144        let mut child = command.spawn()?;
145        let child_pid = child.id();
146        child_registry.track(child_pid);
147
148        let stdout = child
149            .stdout
150            .take()
151            .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
152        let stdin = child
153            .stdin
154            .take()
155            .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
156
157        let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
158        let pending = Arc::new(Mutex::new(PendingMap::new()));
159        let reader_pending = Arc::clone(&pending);
160        let reader_writer = Arc::clone(&writer);
161        let reader_kind = kind.clone();
162        let reader_root = root.clone();
163
164        thread::spawn(move || {
165            let mut reader = BufReader::new(stdout);
166            loop {
167                match transport::read_message(&mut reader) {
168                    Ok(Some(ServerMessage::Response(response))) => {
169                        if let Ok(mut guard) = reader_pending.lock() {
170                            if let Some(tx) = guard.remove(&response.id) {
171                                if tx.send(response).is_err() {
172                                    log::debug!("response channel closed");
173                                }
174                            }
175                        } else {
176                            let _ = event_tx.send(LspEvent::ServerExited {
177                                server_kind: reader_kind.clone(),
178                                root: reader_root.clone(),
179                            });
180                            break;
181                        }
182                    }
183                    Ok(Some(ServerMessage::Notification { method, params })) => {
184                        let _ = event_tx.send(LspEvent::Notification {
185                            server_kind: reader_kind.clone(),
186                            root: reader_root.clone(),
187                            method,
188                            params,
189                        });
190                    }
191                    Ok(Some(ServerMessage::Request { id, method, params })) => {
192                        // Auto-respond to server requests to prevent deadlocks.
193                        // Server requests (like client/registerCapability,
194                        // window/workDoneProgress/create) block the server until
195                        // we respond. If we don't respond, the server won't send
196                        // responses to OUR pending requests → deadlock.
197                        //
198                        // Dispatch by method to return correct types:
199                        // - workspace/configuration expects Vec<Value> (one per item)
200                        // - Everything else gets null (safe default for registration/progress)
201                        let response_value = if method == "workspace/configuration" {
202                            // Return an array of null configs — one per requested item.
203                            // Servers fall back to filesystem config (tsconfig, pyrightconfig, etc.)
204                            let item_count = params
205                                .as_ref()
206                                .and_then(|p| p.get("items"))
207                                .and_then(|items| items.as_array())
208                                .map_or(1, |arr| arr.len());
209                            serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
210                        } else {
211                            serde_json::Value::Null
212                        };
213                        if let Ok(mut w) = reader_writer.lock() {
214                            let response = super::jsonrpc::OutgoingResponse::success(
215                                id.clone(),
216                                response_value,
217                            );
218                            let _ = transport::write_response(&mut *w, &response);
219                        }
220                        // Also forward as event for any interested handlers
221                        let _ = event_tx.send(LspEvent::ServerRequest {
222                            server_kind: reader_kind.clone(),
223                            root: reader_root.clone(),
224                            id,
225                            method,
226                            params,
227                        });
228                    }
229                    Ok(None) | Err(_) => {
230                        if let Ok(mut guard) = reader_pending.lock() {
231                            guard.clear();
232                        }
233                        let _ = event_tx.send(LspEvent::ServerExited {
234                            server_kind: reader_kind.clone(),
235                            root: reader_root.clone(),
236                        });
237                        break;
238                    }
239                }
240            }
241        });
242
243        Ok(Self {
244            kind,
245            root,
246            state: ServerState::Starting,
247            child,
248            child_pid,
249            writer,
250            pending,
251            next_id: AtomicI64::new(1),
252            diagnostic_caps: None,
253            supports_watched_files: false,
254            child_registry,
255        })
256    }
257
258    /// Send the initialize request and wait for response. Transition to Ready.
259    pub fn initialize(
260        &mut self,
261        workspace_root: &Path,
262        initialization_options: Option<serde_json::Value>,
263    ) -> Result<lsp_types::InitializeResult, LspError> {
264        self.ensure_can_send()?;
265        self.state = ServerState::Initializing;
266
267        let normalized = normalize_windows_path(workspace_root);
268        let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
269            LspError::NotFound(format!(
270                "failed to convert workspace root '{}' to file URI",
271                workspace_root.display()
272            ))
273        })?;
274        let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
275            LspError::NotFound(format!(
276                "failed to convert workspace root '{}' to file URI",
277                workspace_root.display()
278            ))
279        })?;
280
281        let mut params_value = json!({
282            "processId": std::process::id(),
283            "rootUri": root_uri,
284            "capabilities": {
285                "workspace": {
286                    "workspaceFolders": true,
287                    "configuration": true,
288                    // LSP 3.17 workspace diagnostic pull. We declare refreshSupport=false
289                    // because we drive diagnostics on-demand via pull/push and re-query
290                    // when the agent calls lsp_diagnostics again — we don't need the
291                    // server to proactively push refresh notifications.
292                    "diagnostic": {
293                        "refreshSupport": false
294                    }
295                },
296                "textDocument": {
297                    "synchronization": {
298                        "dynamicRegistration": false,
299                        "didSave": true,
300                        "willSave": false,
301                        "willSaveWaitUntil": false
302                    },
303                    "publishDiagnostics": {
304                        "relatedInformation": true,
305                        "versionSupport": true,
306                        "codeDescriptionSupport": true,
307                        "dataSupport": true
308                    },
309                    // LSP 3.17 textDocument diagnostic pull. dynamicRegistration=false
310                    // because we use static capability discovery from the InitializeResult.
311                    // relatedDocumentSupport=true to receive cascading diagnostics for
312                    // files that became known while analyzing the requested one.
313                    "diagnostic": {
314                        "dynamicRegistration": false,
315                        "relatedDocumentSupport": true
316                    }
317                }
318            },
319            "clientInfo": {
320                "name": "aft",
321                "version": env!("CARGO_PKG_VERSION")
322            },
323            "workspaceFolders": [
324                {
325                    "uri": root_uri,
326                    "name": workspace_root
327                        .file_name()
328                        .and_then(|name| name.to_str())
329                        .unwrap_or("workspace")
330                }
331            ]
332        });
333        if let Some(initialization_options) = initialization_options {
334            params_value["initializationOptions"] = initialization_options;
335        }
336
337        let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
338
339        let result_value = self.send_request_value(
340            <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
341            params,
342        )?;
343        let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
344
345        // Capture diagnostic capabilities from the initialize response. We parse
346        // from a re-serialized JSON Value because the lsp-types crate's
347        // diagnostic_provider strict variants reject some shapes real servers
348        // emit (e.g. bare `true`), and we want defensive Default fallback.
349        let caps_value = result_value
350            .get("capabilities")
351            .cloned()
352            .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
353        self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
354
355        // Capture whether the server supports workspace/didChangeWatchedFiles.
356        // Missing capability is unsupported by default; callers must not send
357        // notifications unless the server explicitly opted in.
358        self.supports_watched_files = caps_value
359            .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
360            .and_then(|v| v.as_bool())
361            .unwrap_or(false)
362            || caps_value
363                .pointer("/workspace/didChangeWatchedFiles")
364                .map(|v| v.is_object() || v.as_bool() == Some(true))
365                .unwrap_or(false);
366
367        self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
368            json!({}),
369        )?)?;
370        self.state = ServerState::Ready;
371        Ok(result)
372    }
373
374    /// Diagnostic capabilities advertised by the server. Returns `None` until
375    /// `initialize()` has succeeded; returns `Some` with conservative defaults
376    /// (all `false`) when the server didn't advertise diagnosticProvider.
377    pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
378        self.diagnostic_caps.as_ref()
379    }
380
381    /// Whether the server supports `workspace/didChangeWatchedFiles`.
382    /// Captured from the `initialize` response. Default `false` (conservative).
383    pub fn supports_watched_files(&self) -> bool {
384        self.supports_watched_files
385    }
386
387    /// Send a request and wait for the response.
388    pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
389    where
390        R: lsp_types::request::Request,
391        R::Params: serde::Serialize,
392        R::Result: DeserializeOwned,
393    {
394        self.ensure_can_send()?;
395
396        let value = self.send_request_value(R::METHOD, params)?;
397        serde_json::from_value(value).map_err(Into::into)
398    }
399
400    fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
401    where
402        P: serde::Serialize,
403    {
404        self.ensure_can_send()?;
405
406        let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
407        let (tx, rx) = bounded(1);
408        {
409            let mut pending = self.lock_pending()?;
410            pending.insert(id.clone(), tx);
411        }
412
413        let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
414        {
415            let mut writer = self
416                .writer
417                .lock()
418                .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
419            if let Err(err) = transport::write_request(&mut *writer, &request) {
420                self.remove_pending(&id);
421                return Err(err.into());
422            }
423        }
424
425        let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
426            Ok(response) => response,
427            Err(RecvTimeoutError::Timeout) => {
428                self.remove_pending(&id);
429                return Err(LspError::Timeout(format!(
430                    "timed out waiting for '{}' response from {:?}",
431                    method, self.kind
432                )));
433            }
434            Err(RecvTimeoutError::Disconnected) => {
435                self.remove_pending(&id);
436                return Err(LspError::ServerNotReady(format!(
437                    "language server {:?} disconnected while waiting for '{}'",
438                    self.kind, method
439                )));
440            }
441        };
442
443        if let Some(error) = response.error {
444            return Err(LspError::ServerError {
445                code: error.code,
446                message: error.message,
447            });
448        }
449
450        Ok(response.result.unwrap_or(Value::Null))
451    }
452
453    /// Send a notification (fire-and-forget).
454    pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
455    where
456        N: lsp_types::notification::Notification,
457        N::Params: serde::Serialize,
458    {
459        self.ensure_can_send()?;
460        let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
461        let mut writer = self
462            .writer
463            .lock()
464            .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
465        transport::write_notification(&mut *writer, &notification)?;
466        Ok(())
467    }
468
469    /// Graceful shutdown: send shutdown request, then exit notification.
470    pub fn shutdown(&mut self) -> Result<(), LspError> {
471        if self.state == ServerState::Exited {
472            self.child_registry.untrack(self.child_pid);
473            return Ok(());
474        }
475
476        if self.child.try_wait()?.is_some() {
477            self.state = ServerState::Exited;
478            self.child_registry.untrack(self.child_pid);
479            return Ok(());
480        }
481
482        if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
483            self.state = ServerState::ShuttingDown;
484            if self.child.try_wait()?.is_some() {
485                self.state = ServerState::Exited;
486                return Ok(());
487            }
488            return Err(err);
489        }
490
491        self.state = ServerState::ShuttingDown;
492
493        if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
494            if self.child.try_wait()?.is_some() {
495                self.state = ServerState::Exited;
496                return Ok(());
497            }
498            return Err(err);
499        }
500
501        let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
502        loop {
503            if self.child.try_wait()?.is_some() {
504                self.state = ServerState::Exited;
505                return Ok(());
506            }
507            if Instant::now() >= deadline {
508                let _ = self.child.kill();
509                let _ = self.child.wait();
510                self.state = ServerState::Exited;
511                return Err(LspError::Timeout(format!(
512                    "timed out waiting for {:?} to exit",
513                    self.kind
514                )));
515            }
516            thread::sleep(EXIT_POLL_INTERVAL);
517        }
518    }
519
520    pub fn state(&self) -> ServerState {
521        self.state
522    }
523
524    pub fn kind(&self) -> ServerKind {
525        self.kind.clone()
526    }
527
528    pub fn root(&self) -> &Path {
529        &self.root
530    }
531
532    fn ensure_can_send(&self) -> Result<(), LspError> {
533        if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
534            return Err(LspError::ServerNotReady(format!(
535                "language server {:?} is not ready (state: {:?})",
536                self.kind, self.state
537            )));
538        }
539        Ok(())
540    }
541
542    fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
543        self.pending
544            .lock()
545            .map_err(|_| io::Error::other("pending response map poisoned").into())
546    }
547
548    fn remove_pending(&self, id: &RequestId) {
549        if let Ok(mut pending) = self.pending.lock() {
550            pending.remove(id);
551        }
552    }
553}
554
555impl Drop for LspClient {
556    fn drop(&mut self) {
557        // Untrack first so the signal handler can't race with this kill and
558        // try to SIGKILL a PID that's already been reaped.
559        self.child_registry.untrack(self.child_pid);
560        let _ = self.child.kill();
561        let _ = self.child.wait();
562    }
563}
564
/// Normalize a path for file-URI conversion by removing the Win32 verbatim
/// (extended-length) prefix, which `Url::from_file_path` cannot handle:
///
/// - `\\?\C:\dir`               → `C:\dir`
/// - `\\?\UNC\server\share\dir` → `\\server\share\dir`
///
/// The prefix check is purely textual and runs on every platform. Paths
/// without the prefix (including every ordinary Unix path) are returned
/// unchanged. Paths that are not valid UTF-8 are also returned unchanged
/// rather than being lossily re-encoded into a different path.
fn normalize_windows_path(path: &Path) -> PathBuf {
    let Some(text) = path.to_str() else {
        return path.to_path_buf();
    };
    if let Some(unc) = text.strip_prefix(r"\\?\UNC\") {
        // Stripping only `\\?\` here would leave the relative path
        // `UNC\server\share`; restore the regular UNC form instead.
        PathBuf::from(format!(r"\\{unc}"))
    } else if let Some(local) = text.strip_prefix(r"\\?\") {
        PathBuf::from(local)
    } else {
        path.to_path_buf()
    }
}
576
577/// Parse `ServerDiagnosticCapabilities` from a re-serialized
578/// `ServerCapabilities` JSON value.
579///
580/// LSP 3.17 spec for `diagnosticProvider`:
581/// - `capabilities.diagnosticProvider` may be absent (no pull support),
582///   `DiagnosticOptions`, or `DiagnosticRegistrationOptions`.
583/// - If present:
584///   - `interFileDependencies: bool` (we don't currently use this)
585///   - `workspaceDiagnostics: bool` → workspace pull support
586///   - `identifier?: string` → optional identifier scoping result IDs
587///
588/// We parse the raw JSON Value defensively: presence of any
589/// `diagnosticProvider` value (object or `true`) means the server supports
590/// at least `textDocument/diagnostic` pull.
591fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
592    let mut caps = ServerDiagnosticCapabilities::default();
593
594    if let Some(provider) = value.get("diagnosticProvider") {
595        // diagnosticProvider can be `true` (rare) or an object. Treat both as
596        // pull_diagnostics support.
597        if provider.is_object() || provider.as_bool() == Some(true) {
598            caps.pull_diagnostics = true;
599        }
600
601        if let Some(obj) = provider.as_object() {
602            if obj
603                .get("workspaceDiagnostics")
604                .and_then(|v| v.as_bool())
605                .unwrap_or(false)
606            {
607                caps.workspace_diagnostics = true;
608            }
609            if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
610                caps.identifier = Some(identifier.to_string());
611            }
612        }
613    }
614
615    // Workspace diagnostic refresh (rare — most servers don't request this,
616    // and we declared refreshSupport=false in our client capabilities anyway).
617    if let Some(refresh) = value
618        .get("workspace")
619        .and_then(|w| w.get("diagnostic"))
620        .and_then(|d| d.get("refreshSupport"))
621        .and_then(|r| r.as_bool())
622    {
623        caps.refresh_support = refresh;
624    }
625
626    caps
627}
628
#[cfg(test)]
mod tests {
    //! Unit tests for capability parsing and path normalization.
    use super::*;

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        // Lenient parsing: a bare `true` is treated as pull_diagnostics support.
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_provider_as_false_is_not_pull() {
        let value = json!({
            "diagnosticProvider": false
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_ignores_wrongly_typed_fields() {
        // Non-bool workspaceDiagnostics and non-string identifier are ignored,
        // but the object itself still signals pull support.
        let value = json!({
            "diagnosticProvider": {
                "workspaceDiagnostics": "yes",
                "identifier": 42
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        // No diagnosticProvider → pull still false
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn normalize_strips_verbatim_drive_prefix() {
        assert_eq!(
            normalize_windows_path(Path::new(r"\\?\C:\work\repo")),
            PathBuf::from(r"C:\work\repo")
        );
    }

    #[test]
    fn normalize_leaves_plain_paths_unchanged() {
        assert_eq!(
            normalize_windows_path(Path::new("/home/user/repo")),
            PathBuf::from("/home/user/repo")
        );
    }
}