1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::jsonrpc::{
16 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
17};
18use crate::lsp::registry::ServerKind;
19use crate::lsp::{transport, LspError};
20
/// Maximum time `LspClient::send_request` blocks waiting for the matching response.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Grace period `LspClient::shutdown` gives the process to exit before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// How often `LspClient::shutdown` polls the child while waiting for it to exit.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
24
25type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
26
/// Client-side lifecycle of a language-server connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServerState {
    /// Process spawned; `initialize` has not been sent yet.
    Starting,
    /// `initialize` has been issued and its result is pending.
    Initializing,
    /// Handshake finished (`initialized` sent); requests and notifications may be sent.
    Ready,
    /// Shutdown is in progress; ordinary requests and notifications are rejected.
    ShuttingDown,
    /// The server process has exited (or was killed).
    Exited,
}
36
37#[derive(Debug)]
39pub enum LspEvent {
40 Notification {
42 server_kind: ServerKind,
43 root: PathBuf,
44 method: String,
45 params: Option<Value>,
46 },
47 ServerRequest {
49 server_kind: ServerKind,
50 root: PathBuf,
51 id: RequestId,
52 method: String,
53 params: Option<Value>,
54 },
55 ServerExited {
57 server_kind: ServerKind,
58 root: PathBuf,
59 },
60}
61
/// Pull-diagnostic support extracted from a server's `initialize` result.
///
/// Every flag defaults to `false` and `identifier` to `None` when the server says nothing.
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is advertised (as an object, or as bare `true`).
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when the server supplies one.
    pub identifier: Option<String>,
    /// Value found at `workspace.diagnostic.refreshSupport` in the capabilities JSON.
    pub refresh_support: bool,
}
83
84pub struct LspClient {
86 kind: ServerKind,
87 root: PathBuf,
88 state: ServerState,
89 child: Child,
90 writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,
91
92 pending: Arc<Mutex<PendingMap>>,
94 next_id: AtomicI64,
96 diagnostic_caps: Option<ServerDiagnosticCapabilities>,
100 supports_watched_files: bool,
105}
106
107impl LspClient {
108 pub fn spawn(
110 kind: ServerKind,
111 root: PathBuf,
112 binary: &Path,
113 args: &[String],
114 env: &HashMap<String, String>,
115 event_tx: Sender<LspEvent>,
116 ) -> io::Result<Self> {
117 let mut command = Command::new(binary);
118 command
119 .args(args)
120 .current_dir(&root)
121 .stdin(Stdio::piped())
122 .stdout(Stdio::piped())
123 .stderr(Stdio::null());
126 for (key, value) in env {
127 command.env(key, value);
128 }
129
130 let mut child = command.spawn()?;
131
132 let stdout = child
133 .stdout
134 .take()
135 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
136 let stdin = child
137 .stdin
138 .take()
139 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
140
141 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
142 let pending = Arc::new(Mutex::new(PendingMap::new()));
143 let reader_pending = Arc::clone(&pending);
144 let reader_writer = Arc::clone(&writer);
145 let reader_kind = kind.clone();
146 let reader_root = root.clone();
147
148 thread::spawn(move || {
149 let mut reader = BufReader::new(stdout);
150 loop {
151 match transport::read_message(&mut reader) {
152 Ok(Some(ServerMessage::Response(response))) => {
153 if let Ok(mut guard) = reader_pending.lock() {
154 if let Some(tx) = guard.remove(&response.id) {
155 if tx.send(response).is_err() {
156 log::debug!("response channel closed");
157 }
158 }
159 } else {
160 let _ = event_tx.send(LspEvent::ServerExited {
161 server_kind: reader_kind.clone(),
162 root: reader_root.clone(),
163 });
164 break;
165 }
166 }
167 Ok(Some(ServerMessage::Notification { method, params })) => {
168 let _ = event_tx.send(LspEvent::Notification {
169 server_kind: reader_kind.clone(),
170 root: reader_root.clone(),
171 method,
172 params,
173 });
174 }
175 Ok(Some(ServerMessage::Request { id, method, params })) => {
176 let response_value = if method == "workspace/configuration" {
186 let item_count = params
189 .as_ref()
190 .and_then(|p| p.get("items"))
191 .and_then(|items| items.as_array())
192 .map_or(1, |arr| arr.len());
193 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
194 } else {
195 serde_json::Value::Null
196 };
197 if let Ok(mut w) = reader_writer.lock() {
198 let response = super::jsonrpc::OutgoingResponse::success(
199 id.clone(),
200 response_value,
201 );
202 let _ = transport::write_response(&mut *w, &response);
203 }
204 let _ = event_tx.send(LspEvent::ServerRequest {
206 server_kind: reader_kind.clone(),
207 root: reader_root.clone(),
208 id,
209 method,
210 params,
211 });
212 }
213 Ok(None) | Err(_) => {
214 if let Ok(mut guard) = reader_pending.lock() {
215 guard.clear();
216 }
217 let _ = event_tx.send(LspEvent::ServerExited {
218 server_kind: reader_kind.clone(),
219 root: reader_root.clone(),
220 });
221 break;
222 }
223 }
224 }
225 });
226
227 Ok(Self {
228 kind,
229 root,
230 state: ServerState::Starting,
231 child,
232 writer,
233 pending,
234 next_id: AtomicI64::new(1),
235 diagnostic_caps: None,
236 supports_watched_files: false,
237 })
238 }
239
240 pub fn initialize(
242 &mut self,
243 workspace_root: &Path,
244 initialization_options: Option<serde_json::Value>,
245 ) -> Result<lsp_types::InitializeResult, LspError> {
246 self.ensure_can_send()?;
247 self.state = ServerState::Initializing;
248
249 let normalized = normalize_windows_path(workspace_root);
250 let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
251 LspError::NotFound(format!(
252 "failed to convert workspace root '{}' to file URI",
253 workspace_root.display()
254 ))
255 })?;
256 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
257 LspError::NotFound(format!(
258 "failed to convert workspace root '{}' to file URI",
259 workspace_root.display()
260 ))
261 })?;
262
263 let mut params_value = json!({
264 "processId": std::process::id(),
265 "rootUri": root_uri,
266 "capabilities": {
267 "workspace": {
268 "workspaceFolders": true,
269 "configuration": true,
270 "diagnostic": {
275 "refreshSupport": false
276 }
277 },
278 "textDocument": {
279 "synchronization": {
280 "dynamicRegistration": false,
281 "didSave": true,
282 "willSave": false,
283 "willSaveWaitUntil": false
284 },
285 "publishDiagnostics": {
286 "relatedInformation": true,
287 "versionSupport": true,
288 "codeDescriptionSupport": true,
289 "dataSupport": true
290 },
291 "diagnostic": {
296 "dynamicRegistration": false,
297 "relatedDocumentSupport": true
298 }
299 }
300 },
301 "clientInfo": {
302 "name": "aft",
303 "version": env!("CARGO_PKG_VERSION")
304 },
305 "workspaceFolders": [
306 {
307 "uri": root_uri,
308 "name": workspace_root
309 .file_name()
310 .and_then(|name| name.to_str())
311 .unwrap_or("workspace")
312 }
313 ]
314 });
315 if let Some(initialization_options) = initialization_options {
316 params_value["initializationOptions"] = initialization_options;
317 }
318
319 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
320
321 let result = self.send_request::<lsp_types::request::Initialize>(params)?;
322
323 let caps_value = serde_json::to_value(&result.capabilities).unwrap_or(Value::Null);
328 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
329
330 self.supports_watched_files = caps_value
345 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
346 .and_then(|v| v.as_bool())
347 .unwrap_or(true) || caps_value
349 .pointer("/workspace/didChangeWatchedFiles")
350 .map(|v| v.is_object() || v.as_bool() == Some(true))
351 .unwrap_or(true);
352
353 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
354 json!({}),
355 )?)?;
356 self.state = ServerState::Ready;
357 Ok(result)
358 }
359
360 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
364 self.diagnostic_caps.as_ref()
365 }
366
367 pub fn supports_watched_files(&self) -> bool {
370 self.supports_watched_files
371 }
372
373 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
375 where
376 R: lsp_types::request::Request,
377 R::Params: serde::Serialize,
378 R::Result: DeserializeOwned,
379 {
380 self.ensure_can_send()?;
381
382 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
383 let (tx, rx) = bounded(1);
384 {
385 let mut pending = self.lock_pending()?;
386 pending.insert(id.clone(), tx);
387 }
388
389 let request = Request::new(id.clone(), R::METHOD, Some(serde_json::to_value(params)?));
390 {
391 let mut writer = self
392 .writer
393 .lock()
394 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
395 if let Err(err) = transport::write_request(&mut *writer, &request) {
396 self.remove_pending(&id);
397 return Err(err.into());
398 }
399 }
400
401 let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
402 Ok(response) => response,
403 Err(RecvTimeoutError::Timeout) => {
404 self.remove_pending(&id);
405 return Err(LspError::Timeout(format!(
406 "timed out waiting for '{}' response from {:?}",
407 R::METHOD,
408 self.kind
409 )));
410 }
411 Err(RecvTimeoutError::Disconnected) => {
412 self.remove_pending(&id);
413 return Err(LspError::ServerNotReady(format!(
414 "language server {:?} disconnected while waiting for '{}'",
415 self.kind,
416 R::METHOD
417 )));
418 }
419 };
420
421 if let Some(error) = response.error {
422 return Err(LspError::ServerError {
423 code: error.code,
424 message: error.message,
425 });
426 }
427
428 serde_json::from_value(response.result.unwrap_or(Value::Null)).map_err(Into::into)
429 }
430
431 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
433 where
434 N: lsp_types::notification::Notification,
435 N::Params: serde::Serialize,
436 {
437 self.ensure_can_send()?;
438 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
439 let mut writer = self
440 .writer
441 .lock()
442 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
443 transport::write_notification(&mut *writer, ¬ification)?;
444 Ok(())
445 }
446
447 pub fn shutdown(&mut self) -> Result<(), LspError> {
449 if self.state == ServerState::Exited {
450 return Ok(());
451 }
452
453 if self.child.try_wait()?.is_some() {
454 self.state = ServerState::Exited;
455 return Ok(());
456 }
457
458 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
459 self.state = ServerState::ShuttingDown;
460 if self.child.try_wait()?.is_some() {
461 self.state = ServerState::Exited;
462 return Ok(());
463 }
464 return Err(err);
465 }
466
467 self.state = ServerState::ShuttingDown;
468
469 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
470 if self.child.try_wait()?.is_some() {
471 self.state = ServerState::Exited;
472 return Ok(());
473 }
474 return Err(err);
475 }
476
477 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
478 loop {
479 if self.child.try_wait()?.is_some() {
480 self.state = ServerState::Exited;
481 return Ok(());
482 }
483 if Instant::now() >= deadline {
484 let _ = self.child.kill();
485 let _ = self.child.wait();
486 self.state = ServerState::Exited;
487 return Err(LspError::Timeout(format!(
488 "timed out waiting for {:?} to exit",
489 self.kind
490 )));
491 }
492 thread::sleep(EXIT_POLL_INTERVAL);
493 }
494 }
495
496 pub fn state(&self) -> ServerState {
497 self.state
498 }
499
500 pub fn kind(&self) -> ServerKind {
501 self.kind.clone()
502 }
503
504 pub fn root(&self) -> &Path {
505 &self.root
506 }
507
508 fn ensure_can_send(&self) -> Result<(), LspError> {
509 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
510 return Err(LspError::ServerNotReady(format!(
511 "language server {:?} is not ready (state: {:?})",
512 self.kind, self.state
513 )));
514 }
515 Ok(())
516 }
517
518 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
519 self.pending
520 .lock()
521 .map_err(|_| io::Error::other("pending response map poisoned").into())
522 }
523
524 fn remove_pending(&self, id: &RequestId) {
525 if let Ok(mut pending) = self.pending.lock() {
526 pending.remove(id);
527 }
528 }
529}
530
531impl Drop for LspClient {
532 fn drop(&mut self) {
533 let _ = self.child.kill();
534 let _ = self.child.wait();
535 }
536}
537
/// Removes the Win32 verbatim prefix (`\\?\`) so the path can be turned into a `file://` URI.
///
/// Verbatim paths are what `std::fs::canonicalize` returns on Windows.
/// - `\\?\C:\dir` becomes `C:\dir`.
/// - `\\?\UNC\server\share\dir` becomes `\\server\share\dir`. Dropping only `\\?\` here would
///   leave the *relative* path `UNC\server\share\dir`, which cannot be converted to a URI.
///
/// Paths without the prefix are returned unchanged. So are paths that are not valid Unicode:
/// they are not lossily rewritten.
fn normalize_windows_path(path: &Path) -> PathBuf {
    const VERBATIM_UNC_PREFIX: &str = r"\\?\UNC\";
    const VERBATIM_PREFIX: &str = r"\\?\";

    let Some(text) = path.to_str() else {
        return path.to_path_buf();
    };
    // The UNC form must be checked first: it also starts with the plain verbatim prefix.
    if let Some(rest) = text.strip_prefix(VERBATIM_UNC_PREFIX) {
        PathBuf::from(format!(r"\\{rest}"))
    } else if let Some(rest) = text.strip_prefix(VERBATIM_PREFIX) {
        PathBuf::from(rest)
    } else {
        path.to_path_buf()
    }
}
549
550fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
565 let mut caps = ServerDiagnosticCapabilities::default();
566
567 if let Some(provider) = value.get("diagnosticProvider") {
568 if provider.is_object() || provider.as_bool() == Some(true) {
571 caps.pull_diagnostics = true;
572 }
573
574 if let Some(obj) = provider.as_object() {
575 if obj
576 .get("workspaceDiagnostics")
577 .and_then(|v| v.as_bool())
578 .unwrap_or(false)
579 {
580 caps.workspace_diagnostics = true;
581 }
582 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
583 caps.identifier = Some(identifier.to_string());
584 }
585 }
586 }
587
588 if let Some(refresh) = value
591 .get("workspace")
592 .and_then(|w| w.get("diagnostic"))
593 .and_then(|d| d.get("refreshSupport"))
594 .and_then(|r| r.as_bool())
595 {
596 caps.refresh_support = refresh;
597 }
598
599 caps
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the parser on an inline capabilities literal.
    fn caps_of(value: Value) -> ServerDiagnosticCapabilities {
        parse_diagnostic_capabilities(&value)
    }

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let ServerDiagnosticCapabilities {
            pull_diagnostics,
            workspace_diagnostics,
            identifier,
            ..
        } = caps_of(json!({}));
        assert!(!pull_diagnostics);
        assert!(!workspace_diagnostics);
        assert_eq!(identifier, None);
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let caps = caps_of(json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        }));
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let caps = caps_of(json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        }));
        assert!(caps.pull_diagnostics && caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        // A bare `true` provider enables pull diagnostics and nothing else.
        let caps = caps_of(json!({ "diagnosticProvider": true }));
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let caps = caps_of(json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        }));
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }
}