1use std::collections::{HashMap, HashSet, VecDeque};
2use std::io::{self, BufRead, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::position::path_to_uri;
20use crate::lsp::registry::ServerKind;
21use crate::lsp::{transport, LspError};
22
/// Default upper bound on how long a request waits for the server's response.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `LspClient::shutdown` waits for the process to exit after the
/// `exit` notification before killing its process group.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Polling cadence used while waiting for the server process to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
/// Maximum number of stderr lines retained for diagnostics.
const STDERR_TAIL_LINES: usize = 64;
/// Maximum bytes kept per retained stderr line; longer lines keep their end.
const STDERR_LINE_BYTES: usize = 4096;
28
/// In-flight requests keyed by request id; each entry holds the sender half of
/// the channel the stdout reader thread uses to deliver the matching response.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
/// Shared set of registration ids for dynamically registered
/// `workspace/didChangeWatchedFiles` capabilities that have not been
/// unregistered yet.
type WatchedFileRegistrations = Arc<Mutex<HashSet<String>>>;
31
/// Lifecycle of an [`LspClient`]'s server process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started yet.
    Starting,
    /// The `initialize` request is in progress.
    Initializing,
    /// Handshake complete (`initialized` has been sent).
    Ready,
    /// Shutdown in progress; the client refuses to send further requests or
    /// notifications.
    ShuttingDown,
    /// The server process has exited (or was killed).
    Exited,
}
41
/// Events emitted by an [`LspClient`]'s stdout reader thread.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification sent by the server.
    Notification {
        /// Kind of the server that sent the notification.
        server_kind: ServerKind,
        /// Workspace root of that server.
        root: PathBuf,
        method: String,
        params: Option<Value>,
    },
    /// A request sent by the server. The reader thread has already written a
    /// reply (an array of `null`s for `workspace/configuration`, `null`
    /// otherwise) before emitting this event, so it is informational.
    ServerRequest {
        /// Kind of the server that sent the request.
        server_kind: ServerKind,
        /// Workspace root of that server.
        root: PathBuf,
        id: RequestId,
        method: String,
        params: Option<Value>,
    },
    /// Reading the server's stdout hit EOF or an error (or the pending-response
    /// map was poisoned); the reader thread stops after sending this.
    ServerExited {
        /// Kind of the server that went away.
        server_kind: ServerKind,
        /// Workspace root of that server.
        root: PathBuf,
    },
}
66
/// Diagnostic-related capabilities extracted from a server's `initialize`
/// result by `parse_diagnostic_capabilities`.
#[derive(Debug, Clone, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` was present as an options object or as `true`.
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` was `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when the server supplied a string.
    pub identifier: Option<String>,
    /// `workspace.diagnostic.refreshSupport` from the parsed capabilities
    /// object when it is a boolean; `false` otherwise.
    pub refresh_support: bool,
}
88
/// Client for a single language-server process rooted at one workspace.
///
/// Owns the child process and the writer over its stdin; stdout and stderr
/// are serviced by background threads started in [`LspClient::spawn`].
pub struct LspClient {
    /// Which language server this is.
    kind: ServerKind,
    /// Workspace root; also the working directory the process was spawned in.
    root: PathBuf,
    /// Lifecycle state; gates whether messages may still be sent.
    state: ServerState,
    /// Handle to the server process.
    child: Child,
    /// Pid captured at spawn time, used to untrack the child from the registry.
    child_pid: u32,
    /// Buffered writer over the server's stdin. Shared with the stdout reader
    /// thread, which writes replies to server-initiated requests through it.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// Requests awaiting a response; shared with the stdout reader thread.
    pending: Arc<Mutex<PendingMap>>,
    /// Next request id to hand out.
    next_id: AtomicI64,
    /// Diagnostic capabilities from the `initialize` result, once received.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the `initialize` result advertised
    /// `workspace.didChangeWatchedFiles`.
    supports_watched_files: bool,
    /// Ids of active dynamic `workspace/didChangeWatchedFiles` registrations,
    /// maintained by the stdout reader thread.
    watched_file_registrations: WatchedFileRegistrations,
    /// Registry the child was spawned through (`spawn_tracked`); the pid is
    /// untracked again on shutdown/drop.
    child_registry: LspChildRegistry,
    /// Most recent stderr lines, filled by the stderr drain thread.
    stderr_tail: Arc<Mutex<VecDeque<String>>>,
}
125
126impl LspClient {
127 pub fn spawn(
133 kind: ServerKind,
134 root: PathBuf,
135 binary: &Path,
136 args: &[String],
137 env: &HashMap<String, String>,
138 event_tx: Sender<LspEvent>,
139 child_registry: LspChildRegistry,
140 ) -> io::Result<Self> {
141 let mut command = Command::new(binary);
142 command
143 .args(args)
144 .current_dir(&root)
145 .stdin(Stdio::piped())
146 .stdout(Stdio::piped())
147 .stderr(Stdio::piped());
150 for (key, value) in env {
151 command.env(key, value);
152 }
153
154 #[cfg(unix)]
160 unsafe {
161 use std::os::unix::process::CommandExt;
162 command.pre_exec(|| {
163 #[cfg(target_os = "linux")]
164 {
165 if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL) == -1 {
171 return Err(io::Error::last_os_error());
172 }
173 if libc::getppid() == 1 {
174 return Err(io::Error::other("parent died before LSP spawn completed"));
175 }
176 }
177 if libc::setsid() == -1 {
178 return Err(io::Error::last_os_error());
179 }
180 Ok(())
181 });
182 }
183
184 let mut child = child_registry.spawn_tracked(&mut command)?;
185 let child_pid = child.id();
186
187 let stdout = child
188 .stdout
189 .take()
190 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
191 let stdin = child
192 .stdin
193 .take()
194 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
195 let stderr = child
196 .stderr
197 .take()
198 .ok_or_else(|| io::Error::other("language server missing stderr pipe"))?;
199 let stderr_tail = Arc::new(Mutex::new(VecDeque::with_capacity(STDERR_TAIL_LINES)));
200 spawn_stderr_drain_thread(stderr, Arc::clone(&stderr_tail));
201
202 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
203 let pending = Arc::new(Mutex::new(PendingMap::new()));
204 let watched_file_registrations = Arc::new(Mutex::new(HashSet::new()));
205 let reader_pending = Arc::clone(&pending);
206 let reader_writer = Arc::clone(&writer);
207 let reader_watched_file_registrations = Arc::clone(&watched_file_registrations);
208 let reader_kind = kind.clone();
209 let reader_root = root.clone();
210
211 thread::spawn(move || {
212 let mut reader = BufReader::new(stdout);
213 loop {
214 match transport::read_message(&mut reader) {
215 Ok(Some(ServerMessage::Response(response))) => {
216 if let Ok(mut guard) = reader_pending.lock() {
217 if let Some(tx) = guard.remove(&response.id) {
218 if tx.send(response).is_err() {
219 log::debug!("response channel closed");
220 }
221 }
222 } else {
223 let _ = event_tx.send(LspEvent::ServerExited {
224 server_kind: reader_kind.clone(),
225 root: reader_root.clone(),
226 });
227 break;
228 }
229 }
230 Ok(Some(ServerMessage::Notification { method, params })) => {
231 let _ = event_tx.send(LspEvent::Notification {
232 server_kind: reader_kind.clone(),
233 root: reader_root.clone(),
234 method,
235 params,
236 });
237 }
238 Ok(Some(ServerMessage::Request { id, method, params })) => {
239 record_watched_file_registration(
240 &reader_watched_file_registrations,
241 &method,
242 params.as_ref(),
243 );
244 let response_value = if method == "workspace/configuration" {
254 let item_count = params
257 .as_ref()
258 .and_then(|p| p.get("items"))
259 .and_then(|items| items.as_array())
260 .map_or(1, |arr| arr.len());
261 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
262 } else {
263 serde_json::Value::Null
264 };
265 if let Ok(mut w) = reader_writer.lock() {
266 let response = super::jsonrpc::OutgoingResponse::success(
267 id.clone(),
268 response_value,
269 );
270 let _ = transport::write_response(&mut *w, &response);
271 }
272 let _ = event_tx.send(LspEvent::ServerRequest {
274 server_kind: reader_kind.clone(),
275 root: reader_root.clone(),
276 id,
277 method,
278 params,
279 });
280 }
281 Ok(None) | Err(_) => {
282 if let Ok(mut guard) = reader_pending.lock() {
283 guard.clear();
284 }
285 let _ = event_tx.send(LspEvent::ServerExited {
286 server_kind: reader_kind.clone(),
287 root: reader_root.clone(),
288 });
289 break;
290 }
291 }
292 }
293 });
294
295 Ok(Self {
296 kind,
297 root,
298 state: ServerState::Starting,
299 child,
300 child_pid,
301 writer,
302 pending,
303 next_id: AtomicI64::new(1),
304 diagnostic_caps: None,
305 supports_watched_files: false,
306 watched_file_registrations,
307 child_registry,
308 stderr_tail,
309 })
310 }
311
312 pub fn initialize(
314 &mut self,
315 workspace_root: &Path,
316 initialization_options: Option<serde_json::Value>,
317 ) -> Result<lsp_types::InitializeResult, LspError> {
318 self.ensure_can_send()?;
319 self.state = ServerState::Initializing;
320
321 let root_url = path_to_uri(workspace_root)?;
322 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
323 LspError::NotFound(format!(
324 "failed to convert workspace root '{}' to file URI",
325 workspace_root.display()
326 ))
327 })?;
328
329 let mut params_value = json!({
330 "processId": std::process::id(),
331 "rootUri": root_uri,
332 "capabilities": {
333 "workspace": {
334 "workspaceFolders": true,
335 "configuration": true,
336 "didChangeWatchedFiles": {
337 "dynamicRegistration": true
338 },
339 "diagnostic": {
344 "refreshSupport": false
345 }
346 },
347 "textDocument": {
348 "synchronization": {
349 "dynamicRegistration": false,
350 "didSave": true,
351 "willSave": false,
352 "willSaveWaitUntil": false
353 },
354 "publishDiagnostics": {
355 "relatedInformation": true,
356 "versionSupport": true,
357 "codeDescriptionSupport": true,
358 "dataSupport": true
359 },
360 "diagnostic": {
365 "dynamicRegistration": false,
366 "relatedDocumentSupport": true
367 }
368 }
369 },
370 "clientInfo": {
371 "name": "aft",
372 "version": env!("CARGO_PKG_VERSION")
373 },
374 "workspaceFolders": [
375 {
376 "uri": root_uri,
377 "name": workspace_root
378 .file_name()
379 .and_then(|name| name.to_str())
380 .unwrap_or("workspace")
381 }
382 ]
383 });
384 if let Some(initialization_options) = initialization_options {
385 params_value["initializationOptions"] = initialization_options;
386 }
387
388 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
389
390 let result_value = self.send_request_value(
391 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
392 params,
393 )?;
394 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
395
396 let caps_value = result_value
401 .get("capabilities")
402 .cloned()
403 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
404 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
405
406 self.supports_watched_files = caps_value
412 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
413 .and_then(|v| v.as_bool())
414 .unwrap_or(false)
415 || caps_value
416 .pointer("/workspace/didChangeWatchedFiles")
417 .map(|v| v.is_object() || v.as_bool() == Some(true))
418 .unwrap_or(false);
419
420 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
421 json!({}),
422 )?)?;
423 self.state = ServerState::Ready;
424 Ok(result)
425 }
426
/// Diagnostic capabilities parsed from the server's `initialize` result, or
/// `None` until `initialize` has received that result.
pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
    self.diagnostic_caps.as_ref()
}
433
/// Whether the server's `initialize` result advertised
/// `workspace.didChangeWatchedFiles`; `false` before `initialize` completes.
pub fn supports_watched_files(&self) -> bool {
    self.supports_watched_files
}
440
441 pub fn has_watched_file_registration(&self) -> bool {
445 self.watched_file_registrations
446 .lock()
447 .map(|registrations| !registrations.is_empty())
448 .unwrap_or(false)
449 }
450
451 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
453 where
454 R: lsp_types::request::Request,
455 R::Params: serde::Serialize,
456 R::Result: DeserializeOwned,
457 {
458 self.ensure_can_send()?;
459
460 let value = self.send_request_value(R::METHOD, params)?;
461 serde_json::from_value(value).map_err(Into::into)
462 }
463
464 pub fn send_request_with_timeout<R>(
468 &mut self,
469 params: R::Params,
470 timeout: Duration,
471 ) -> Result<R::Result, LspError>
472 where
473 R: lsp_types::request::Request,
474 R::Params: serde::Serialize,
475 R::Result: DeserializeOwned,
476 {
477 self.ensure_can_send()?;
478
479 let value = self.send_request_value_with_timeout(R::METHOD, params, timeout)?;
480 serde_json::from_value(value).map_err(Into::into)
481 }
482
483 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
484 where
485 P: serde::Serialize,
486 {
487 self.send_request_value_with_timeout(method, params, REQUEST_TIMEOUT)
488 }
489
490 fn send_request_value_with_timeout<P>(
491 &mut self,
492 method: &'static str,
493 params: P,
494 timeout: Duration,
495 ) -> Result<Value, LspError>
496 where
497 P: serde::Serialize,
498 {
499 self.ensure_can_send()?;
500
501 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
502 let (tx, rx) = bounded(1);
503 {
504 let mut pending = self.lock_pending()?;
505 pending.insert(id.clone(), tx);
506 }
507
508 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
509 {
510 let mut writer = self
511 .writer
512 .lock()
513 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
514 if let Err(err) = transport::write_request(&mut *writer, &request) {
515 self.remove_pending(&id);
516 return Err(err.into());
517 }
518 }
519
520 let response = match rx.recv_timeout(timeout) {
521 Ok(response) => response,
522 Err(RecvTimeoutError::Timeout) => {
523 self.remove_pending(&id);
524 self.send_cancel_request(&id)?;
525 return Err(LspError::Timeout(format!(
526 "timed out waiting for '{}' response from {:?}",
527 method, self.kind
528 )));
529 }
530 Err(RecvTimeoutError::Disconnected) => {
531 self.remove_pending(&id);
532 return Err(LspError::ServerNotReady(format!(
533 "language server {:?} disconnected while waiting for '{}'",
534 self.kind, method
535 )));
536 }
537 };
538
539 if let Some(error) = response.error {
540 return Err(LspError::ServerError {
541 code: error.code,
542 message: error.message,
543 });
544 }
545
546 Ok(response.result.unwrap_or(Value::Null))
547 }
548
549 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
551 where
552 N: lsp_types::notification::Notification,
553 N::Params: serde::Serialize,
554 {
555 self.ensure_can_send()?;
556 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
557 let mut writer = self
558 .writer
559 .lock()
560 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
561 transport::write_notification(&mut *writer, ¬ification)?;
562 Ok(())
563 }
564
565 pub fn shutdown(&mut self) -> Result<(), LspError> {
567 if self.state == ServerState::Exited {
568 self.child_registry.untrack(self.child_pid);
569 return Ok(());
570 }
571
572 if self.child.try_wait()?.is_some() {
573 self.state = ServerState::Exited;
574 self.child_registry.untrack(self.child_pid);
575 return Ok(());
576 }
577
578 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
579 self.state = ServerState::ShuttingDown;
580 if self.child.try_wait()?.is_some() {
581 self.state = ServerState::Exited;
582 return Ok(());
583 }
584 return Err(err);
585 }
586
587 self.state = ServerState::ShuttingDown;
588
589 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
590 if self.child.try_wait()?.is_some() {
591 self.state = ServerState::Exited;
592 return Ok(());
593 }
594 return Err(err);
595 }
596
597 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
598 loop {
599 if self.child.try_wait()?.is_some() {
600 self.state = ServerState::Exited;
601 return Ok(());
602 }
603 if Instant::now() >= deadline {
604 kill_lsp_child_group(&mut self.child);
608 self.state = ServerState::Exited;
609 return Err(LspError::Timeout(format!(
610 "timed out waiting for {:?} to exit",
611 self.kind
612 )));
613 }
614 thread::sleep(EXIT_POLL_INTERVAL);
615 }
616 }
617
618 pub fn stderr_tail(&self) -> String {
619 self.stderr_tail
620 .lock()
621 .map(|tail| stderr_tail_to_string(&tail))
622 .unwrap_or_default()
623 }
624
625 pub fn child_exited(&mut self) -> bool {
626 self.child.try_wait().ok().flatten().is_some()
627 }
628
629 pub fn child_exit_status(&mut self) -> Option<std::process::ExitStatus> {
630 self.child.try_wait().ok().flatten()
631 }
632
/// Current lifecycle state of the client.
pub fn state(&self) -> ServerState {
    self.state
}
636
/// The kind of language server this client drives (returned as a clone).
pub fn kind(&self) -> ServerKind {
    self.kind.clone()
}
640
641 pub fn root(&self) -> &Path {
642 &self.root
643 }
644
645 fn ensure_can_send(&self) -> Result<(), LspError> {
646 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
647 return Err(LspError::ServerNotReady(format!(
648 "language server {:?} is not ready (state: {:?})",
649 self.kind, self.state
650 )));
651 }
652 Ok(())
653 }
654
655 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
656 self.pending
657 .lock()
658 .map_err(|_| io::Error::other("pending response map poisoned").into())
659 }
660
661 fn remove_pending(&self, id: &RequestId) {
662 if let Ok(mut pending) = self.pending.lock() {
663 pending.remove(id);
664 }
665 }
666
667 fn send_cancel_request(&mut self, id: &RequestId) -> Result<(), LspError> {
668 let notification = Notification::new("$/cancelRequest", Some(json!({ "id": id })));
669 let mut writer = self
670 .writer
671 .lock()
672 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
673 transport::write_notification(&mut *writer, ¬ification)?;
674 Ok(())
675 }
676}
677
678impl Drop for LspClient {
679 fn drop(&mut self) {
680 self.child_registry.untrack(self.child_pid);
683 kill_lsp_child_group(&mut self.child);
684 }
685}
686
687fn spawn_stderr_drain_thread(
688 stderr: std::process::ChildStderr,
689 stderr_tail: Arc<Mutex<VecDeque<String>>>,
690) {
691 thread::spawn(move || {
692 let mut reader = BufReader::new(stderr);
693 let mut line = String::new();
694
695 loop {
696 line.clear();
697 match reader.read_line(&mut line) {
698 Ok(0) => break,
699 Ok(_) => {
700 if let Ok(mut tail) = stderr_tail.lock() {
701 append_stderr_tail(&mut tail, &line);
702 } else {
703 break;
704 }
705 }
706 Err(_) => break,
707 }
708 }
709 });
710}
711
712fn append_stderr_tail(tail: &mut VecDeque<String>, line: &str) {
713 if tail.len() == STDERR_TAIL_LINES {
714 tail.pop_front();
715 }
716 tail.push_back(trim_stderr_line(line));
717}
718
719fn trim_stderr_line(line: &str) -> String {
720 let line = line.trim_end_matches(|ch| ch == '\r' || ch == '\n');
721 if line.len() <= STDERR_LINE_BYTES {
722 return line.to_string();
723 }
724
725 let mut start = line.len() - STDERR_LINE_BYTES;
726 while start < line.len() && !line.is_char_boundary(start) {
727 start += 1;
728 }
729 format!("...{}", &line[start..])
730}
731
/// Joins the retained stderr lines with `\n` (no trailing newline).
fn stderr_tail_to_string(tail: &VecDeque<String>) -> String {
    let mut joined = String::new();
    for (idx, entry) in tail.iter().enumerate() {
        if idx > 0 {
            joined.push('\n');
        }
        joined.push_str(entry);
    }
    joined
}
738
739fn kill_lsp_child_group(child: &mut std::process::Child) {
746 #[cfg(unix)]
747 {
748 let pgid = child.id() as i32;
749 crate::bash_background::process::terminate_pgid(pgid, Some(child));
750 let _ = child.wait();
751 }
752 #[cfg(not(unix))]
753 {
754 crate::bash_background::process::terminate_process(child);
755 let _ = child.wait();
756 }
757}
758
759fn record_watched_file_registration(
760 registrations: &WatchedFileRegistrations,
761 method: &str,
762 params: Option<&Value>,
763) {
764 match method {
765 "client/registerCapability" => {
766 let Some(items) = params
767 .and_then(|params| params.get("registrations"))
768 .and_then(|registrations| registrations.as_array())
769 else {
770 return;
771 };
772 if let Ok(mut guard) = registrations.lock() {
773 for item in items {
774 if item.get("method").and_then(Value::as_str)
775 == Some("workspace/didChangeWatchedFiles")
776 {
777 if let Some(id) = item.get("id").and_then(Value::as_str) {
778 guard.insert(id.to_string());
779 }
780 }
781 }
782 }
783 }
784 "client/unregisterCapability" => {
785 let Some(items) = params
786 .and_then(|params| params.get("unregisterations"))
787 .and_then(|registrations| registrations.as_array())
788 else {
789 return;
790 };
791 if let Ok(mut guard) = registrations.lock() {
792 for item in items {
793 if item.get("method").and_then(Value::as_str)
794 == Some("workspace/didChangeWatchedFiles")
795 {
796 if let Some(id) = item.get("id").and_then(Value::as_str) {
797 guard.remove(id);
798 }
799 }
800 }
801 }
802 }
803 _ => {}
804 }
805}
806
807fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
822 let mut caps = ServerDiagnosticCapabilities::default();
823
824 if let Some(provider) = value.get("diagnosticProvider") {
825 if provider.is_object() || provider.as_bool() == Some(true) {
828 caps.pull_diagnostics = true;
829 }
830
831 if let Some(obj) = provider.as_object() {
832 if obj
833 .get("workspaceDiagnostics")
834 .and_then(|v| v.as_bool())
835 .unwrap_or(false)
836 {
837 caps.workspace_diagnostics = true;
838 }
839 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
840 caps.identifier = Some(identifier.to_string());
841 }
842 }
843 }
844
845 if let Some(refresh) = value
848 .get("workspace")
849 .and_then(|w| w.get("diagnostic"))
850 .and_then(|d| d.get("refreshSupport"))
851 .and_then(|r| r.as_bool())
852 {
853 caps.refresh_support = refresh;
854 }
855
856 caps
857}
858
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn trim_stderr_line_strips_line_endings() {
        assert_eq!(trim_stderr_line("boom\r\n"), "boom");
        assert_eq!(trim_stderr_line(""), "");
    }

    #[test]
    fn trim_stderr_line_keeps_tail_on_char_boundary() {
        // Two-byte chars make the raw cut point land mid-character.
        let line = format!("{}a", "é".repeat(STDERR_LINE_BYTES));
        let trimmed = trim_stderr_line(&line);
        assert!(trimmed.starts_with("..."));
        assert!(trimmed.ends_with('a'));
        assert!(trimmed.len() <= STDERR_LINE_BYTES + "...".len());
    }

    #[test]
    fn append_stderr_tail_evicts_oldest() {
        let mut tail = VecDeque::new();
        for i in 0..STDERR_TAIL_LINES + 3 {
            append_stderr_tail(&mut tail, &format!("line {i}\n"));
        }
        assert_eq!(tail.len(), STDERR_TAIL_LINES);
        assert_eq!(tail.front().map(String::as_str), Some("line 3"));
        assert_eq!(
            stderr_tail_to_string(&tail).lines().count(),
            STDERR_TAIL_LINES
        );
    }

    #[test]
    fn watched_file_registrations_track_register_and_unregister() {
        let registrations: WatchedFileRegistrations = Arc::new(Mutex::new(HashSet::new()));
        record_watched_file_registration(
            &registrations,
            "client/registerCapability",
            Some(&json!({
                "registrations": [
                    { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" },
                    { "id": "save-1", "method": "textDocument/didSave" }
                ]
            })),
        );
        assert_eq!(registrations.lock().unwrap().len(), 1);

        // The LSP specification spells this field `unregisterations`.
        record_watched_file_registration(
            &registrations,
            "client/unregisterCapability",
            Some(&json!({
                "unregisterations": [
                    { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" }
                ]
            })),
        );
        assert!(registrations.lock().unwrap().is_empty());
    }
}