1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::registry::ServerKind;
20use crate::lsp::{transport, LspError};
21
/// How long `send_request` waits for a matching response before giving up.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Grace period for the server process to exit after the `exit` notification
/// before it is killed.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sleep between `try_wait` polls while waiting for the server process to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
25
/// In-flight requests: maps each outstanding request id to the channel on which
/// the reader thread delivers that request's response.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
27
/// Lifecycle of a language server managed by `LspClient`.
///
/// Requests and notifications are refused once the state reaches
/// `ShuttingDown` or `Exited`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerState {
    /// Process spawned; the `initialize` handshake has not started.
    Starting,
    /// The `initialize` request has been sent.
    Initializing,
    /// Handshake complete: the `initialized` notification has been sent.
    Ready,
    /// Shutdown has begun; no new requests may be issued.
    ShuttingDown,
    /// The server process is gone (exited on its own or was killed).
    Exited,
}
37
/// Event emitted by a language server's background reader thread on the
/// `event_tx` channel passed to `LspClient::spawn`.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification sent by the server, forwarded as-is.
    Notification {
        /// Which server sent the notification.
        server_kind: ServerKind,
        /// Directory the server was started in (its working directory).
        root: PathBuf,
        /// JSON-RPC method name.
        method: String,
        /// Raw notification params, if present.
        params: Option<Value>,
    },
    /// A request initiated by the server.
    ///
    /// The reader thread has already written an automatic reply (`null`, or an
    /// array of `null`s for `workspace/configuration`) before emitting this event.
    ServerRequest {
        /// Which server sent the request.
        server_kind: ServerKind,
        /// Directory the server was started in (its working directory).
        root: PathBuf,
        /// Id of the server's request.
        id: RequestId,
        /// JSON-RPC method name.
        method: String,
        /// Raw request params, if present.
        params: Option<Value>,
    },
    /// The reader thread stopped: the server's output ended, a read failed, or
    /// the pending-response map was poisoned. No further events follow for
    /// this server.
    ServerExited {
        /// Which server stopped.
        server_kind: ServerKind,
        /// Directory the server was started in (its working directory).
        root: PathBuf,
    },
}
62
/// Diagnostic-related capabilities read from a server's `initialize` result by
/// `parse_diagnostic_capabilities`. Every field defaults to "unsupported".
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object or as `true`.
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when it is a string.
    pub identifier: Option<String>,
    /// `workspace.diagnostic.refreshSupport`, when it is a boolean.
    pub refresh_support: bool,
}
84
/// Handle to a spawned language-server child process that speaks JSON-RPC over
/// the child's stdin/stdout.
pub struct LspClient {
    /// Which language server this is.
    kind: ServerKind,
    /// Directory the server was started in (its working directory).
    root: PathBuf,
    /// Current lifecycle state.
    state: ServerState,
    /// The server process; killed and reaped when the client is dropped.
    child: Child,
    /// Pid of `child`, used to track/untrack it in `child_registry`.
    child_pid: u32,
    /// Buffered writer for the server's stdin. Shared with the reader thread,
    /// which also writes through it to reply to server-initiated requests.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// In-flight requests keyed by id. The reader thread removes the entry for
    /// each response and delivers it on the stored sender; it clears the map
    /// when the server's output ends.
    pending: Arc<Mutex<PendingMap>>,
    /// Id for the next outgoing request (starts at 1).
    next_id: AtomicI64,
    /// Diagnostic capabilities parsed from the `initialize` result; `None`
    /// until `initialize` has succeeded.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Result of the watched-files capability check done in `initialize`;
    /// `false` before that.
    supports_watched_files: bool,
    /// Registry in which `child_pid` is tracked while the process may still
    /// be running.
    child_registry: LspChildRegistry,
}
115
116impl LspClient {
117 pub fn spawn(
123 kind: ServerKind,
124 root: PathBuf,
125 binary: &Path,
126 args: &[String],
127 env: &HashMap<String, String>,
128 event_tx: Sender<LspEvent>,
129 child_registry: LspChildRegistry,
130 ) -> io::Result<Self> {
131 let mut command = Command::new(binary);
132 command
133 .args(args)
134 .current_dir(&root)
135 .stdin(Stdio::piped())
136 .stdout(Stdio::piped())
137 .stderr(Stdio::null());
140 for (key, value) in env {
141 command.env(key, value);
142 }
143
144 let mut child = command.spawn()?;
145 let child_pid = child.id();
146 child_registry.track(child_pid);
147
148 let stdout = child
149 .stdout
150 .take()
151 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
152 let stdin = child
153 .stdin
154 .take()
155 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
156
157 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
158 let pending = Arc::new(Mutex::new(PendingMap::new()));
159 let reader_pending = Arc::clone(&pending);
160 let reader_writer = Arc::clone(&writer);
161 let reader_kind = kind.clone();
162 let reader_root = root.clone();
163
164 thread::spawn(move || {
165 let mut reader = BufReader::new(stdout);
166 loop {
167 match transport::read_message(&mut reader) {
168 Ok(Some(ServerMessage::Response(response))) => {
169 if let Ok(mut guard) = reader_pending.lock() {
170 if let Some(tx) = guard.remove(&response.id) {
171 if tx.send(response).is_err() {
172 log::debug!("response channel closed");
173 }
174 }
175 } else {
176 let _ = event_tx.send(LspEvent::ServerExited {
177 server_kind: reader_kind.clone(),
178 root: reader_root.clone(),
179 });
180 break;
181 }
182 }
183 Ok(Some(ServerMessage::Notification { method, params })) => {
184 let _ = event_tx.send(LspEvent::Notification {
185 server_kind: reader_kind.clone(),
186 root: reader_root.clone(),
187 method,
188 params,
189 });
190 }
191 Ok(Some(ServerMessage::Request { id, method, params })) => {
192 let response_value = if method == "workspace/configuration" {
202 let item_count = params
205 .as_ref()
206 .and_then(|p| p.get("items"))
207 .and_then(|items| items.as_array())
208 .map_or(1, |arr| arr.len());
209 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
210 } else {
211 serde_json::Value::Null
212 };
213 if let Ok(mut w) = reader_writer.lock() {
214 let response = super::jsonrpc::OutgoingResponse::success(
215 id.clone(),
216 response_value,
217 );
218 let _ = transport::write_response(&mut *w, &response);
219 }
220 let _ = event_tx.send(LspEvent::ServerRequest {
222 server_kind: reader_kind.clone(),
223 root: reader_root.clone(),
224 id,
225 method,
226 params,
227 });
228 }
229 Ok(None) | Err(_) => {
230 if let Ok(mut guard) = reader_pending.lock() {
231 guard.clear();
232 }
233 let _ = event_tx.send(LspEvent::ServerExited {
234 server_kind: reader_kind.clone(),
235 root: reader_root.clone(),
236 });
237 break;
238 }
239 }
240 }
241 });
242
243 Ok(Self {
244 kind,
245 root,
246 state: ServerState::Starting,
247 child,
248 child_pid,
249 writer,
250 pending,
251 next_id: AtomicI64::new(1),
252 diagnostic_caps: None,
253 supports_watched_files: false,
254 child_registry,
255 })
256 }
257
258 pub fn initialize(
260 &mut self,
261 workspace_root: &Path,
262 initialization_options: Option<serde_json::Value>,
263 ) -> Result<lsp_types::InitializeResult, LspError> {
264 self.ensure_can_send()?;
265 self.state = ServerState::Initializing;
266
267 let normalized = normalize_windows_path(workspace_root);
268 let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
269 LspError::NotFound(format!(
270 "failed to convert workspace root '{}' to file URI",
271 workspace_root.display()
272 ))
273 })?;
274 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
275 LspError::NotFound(format!(
276 "failed to convert workspace root '{}' to file URI",
277 workspace_root.display()
278 ))
279 })?;
280
281 let mut params_value = json!({
282 "processId": std::process::id(),
283 "rootUri": root_uri,
284 "capabilities": {
285 "workspace": {
286 "workspaceFolders": true,
287 "configuration": true,
288 "diagnostic": {
293 "refreshSupport": false
294 }
295 },
296 "textDocument": {
297 "synchronization": {
298 "dynamicRegistration": false,
299 "didSave": true,
300 "willSave": false,
301 "willSaveWaitUntil": false
302 },
303 "publishDiagnostics": {
304 "relatedInformation": true,
305 "versionSupport": true,
306 "codeDescriptionSupport": true,
307 "dataSupport": true
308 },
309 "diagnostic": {
314 "dynamicRegistration": false,
315 "relatedDocumentSupport": true
316 }
317 }
318 },
319 "clientInfo": {
320 "name": "aft",
321 "version": env!("CARGO_PKG_VERSION")
322 },
323 "workspaceFolders": [
324 {
325 "uri": root_uri,
326 "name": workspace_root
327 .file_name()
328 .and_then(|name| name.to_str())
329 .unwrap_or("workspace")
330 }
331 ]
332 });
333 if let Some(initialization_options) = initialization_options {
334 params_value["initializationOptions"] = initialization_options;
335 }
336
337 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
338
339 let result = self.send_request::<lsp_types::request::Initialize>(params)?;
340
341 let caps_value = serde_json::to_value(&result.capabilities).unwrap_or(Value::Null);
346 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
347
348 self.supports_watched_files = caps_value
363 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
364 .and_then(|v| v.as_bool())
365 .unwrap_or(true) || caps_value
367 .pointer("/workspace/didChangeWatchedFiles")
368 .map(|v| v.is_object() || v.as_bool() == Some(true))
369 .unwrap_or(true);
370
371 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
372 json!({}),
373 )?)?;
374 self.state = ServerState::Ready;
375 Ok(result)
376 }
377
378 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
382 self.diagnostic_caps.as_ref()
383 }
384
385 pub fn supports_watched_files(&self) -> bool {
388 self.supports_watched_files
389 }
390
391 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
393 where
394 R: lsp_types::request::Request,
395 R::Params: serde::Serialize,
396 R::Result: DeserializeOwned,
397 {
398 self.ensure_can_send()?;
399
400 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
401 let (tx, rx) = bounded(1);
402 {
403 let mut pending = self.lock_pending()?;
404 pending.insert(id.clone(), tx);
405 }
406
407 let request = Request::new(id.clone(), R::METHOD, Some(serde_json::to_value(params)?));
408 {
409 let mut writer = self
410 .writer
411 .lock()
412 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
413 if let Err(err) = transport::write_request(&mut *writer, &request) {
414 self.remove_pending(&id);
415 return Err(err.into());
416 }
417 }
418
419 let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
420 Ok(response) => response,
421 Err(RecvTimeoutError::Timeout) => {
422 self.remove_pending(&id);
423 return Err(LspError::Timeout(format!(
424 "timed out waiting for '{}' response from {:?}",
425 R::METHOD,
426 self.kind
427 )));
428 }
429 Err(RecvTimeoutError::Disconnected) => {
430 self.remove_pending(&id);
431 return Err(LspError::ServerNotReady(format!(
432 "language server {:?} disconnected while waiting for '{}'",
433 self.kind,
434 R::METHOD
435 )));
436 }
437 };
438
439 if let Some(error) = response.error {
440 return Err(LspError::ServerError {
441 code: error.code,
442 message: error.message,
443 });
444 }
445
446 serde_json::from_value(response.result.unwrap_or(Value::Null)).map_err(Into::into)
447 }
448
449 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
451 where
452 N: lsp_types::notification::Notification,
453 N::Params: serde::Serialize,
454 {
455 self.ensure_can_send()?;
456 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
457 let mut writer = self
458 .writer
459 .lock()
460 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
461 transport::write_notification(&mut *writer, ¬ification)?;
462 Ok(())
463 }
464
465 pub fn shutdown(&mut self) -> Result<(), LspError> {
467 if self.state == ServerState::Exited {
468 self.child_registry.untrack(self.child_pid);
469 return Ok(());
470 }
471
472 if self.child.try_wait()?.is_some() {
473 self.state = ServerState::Exited;
474 self.child_registry.untrack(self.child_pid);
475 return Ok(());
476 }
477
478 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
479 self.state = ServerState::ShuttingDown;
480 if self.child.try_wait()?.is_some() {
481 self.state = ServerState::Exited;
482 return Ok(());
483 }
484 return Err(err);
485 }
486
487 self.state = ServerState::ShuttingDown;
488
489 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
490 if self.child.try_wait()?.is_some() {
491 self.state = ServerState::Exited;
492 return Ok(());
493 }
494 return Err(err);
495 }
496
497 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
498 loop {
499 if self.child.try_wait()?.is_some() {
500 self.state = ServerState::Exited;
501 return Ok(());
502 }
503 if Instant::now() >= deadline {
504 let _ = self.child.kill();
505 let _ = self.child.wait();
506 self.state = ServerState::Exited;
507 return Err(LspError::Timeout(format!(
508 "timed out waiting for {:?} to exit",
509 self.kind
510 )));
511 }
512 thread::sleep(EXIT_POLL_INTERVAL);
513 }
514 }
515
516 pub fn state(&self) -> ServerState {
517 self.state
518 }
519
520 pub fn kind(&self) -> ServerKind {
521 self.kind.clone()
522 }
523
524 pub fn root(&self) -> &Path {
525 &self.root
526 }
527
528 fn ensure_can_send(&self) -> Result<(), LspError> {
529 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
530 return Err(LspError::ServerNotReady(format!(
531 "language server {:?} is not ready (state: {:?})",
532 self.kind, self.state
533 )));
534 }
535 Ok(())
536 }
537
538 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
539 self.pending
540 .lock()
541 .map_err(|_| io::Error::other("pending response map poisoned").into())
542 }
543
544 fn remove_pending(&self, id: &RequestId) {
545 if let Ok(mut pending) = self.pending.lock() {
546 pending.remove(id);
547 }
548 }
549}
550
551impl Drop for LspClient {
552 fn drop(&mut self) {
553 self.child_registry.untrack(self.child_pid);
556 let _ = self.child.kill();
557 let _ = self.child.wait();
558 }
559}
560
/// Removes the Win32 verbatim prefix (`\\?\`) from `path`.
///
/// This is the form returned by, for example, `std::fs::canonicalize` on
/// Windows.
///
/// - `\\?\C:\dir` becomes `C:\dir`.
/// - `\\?\UNC\server\share\dir` becomes `\\server\share\dir`. Stripping only
///   `\\?\` here would yield the relative path `UNC\server\share\dir`.
/// - Paths without the prefix are returned unchanged.
/// - Paths that are not valid Unicode are also returned unchanged. A lossy
///   conversion would silently replace characters and point at a different
///   location.
fn normalize_windows_path(path: &Path) -> PathBuf {
    let Some(text) = path.to_str() else {
        return path.to_path_buf();
    };
    if let Some(unc) = text.strip_prefix(r"\\?\UNC\") {
        PathBuf::from(format!(r"\\{unc}"))
    } else if let Some(stripped) = text.strip_prefix(r"\\?\") {
        PathBuf::from(stripped)
    } else {
        path.to_path_buf()
    }
}
572
573fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
588 let mut caps = ServerDiagnosticCapabilities::default();
589
590 if let Some(provider) = value.get("diagnosticProvider") {
591 if provider.is_object() || provider.as_bool() == Some(true) {
594 caps.pull_diagnostics = true;
595 }
596
597 if let Some(obj) = provider.as_object() {
598 if obj
599 .get("workspaceDiagnostics")
600 .and_then(|v| v.as_bool())
601 .unwrap_or(false)
602 {
603 caps.workspace_diagnostics = true;
604 }
605 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
606 caps.identifier = Some(identifier.to_string());
607 }
608 }
609 }
610
611 if let Some(refresh) = value
614 .get("workspace")
615 .and_then(|w| w.get("diagnostic"))
616 .and_then(|d| d.get("refreshSupport"))
617 .and_then(|r| r.as_bool())
618 {
619 caps.refresh_support = refresh;
620 }
621
622 caps
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
        assert!(!caps.refresh_support);
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_provider_as_bare_false() {
        let value = json!({
            "diagnosticProvider": false
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_ignores_wrongly_typed_fields() {
        let value = json!({
            "diagnosticProvider": {
                "workspaceDiagnostics": "yes",
                "identifier": 42
            },
            "workspace": {
                "diagnostic": {
                    "refreshSupport": "true"
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
        assert!(!caps.refresh_support);
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support_explicit_false() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": false
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.refresh_support);
    }

    #[test]
    fn normalize_strips_verbatim_disk_prefix() {
        assert_eq!(
            normalize_windows_path(Path::new(r"\\?\C:\work\proj")),
            PathBuf::from(r"C:\work\proj")
        );
        assert_eq!(
            normalize_windows_path(Path::new("/home/user/proj")),
            PathBuf::from("/home/user/proj")
        );
    }
}