1use std::collections::{HashMap, HashSet};
2use std::io::{self, BufReader, BufWriter, Read};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::position::path_to_uri;
20use crate::lsp::registry::ServerKind;
21use crate::lsp::{transport, LspError};
22
/// Default time to wait for the response to a client request.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `shutdown` polls for the process to exit after sending `exit` before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sleep between exit polls while waiting for the server process to terminate.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
/// Maximum number of bytes of server stderr retained for diagnostics (16 KiB).
const STDERR_TAIL_BYTES: usize = 16_384;
/// Maximum number of stderr lines retained for diagnostics.
const STDERR_TAIL_LINES: usize = 64;
28
/// In-flight client requests keyed by request id. The stdout reader thread removes the
/// entry for each incoming response and sends the response on the stored channel.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
/// Ids of the server's active dynamic `workspace/didChangeWatchedFiles` registrations,
/// shared between the client and its stdout reader thread.
type WatchedFileRegistrations = Arc<Mutex<HashSet<String>>>;
31
/// Lifecycle of the language-server process managed by an `LspClient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started.
    Starting,
    /// The `initialize` request has been issued and the handshake is in progress.
    Initializing,
    /// The handshake completed and the `initialized` notification was sent.
    Ready,
    /// Shutdown has begun; no further requests or notifications may be sent.
    ShuttingDown,
    /// The server process has exited (or was killed).
    Exited,
}
41
/// Events forwarded from a server's stdout reader thread to the owner of the event channel.
#[derive(Debug)]
pub enum LspEvent {
    /// A JSON-RPC notification sent by the server.
    Notification {
        /// Kind of server that sent the notification.
        server_kind: ServerKind,
        /// Workspace root the server was spawned for.
        root: PathBuf,
        method: String,
        params: Option<Value>,
    },
    /// A JSON-RPC request sent by the server. The reader thread has already written a reply
    /// (`null`, or an array of `null`s for `workspace/configuration`) before emitting this.
    ServerRequest {
        /// Kind of server that sent the request.
        server_kind: ServerKind,
        /// Workspace root the server was spawned for.
        root: PathBuf,
        id: RequestId,
        method: String,
        params: Option<Value>,
    },
    /// The reader thread stopped: stdout hit EOF, a message could not be read, or the
    /// pending-request map was poisoned.
    ServerExited {
        /// Kind of server that went away.
        server_kind: ServerKind,
        /// Workspace root the server was spawned for.
        root: PathBuf,
    },
}
66
/// Diagnostic-related capabilities extracted from a server's `initialize` result.
#[derive(Debug, Clone, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object (or the literal `true`), i.e. the server
    /// supports pull diagnostics (`textDocument/diagnostic`).
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when the server supplies one.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport` in the server capabilities.
    /// NOTE(review): the LSP spec defines `refreshSupport` as a *client* capability;
    /// confirm which servers actually send this.
    pub refresh_support: bool,
}
88
/// Client side of one language-server process. It owns the child process and its stdin
/// writer, plus the state it shares with the stdout reader and stderr drain threads
/// started by `LspClient::spawn`.
///
/// NOTE(review): fields are dropped in declaration order after `Drop::drop` runs; keep the
/// order unless that has been checked.
pub struct LspClient {
    /// Server kind this client was spawned for.
    kind: ServerKind,
    /// Workspace root the server process was started in.
    root: PathBuf,
    /// Lifecycle state; `ShuttingDown` and `Exited` block further sends.
    state: ServerState,
    /// The spawned server process.
    child: Child,
    /// PID captured at spawn time; used to untrack the child from `child_registry`.
    child_pid: u32,
    /// Buffered server stdin. Shared with the reader thread, which writes replies to server
    /// requests through it.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// Response channels for in-flight requests, keyed by request id.
    pending: Arc<Mutex<PendingMap>>,
    /// Id assigned to the next request; starts at 1.
    next_id: AtomicI64,
    /// Diagnostic capabilities parsed during `initialize`; `None` before that.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the `initialize` result advertised `workspace.didChangeWatchedFiles` support.
    supports_watched_files: bool,
    /// Ids of active dynamic watched-files registrations, maintained by the reader thread.
    watched_file_registrations: WatchedFileRegistrations,
    /// Registry the child was spawned through; the PID is untracked from it on exit/drop.
    child_registry: LspChildRegistry,
    /// Bounded tail of the server's stderr, filled by the drain thread.
    stderr_tail: Arc<Mutex<String>>,
}
125
126impl LspClient {
127 pub fn spawn(
133 kind: ServerKind,
134 root: PathBuf,
135 binary: &Path,
136 args: &[String],
137 env: &HashMap<String, String>,
138 event_tx: Sender<LspEvent>,
139 child_registry: LspChildRegistry,
140 ) -> io::Result<Self> {
141 let mut command = Command::new(binary);
142 command
143 .args(args)
144 .current_dir(&root)
145 .stdin(Stdio::piped())
146 .stdout(Stdio::piped())
147 .stderr(Stdio::piped());
150 for (key, value) in env {
151 command.env(key, value);
152 }
153
154 #[cfg(unix)]
160 unsafe {
161 use std::os::unix::process::CommandExt;
162 command.pre_exec(|| {
163 #[cfg(target_os = "linux")]
164 {
165 if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) == -1 {
171 return Err(io::Error::last_os_error());
172 }
173 if libc::getppid() == 1 {
174 return Err(io::Error::other("parent died before LSP spawn completed"));
175 }
176 }
177 if libc::setsid() == -1 {
178 return Err(io::Error::last_os_error());
179 }
180 Ok(())
181 });
182 }
183
184 let mut child = child_registry.spawn_tracked(&mut command)?;
185 let child_pid = child.id();
186
187 let stdout = child
188 .stdout
189 .take()
190 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
191 let stdin = child
192 .stdin
193 .take()
194 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
195 let stderr = child
196 .stderr
197 .take()
198 .ok_or_else(|| io::Error::other("language server missing stderr pipe"))?;
199 let stderr_tail = Arc::new(Mutex::new(String::new()));
200 spawn_stderr_drain_thread(stderr, Arc::clone(&stderr_tail));
201
202 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
203 let pending = Arc::new(Mutex::new(PendingMap::new()));
204 let watched_file_registrations = Arc::new(Mutex::new(HashSet::new()));
205 let reader_pending = Arc::clone(&pending);
206 let reader_writer = Arc::clone(&writer);
207 let reader_watched_file_registrations = Arc::clone(&watched_file_registrations);
208 let reader_kind = kind.clone();
209 let reader_root = root.clone();
210
211 thread::spawn(move || {
212 let mut reader = BufReader::new(stdout);
213 loop {
214 match transport::read_message(&mut reader) {
215 Ok(Some(ServerMessage::Response(response))) => {
216 if let Ok(mut guard) = reader_pending.lock() {
217 if let Some(tx) = guard.remove(&response.id) {
218 if tx.send(response).is_err() {
219 log::debug!("response channel closed");
220 }
221 }
222 } else {
223 let _ = event_tx.send(LspEvent::ServerExited {
224 server_kind: reader_kind.clone(),
225 root: reader_root.clone(),
226 });
227 break;
228 }
229 }
230 Ok(Some(ServerMessage::Notification { method, params })) => {
231 let _ = event_tx.send(LspEvent::Notification {
232 server_kind: reader_kind.clone(),
233 root: reader_root.clone(),
234 method,
235 params,
236 });
237 }
238 Ok(Some(ServerMessage::Request { id, method, params })) => {
239 record_watched_file_registration(
240 &reader_watched_file_registrations,
241 &method,
242 params.as_ref(),
243 );
244 let response_value = if method == "workspace/configuration" {
254 let item_count = params
257 .as_ref()
258 .and_then(|p| p.get("items"))
259 .and_then(|items| items.as_array())
260 .map_or(1, |arr| arr.len());
261 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
262 } else {
263 serde_json::Value::Null
264 };
265 if let Ok(mut w) = reader_writer.lock() {
266 let response = super::jsonrpc::OutgoingResponse::success(
267 id.clone(),
268 response_value,
269 );
270 let _ = transport::write_response(&mut *w, &response);
271 }
272 let _ = event_tx.send(LspEvent::ServerRequest {
274 server_kind: reader_kind.clone(),
275 root: reader_root.clone(),
276 id,
277 method,
278 params,
279 });
280 }
281 Ok(None) | Err(_) => {
282 if let Ok(mut guard) = reader_pending.lock() {
283 guard.clear();
284 }
285 let _ = event_tx.send(LspEvent::ServerExited {
286 server_kind: reader_kind.clone(),
287 root: reader_root.clone(),
288 });
289 break;
290 }
291 }
292 }
293 });
294
295 Ok(Self {
296 kind,
297 root,
298 state: ServerState::Starting,
299 child,
300 child_pid,
301 writer,
302 pending,
303 next_id: AtomicI64::new(1),
304 diagnostic_caps: None,
305 supports_watched_files: false,
306 watched_file_registrations,
307 child_registry,
308 stderr_tail,
309 })
310 }
311
312 pub fn initialize(
314 &mut self,
315 workspace_root: &Path,
316 initialization_options: Option<serde_json::Value>,
317 ) -> Result<lsp_types::InitializeResult, LspError> {
318 self.ensure_can_send()?;
319 self.state = ServerState::Initializing;
320
321 let root_url = path_to_uri(workspace_root)?;
322 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
323 LspError::NotFound(format!(
324 "failed to convert workspace root '{}' to file URI",
325 workspace_root.display()
326 ))
327 })?;
328
329 let mut params_value = json!({
330 "processId": std::process::id(),
331 "rootUri": root_uri,
332 "capabilities": {
333 "workspace": {
334 "workspaceFolders": true,
335 "configuration": true,
336 "didChangeWatchedFiles": {
337 "dynamicRegistration": true
338 },
339 "diagnostic": {
344 "refreshSupport": false
345 }
346 },
347 "textDocument": {
348 "synchronization": {
349 "dynamicRegistration": false,
350 "didSave": true,
351 "willSave": false,
352 "willSaveWaitUntil": false
353 },
354 "publishDiagnostics": {
355 "relatedInformation": true,
356 "versionSupport": true,
357 "codeDescriptionSupport": true,
358 "dataSupport": true
359 },
360 "diagnostic": {
365 "dynamicRegistration": false,
366 "relatedDocumentSupport": true
367 }
368 }
369 },
370 "clientInfo": {
371 "name": "aft",
372 "version": env!("CARGO_PKG_VERSION")
373 },
374 "workspaceFolders": [
375 {
376 "uri": root_uri,
377 "name": workspace_root
378 .file_name()
379 .and_then(|name| name.to_str())
380 .unwrap_or("workspace")
381 }
382 ]
383 });
384 if let Some(initialization_options) = initialization_options {
385 params_value["initializationOptions"] = initialization_options;
386 }
387
388 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
389
390 let result_value = self.send_request_value(
391 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
392 params,
393 )?;
394 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
395
396 let caps_value = result_value
401 .get("capabilities")
402 .cloned()
403 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
404 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
405
406 self.supports_watched_files = caps_value
410 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
411 .and_then(|v| v.as_bool())
412 .unwrap_or(false)
413 || caps_value
414 .pointer("/workspace/didChangeWatchedFiles")
415 .map(|v| v.is_object() || v.as_bool() == Some(true))
416 .unwrap_or(false);
417
418 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
419 json!({}),
420 )?)?;
421 self.state = ServerState::Ready;
422 Ok(result)
423 }
424
425 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
429 self.diagnostic_caps.as_ref()
430 }
431
432 pub fn supports_watched_files(&self) -> bool {
435 self.supports_watched_files
436 }
437
438 pub fn has_watched_file_registration(&self) -> bool {
442 self.watched_file_registrations
443 .lock()
444 .map(|registrations| !registrations.is_empty())
445 .unwrap_or(false)
446 }
447
448 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
450 where
451 R: lsp_types::request::Request,
452 R::Params: serde::Serialize,
453 R::Result: DeserializeOwned,
454 {
455 self.ensure_can_send()?;
456
457 let value = self.send_request_value(R::METHOD, params)?;
458 serde_json::from_value(value).map_err(Into::into)
459 }
460
461 pub fn send_request_with_timeout<R>(
465 &mut self,
466 params: R::Params,
467 timeout: Duration,
468 ) -> Result<R::Result, LspError>
469 where
470 R: lsp_types::request::Request,
471 R::Params: serde::Serialize,
472 R::Result: DeserializeOwned,
473 {
474 self.ensure_can_send()?;
475
476 let value = self.send_request_value_with_timeout(R::METHOD, params, timeout)?;
477 serde_json::from_value(value).map_err(Into::into)
478 }
479
480 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
481 where
482 P: serde::Serialize,
483 {
484 self.send_request_value_with_timeout(method, params, REQUEST_TIMEOUT)
485 }
486
487 fn send_request_value_with_timeout<P>(
488 &mut self,
489 method: &'static str,
490 params: P,
491 timeout: Duration,
492 ) -> Result<Value, LspError>
493 where
494 P: serde::Serialize,
495 {
496 self.ensure_can_send()?;
497
498 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
499 let (tx, rx) = bounded(1);
500 {
501 let mut pending = self.lock_pending()?;
502 pending.insert(id.clone(), tx);
503 }
504
505 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
506 {
507 let mut writer = self
508 .writer
509 .lock()
510 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
511 if let Err(err) = transport::write_request(&mut *writer, &request) {
512 self.remove_pending(&id);
513 return Err(err.into());
514 }
515 }
516
517 let response = match rx.recv_timeout(timeout) {
518 Ok(response) => response,
519 Err(RecvTimeoutError::Timeout) => {
520 self.remove_pending(&id);
521 self.send_cancel_request(&id)?;
522 return Err(LspError::Timeout(format!(
523 "timed out waiting for '{}' response from {:?}",
524 method, self.kind
525 )));
526 }
527 Err(RecvTimeoutError::Disconnected) => {
528 self.remove_pending(&id);
529 return Err(LspError::ServerNotReady(format!(
530 "language server {:?} disconnected while waiting for '{}'",
531 self.kind, method
532 )));
533 }
534 };
535
536 if let Some(error) = response.error {
537 return Err(LspError::ServerError {
538 code: error.code,
539 message: error.message,
540 });
541 }
542
543 Ok(response.result.unwrap_or(Value::Null))
544 }
545
546 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
548 where
549 N: lsp_types::notification::Notification,
550 N::Params: serde::Serialize,
551 {
552 self.ensure_can_send()?;
553 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
554 let mut writer = self
555 .writer
556 .lock()
557 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
558 transport::write_notification(&mut *writer, ¬ification)?;
559 Ok(())
560 }
561
562 pub fn shutdown(&mut self) -> Result<(), LspError> {
564 if self.state == ServerState::Exited {
565 self.child_registry.untrack(self.child_pid);
566 return Ok(());
567 }
568
569 if self.child.try_wait()?.is_some() {
570 self.state = ServerState::Exited;
571 self.child_registry.untrack(self.child_pid);
572 return Ok(());
573 }
574
575 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
576 self.state = ServerState::ShuttingDown;
577 if self.child.try_wait()?.is_some() {
578 self.state = ServerState::Exited;
579 return Ok(());
580 }
581 return Err(err);
582 }
583
584 self.state = ServerState::ShuttingDown;
585
586 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
587 if self.child.try_wait()?.is_some() {
588 self.state = ServerState::Exited;
589 return Ok(());
590 }
591 return Err(err);
592 }
593
594 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
595 loop {
596 if self.child.try_wait()?.is_some() {
597 self.state = ServerState::Exited;
598 return Ok(());
599 }
600 if Instant::now() >= deadline {
601 kill_lsp_child_group(&mut self.child);
605 self.state = ServerState::Exited;
606 return Err(LspError::Timeout(format!(
607 "timed out waiting for {:?} to exit",
608 self.kind
609 )));
610 }
611 thread::sleep(EXIT_POLL_INTERVAL);
612 }
613 }
614
615 pub fn stderr_tail(&self) -> String {
616 self.stderr_tail
617 .lock()
618 .map(|s| s.clone())
619 .unwrap_or_default()
620 }
621
622 pub fn child_exited(&mut self) -> bool {
623 self.child.try_wait().ok().flatten().is_some()
624 }
625
626 pub fn child_exit_status(&mut self) -> Option<std::process::ExitStatus> {
627 self.child.try_wait().ok().flatten()
628 }
629
630 pub fn state(&self) -> ServerState {
631 self.state
632 }
633
634 pub fn kind(&self) -> ServerKind {
635 self.kind.clone()
636 }
637
638 pub fn root(&self) -> &Path {
639 &self.root
640 }
641
642 fn ensure_can_send(&self) -> Result<(), LspError> {
643 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
644 return Err(LspError::ServerNotReady(format!(
645 "language server {:?} is not ready (state: {:?})",
646 self.kind, self.state
647 )));
648 }
649 Ok(())
650 }
651
652 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
653 self.pending
654 .lock()
655 .map_err(|_| io::Error::other("pending response map poisoned").into())
656 }
657
658 fn remove_pending(&self, id: &RequestId) {
659 if let Ok(mut pending) = self.pending.lock() {
660 pending.remove(id);
661 }
662 }
663
664 fn send_cancel_request(&mut self, id: &RequestId) -> Result<(), LspError> {
665 let notification = Notification::new("$/cancelRequest", Some(json!({ "id": id })));
666 let mut writer = self
667 .writer
668 .lock()
669 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
670 transport::write_notification(&mut *writer, ¬ification)?;
671 Ok(())
672 }
673}
674
675impl Drop for LspClient {
676 fn drop(&mut self) {
677 self.child_registry.untrack(self.child_pid);
680 kill_lsp_child_group(&mut self.child);
681 }
682}
683
684fn spawn_stderr_drain_thread(
685 mut stderr: std::process::ChildStderr,
686 stderr_tail: Arc<Mutex<String>>,
687) {
688 thread::spawn(move || {
689 let mut buf = [0_u8; 4096];
690 loop {
691 match stderr.read(&mut buf) {
692 Ok(0) => break,
693 Ok(n) => {
694 let chunk = String::from_utf8_lossy(&buf[..n]);
695 if let Ok(mut tail) = stderr_tail.lock() {
696 append_stderr_tail(&mut tail, &chunk);
697 } else {
698 break;
699 }
700 }
701 Err(_) => break,
702 }
703 }
704 });
705}
706
707fn append_stderr_tail(tail: &mut String, chunk: &str) {
708 tail.push_str(chunk);
709 trim_stderr_tail_bytes(tail);
710 trim_stderr_tail_lines(tail);
711}
712
713fn trim_stderr_tail_bytes(tail: &mut String) {
714 if tail.len() <= STDERR_TAIL_BYTES {
715 return;
716 }
717 let mut start = tail.len() - STDERR_TAIL_BYTES;
718 while start < tail.len() && !tail.is_char_boundary(start) {
719 start += 1;
720 }
721 tail.drain(..start);
722}
723
724fn trim_stderr_tail_lines(tail: &mut String) {
725 let line_count = tail.lines().count();
726 if line_count <= STDERR_TAIL_LINES {
727 return;
728 }
729 let excess = line_count - STDERR_TAIL_LINES;
730 let split_at = tail.match_indices('\n').nth(excess - 1).map(|(i, _)| i + 1);
731 if let Some(at) = split_at {
732 tail.drain(..at);
733 }
734}
735
736fn kill_lsp_child_group(child: &mut std::process::Child) {
743 #[cfg(unix)]
744 {
745 let pgid = child.id() as i32;
746 crate::bash_background::process::terminate_pgid(pgid, Some(child));
747 let _ = child.wait();
748 }
749 #[cfg(not(unix))]
750 {
751 crate::bash_background::process::terminate_process(child);
752 let _ = child.wait();
753 }
754}
755
756fn record_watched_file_registration(
757 registrations: &WatchedFileRegistrations,
758 method: &str,
759 params: Option<&Value>,
760) {
761 match method {
762 "client/registerCapability" => {
763 let Some(items) = params
764 .and_then(|params| params.get("registrations"))
765 .and_then(|registrations| registrations.as_array())
766 else {
767 return;
768 };
769 if let Ok(mut guard) = registrations.lock() {
770 for item in items {
771 if item.get("method").and_then(Value::as_str)
772 == Some("workspace/didChangeWatchedFiles")
773 {
774 if let Some(id) = item.get("id").and_then(Value::as_str) {
775 guard.insert(id.to_string());
776 }
777 }
778 }
779 }
780 }
781 "client/unregisterCapability" => {
782 let Some(items) = params
783 .and_then(|params| params.get("unregisterations"))
784 .and_then(|registrations| registrations.as_array())
785 else {
786 return;
787 };
788 if let Ok(mut guard) = registrations.lock() {
789 for item in items {
790 if item.get("method").and_then(Value::as_str)
791 == Some("workspace/didChangeWatchedFiles")
792 {
793 if let Some(id) = item.get("id").and_then(Value::as_str) {
794 guard.remove(id);
795 }
796 }
797 }
798 }
799 }
800 _ => {}
801 }
802}
803
804fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
819 let mut caps = ServerDiagnosticCapabilities::default();
820
821 if let Some(provider) = value.get("diagnosticProvider") {
822 if provider.is_object() || provider.as_bool() == Some(true) {
825 caps.pull_diagnostics = true;
826 }
827
828 if let Some(obj) = provider.as_object() {
829 if obj
830 .get("workspaceDiagnostics")
831 .and_then(|v| v.as_bool())
832 .unwrap_or(false)
833 {
834 caps.workspace_diagnostics = true;
835 }
836 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
837 caps.identifier = Some(identifier.to_string());
838 }
839 }
840 }
841
842 if let Some(refresh) = value
845 .get("workspace")
846 .and_then(|w| w.get("diagnostic"))
847 .and_then(|d| d.get("refreshSupport"))
848 .and_then(|r| r.as_bool())
849 {
850 caps.refresh_support = refresh;
851 }
852
853 caps
854}
855
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn watched_file_registration_register_then_unregister() {
        let registrations: WatchedFileRegistrations = Arc::new(Mutex::new(HashSet::new()));
        record_watched_file_registration(
            &registrations,
            "client/registerCapability",
            Some(&json!({
                "registrations": [
                    { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" },
                    { "id": "other", "method": "textDocument/didOpen" }
                ]
            })),
        );
        {
            let active = registrations.lock().unwrap();
            assert!(active.contains("watch-1"));
            assert!(!active.contains("other"));
        }

        // The LSP spec spells this field `unregisterations`.
        record_watched_file_registration(
            &registrations,
            "client/unregisterCapability",
            Some(&json!({
                "unregisterations": [
                    { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" }
                ]
            })),
        );
        assert!(registrations.lock().unwrap().is_empty());
    }

    #[test]
    fn watched_file_registration_ignores_other_methods_and_missing_params() {
        let registrations: WatchedFileRegistrations = Arc::new(Mutex::new(HashSet::new()));
        record_watched_file_registration(
            &registrations,
            "workspace/configuration",
            Some(&json!({
                "registrations": [
                    { "id": "x", "method": "workspace/didChangeWatchedFiles" }
                ]
            })),
        );
        record_watched_file_registration(&registrations, "client/registerCapability", None);
        assert!(registrations.lock().unwrap().is_empty());
    }

    #[test]
    fn stderr_tail_keeps_most_recent_lines() {
        let mut tail = String::new();
        for i in 0..(STDERR_TAIL_LINES + 10) {
            append_stderr_tail(&mut tail, &format!("line {i}\n"));
        }
        assert_eq!(tail.lines().count(), STDERR_TAIL_LINES);
        assert!(tail.starts_with("line 10\n"));
        assert!(tail.ends_with(&format!("line {}\n", STDERR_TAIL_LINES + 9)));
    }

    #[test]
    fn stderr_tail_byte_cap_respects_char_boundaries() {
        let mut tail = String::new();
        // Two-byte chars followed by one ASCII byte, so the raw cut point lands mid-char.
        let chunk = format!("{}a", "é".repeat(STDERR_TAIL_BYTES));
        append_stderr_tail(&mut tail, &chunk);
        assert!(tail.len() <= STDERR_TAIL_BYTES);
        assert!(tail.starts_with('é'));
        assert!(tail.ends_with('a'));
    }
}