1use std::collections::{HashMap, HashSet};
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::position::path_to_uri;
20use crate::lsp::registry::ServerKind;
21use crate::lsp::{transport, LspError};
22
/// Default time to wait for the response to a client-initiated request.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `LspClient::shutdown` polls for the server process to exit before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sleep between exit polls while `LspClient::shutdown` waits for the process.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
26
27type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
28type WatchedFileRegistrations = Arc<Mutex<HashSet<String>>>;
29
/// Lifecycle of the language-server process owned by an `LspClient`.
///
/// `spawn` starts in `Starting`, `initialize` passes through `Initializing` to `Ready`,
/// and `shutdown` moves to `ShuttingDown` and finally `Exited`. Requests and
/// notifications are refused once the state is `ShuttingDown` or `Exited`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started.
    Starting,
    /// The `initialize` request has been issued.
    Initializing,
    /// Handshake finished; normal traffic is allowed.
    Ready,
    /// `shutdown` has been attempted; no further messages are sent.
    ShuttingDown,
    /// The server process is gone (it exited or was killed).
    Exited,
}
39
/// Event emitted by an `LspClient`'s reader thread on the channel passed to `spawn`.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification sent by the server, forwarded unchanged.
    Notification {
        /// Kind of the server that sent the message.
        server_kind: ServerKind,
        /// The `root` the client was spawned with.
        root: PathBuf,
        method: String,
        params: Option<Value>,
    },
    /// A request initiated by the server. Before emitting this event the reader thread
    /// has already written a default reply (`null`, or one `null` per item for
    /// `workspace/configuration`).
    // NOTE(review): receivers presumably must not answer this request a second time — confirm.
    ServerRequest {
        /// Kind of the server that sent the message.
        server_kind: ServerKind,
        /// The `root` the client was spawned with.
        root: PathBuf,
        id: RequestId,
        method: String,
        params: Option<Value>,
    },
    /// The reader thread stopped: the server's output ended, a read failed, or the
    /// pending-request map was poisoned.
    ServerExited {
        /// Kind of the server whose reader stopped.
        server_kind: ServerKind,
        /// The `root` the client was spawned with.
        root: PathBuf,
    },
}
64
/// Diagnostic-related capabilities read from a server's `initialize` result.
///
/// Filled in by `parse_diagnostic_capabilities`; anything the server does not mention
/// keeps its `Default` value (`false` / `None`).
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object or as `true`, i.e. the server
    /// accepts pull-diagnostic requests.
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when the server supplies one.
    pub identifier: Option<String>,
    /// `workspace.diagnostic.refreshSupport` as found in the server capabilities.
    pub refresh_support: bool,
}
86
/// Client for one spawned language-server process serving a single workspace root.
///
/// Outgoing messages are written under the shared `writer` lock. Responses are matched
/// to requests by id through `pending`, which the background reader thread started in
/// `spawn` completes. Dropping the client untracks the child pid and terminates the
/// server's process group.
// NOTE(review): fields drop in declaration order after `Drop::drop` runs; keep the
// order stable unless that has been verified not to matter.
pub struct LspClient {
    /// Which server this client talks to.
    kind: ServerKind,
    /// Workspace root; also the child process's working directory.
    root: PathBuf,
    /// Lifecycle state; `ensure_can_send` gates outgoing traffic on it.
    state: ServerState,
    /// Handle to the server process.
    child: Child,
    /// `child.id()` captured at spawn time, used with `child_registry`.
    child_pid: u32,
    /// Buffered server stdin, shared with the reader thread, which also writes
    /// replies to server-initiated requests through it.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// Requests awaiting a response; completed (or cleared at EOF) by the reader thread.
    pending: Arc<Mutex<PendingMap>>,
    /// Next JSON-RPC request id; starts at 1.
    next_id: AtomicI64,
    /// Parsed from the `initialize` result; `None` until `initialize` gets that far.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the server's capabilities mention `workspace.didChangeWatchedFiles`.
    supports_watched_files: bool,
    /// Active watched-files registration ids, updated by the reader thread.
    watched_file_registrations: WatchedFileRegistrations,
    /// Registry the child pid is tracked in (passed to `spawn`).
    child_registry: LspChildRegistry,
}
122
123impl LspClient {
124 pub fn spawn(
130 kind: ServerKind,
131 root: PathBuf,
132 binary: &Path,
133 args: &[String],
134 env: &HashMap<String, String>,
135 event_tx: Sender<LspEvent>,
136 child_registry: LspChildRegistry,
137 ) -> io::Result<Self> {
138 let mut command = Command::new(binary);
139 command
140 .args(args)
141 .current_dir(&root)
142 .stdin(Stdio::piped())
143 .stdout(Stdio::piped())
144 .stderr(Stdio::null());
147 for (key, value) in env {
148 command.env(key, value);
149 }
150
151 #[cfg(unix)]
157 unsafe {
158 use std::os::unix::process::CommandExt;
159 command.pre_exec(|| {
160 if libc::setsid() == -1 {
161 return Err(io::Error::last_os_error());
162 }
163 Ok(())
164 });
165 }
166
167 let mut child = command.spawn()?;
168 let child_pid = child.id();
169 child_registry.track(child_pid);
170
171 let stdout = child
172 .stdout
173 .take()
174 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
175 let stdin = child
176 .stdin
177 .take()
178 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
179
180 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
181 let pending = Arc::new(Mutex::new(PendingMap::new()));
182 let watched_file_registrations = Arc::new(Mutex::new(HashSet::new()));
183 let reader_pending = Arc::clone(&pending);
184 let reader_writer = Arc::clone(&writer);
185 let reader_watched_file_registrations = Arc::clone(&watched_file_registrations);
186 let reader_kind = kind.clone();
187 let reader_root = root.clone();
188
189 thread::spawn(move || {
190 let mut reader = BufReader::new(stdout);
191 loop {
192 match transport::read_message(&mut reader) {
193 Ok(Some(ServerMessage::Response(response))) => {
194 if let Ok(mut guard) = reader_pending.lock() {
195 if let Some(tx) = guard.remove(&response.id) {
196 if tx.send(response).is_err() {
197 log::debug!("response channel closed");
198 }
199 }
200 } else {
201 let _ = event_tx.send(LspEvent::ServerExited {
202 server_kind: reader_kind.clone(),
203 root: reader_root.clone(),
204 });
205 break;
206 }
207 }
208 Ok(Some(ServerMessage::Notification { method, params })) => {
209 let _ = event_tx.send(LspEvent::Notification {
210 server_kind: reader_kind.clone(),
211 root: reader_root.clone(),
212 method,
213 params,
214 });
215 }
216 Ok(Some(ServerMessage::Request { id, method, params })) => {
217 record_watched_file_registration(
218 &reader_watched_file_registrations,
219 &method,
220 params.as_ref(),
221 );
222 let response_value = if method == "workspace/configuration" {
232 let item_count = params
235 .as_ref()
236 .and_then(|p| p.get("items"))
237 .and_then(|items| items.as_array())
238 .map_or(1, |arr| arr.len());
239 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
240 } else {
241 serde_json::Value::Null
242 };
243 if let Ok(mut w) = reader_writer.lock() {
244 let response = super::jsonrpc::OutgoingResponse::success(
245 id.clone(),
246 response_value,
247 );
248 let _ = transport::write_response(&mut *w, &response);
249 }
250 let _ = event_tx.send(LspEvent::ServerRequest {
252 server_kind: reader_kind.clone(),
253 root: reader_root.clone(),
254 id,
255 method,
256 params,
257 });
258 }
259 Ok(None) | Err(_) => {
260 if let Ok(mut guard) = reader_pending.lock() {
261 guard.clear();
262 }
263 let _ = event_tx.send(LspEvent::ServerExited {
264 server_kind: reader_kind.clone(),
265 root: reader_root.clone(),
266 });
267 break;
268 }
269 }
270 }
271 });
272
273 Ok(Self {
274 kind,
275 root,
276 state: ServerState::Starting,
277 child,
278 child_pid,
279 writer,
280 pending,
281 next_id: AtomicI64::new(1),
282 diagnostic_caps: None,
283 supports_watched_files: false,
284 watched_file_registrations,
285 child_registry,
286 })
287 }
288
289 pub fn initialize(
291 &mut self,
292 workspace_root: &Path,
293 initialization_options: Option<serde_json::Value>,
294 ) -> Result<lsp_types::InitializeResult, LspError> {
295 self.ensure_can_send()?;
296 self.state = ServerState::Initializing;
297
298 let root_url = path_to_uri(workspace_root)?;
299 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
300 LspError::NotFound(format!(
301 "failed to convert workspace root '{}' to file URI",
302 workspace_root.display()
303 ))
304 })?;
305
306 let mut params_value = json!({
307 "processId": std::process::id(),
308 "rootUri": root_uri,
309 "capabilities": {
310 "workspace": {
311 "workspaceFolders": true,
312 "configuration": true,
313 "didChangeWatchedFiles": {
314 "dynamicRegistration": true
315 },
316 "diagnostic": {
321 "refreshSupport": false
322 }
323 },
324 "textDocument": {
325 "synchronization": {
326 "dynamicRegistration": false,
327 "didSave": true,
328 "willSave": false,
329 "willSaveWaitUntil": false
330 },
331 "publishDiagnostics": {
332 "relatedInformation": true,
333 "versionSupport": true,
334 "codeDescriptionSupport": true,
335 "dataSupport": true
336 },
337 "diagnostic": {
342 "dynamicRegistration": false,
343 "relatedDocumentSupport": true
344 }
345 }
346 },
347 "clientInfo": {
348 "name": "aft",
349 "version": env!("CARGO_PKG_VERSION")
350 },
351 "workspaceFolders": [
352 {
353 "uri": root_uri,
354 "name": workspace_root
355 .file_name()
356 .and_then(|name| name.to_str())
357 .unwrap_or("workspace")
358 }
359 ]
360 });
361 if let Some(initialization_options) = initialization_options {
362 params_value["initializationOptions"] = initialization_options;
363 }
364
365 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
366
367 let result_value = self.send_request_value(
368 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
369 params,
370 )?;
371 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
372
373 let caps_value = result_value
378 .get("capabilities")
379 .cloned()
380 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
381 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
382
383 self.supports_watched_files = caps_value
387 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
388 .and_then(|v| v.as_bool())
389 .unwrap_or(false)
390 || caps_value
391 .pointer("/workspace/didChangeWatchedFiles")
392 .map(|v| v.is_object() || v.as_bool() == Some(true))
393 .unwrap_or(false);
394
395 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
396 json!({}),
397 )?)?;
398 self.state = ServerState::Ready;
399 Ok(result)
400 }
401
402 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
406 self.diagnostic_caps.as_ref()
407 }
408
409 pub fn supports_watched_files(&self) -> bool {
412 self.supports_watched_files
413 }
414
415 pub fn has_watched_file_registration(&self) -> bool {
419 self.watched_file_registrations
420 .lock()
421 .map(|registrations| !registrations.is_empty())
422 .unwrap_or(false)
423 }
424
425 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
427 where
428 R: lsp_types::request::Request,
429 R::Params: serde::Serialize,
430 R::Result: DeserializeOwned,
431 {
432 self.ensure_can_send()?;
433
434 let value = self.send_request_value(R::METHOD, params)?;
435 serde_json::from_value(value).map_err(Into::into)
436 }
437
438 pub fn send_request_with_timeout<R>(
442 &mut self,
443 params: R::Params,
444 timeout: Duration,
445 ) -> Result<R::Result, LspError>
446 where
447 R: lsp_types::request::Request,
448 R::Params: serde::Serialize,
449 R::Result: DeserializeOwned,
450 {
451 self.ensure_can_send()?;
452
453 let value = self.send_request_value_with_timeout(R::METHOD, params, timeout)?;
454 serde_json::from_value(value).map_err(Into::into)
455 }
456
457 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
458 where
459 P: serde::Serialize,
460 {
461 self.send_request_value_with_timeout(method, params, REQUEST_TIMEOUT)
462 }
463
464 fn send_request_value_with_timeout<P>(
465 &mut self,
466 method: &'static str,
467 params: P,
468 timeout: Duration,
469 ) -> Result<Value, LspError>
470 where
471 P: serde::Serialize,
472 {
473 self.ensure_can_send()?;
474
475 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
476 let (tx, rx) = bounded(1);
477 {
478 let mut pending = self.lock_pending()?;
479 pending.insert(id.clone(), tx);
480 }
481
482 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
483 {
484 let mut writer = self
485 .writer
486 .lock()
487 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
488 if let Err(err) = transport::write_request(&mut *writer, &request) {
489 self.remove_pending(&id);
490 return Err(err.into());
491 }
492 }
493
494 let response = match rx.recv_timeout(timeout) {
495 Ok(response) => response,
496 Err(RecvTimeoutError::Timeout) => {
497 self.remove_pending(&id);
498 self.send_cancel_request(&id)?;
499 return Err(LspError::Timeout(format!(
500 "timed out waiting for '{}' response from {:?}",
501 method, self.kind
502 )));
503 }
504 Err(RecvTimeoutError::Disconnected) => {
505 self.remove_pending(&id);
506 return Err(LspError::ServerNotReady(format!(
507 "language server {:?} disconnected while waiting for '{}'",
508 self.kind, method
509 )));
510 }
511 };
512
513 if let Some(error) = response.error {
514 return Err(LspError::ServerError {
515 code: error.code,
516 message: error.message,
517 });
518 }
519
520 Ok(response.result.unwrap_or(Value::Null))
521 }
522
523 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
525 where
526 N: lsp_types::notification::Notification,
527 N::Params: serde::Serialize,
528 {
529 self.ensure_can_send()?;
530 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
531 let mut writer = self
532 .writer
533 .lock()
534 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
535 transport::write_notification(&mut *writer, ¬ification)?;
536 Ok(())
537 }
538
539 pub fn shutdown(&mut self) -> Result<(), LspError> {
541 if self.state == ServerState::Exited {
542 self.child_registry.untrack(self.child_pid);
543 return Ok(());
544 }
545
546 if self.child.try_wait()?.is_some() {
547 self.state = ServerState::Exited;
548 self.child_registry.untrack(self.child_pid);
549 return Ok(());
550 }
551
552 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
553 self.state = ServerState::ShuttingDown;
554 if self.child.try_wait()?.is_some() {
555 self.state = ServerState::Exited;
556 return Ok(());
557 }
558 return Err(err);
559 }
560
561 self.state = ServerState::ShuttingDown;
562
563 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
564 if self.child.try_wait()?.is_some() {
565 self.state = ServerState::Exited;
566 return Ok(());
567 }
568 return Err(err);
569 }
570
571 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
572 loop {
573 if self.child.try_wait()?.is_some() {
574 self.state = ServerState::Exited;
575 return Ok(());
576 }
577 if Instant::now() >= deadline {
578 kill_lsp_child_group(&mut self.child);
582 self.state = ServerState::Exited;
583 return Err(LspError::Timeout(format!(
584 "timed out waiting for {:?} to exit",
585 self.kind
586 )));
587 }
588 thread::sleep(EXIT_POLL_INTERVAL);
589 }
590 }
591
592 pub fn state(&self) -> ServerState {
593 self.state
594 }
595
596 pub fn kind(&self) -> ServerKind {
597 self.kind.clone()
598 }
599
600 pub fn root(&self) -> &Path {
601 &self.root
602 }
603
604 fn ensure_can_send(&self) -> Result<(), LspError> {
605 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
606 return Err(LspError::ServerNotReady(format!(
607 "language server {:?} is not ready (state: {:?})",
608 self.kind, self.state
609 )));
610 }
611 Ok(())
612 }
613
614 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
615 self.pending
616 .lock()
617 .map_err(|_| io::Error::other("pending response map poisoned").into())
618 }
619
620 fn remove_pending(&self, id: &RequestId) {
621 if let Ok(mut pending) = self.pending.lock() {
622 pending.remove(id);
623 }
624 }
625
626 fn send_cancel_request(&mut self, id: &RequestId) -> Result<(), LspError> {
627 let notification = Notification::new("$/cancelRequest", Some(json!({ "id": id })));
628 let mut writer = self
629 .writer
630 .lock()
631 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
632 transport::write_notification(&mut *writer, ¬ification)?;
633 Ok(())
634 }
635}
636
637impl Drop for LspClient {
638 fn drop(&mut self) {
639 self.child_registry.untrack(self.child_pid);
642 kill_lsp_child_group(&mut self.child);
643 }
644}
645
646fn kill_lsp_child_group(child: &mut std::process::Child) {
653 #[cfg(unix)]
654 {
655 let pgid = child.id() as i32;
656 crate::bash_background::process::terminate_pgid(pgid, Some(child));
657 let _ = child.wait();
658 }
659 #[cfg(not(unix))]
660 {
661 crate::bash_background::process::terminate_process(child);
662 let _ = child.wait();
663 }
664}
665
666fn record_watched_file_registration(
667 registrations: &WatchedFileRegistrations,
668 method: &str,
669 params: Option<&Value>,
670) {
671 match method {
672 "client/registerCapability" => {
673 let Some(items) = params
674 .and_then(|params| params.get("registrations"))
675 .and_then(|registrations| registrations.as_array())
676 else {
677 return;
678 };
679 if let Ok(mut guard) = registrations.lock() {
680 for item in items {
681 if item.get("method").and_then(Value::as_str)
682 == Some("workspace/didChangeWatchedFiles")
683 {
684 if let Some(id) = item.get("id").and_then(Value::as_str) {
685 guard.insert(id.to_string());
686 }
687 }
688 }
689 }
690 }
691 "client/unregisterCapability" => {
692 let Some(items) = params
693 .and_then(|params| params.get("unregisterations"))
694 .and_then(|registrations| registrations.as_array())
695 else {
696 return;
697 };
698 if let Ok(mut guard) = registrations.lock() {
699 for item in items {
700 if item.get("method").and_then(Value::as_str)
701 == Some("workspace/didChangeWatchedFiles")
702 {
703 if let Some(id) = item.get("id").and_then(Value::as_str) {
704 guard.remove(id);
705 }
706 }
707 }
708 }
709 }
710 _ => {}
711 }
712}
713
714fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
729 let mut caps = ServerDiagnosticCapabilities::default();
730
731 if let Some(provider) = value.get("diagnosticProvider") {
732 if provider.is_object() || provider.as_bool() == Some(true) {
735 caps.pull_diagnostics = true;
736 }
737
738 if let Some(obj) = provider.as_object() {
739 if obj
740 .get("workspaceDiagnostics")
741 .and_then(|v| v.as_bool())
742 .unwrap_or(false)
743 {
744 caps.workspace_diagnostics = true;
745 }
746 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
747 caps.identifier = Some(identifier.to_string());
748 }
749 }
750 }
751
752 if let Some(refresh) = value
755 .get("workspace")
756 .and_then(|w| w.get("diagnostic"))
757 .and_then(|d| d.get("refreshSupport"))
758 .and_then(|r| r.as_bool())
759 {
760 caps.refresh_support = refresh;
761 }
762
763 caps
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty registration set shaped like the one `LspClient::spawn` creates.
    fn watched_registrations() -> WatchedFileRegistrations {
        std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashSet::new()))
    }

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_provider_as_bare_false() {
        let value = json!({
            "diagnosticProvider": false
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn watched_file_registration_is_recorded_and_removed() {
        let registrations = watched_registrations();
        let register = json!({
            "registrations": [
                {
                    "id": "watch-1",
                    "method": "workspace/didChangeWatchedFiles",
                    "registerOptions": { "watchers": [] }
                },
                { "id": "sync-1", "method": "textDocument/didOpen" }
            ]
        });
        record_watched_file_registration(
            &registrations,
            "client/registerCapability",
            Some(&register),
        );
        {
            let ids = registrations.lock().unwrap();
            assert!(ids.contains("watch-1"));
            assert!(!ids.contains("sync-1"));
        }

        // `unregisterations` is the (misspelled) key mandated by the LSP specification;
        // this guards against it being "corrected".
        let unregister = json!({
            "unregisterations": [
                { "id": "watch-1", "method": "workspace/didChangeWatchedFiles" }
            ]
        });
        record_watched_file_registration(
            &registrations,
            "client/unregisterCapability",
            Some(&unregister),
        );
        assert!(registrations.lock().unwrap().is_empty());
    }

    #[test]
    fn watched_file_registration_ignores_unrelated_input() {
        let registrations = watched_registrations();
        record_watched_file_registration(&registrations, "client/registerCapability", None);
        let params = json!({
            "registrations": [
                { "id": "x", "method": "workspace/didChangeWatchedFiles" }
            ]
        });
        record_watched_file_registration(&registrations, "workspace/configuration", Some(&params));
        assert!(registrations.lock().unwrap().is_empty());
    }
}