1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::registry::ServerKind;
20use crate::lsp::{transport, LspError};
21
/// Maximum time `send_request_value` waits for a response before reporting a timeout.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum time `shutdown` polls for the child to exit before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Sleep between `try_wait` polls while `shutdown` waits for the child to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);

/// In-flight requests: request id mapped to the channel on which the reader
/// thread delivers the matching response.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
27
/// Lifecycle state of a language server as tracked by [`LspClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// Process spawned; `initialize` has not been called yet.
    Starting,
    /// `initialize` has started but has not yet completed.
    Initializing,
    /// The `initialize` handshake finished successfully.
    Ready,
    /// `shutdown` has been attempted; no further messages may be sent.
    ShuttingDown,
    /// The child was observed to have exited, or was killed after a timeout.
    Exited,
}
37
/// Events the reader thread forwards to the owner of the event channel.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification received from the server.
    Notification {
        /// Kind of the server that sent the notification.
        server_kind: ServerKind,
        /// Root the server was spawned for.
        root: PathBuf,
        /// JSON-RPC method name.
        method: String,
        /// Raw notification parameters, if any.
        params: Option<Value>,
    },
    /// A request received from the server. The reader thread has already
    /// attempted to write a reply by the time this event is sent.
    ServerRequest {
        /// Kind of the server that sent the request.
        server_kind: ServerKind,
        /// Root the server was spawned for.
        root: PathBuf,
        /// Id of the server's request.
        id: RequestId,
        /// JSON-RPC method name.
        method: String,
        /// Raw request parameters, if any.
        params: Option<Value>,
    },
    /// The reader thread stopped: the stream ended, a read failed, or the
    /// pending-request map lock was poisoned.
    ServerExited {
        /// Kind of the server whose reader stopped.
        server_kind: ServerKind,
        /// Root the server was spawned for.
        root: PathBuf,
    },
}
62
/// Diagnostic-related capabilities extracted from a server's `initialize` result.
#[derive(Debug, Clone, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` was present as an object or as `true`.
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` was `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when it is a string.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport`, when it is a boolean.
    pub refresh_support: bool,
}
84
/// Handle to one spawned language server process and its JSON-RPC connection.
pub struct LspClient {
    /// Which server this is; used in error messages and events.
    kind: ServerKind,
    /// Directory the server was started in.
    root: PathBuf,
    /// Current lifecycle state; gates whether messages may be sent.
    state: ServerState,
    /// The child process handle.
    child: Child,
    /// Pid captured at spawn time; the key used with `child_registry`.
    child_pid: u32,
    /// Buffered stdin of the child, shared with the reader thread (which
    /// writes replies to server-initiated requests).
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// Requests awaiting a response, shared with the reader thread.
    pending: Arc<Mutex<PendingMap>>,
    /// Next integer request id; starts at 1.
    next_id: AtomicI64,
    /// Diagnostic capabilities parsed during `initialize`; `None` before that.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the `initialize` result mentioned `workspace.didChangeWatchedFiles`.
    supports_watched_files: bool,
    /// Registry notified of this child's pid on spawn and on teardown.
    child_registry: LspChildRegistry,
}
115
116impl LspClient {
    /// Spawns a language server process and starts a background reader thread.
    ///
    /// The child runs `binary` with `args` in `root`, with stdin/stdout piped
    /// and stderr discarded; every entry of `env` is added to its environment.
    /// The child's pid is handed to `child_registry.track`.
    ///
    /// The reader thread (detached; its join handle is dropped) loops over
    /// incoming messages:
    /// - responses are delivered to the waiting sender in the pending map;
    /// - notifications are forwarded on `event_tx`;
    /// - server requests are answered immediately with a success reply and
    ///   then forwarded on `event_tx`;
    /// - end of stream or a read error clears the pending map, emits
    ///   `LspEvent::ServerExited`, and ends the thread.
    ///
    /// The returned client is in `ServerState::Starting`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from spawning the process, or an error if the
    /// child's stdout or stdin pipe is unavailable.
    pub fn spawn(
        kind: ServerKind,
        root: PathBuf,
        binary: &Path,
        args: &[String],
        env: &HashMap<String, String>,
        event_tx: Sender<LspEvent>,
        child_registry: LspChildRegistry,
    ) -> io::Result<Self> {
        let mut command = Command::new(binary);
        command
            .args(args)
            .current_dir(&root)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null());
        for (key, value) in env {
            command.env(key, value);
        }

        // Put the child in its own session so that its pid can later be used
        // as a process-group id (see `kill_lsp_child_group`).
        // SAFETY: the hook runs in the forked child before exec; it only calls
        // `setsid` and, on failure, reads the OS error code.
        #[cfg(unix)]
        unsafe {
            use std::os::unix::process::CommandExt;
            command.pre_exec(|| {
                if libc::setsid() == -1 {
                    return Err(io::Error::last_os_error());
                }
                Ok(())
            });
        }

        let mut child = command.spawn()?;
        let child_pid = child.id();
        child_registry.track(child_pid);

        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;

        // Shared state: the reader thread gets its own handles to the writer
        // and the pending map.
        let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
        let pending = Arc::new(Mutex::new(PendingMap::new()));
        let reader_pending = Arc::clone(&pending);
        let reader_writer = Arc::clone(&writer);
        let reader_kind = kind.clone();
        let reader_root = root.clone();

        thread::spawn(move || {
            let mut reader = BufReader::new(stdout);
            loop {
                match transport::read_message(&mut reader) {
                    Ok(Some(ServerMessage::Response(response))) => {
                        if let Ok(mut guard) = reader_pending.lock() {
                            // Responses with no matching pending entry are dropped.
                            if let Some(tx) = guard.remove(&response.id) {
                                if tx.send(response).is_err() {
                                    log::debug!("response channel closed");
                                }
                            }
                        } else {
                            // Poisoned pending map: report the server as gone and stop.
                            let _ = event_tx.send(LspEvent::ServerExited {
                                server_kind: reader_kind.clone(),
                                root: reader_root.clone(),
                            });
                            break;
                        }
                    }
                    Ok(Some(ServerMessage::Notification { method, params })) => {
                        let _ = event_tx.send(LspEvent::Notification {
                            server_kind: reader_kind.clone(),
                            root: reader_root.clone(),
                            method,
                            params,
                        });
                    }
                    Ok(Some(ServerMessage::Request { id, method, params })) => {
                        // `workspace/configuration` gets one `null` per requested
                        // item (a single `null` when `items` is absent or not an
                        // array); every other request gets a bare `null` result.
                        let response_value = if method == "workspace/configuration" {
                            let item_count = params
                                .as_ref()
                                .and_then(|p| p.get("items"))
                                .and_then(|items| items.as_array())
                                .map_or(1, |arr| arr.len());
                            serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
                        } else {
                            serde_json::Value::Null
                        };
                        // Best effort: a poisoned writer lock or a write error is ignored.
                        if let Ok(mut w) = reader_writer.lock() {
                            let response = super::jsonrpc::OutgoingResponse::success(
                                id.clone(),
                                response_value,
                            );
                            let _ = transport::write_response(&mut *w, &response);
                        }
                        let _ = event_tx.send(LspEvent::ServerRequest {
                            server_kind: reader_kind.clone(),
                            root: reader_root.clone(),
                            id,
                            method,
                            params,
                        });
                    }
                    Ok(None) | Err(_) => {
                        // Dropping the senders makes waiting requests observe a
                        // disconnected channel instead of running into the timeout.
                        if let Ok(mut guard) = reader_pending.lock() {
                            guard.clear();
                        }
                        let _ = event_tx.send(LspEvent::ServerExited {
                            server_kind: reader_kind.clone(),
                            root: reader_root.clone(),
                        });
                        break;
                    }
                }
            }
        });

        Ok(Self {
            kind,
            root,
            state: ServerState::Starting,
            child,
            child_pid,
            writer,
            pending,
            next_id: AtomicI64::new(1),
            diagnostic_caps: None,
            supports_watched_files: false,
            child_registry,
        })
    }
273
274 pub fn initialize(
276 &mut self,
277 workspace_root: &Path,
278 initialization_options: Option<serde_json::Value>,
279 ) -> Result<lsp_types::InitializeResult, LspError> {
280 self.ensure_can_send()?;
281 self.state = ServerState::Initializing;
282
283 let normalized = normalize_windows_path(workspace_root);
284 let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
285 LspError::NotFound(format!(
286 "failed to convert workspace root '{}' to file URI",
287 workspace_root.display()
288 ))
289 })?;
290 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
291 LspError::NotFound(format!(
292 "failed to convert workspace root '{}' to file URI",
293 workspace_root.display()
294 ))
295 })?;
296
297 let mut params_value = json!({
298 "processId": std::process::id(),
299 "rootUri": root_uri,
300 "capabilities": {
301 "workspace": {
302 "workspaceFolders": true,
303 "configuration": true,
304 "diagnostic": {
309 "refreshSupport": false
310 }
311 },
312 "textDocument": {
313 "synchronization": {
314 "dynamicRegistration": false,
315 "didSave": true,
316 "willSave": false,
317 "willSaveWaitUntil": false
318 },
319 "publishDiagnostics": {
320 "relatedInformation": true,
321 "versionSupport": true,
322 "codeDescriptionSupport": true,
323 "dataSupport": true
324 },
325 "diagnostic": {
330 "dynamicRegistration": false,
331 "relatedDocumentSupport": true
332 }
333 }
334 },
335 "clientInfo": {
336 "name": "aft",
337 "version": env!("CARGO_PKG_VERSION")
338 },
339 "workspaceFolders": [
340 {
341 "uri": root_uri,
342 "name": workspace_root
343 .file_name()
344 .and_then(|name| name.to_str())
345 .unwrap_or("workspace")
346 }
347 ]
348 });
349 if let Some(initialization_options) = initialization_options {
350 params_value["initializationOptions"] = initialization_options;
351 }
352
353 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
354
355 let result_value = self.send_request_value(
356 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
357 params,
358 )?;
359 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
360
361 let caps_value = result_value
366 .get("capabilities")
367 .cloned()
368 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
369 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
370
371 self.supports_watched_files = caps_value
375 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
376 .and_then(|v| v.as_bool())
377 .unwrap_or(false)
378 || caps_value
379 .pointer("/workspace/didChangeWatchedFiles")
380 .map(|v| v.is_object() || v.as_bool() == Some(true))
381 .unwrap_or(false);
382
383 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
384 json!({}),
385 )?)?;
386 self.state = ServerState::Ready;
387 Ok(result)
388 }
389
    /// Returns the diagnostic capabilities recorded by `initialize`, or `None`
    /// if `initialize` has not stored them yet.
    pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
        self.diagnostic_caps.as_ref()
    }
396
    /// Returns the watched-files flag computed by `initialize` (`false` until then).
    pub fn supports_watched_files(&self) -> bool {
        self.supports_watched_files
    }
402
403 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
405 where
406 R: lsp_types::request::Request,
407 R::Params: serde::Serialize,
408 R::Result: DeserializeOwned,
409 {
410 self.ensure_can_send()?;
411
412 let value = self.send_request_value(R::METHOD, params)?;
413 serde_json::from_value(value).map_err(Into::into)
414 }
415
416 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
417 where
418 P: serde::Serialize,
419 {
420 self.ensure_can_send()?;
421
422 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
423 let (tx, rx) = bounded(1);
424 {
425 let mut pending = self.lock_pending()?;
426 pending.insert(id.clone(), tx);
427 }
428
429 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
430 {
431 let mut writer = self
432 .writer
433 .lock()
434 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
435 if let Err(err) = transport::write_request(&mut *writer, &request) {
436 self.remove_pending(&id);
437 return Err(err.into());
438 }
439 }
440
441 let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
442 Ok(response) => response,
443 Err(RecvTimeoutError::Timeout) => {
444 self.remove_pending(&id);
445 return Err(LspError::Timeout(format!(
446 "timed out waiting for '{}' response from {:?}",
447 method, self.kind
448 )));
449 }
450 Err(RecvTimeoutError::Disconnected) => {
451 self.remove_pending(&id);
452 return Err(LspError::ServerNotReady(format!(
453 "language server {:?} disconnected while waiting for '{}'",
454 self.kind, method
455 )));
456 }
457 };
458
459 if let Some(error) = response.error {
460 return Err(LspError::ServerError {
461 code: error.code,
462 message: error.message,
463 });
464 }
465
466 Ok(response.result.unwrap_or(Value::Null))
467 }
468
469 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
471 where
472 N: lsp_types::notification::Notification,
473 N::Params: serde::Serialize,
474 {
475 self.ensure_can_send()?;
476 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
477 let mut writer = self
478 .writer
479 .lock()
480 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
481 transport::write_notification(&mut *writer, ¬ification)?;
482 Ok(())
483 }
484
485 pub fn shutdown(&mut self) -> Result<(), LspError> {
487 if self.state == ServerState::Exited {
488 self.child_registry.untrack(self.child_pid);
489 return Ok(());
490 }
491
492 if self.child.try_wait()?.is_some() {
493 self.state = ServerState::Exited;
494 self.child_registry.untrack(self.child_pid);
495 return Ok(());
496 }
497
498 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
499 self.state = ServerState::ShuttingDown;
500 if self.child.try_wait()?.is_some() {
501 self.state = ServerState::Exited;
502 return Ok(());
503 }
504 return Err(err);
505 }
506
507 self.state = ServerState::ShuttingDown;
508
509 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
510 if self.child.try_wait()?.is_some() {
511 self.state = ServerState::Exited;
512 return Ok(());
513 }
514 return Err(err);
515 }
516
517 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
518 loop {
519 if self.child.try_wait()?.is_some() {
520 self.state = ServerState::Exited;
521 return Ok(());
522 }
523 if Instant::now() >= deadline {
524 kill_lsp_child_group(&mut self.child);
528 self.state = ServerState::Exited;
529 return Err(LspError::Timeout(format!(
530 "timed out waiting for {:?} to exit",
531 self.kind
532 )));
533 }
534 thread::sleep(EXIT_POLL_INTERVAL);
535 }
536 }
537
    /// Returns the current lifecycle state.
    pub fn state(&self) -> ServerState {
        self.state
    }
541
    /// Returns a clone of the server kind this client was spawned with.
    pub fn kind(&self) -> ServerKind {
        self.kind.clone()
    }
545
    /// Returns the root directory this client was spawned with.
    pub fn root(&self) -> &Path {
        &self.root
    }
549
550 fn ensure_can_send(&self) -> Result<(), LspError> {
551 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
552 return Err(LspError::ServerNotReady(format!(
553 "language server {:?} is not ready (state: {:?})",
554 self.kind, self.state
555 )));
556 }
557 Ok(())
558 }
559
560 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
561 self.pending
562 .lock()
563 .map_err(|_| io::Error::other("pending response map poisoned").into())
564 }
565
566 fn remove_pending(&self, id: &RequestId) {
567 if let Ok(mut pending) = self.pending.lock() {
568 pending.remove(id);
569 }
570 }
571}
572
/// Tears the child down when the client is dropped.
impl Drop for LspClient {
    /// Untracks the pid in the child registry, then terminates the child via
    /// `kill_lsp_child_group` regardless of the current `ServerState`.
    fn drop(&mut self) {
        self.child_registry.untrack(self.child_pid);
        kill_lsp_child_group(&mut self.child);
    }
}
581
582fn kill_lsp_child_group(child: &mut std::process::Child) {
589 #[cfg(unix)]
590 {
591 let pgid = child.id() as i32;
592 crate::bash_background::process::terminate_pgid(pgid, Some(child));
593 let _ = child.wait();
594 }
595 #[cfg(not(unix))]
596 {
597 crate::bash_background::process::terminate_process(child);
598 let _ = child.wait();
599 }
600}
601
/// Removes the Windows verbatim (`\\?\`) prefix from `path`.
///
/// - `\\?\UNC\server\share\...` becomes `\\server\share\...`;
/// - `\\?\C:\...` becomes `C:\...`;
/// - any other path is returned unchanged.
///
/// The prefix check runs on the lossy UTF-8 rendering of the path, so a
/// prefixed path containing invalid Unicode comes back with replacement
/// characters; unprefixed paths are copied verbatim.
fn normalize_windows_path(path: &Path) -> PathBuf {
    let text = path.to_string_lossy();
    // The UNC form must be tested first: it also starts with `\\?\`, and
    // stripping only that would leave the relative path `UNC\server\share`.
    if let Some(rest) = text.strip_prefix(r"\\?\UNC\") {
        PathBuf::from(format!(r"\\{rest}"))
    } else if let Some(rest) = text.strip_prefix(r"\\?\") {
        PathBuf::from(rest)
    } else {
        path.to_path_buf()
    }
}
613
614fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
629 let mut caps = ServerDiagnosticCapabilities::default();
630
631 if let Some(provider) = value.get("diagnosticProvider") {
632 if provider.is_object() || provider.as_bool() == Some(true) {
635 caps.pull_diagnostics = true;
636 }
637
638 if let Some(obj) = provider.as_object() {
639 if obj
640 .get("workspaceDiagnostics")
641 .and_then(|v| v.as_bool())
642 .unwrap_or(false)
643 {
644 caps.workspace_diagnostics = true;
645 }
646 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
647 caps.identifier = Some(identifier.to_string());
648 }
649 }
650 }
651
652 if let Some(refresh) = value
655 .get("workspace")
656 .and_then(|w| w.get("diagnostic"))
657 .and_then(|d| d.get("refreshSupport"))
658 .and_then(|r| r.as_bool())
659 {
660 caps.refresh_support = refresh;
661 }
662
663 caps
664}
665
/// Unit tests for `parse_diagnostic_capabilities`.
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty capabilities object yields all-default flags.
    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    /// A provider object enables pull diagnostics even when workspace
    /// diagnostics are explicitly disabled.
    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    /// Workspace diagnostics and the identifier are read from the provider object.
    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    /// A bare `true` provider enables pull diagnostics only.
    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    /// `workspace.diagnostic.refreshSupport` is picked up independently of
    /// `diagnosticProvider`.
    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }
}