Skip to main content

krait/lsp/
client.rs

1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::{bail, Context};
7use lsp_types::{
8    ClientCapabilities, CodeActionClientCapabilities, DocumentSymbolClientCapabilities,
9    DynamicRegistrationClientCapabilities, GotoCapability, HoverClientCapabilities,
10    InitializeParams, InitializeResult, InitializedParams, PublishDiagnosticsClientCapabilities,
11    RenameClientCapabilities, ServerCapabilities, TextDocumentClientCapabilities,
12    TextDocumentSyncClientCapabilities, Uri, WindowClientCapabilities, WorkspaceClientCapabilities,
13    WorkspaceFolder, WorkspaceSymbolClientCapabilities,
14};
15use serde_json::Value;
16use tracing::debug;
17
18use super::diagnostics::{ingest_publish_diagnostics, DiagnosticStore};
19use super::error::LspError;
20use super::registry::{find_server, get_entry};
21use super::transport::{JsonRpcMessage, LspTransport};
22use crate::detect::Language;
23
/// Upper bound on how long `initialize` waits for the server's reply to the
/// `initialize` request (30 seconds). Also used by `wait_for_response_public`.
const INITIALIZE_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Upper bound on how long `shutdown` waits for the reply to `shutdown` (5 seconds).
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
29
30/// High-level LSP client that manages a language server lifecycle.
31pub struct LspClient {
32    transport: LspTransport,
33    capabilities: Option<ServerCapabilities>,
34    language: Language,
35    /// Buffered responses for request IDs that arrived out of order.
36    buffered_responses: HashMap<i64, BufferedResponse>,
37    /// Whether the server supports `workspace/didChangeWorkspaceFolders`.
38    supports_workspace_folders: bool,
39    /// The name of the server binary (e.g., "vtsls", "rust-analyzer").
40    server_name: String,
41    /// Workspace folders currently attached to this server.
42    attached_folders: HashSet<PathBuf>,
43    /// Optional store for collecting `textDocument/publishDiagnostics` notifications.
44    diagnostic_store: Option<Arc<DiagnosticStore>>,
45}
46
47/// A response received for a request ID we weren't waiting for yet.
48enum BufferedResponse {
49    Ok(Value),
50    Err(String),
51}
52
53impl LspClient {
54    /// Start an LSP server for the given language and project root.
55    ///
56    /// Looks for the server binary in PATH first, then in `~/.krait/servers/`.
57    /// Use `start_with_auto_install()` to also download if missing.
58    ///
59    /// # Errors
60    /// Returns `LspError::ServerNotFound` if the binary is missing.
61    /// Returns an error if the process cannot be spawned.
62    pub fn start(language: Language, project_root: &Path) -> Result<Self, LspError> {
63        let entry = get_entry(language).ok_or_else(|| LspError::InitializeFailed {
64            message: format!("no LSP config for {language}"),
65        })?;
66
67        let binary_path = find_server(&entry).ok_or_else(|| LspError::ServerNotFound {
68            language,
69            advice: entry.install_advice.to_string(),
70        })?;
71
72        Self::start_with_binary(&binary_path, entry.args, language, project_root)
73    }
74
75    /// Start an LSP server using a specific binary path.
76    ///
77    /// # Errors
78    /// Returns an error if the process cannot be spawned.
79    pub fn start_with_binary(
80        binary_path: &Path,
81        args: &[&str],
82        language: Language,
83        project_root: &Path,
84    ) -> Result<Self, LspError> {
85        let binary_str = binary_path.to_str().unwrap_or("unknown");
86        let server_name = binary_path
87            .file_name()
88            .and_then(|n| n.to_str())
89            .unwrap_or("unknown")
90            .to_string();
91
92        let transport = LspTransport::spawn(binary_str, args, project_root).map_err(|e| {
93            LspError::InitializeFailed {
94                message: format!("failed to spawn {binary_str}: {e}"),
95            }
96        })?;
97
98        debug!(
99            "started LSP server for {language}: {binary_str} {}",
100            args.join(" ")
101        );
102
103        Ok(Self {
104            transport,
105            capabilities: None,
106            language,
107            buffered_responses: HashMap::new(),
108            supports_workspace_folders: false,
109            server_name,
110            attached_folders: HashSet::new(),
111            diagnostic_store: None,
112        })
113    }
114
115    /// Perform the LSP initialize handshake.
116    ///
117    /// Sends `initialize` request, waits for response, then sends `initialized` notification.
118    ///
119    /// # Errors
120    /// Returns an error if the handshake fails or times out.
121    ///
122    /// # Panics
123    /// Panics if capabilities are not stored after a successful response (should never happen).
124    pub async fn initialize(&mut self, project_root: &Path) -> anyhow::Result<&ServerCapabilities> {
125        let root_uri = path_to_uri(project_root)?;
126        let params = build_initialize_params(&root_uri, project_root, self.language);
127        let params_value = serde_json::to_value(&params)?;
128
129        let request_id = self
130            .transport
131            .send_request("initialize", params_value)
132            .await?;
133
134        let result = self
135            .wait_for_response(request_id, INITIALIZE_TIMEOUT)
136            .await
137            .context("initialize handshake failed")?;
138
139        let init_result: InitializeResult =
140            serde_json::from_value(result).context("failed to parse InitializeResult")?;
141
142        // Detect workspace folder support
143        self.supports_workspace_folders = init_result
144            .capabilities
145            .workspace
146            .as_ref()
147            .and_then(|w| w.workspace_folders.as_ref())
148            .and_then(|wf| wf.supported)
149            .unwrap_or(false);
150
151        debug!(
152            "server capabilities received for {} (workspace_folders={})",
153            self.language, self.supports_workspace_folders
154        );
155
156        self.capabilities = Some(init_result.capabilities);
157
158        // Track the initial workspace folder
159        self.attached_folders.insert(project_root.to_path_buf());
160
161        // Send initialized notification (must be after storing capabilities)
162        self.transport
163            .send_notification("initialized", serde_json::to_value(InitializedParams {})?)
164            .await?;
165
166        debug!("initialized notification sent for {}", self.language);
167
168        self.capabilities
169            .as_ref()
170            .ok_or_else(|| anyhow::anyhow!("internal: capabilities missing after initialize"))
171    }
172
173    /// Shut down the LSP server cleanly.
174    ///
175    /// Sends `shutdown` request, waits for response, sends `exit` notification,
176    /// then waits for the process to exit.
177    ///
178    /// # Errors
179    /// Returns an error if shutdown fails or the process doesn't exit.
180    pub async fn shutdown(&mut self) -> anyhow::Result<()> {
181        let request_id = self.transport.send_request("shutdown", Value::Null).await?;
182
183        // Wait for shutdown response (with timeout)
184        let _ = self.wait_for_response(request_id, SHUTDOWN_TIMEOUT).await;
185
186        // Send exit notification
187        self.transport
188            .send_notification("exit", Value::Null)
189            .await
190            .ok();
191
192        // Give the process a moment to exit, then force kill
193        tokio::time::sleep(Duration::from_millis(100)).await;
194        if self.transport.is_alive() {
195            debug!("LSP server still alive after exit, killing");
196            self.transport.kill().await.ok();
197        }
198
199        debug!("LSP server for {} shut down", self.language);
200        Ok(())
201    }
202
203    /// Wait for a response to a previously sent request.
204    ///
205    /// Uses the default initialize timeout. For commands that need LSP responses
206    /// after the handshake is complete.
207    ///
208    /// # Errors
209    /// Returns an error if the response times out or contains an LSP error.
210    pub async fn wait_for_response_public(&mut self, request_id: i64) -> anyhow::Result<Value> {
211        self.wait_for_response(request_id, INITIALIZE_TIMEOUT).await
212    }
213
214    /// Wait for a response with a caller-specified timeout.
215    ///
216    /// Unlike wrapping `wait_for_response_public` in `tokio::time::timeout`,
217    /// this ensures the response is always consumed (no orphaned responses in the pipe).
218    ///
219    /// # Errors
220    /// Returns an error if the timeout expires before a response is received.
221    pub async fn wait_for_response_with_timeout(
222        &mut self,
223        request_id: i64,
224        timeout: Duration,
225    ) -> anyhow::Result<Value> {
226        self.wait_for_response(request_id, timeout).await
227    }
228
229    /// Wait until the server sends a `$/progress` notification with `"kind": "end"`.
230    ///
231    /// This replaces the fixed-delay polling heuristic: instead of sleeping 200ms × N,
232    /// we listen for the server's own signal that background indexing is complete.
233    /// If no progress end is received within `timeout`, we proceed anyway (graceful degradation).
234    ///
235    /// Any responses received while waiting are buffered for future retrieval.
236    pub async fn wait_for_progress_end(&mut self, timeout: Duration) {
237        let _ = tokio::time::timeout(timeout, async {
238            loop {
239                let message = match self.transport.read_message().await {
240                    Ok(m) => m,
241                    Err(e) => {
242                        debug!("wait_for_progress_end: transport error: {e}");
243                        return;
244                    }
245                };
246                match message {
247                    JsonRpcMessage::Notification { method, params } if method == "$/progress" => {
248                        let kind = params
249                            .as_ref()
250                            .and_then(|p| p.get("value"))
251                            .and_then(|v| v.get("kind"))
252                            .and_then(|k| k.as_str())
253                            .unwrap_or("");
254                        debug!("$/progress kind={kind}");
255                        if kind == "end" {
256                            return;
257                        }
258                    }
259                    JsonRpcMessage::Response { id, result, error } => {
260                        debug!("buffering response id={id} during progress wait");
261                        let buffered = if let Some(err) = error {
262                            BufferedResponse::Err(err.to_string())
263                        } else {
264                            BufferedResponse::Ok(result.unwrap_or(Value::Null))
265                        };
266                        self.buffered_responses.insert(id, buffered);
267                    }
268                    JsonRpcMessage::ServerRequest { id, method, .. } => {
269                        debug!("auto-responding to server request during progress wait: {method}");
270                        let response = serde_json::json!({
271                            "jsonrpc": "2.0",
272                            "id": id,
273                            "result": null,
274                        });
275                        let body = serde_json::to_string(&response).unwrap_or_default();
276                        let header = format!("Content-Length: {}\r\n\r\n", body.len());
277                        let _ = self.transport.write_raw(header.as_bytes()).await;
278                        let _ = self.transport.write_raw(body.as_bytes()).await;
279                        let _ = self.transport.flush().await;
280                    }
281                    JsonRpcMessage::Notification { method, .. } => {
282                        debug!("ignoring notification during progress wait: {method}");
283                    }
284                }
285            }
286        })
287        .await;
288        debug!("wait_for_progress_end: done (ready or timed out)");
289    }
290
291    /// Get the server capabilities (available after initialize).
292    #[must_use]
293    pub fn capabilities(&self) -> Option<&ServerCapabilities> {
294        self.capabilities.as_ref()
295    }
296
297    /// Get the language this client serves.
298    #[must_use]
299    pub fn language(&self) -> Language {
300        self.language
301    }
302
303    /// Whether the server supports `workspace/didChangeWorkspaceFolders`.
304    #[must_use]
305    pub fn supports_workspace_folders(&self) -> bool {
306        self.supports_workspace_folders
307    }
308
309    /// The server binary name (e.g., "vtsls", "rust-analyzer").
310    #[must_use]
311    pub fn server_name(&self) -> &str {
312        &self.server_name
313    }
314
315    /// Check if a workspace folder is currently attached to this server.
316    #[must_use]
317    pub fn is_folder_attached(&self, path: &Path) -> bool {
318        self.attached_folders.contains(path)
319    }
320
321    /// Get all attached workspace folders.
322    #[must_use]
323    pub fn attached_folders(&self) -> &HashSet<PathBuf> {
324        &self.attached_folders
325    }
326
327    /// Dynamically attach a workspace folder to the running server.
328    ///
329    /// Sends `workspace/didChangeWorkspaceFolders` notification.
330    /// No-op if already attached or server doesn't support it.
331    ///
332    /// # Errors
333    /// Returns an error if the notification cannot be sent.
334    pub async fn attach_folder(&mut self, path: &Path) -> anyhow::Result<()> {
335        if self.attached_folders.contains(path) {
336            return Ok(());
337        }
338
339        if !self.supports_workspace_folders {
340            debug!(
341                "server {} does not support workspace folders, skipping attach",
342                self.server_name
343            );
344            // Still track it so we don't re-attempt
345            self.attached_folders.insert(path.to_path_buf());
346            return Ok(());
347        }
348
349        let uri = path_to_uri(path)?;
350        let name = path
351            .file_name()
352            .and_then(|n| n.to_str())
353            .unwrap_or("workspace");
354
355        let params = serde_json::json!({
356            "event": {
357                "added": [{ "uri": uri.as_str(), "name": name }],
358                "removed": []
359            }
360        });
361
362        self.transport
363            .send_notification("workspace/didChangeWorkspaceFolders", params)
364            .await?;
365
366        self.attached_folders.insert(path.to_path_buf());
367        debug!(
368            "attached workspace folder: {} (total: {})",
369            path.display(),
370            self.attached_folders.len()
371        );
372        Ok(())
373    }
374
375    /// Dynamically detach a workspace folder from the running server.
376    ///
377    /// Sends `workspace/didChangeWorkspaceFolders` notification.
378    /// No-op if not attached.
379    ///
380    /// # Errors
381    /// Returns an error if the notification cannot be sent.
382    pub async fn detach_folder(&mut self, path: &Path) -> anyhow::Result<()> {
383        if !self.attached_folders.contains(path) {
384            return Ok(());
385        }
386
387        if self.supports_workspace_folders {
388            let uri = path_to_uri(path)?;
389            let name = path
390                .file_name()
391                .and_then(|n| n.to_str())
392                .unwrap_or("workspace");
393
394            let params = serde_json::json!({
395                "event": {
396                    "added": [],
397                    "removed": [{ "uri": uri.as_str(), "name": name }]
398                }
399            });
400
401            self.transport
402                .send_notification("workspace/didChangeWorkspaceFolders", params)
403                .await?;
404        }
405
406        self.attached_folders.remove(path);
407        debug!(
408            "detached workspace folder: {} (remaining: {})",
409            path.display(),
410            self.attached_folders.len()
411        );
412        Ok(())
413    }
414
415    /// Attach a diagnostic store so `textDocument/publishDiagnostics` notifications
416    /// are captured while waiting for responses.
417    pub fn set_diagnostic_store(&mut self, store: Arc<DiagnosticStore>) {
418        self.diagnostic_store = Some(store);
419    }
420
421    /// Get mutable access to the transport for sending additional requests.
422    pub fn transport_mut(&mut self) -> &mut LspTransport {
423        &mut self.transport
424    }
425
426    /// Wait for a response with a specific ID, buffering out-of-order responses
427    /// and auto-responding to server requests.
428    async fn wait_for_response(
429        &mut self,
430        expected_id: i64,
431        timeout: Duration,
432    ) -> anyhow::Result<Value> {
433        // Check if this response was already buffered from a previous read
434        if let Some(buffered) = self.buffered_responses.remove(&expected_id) {
435            return match buffered {
436                BufferedResponse::Ok(value) => Ok(value),
437                BufferedResponse::Err(msg) => bail!("LSP error: {msg}"),
438            };
439        }
440
441        let result = tokio::time::timeout(timeout, async {
442            loop {
443                let message = self.transport.read_message().await?;
444                match message {
445                    JsonRpcMessage::Response { id, result, error } if id == expected_id => {
446                        if let Some(err) = error {
447                            debug!("LSP error response for id={id}: {err}");
448                            bail!("LSP error: {err}");
449                        }
450                        debug!("received response for id={id}");
451                        return Ok(result.unwrap_or(Value::Null));
452                    }
453                    JsonRpcMessage::Response { id, result, error } => {
454                        // Buffer for later retrieval instead of discarding
455                        debug!("buffering out-of-order response id={id}");
456                        let buffered = if let Some(err) = error {
457                            BufferedResponse::Err(err.to_string())
458                        } else {
459                            BufferedResponse::Ok(result.unwrap_or(Value::Null))
460                        };
461                        self.buffered_responses.insert(id, buffered);
462                    }
463                    JsonRpcMessage::Notification { method, params } => {
464                        if method == "textDocument/publishDiagnostics" {
465                            if let Some(store) = &self.diagnostic_store {
466                                ingest_publish_diagnostics(params, store);
467                            }
468                        } else {
469                            debug!("ignoring notification during wait: {method}");
470                        }
471                    }
472                    JsonRpcMessage::ServerRequest { id, method, .. } => {
473                        debug!("auto-responding to server request: {method}");
474                        let response = serde_json::json!({
475                            "jsonrpc": "2.0",
476                            "id": id,
477                            "result": null,
478                        });
479                        let body = serde_json::to_string(&response)?;
480                        let header = format!("Content-Length: {}\r\n\r\n", body.len());
481                        self.transport.write_raw(header.as_bytes()).await?;
482                        self.transport.write_raw(body.as_bytes()).await?;
483                        self.transport.flush().await?;
484                    }
485                }
486            }
487        })
488        .await;
489
490        match result {
491            Ok(inner) => inner,
492            Err(_) => bail!("timed out waiting for response ({}s)", timeout.as_secs()),
493        }
494    }
495}
496
497/// Convert a filesystem path to an LSP `file://` URI.
498///
499/// # Errors
500/// Returns an error if the path is not absolute or not valid UTF-8.
501pub fn path_to_uri(path: &Path) -> anyhow::Result<Uri> {
502    let abs = if path.is_absolute() {
503        path.to_path_buf()
504    } else {
505        std::env::current_dir()?.join(path)
506    };
507    let path_str = abs.to_str().context("path is not valid UTF-8")?;
508    let uri_string = format!("file://{path_str}");
509    uri_string
510        .parse()
511        .map_err(|e| anyhow::anyhow!("invalid URI: {e}"))
512}
513
514/// Return language-specific `initializationOptions` for servers that need them.
515///
516/// - Java (jdtls): empty settings object; full jdtls setup requires bundles path
517///   but basic symbol queries work without it.
518/// - Lua: sets `Lua.runtime.version` so the server indexes standard Lua/LuaJIT globals.
519/// - Others: `None` (the server uses its own defaults).
520fn language_init_options(_lang: Language) -> Option<Value> {
521    None
522}
523
524/// Build the `InitializeParams` for the LSP handshake.
525#[allow(deprecated)] // root_uri is deprecated but needed for compatibility
526fn build_initialize_params(
527    root_uri: &Uri,
528    project_root: &Path,
529    lang: Language,
530) -> InitializeParams {
531    let project_name = project_root
532        .file_name()
533        .and_then(|n| n.to_str())
534        .unwrap_or("project");
535
536    InitializeParams {
537        process_id: Some(std::process::id()),
538        root_uri: Some(root_uri.clone()),
539        capabilities: ClientCapabilities {
540            text_document: Some(TextDocumentClientCapabilities {
541                synchronization: Some(TextDocumentSyncClientCapabilities {
542                    dynamic_registration: Some(false),
543                    did_save: Some(true),
544                    ..Default::default()
545                }),
546                definition: Some(GotoCapability {
547                    dynamic_registration: Some(false),
548                    link_support: Some(false),
549                }),
550                references: Some(DynamicRegistrationClientCapabilities {
551                    dynamic_registration: Some(false),
552                }),
553                document_symbol: Some(DocumentSymbolClientCapabilities {
554                    dynamic_registration: Some(false),
555                    hierarchical_document_symbol_support: Some(true),
556                    ..Default::default()
557                }),
558                rename: Some(RenameClientCapabilities {
559                    dynamic_registration: Some(false),
560                    prepare_support: Some(true),
561                    ..Default::default()
562                }),
563                hover: Some(HoverClientCapabilities {
564                    dynamic_registration: Some(false),
565                    content_format: None,
566                }),
567                publish_diagnostics: Some(PublishDiagnosticsClientCapabilities {
568                    related_information: Some(true),
569                    ..Default::default()
570                }),
571                code_action: Some(CodeActionClientCapabilities {
572                    dynamic_registration: Some(false),
573                    ..Default::default()
574                }),
575                formatting: Some(DynamicRegistrationClientCapabilities {
576                    dynamic_registration: Some(false),
577                }),
578                ..Default::default()
579            }),
580            workspace: Some(WorkspaceClientCapabilities {
581                symbol: Some(WorkspaceSymbolClientCapabilities {
582                    dynamic_registration: Some(false),
583                    ..Default::default()
584                }),
585                workspace_folders: Some(true),
586                ..Default::default()
587            }),
588            window: Some(WindowClientCapabilities {
589                work_done_progress: Some(true),
590                ..Default::default()
591            }),
592            ..Default::default()
593        },
594        workspace_folders: Some(vec![WorkspaceFolder {
595            uri: root_uri.clone(),
596            name: project_name.to_string(),
597        }]),
598        initialization_options: language_init_options(lang),
599        ..Default::default()
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute project root shared by the tests that need no server.
    const TEST_ROOT: &str = "/tmp/test-project";

    /// Cargo fixture used by the rust-analyzer integration tests.
    fn rust_fixture() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/rust-hello")
    }

    #[test]
    fn path_to_uri_absolute() {
        let uri = path_to_uri(Path::new(TEST_ROOT)).unwrap();
        assert_eq!(uri.as_str(), "file:///tmp/test-project");
    }

    #[test]
    fn build_params_has_required_fields() {
        let root = Path::new(TEST_ROOT);
        let params = build_initialize_params(&path_to_uri(root).unwrap(), root, Language::Rust);

        assert!(params.process_id.is_some());
        assert!(params.capabilities.text_document.is_some());
        assert!(params.capabilities.workspace.is_some());

        let folders = params
            .workspace_folders
            .expect("initial workspace folder should be set");
        assert_eq!(folders.len(), 1);
        let only = &folders[0];
        assert_eq!(only.name, "test-project");
        assert_eq!(only.uri.as_str(), "file:///tmp/test-project");
    }

    #[test]
    fn start_missing_server_returns_not_found() {
        // gopls may or may not be installed, so only a ServerNotFound error is
        // inspected. Any other outcome (installed server, spawn failure) is accepted.
        let outcome = LspClient::start(Language::Go, Path::new("/tmp/nonexistent"));
        if let Err(LspError::ServerNotFound { language, advice }) = outcome {
            assert_eq!(language, Language::Go);
            assert!(!advice.is_empty());
        }
    }

    #[test]
    fn build_params_declares_workspace_folder_support() {
        let root = Path::new(TEST_ROOT);
        let params =
            build_initialize_params(&path_to_uri(root).unwrap(), root, Language::TypeScript);

        let workspace = params.capabilities.workspace.unwrap();
        assert_eq!(workspace.workspace_folders, Some(true));
    }

    #[test]
    fn attached_folders_tracking() {
        // A real LspClient needs a live server process, so this exercises the same
        // HashSet operations that attach_folder/detach_folder rely on.
        let api = PathBuf::from("/project/packages/api");
        let web = PathBuf::from("/project/packages/web");
        let mut folders: HashSet<PathBuf> = HashSet::new();

        assert!(!folders.contains(&api));
        folders.insert(api.clone());
        assert!(folders.contains(&api));
        assert!(!folders.contains(&web));

        // Re-inserting an attached folder leaves the set unchanged.
        folders.insert(api.clone());
        assert_eq!(folders.len(), 1);

        folders.insert(web);
        assert_eq!(folders.len(), 2);

        folders.remove(&api);
        assert_eq!(folders.len(), 1);
        assert!(!folders.contains(&api));
    }

    // Integration tests requiring real LSP servers
    #[tokio::test]
    #[ignore = "requires rust-analyzer installed"]
    async fn initialize_rust_analyzer() {
        let fixture = rust_fixture();
        let mut client =
            LspClient::start(Language::Rust, &fixture).expect("rust-analyzer should be available");

        let caps = client
            .initialize(&fixture)
            .await
            .expect("init should succeed");
        assert!(caps.document_symbol_provider.is_some());

        client.shutdown().await.expect("shutdown should succeed");
    }

    #[tokio::test]
    #[ignore = "requires rust-analyzer installed"]
    async fn shutdown_kills_process() {
        let fixture = rust_fixture();
        let mut client =
            LspClient::start(Language::Rust, &fixture).expect("rust-analyzer should be available");

        client
            .initialize(&fixture)
            .await
            .expect("init should succeed");
        client.shutdown().await.expect("shutdown should succeed");

        assert!(!client.transport.is_alive());
    }
}