1use std::collections::{HashMap, HashSet};
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::position::path_to_uri;
20use crate::lsp::registry::ServerKind;
21use crate::lsp::{transport, LspError};
22
/// How long a request waits for its response before it is cancelled and reported as a timeout.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `shutdown` waits for the server process to exit before killing its process group.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Sleep between exit checks while `shutdown` waits for the server process.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
26
27type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
28type WatchedFileRegistrations = Arc<Mutex<HashSet<String>>>;
29
/// Lifecycle phase of a language-server connection as tracked by `LspClient`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started yet.
    Starting,
    /// The `initialize` request has been issued.
    Initializing,
    /// `initialize` succeeded and the `initialized` notification was sent.
    Ready,
    /// Shutdown has begun; no further requests or notifications are accepted.
    ShuttingDown,
    /// The server process has exited or was killed.
    Exited,
}
39
40#[derive(Debug)]
42pub enum LspEvent {
43 Notification {
45 server_kind: ServerKind,
46 root: PathBuf,
47 method: String,
48 params: Option<Value>,
49 },
50 ServerRequest {
52 server_kind: ServerKind,
53 root: PathBuf,
54 id: RequestId,
55 method: String,
56 params: Option<Value>,
57 },
58 ServerExited {
60 server_kind: ServerKind,
61 root: PathBuf,
62 },
63}
64
/// Diagnostic-related capabilities extracted from a server's `initialize` result.
///
/// Every flag defaults to `false` and `identifier` defaults to `None` when the
/// server does not advertise the corresponding capability.
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object or as `true` (pull diagnostics).
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when the server supplies one.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport` in the server capabilities, if present.
    pub refresh_support: bool,
}
86
87pub struct LspClient {
89 kind: ServerKind,
90 root: PathBuf,
91 state: ServerState,
92 child: Child,
93 child_pid: u32,
97 writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,
98
99 pending: Arc<Mutex<PendingMap>>,
101 next_id: AtomicI64,
103 diagnostic_caps: Option<ServerDiagnosticCapabilities>,
107 supports_watched_files: bool,
112 watched_file_registrations: WatchedFileRegistrations,
117 child_registry: LspChildRegistry,
121}
122
123impl LspClient {
124 pub fn spawn(
130 kind: ServerKind,
131 root: PathBuf,
132 binary: &Path,
133 args: &[String],
134 env: &HashMap<String, String>,
135 event_tx: Sender<LspEvent>,
136 child_registry: LspChildRegistry,
137 ) -> io::Result<Self> {
138 let mut command = Command::new(binary);
139 command
140 .args(args)
141 .current_dir(&root)
142 .stdin(Stdio::piped())
143 .stdout(Stdio::piped())
144 .stderr(Stdio::null());
147 for (key, value) in env {
148 command.env(key, value);
149 }
150
151 #[cfg(unix)]
157 unsafe {
158 use std::os::unix::process::CommandExt;
159 command.pre_exec(|| {
160 #[cfg(target_os = "linux")]
161 {
162 if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) == -1 {
168 return Err(io::Error::last_os_error());
169 }
170 if libc::getppid() == 1 {
171 return Err(io::Error::other("parent died before LSP spawn completed"));
172 }
173 }
174 if libc::setsid() == -1 {
175 return Err(io::Error::last_os_error());
176 }
177 Ok(())
178 });
179 }
180
181 let mut child = child_registry.spawn_tracked(&mut command)?;
182 let child_pid = child.id();
183
184 let stdout = child
185 .stdout
186 .take()
187 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
188 let stdin = child
189 .stdin
190 .take()
191 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
192
193 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
194 let pending = Arc::new(Mutex::new(PendingMap::new()));
195 let watched_file_registrations = Arc::new(Mutex::new(HashSet::new()));
196 let reader_pending = Arc::clone(&pending);
197 let reader_writer = Arc::clone(&writer);
198 let reader_watched_file_registrations = Arc::clone(&watched_file_registrations);
199 let reader_kind = kind.clone();
200 let reader_root = root.clone();
201
202 thread::spawn(move || {
203 let mut reader = BufReader::new(stdout);
204 loop {
205 match transport::read_message(&mut reader) {
206 Ok(Some(ServerMessage::Response(response))) => {
207 if let Ok(mut guard) = reader_pending.lock() {
208 if let Some(tx) = guard.remove(&response.id) {
209 if tx.send(response).is_err() {
210 log::debug!("response channel closed");
211 }
212 }
213 } else {
214 let _ = event_tx.send(LspEvent::ServerExited {
215 server_kind: reader_kind.clone(),
216 root: reader_root.clone(),
217 });
218 break;
219 }
220 }
221 Ok(Some(ServerMessage::Notification { method, params })) => {
222 let _ = event_tx.send(LspEvent::Notification {
223 server_kind: reader_kind.clone(),
224 root: reader_root.clone(),
225 method,
226 params,
227 });
228 }
229 Ok(Some(ServerMessage::Request { id, method, params })) => {
230 record_watched_file_registration(
231 &reader_watched_file_registrations,
232 &method,
233 params.as_ref(),
234 );
235 let response_value = if method == "workspace/configuration" {
245 let item_count = params
248 .as_ref()
249 .and_then(|p| p.get("items"))
250 .and_then(|items| items.as_array())
251 .map_or(1, |arr| arr.len());
252 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
253 } else {
254 serde_json::Value::Null
255 };
256 if let Ok(mut w) = reader_writer.lock() {
257 let response = super::jsonrpc::OutgoingResponse::success(
258 id.clone(),
259 response_value,
260 );
261 let _ = transport::write_response(&mut *w, &response);
262 }
263 let _ = event_tx.send(LspEvent::ServerRequest {
265 server_kind: reader_kind.clone(),
266 root: reader_root.clone(),
267 id,
268 method,
269 params,
270 });
271 }
272 Ok(None) | Err(_) => {
273 if let Ok(mut guard) = reader_pending.lock() {
274 guard.clear();
275 }
276 let _ = event_tx.send(LspEvent::ServerExited {
277 server_kind: reader_kind.clone(),
278 root: reader_root.clone(),
279 });
280 break;
281 }
282 }
283 }
284 });
285
286 Ok(Self {
287 kind,
288 root,
289 state: ServerState::Starting,
290 child,
291 child_pid,
292 writer,
293 pending,
294 next_id: AtomicI64::new(1),
295 diagnostic_caps: None,
296 supports_watched_files: false,
297 watched_file_registrations,
298 child_registry,
299 })
300 }
301
302 pub fn initialize(
304 &mut self,
305 workspace_root: &Path,
306 initialization_options: Option<serde_json::Value>,
307 ) -> Result<lsp_types::InitializeResult, LspError> {
308 self.ensure_can_send()?;
309 self.state = ServerState::Initializing;
310
311 let root_url = path_to_uri(workspace_root)?;
312 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
313 LspError::NotFound(format!(
314 "failed to convert workspace root '{}' to file URI",
315 workspace_root.display()
316 ))
317 })?;
318
319 let mut params_value = json!({
320 "processId": std::process::id(),
321 "rootUri": root_uri,
322 "capabilities": {
323 "workspace": {
324 "workspaceFolders": true,
325 "configuration": true,
326 "didChangeWatchedFiles": {
327 "dynamicRegistration": true
328 },
329 "diagnostic": {
334 "refreshSupport": false
335 }
336 },
337 "textDocument": {
338 "synchronization": {
339 "dynamicRegistration": false,
340 "didSave": true,
341 "willSave": false,
342 "willSaveWaitUntil": false
343 },
344 "publishDiagnostics": {
345 "relatedInformation": true,
346 "versionSupport": true,
347 "codeDescriptionSupport": true,
348 "dataSupport": true
349 },
350 "diagnostic": {
355 "dynamicRegistration": false,
356 "relatedDocumentSupport": true
357 }
358 }
359 },
360 "clientInfo": {
361 "name": "aft",
362 "version": env!("CARGO_PKG_VERSION")
363 },
364 "workspaceFolders": [
365 {
366 "uri": root_uri,
367 "name": workspace_root
368 .file_name()
369 .and_then(|name| name.to_str())
370 .unwrap_or("workspace")
371 }
372 ]
373 });
374 if let Some(initialization_options) = initialization_options {
375 params_value["initializationOptions"] = initialization_options;
376 }
377
378 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
379
380 let result_value = self.send_request_value(
381 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
382 params,
383 )?;
384 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
385
386 let caps_value = result_value
391 .get("capabilities")
392 .cloned()
393 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
394 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
395
396 self.supports_watched_files = caps_value
400 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
401 .and_then(|v| v.as_bool())
402 .unwrap_or(false)
403 || caps_value
404 .pointer("/workspace/didChangeWatchedFiles")
405 .map(|v| v.is_object() || v.as_bool() == Some(true))
406 .unwrap_or(false);
407
408 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
409 json!({}),
410 )?)?;
411 self.state = ServerState::Ready;
412 Ok(result)
413 }
414
415 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
419 self.diagnostic_caps.as_ref()
420 }
421
422 pub fn supports_watched_files(&self) -> bool {
425 self.supports_watched_files
426 }
427
428 pub fn has_watched_file_registration(&self) -> bool {
432 self.watched_file_registrations
433 .lock()
434 .map(|registrations| !registrations.is_empty())
435 .unwrap_or(false)
436 }
437
438 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
440 where
441 R: lsp_types::request::Request,
442 R::Params: serde::Serialize,
443 R::Result: DeserializeOwned,
444 {
445 self.ensure_can_send()?;
446
447 let value = self.send_request_value(R::METHOD, params)?;
448 serde_json::from_value(value).map_err(Into::into)
449 }
450
451 pub fn send_request_with_timeout<R>(
455 &mut self,
456 params: R::Params,
457 timeout: Duration,
458 ) -> Result<R::Result, LspError>
459 where
460 R: lsp_types::request::Request,
461 R::Params: serde::Serialize,
462 R::Result: DeserializeOwned,
463 {
464 self.ensure_can_send()?;
465
466 let value = self.send_request_value_with_timeout(R::METHOD, params, timeout)?;
467 serde_json::from_value(value).map_err(Into::into)
468 }
469
470 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
471 where
472 P: serde::Serialize,
473 {
474 self.send_request_value_with_timeout(method, params, REQUEST_TIMEOUT)
475 }
476
477 fn send_request_value_with_timeout<P>(
478 &mut self,
479 method: &'static str,
480 params: P,
481 timeout: Duration,
482 ) -> Result<Value, LspError>
483 where
484 P: serde::Serialize,
485 {
486 self.ensure_can_send()?;
487
488 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
489 let (tx, rx) = bounded(1);
490 {
491 let mut pending = self.lock_pending()?;
492 pending.insert(id.clone(), tx);
493 }
494
495 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
496 {
497 let mut writer = self
498 .writer
499 .lock()
500 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
501 if let Err(err) = transport::write_request(&mut *writer, &request) {
502 self.remove_pending(&id);
503 return Err(err.into());
504 }
505 }
506
507 let response = match rx.recv_timeout(timeout) {
508 Ok(response) => response,
509 Err(RecvTimeoutError::Timeout) => {
510 self.remove_pending(&id);
511 self.send_cancel_request(&id)?;
512 return Err(LspError::Timeout(format!(
513 "timed out waiting for '{}' response from {:?}",
514 method, self.kind
515 )));
516 }
517 Err(RecvTimeoutError::Disconnected) => {
518 self.remove_pending(&id);
519 return Err(LspError::ServerNotReady(format!(
520 "language server {:?} disconnected while waiting for '{}'",
521 self.kind, method
522 )));
523 }
524 };
525
526 if let Some(error) = response.error {
527 return Err(LspError::ServerError {
528 code: error.code,
529 message: error.message,
530 });
531 }
532
533 Ok(response.result.unwrap_or(Value::Null))
534 }
535
536 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
538 where
539 N: lsp_types::notification::Notification,
540 N::Params: serde::Serialize,
541 {
542 self.ensure_can_send()?;
543 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
544 let mut writer = self
545 .writer
546 .lock()
547 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
548 transport::write_notification(&mut *writer, ¬ification)?;
549 Ok(())
550 }
551
552 pub fn shutdown(&mut self) -> Result<(), LspError> {
554 if self.state == ServerState::Exited {
555 self.child_registry.untrack(self.child_pid);
556 return Ok(());
557 }
558
559 if self.child.try_wait()?.is_some() {
560 self.state = ServerState::Exited;
561 self.child_registry.untrack(self.child_pid);
562 return Ok(());
563 }
564
565 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
566 self.state = ServerState::ShuttingDown;
567 if self.child.try_wait()?.is_some() {
568 self.state = ServerState::Exited;
569 return Ok(());
570 }
571 return Err(err);
572 }
573
574 self.state = ServerState::ShuttingDown;
575
576 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
577 if self.child.try_wait()?.is_some() {
578 self.state = ServerState::Exited;
579 return Ok(());
580 }
581 return Err(err);
582 }
583
584 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
585 loop {
586 if self.child.try_wait()?.is_some() {
587 self.state = ServerState::Exited;
588 return Ok(());
589 }
590 if Instant::now() >= deadline {
591 kill_lsp_child_group(&mut self.child);
595 self.state = ServerState::Exited;
596 return Err(LspError::Timeout(format!(
597 "timed out waiting for {:?} to exit",
598 self.kind
599 )));
600 }
601 thread::sleep(EXIT_POLL_INTERVAL);
602 }
603 }
604
605 pub fn state(&self) -> ServerState {
606 self.state
607 }
608
609 pub fn kind(&self) -> ServerKind {
610 self.kind.clone()
611 }
612
613 pub fn root(&self) -> &Path {
614 &self.root
615 }
616
617 fn ensure_can_send(&self) -> Result<(), LspError> {
618 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
619 return Err(LspError::ServerNotReady(format!(
620 "language server {:?} is not ready (state: {:?})",
621 self.kind, self.state
622 )));
623 }
624 Ok(())
625 }
626
627 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
628 self.pending
629 .lock()
630 .map_err(|_| io::Error::other("pending response map poisoned").into())
631 }
632
633 fn remove_pending(&self, id: &RequestId) {
634 if let Ok(mut pending) = self.pending.lock() {
635 pending.remove(id);
636 }
637 }
638
639 fn send_cancel_request(&mut self, id: &RequestId) -> Result<(), LspError> {
640 let notification = Notification::new("$/cancelRequest", Some(json!({ "id": id })));
641 let mut writer = self
642 .writer
643 .lock()
644 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
645 transport::write_notification(&mut *writer, ¬ification)?;
646 Ok(())
647 }
648}
649
650impl Drop for LspClient {
651 fn drop(&mut self) {
652 self.child_registry.untrack(self.child_pid);
655 kill_lsp_child_group(&mut self.child);
656 }
657}
658
659fn kill_lsp_child_group(child: &mut std::process::Child) {
666 #[cfg(unix)]
667 {
668 let pgid = child.id() as i32;
669 crate::bash_background::process::terminate_pgid(pgid, Some(child));
670 let _ = child.wait();
671 }
672 #[cfg(not(unix))]
673 {
674 crate::bash_background::process::terminate_process(child);
675 let _ = child.wait();
676 }
677}
678
679fn record_watched_file_registration(
680 registrations: &WatchedFileRegistrations,
681 method: &str,
682 params: Option<&Value>,
683) {
684 match method {
685 "client/registerCapability" => {
686 let Some(items) = params
687 .and_then(|params| params.get("registrations"))
688 .and_then(|registrations| registrations.as_array())
689 else {
690 return;
691 };
692 if let Ok(mut guard) = registrations.lock() {
693 for item in items {
694 if item.get("method").and_then(Value::as_str)
695 == Some("workspace/didChangeWatchedFiles")
696 {
697 if let Some(id) = item.get("id").and_then(Value::as_str) {
698 guard.insert(id.to_string());
699 }
700 }
701 }
702 }
703 }
704 "client/unregisterCapability" => {
705 let Some(items) = params
706 .and_then(|params| params.get("unregisterations"))
707 .and_then(|registrations| registrations.as_array())
708 else {
709 return;
710 };
711 if let Ok(mut guard) = registrations.lock() {
712 for item in items {
713 if item.get("method").and_then(Value::as_str)
714 == Some("workspace/didChangeWatchedFiles")
715 {
716 if let Some(id) = item.get("id").and_then(Value::as_str) {
717 guard.remove(id);
718 }
719 }
720 }
721 }
722 }
723 _ => {}
724 }
725}
726
727fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
742 let mut caps = ServerDiagnosticCapabilities::default();
743
744 if let Some(provider) = value.get("diagnosticProvider") {
745 if provider.is_object() || provider.as_bool() == Some(true) {
748 caps.pull_diagnostics = true;
749 }
750
751 if let Some(obj) = provider.as_object() {
752 if obj
753 .get("workspaceDiagnostics")
754 .and_then(|v| v.as_bool())
755 .unwrap_or(false)
756 {
757 caps.workspace_diagnostics = true;
758 }
759 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
760 caps.identifier = Some(identifier.to_string());
761 }
762 }
763 }
764
765 if let Some(refresh) = value
768 .get("workspace")
769 .and_then(|w| w.get("diagnostic"))
770 .and_then(|d| d.get("refreshSupport"))
771 .and_then(|r| r.as_bool())
772 {
773 caps.refresh_support = refresh;
774 }
775
776 caps
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty registration set for the watched-files tests.
    fn empty_registrations() -> WatchedFileRegistrations {
        Arc::new(Mutex::new(HashSet::new()))
    }

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
        assert!(!caps.refresh_support);
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_provider_as_bare_false() {
        let caps = parse_diagnostic_capabilities(&json!({ "diagnosticProvider": false }));
        assert!(!caps.pull_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn watched_file_registration_round_trip() {
        let registrations = empty_registrations();
        let register = json!({
            "registrations": [
                { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" },
                { "id": "save-1", "method": "textDocument/didSave" }
            ]
        });
        record_watched_file_registration(
            &registrations,
            "client/registerCapability",
            Some(&register),
        );
        assert_eq!(
            *registrations.lock().unwrap(),
            HashSet::from(["watch-1".to_string()])
        );

        let unregister = json!({
            "unregisterations": [
                { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" }
            ]
        });
        record_watched_file_registration(
            &registrations,
            "client/unregisterCapability",
            Some(&unregister),
        );
        assert!(registrations.lock().unwrap().is_empty());
    }

    #[test]
    fn watched_file_registration_ignores_other_methods_and_missing_params() {
        let registrations = empty_registrations();
        let params = json!({
            "registrations": [
                { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" }
            ]
        });
        record_watched_file_registration(
            &registrations,
            "window/workDoneProgress/create",
            Some(&params),
        );
        record_watched_file_registration(&registrations, "client/registerCapability", None);
        assert!(registrations.lock().unwrap().is_empty());
    }
}