1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::child_registry::LspChildRegistry;
16use crate::lsp::jsonrpc::{
17 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
18};
19use crate::lsp::registry::ServerKind;
20use crate::lsp::{transport, LspError};
21
/// Longest time a single request waits for its response before failing with a timeout.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Longest time `shutdown` polls for the child to exit before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
/// Sleep between two consecutive exit-status polls during shutdown.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);

/// In-flight requests: each request id maps to the sender its response is delivered on.
type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
27
/// Lifecycle state of a language server as tracked by [`LspClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerState {
    /// The process has been spawned; `initialize` has not been called yet.
    Starting,
    /// `initialize` is in progress.
    Initializing,
    /// `initialize` completed and the `initialized` notification was sent.
    Ready,
    /// `shutdown` has started; no further messages may be sent.
    ShuttingDown,
    /// The child process is known to have exited (or was killed).
    Exited,
}
37
/// Events emitted by the client's reader thread on the event channel.
#[derive(Debug)]
pub enum LspEvent {
    /// A notification received from the server.
    Notification {
        /// Kind of the server that sent the notification.
        server_kind: ServerKind,
        /// Root directory the server was spawned for.
        root: PathBuf,
        /// JSON-RPC method name.
        method: String,
        /// Raw parameters, if the message carried any.
        params: Option<Value>,
    },
    /// A request received from the server. The reader thread has already
    /// attempted to write a reply before this event is emitted.
    ServerRequest {
        /// Kind of the server that sent the request.
        server_kind: ServerKind,
        /// Root directory the server was spawned for.
        root: PathBuf,
        /// Id of the server's request.
        id: RequestId,
        /// JSON-RPC method name.
        method: String,
        /// Raw parameters, if the message carried any.
        params: Option<Value>,
    },
    /// The reader thread stopped: the output stream ended, a read failed, or
    /// the pending-request map could not be locked.
    ServerExited {
        /// Kind of the server whose reader stopped.
        server_kind: ServerKind,
        /// Root directory the server was spawned for.
        root: PathBuf,
    },
}
62
/// Diagnostic-related capabilities extracted from a server's `initialize` result.
#[derive(Debug, Clone, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object or as `true`.
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when it is a string.
    pub identifier: Option<String>,
    /// `workspace.diagnostic.refreshSupport`, when it is a boolean; `false` otherwise.
    pub refresh_support: bool,
}
84
/// Client for one language server child process, exchanging JSON-RPC
/// messages over the child's stdin/stdout pipes.
pub struct LspClient {
    /// Which server this client talks to.
    kind: ServerKind,
    /// Directory the server process was started in.
    root: PathBuf,
    /// Current lifecycle state; gates whether messages may be sent.
    state: ServerState,
    /// Handle to the server process.
    child: Child,
    /// Pid of `child`, captured at spawn time and used for registry bookkeeping.
    child_pid: u32,
    /// Buffered writer over the child's stdin. Shared with the reader thread,
    /// which uses it to reply to server-initiated requests.
    writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,

    /// Requests awaiting a response. Shared with the reader thread, which
    /// delivers each response to the matching sender.
    pending: Arc<Mutex<PendingMap>>,
    /// Source of request ids; starts at 1 and increments per request.
    next_id: AtomicI64,
    /// Diagnostic capabilities parsed during `initialize`; `None` before that.
    diagnostic_caps: Option<ServerDiagnosticCapabilities>,
    /// Whether the `initialize` result mentioned `workspace.didChangeWatchedFiles`.
    supports_watched_files: bool,
    /// Registry in which the child's pid is tracked while the process is alive.
    child_registry: LspChildRegistry,
}
115
116impl LspClient {
117 pub fn spawn(
123 kind: ServerKind,
124 root: PathBuf,
125 binary: &Path,
126 args: &[String],
127 env: &HashMap<String, String>,
128 event_tx: Sender<LspEvent>,
129 child_registry: LspChildRegistry,
130 ) -> io::Result<Self> {
131 let mut command = Command::new(binary);
132 command
133 .args(args)
134 .current_dir(&root)
135 .stdin(Stdio::piped())
136 .stdout(Stdio::piped())
137 .stderr(Stdio::null());
140 for (key, value) in env {
141 command.env(key, value);
142 }
143
144 let mut child = command.spawn()?;
145 let child_pid = child.id();
146 child_registry.track(child_pid);
147
148 let stdout = child
149 .stdout
150 .take()
151 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
152 let stdin = child
153 .stdin
154 .take()
155 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
156
157 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
158 let pending = Arc::new(Mutex::new(PendingMap::new()));
159 let reader_pending = Arc::clone(&pending);
160 let reader_writer = Arc::clone(&writer);
161 let reader_kind = kind.clone();
162 let reader_root = root.clone();
163
164 thread::spawn(move || {
165 let mut reader = BufReader::new(stdout);
166 loop {
167 match transport::read_message(&mut reader) {
168 Ok(Some(ServerMessage::Response(response))) => {
169 if let Ok(mut guard) = reader_pending.lock() {
170 if let Some(tx) = guard.remove(&response.id) {
171 if tx.send(response).is_err() {
172 log::debug!("response channel closed");
173 }
174 }
175 } else {
176 let _ = event_tx.send(LspEvent::ServerExited {
177 server_kind: reader_kind.clone(),
178 root: reader_root.clone(),
179 });
180 break;
181 }
182 }
183 Ok(Some(ServerMessage::Notification { method, params })) => {
184 let _ = event_tx.send(LspEvent::Notification {
185 server_kind: reader_kind.clone(),
186 root: reader_root.clone(),
187 method,
188 params,
189 });
190 }
191 Ok(Some(ServerMessage::Request { id, method, params })) => {
192 let response_value = if method == "workspace/configuration" {
202 let item_count = params
205 .as_ref()
206 .and_then(|p| p.get("items"))
207 .and_then(|items| items.as_array())
208 .map_or(1, |arr| arr.len());
209 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
210 } else {
211 serde_json::Value::Null
212 };
213 if let Ok(mut w) = reader_writer.lock() {
214 let response = super::jsonrpc::OutgoingResponse::success(
215 id.clone(),
216 response_value,
217 );
218 let _ = transport::write_response(&mut *w, &response);
219 }
220 let _ = event_tx.send(LspEvent::ServerRequest {
222 server_kind: reader_kind.clone(),
223 root: reader_root.clone(),
224 id,
225 method,
226 params,
227 });
228 }
229 Ok(None) | Err(_) => {
230 if let Ok(mut guard) = reader_pending.lock() {
231 guard.clear();
232 }
233 let _ = event_tx.send(LspEvent::ServerExited {
234 server_kind: reader_kind.clone(),
235 root: reader_root.clone(),
236 });
237 break;
238 }
239 }
240 }
241 });
242
243 Ok(Self {
244 kind,
245 root,
246 state: ServerState::Starting,
247 child,
248 child_pid,
249 writer,
250 pending,
251 next_id: AtomicI64::new(1),
252 diagnostic_caps: None,
253 supports_watched_files: false,
254 child_registry,
255 })
256 }
257
258 pub fn initialize(
260 &mut self,
261 workspace_root: &Path,
262 initialization_options: Option<serde_json::Value>,
263 ) -> Result<lsp_types::InitializeResult, LspError> {
264 self.ensure_can_send()?;
265 self.state = ServerState::Initializing;
266
267 let normalized = normalize_windows_path(workspace_root);
268 let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
269 LspError::NotFound(format!(
270 "failed to convert workspace root '{}' to file URI",
271 workspace_root.display()
272 ))
273 })?;
274 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
275 LspError::NotFound(format!(
276 "failed to convert workspace root '{}' to file URI",
277 workspace_root.display()
278 ))
279 })?;
280
281 let mut params_value = json!({
282 "processId": std::process::id(),
283 "rootUri": root_uri,
284 "capabilities": {
285 "workspace": {
286 "workspaceFolders": true,
287 "configuration": true,
288 "diagnostic": {
293 "refreshSupport": false
294 }
295 },
296 "textDocument": {
297 "synchronization": {
298 "dynamicRegistration": false,
299 "didSave": true,
300 "willSave": false,
301 "willSaveWaitUntil": false
302 },
303 "publishDiagnostics": {
304 "relatedInformation": true,
305 "versionSupport": true,
306 "codeDescriptionSupport": true,
307 "dataSupport": true
308 },
309 "diagnostic": {
314 "dynamicRegistration": false,
315 "relatedDocumentSupport": true
316 }
317 }
318 },
319 "clientInfo": {
320 "name": "aft",
321 "version": env!("CARGO_PKG_VERSION")
322 },
323 "workspaceFolders": [
324 {
325 "uri": root_uri,
326 "name": workspace_root
327 .file_name()
328 .and_then(|name| name.to_str())
329 .unwrap_or("workspace")
330 }
331 ]
332 });
333 if let Some(initialization_options) = initialization_options {
334 params_value["initializationOptions"] = initialization_options;
335 }
336
337 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
338
339 let result_value = self.send_request_value(
340 <lsp_types::request::Initialize as lsp_types::request::Request>::METHOD,
341 params,
342 )?;
343 let result: lsp_types::InitializeResult = serde_json::from_value(result_value.clone())?;
344
345 let caps_value = result_value
350 .get("capabilities")
351 .cloned()
352 .unwrap_or_else(|| serde_json::to_value(&result.capabilities).unwrap_or(Value::Null));
353 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
354
355 self.supports_watched_files = caps_value
359 .pointer("/workspace/didChangeWatchedFiles/dynamicRegistration")
360 .and_then(|v| v.as_bool())
361 .unwrap_or(false)
362 || caps_value
363 .pointer("/workspace/didChangeWatchedFiles")
364 .map(|v| v.is_object() || v.as_bool() == Some(true))
365 .unwrap_or(false);
366
367 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
368 json!({}),
369 )?)?;
370 self.state = ServerState::Ready;
371 Ok(result)
372 }
373
    /// Returns the diagnostic capabilities recorded by `initialize`, or
    /// `None` if they have not been recorded yet.
    pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
        self.diagnostic_caps.as_ref()
    }
380
    /// Returns the watched-files flag recorded by `initialize`; `false`
    /// until then.
    pub fn supports_watched_files(&self) -> bool {
        self.supports_watched_files
    }
386
387 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
389 where
390 R: lsp_types::request::Request,
391 R::Params: serde::Serialize,
392 R::Result: DeserializeOwned,
393 {
394 self.ensure_can_send()?;
395
396 let value = self.send_request_value(R::METHOD, params)?;
397 serde_json::from_value(value).map_err(Into::into)
398 }
399
400 fn send_request_value<P>(&mut self, method: &'static str, params: P) -> Result<Value, LspError>
401 where
402 P: serde::Serialize,
403 {
404 self.ensure_can_send()?;
405
406 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
407 let (tx, rx) = bounded(1);
408 {
409 let mut pending = self.lock_pending()?;
410 pending.insert(id.clone(), tx);
411 }
412
413 let request = Request::new(id.clone(), method, Some(serde_json::to_value(params)?));
414 {
415 let mut writer = self
416 .writer
417 .lock()
418 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
419 if let Err(err) = transport::write_request(&mut *writer, &request) {
420 self.remove_pending(&id);
421 return Err(err.into());
422 }
423 }
424
425 let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
426 Ok(response) => response,
427 Err(RecvTimeoutError::Timeout) => {
428 self.remove_pending(&id);
429 return Err(LspError::Timeout(format!(
430 "timed out waiting for '{}' response from {:?}",
431 method, self.kind
432 )));
433 }
434 Err(RecvTimeoutError::Disconnected) => {
435 self.remove_pending(&id);
436 return Err(LspError::ServerNotReady(format!(
437 "language server {:?} disconnected while waiting for '{}'",
438 self.kind, method
439 )));
440 }
441 };
442
443 if let Some(error) = response.error {
444 return Err(LspError::ServerError {
445 code: error.code,
446 message: error.message,
447 });
448 }
449
450 Ok(response.result.unwrap_or(Value::Null))
451 }
452
453 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
455 where
456 N: lsp_types::notification::Notification,
457 N::Params: serde::Serialize,
458 {
459 self.ensure_can_send()?;
460 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
461 let mut writer = self
462 .writer
463 .lock()
464 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
465 transport::write_notification(&mut *writer, ¬ification)?;
466 Ok(())
467 }
468
469 pub fn shutdown(&mut self) -> Result<(), LspError> {
471 if self.state == ServerState::Exited {
472 self.child_registry.untrack(self.child_pid);
473 return Ok(());
474 }
475
476 if self.child.try_wait()?.is_some() {
477 self.state = ServerState::Exited;
478 self.child_registry.untrack(self.child_pid);
479 return Ok(());
480 }
481
482 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
483 self.state = ServerState::ShuttingDown;
484 if self.child.try_wait()?.is_some() {
485 self.state = ServerState::Exited;
486 return Ok(());
487 }
488 return Err(err);
489 }
490
491 self.state = ServerState::ShuttingDown;
492
493 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
494 if self.child.try_wait()?.is_some() {
495 self.state = ServerState::Exited;
496 return Ok(());
497 }
498 return Err(err);
499 }
500
501 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
502 loop {
503 if self.child.try_wait()?.is_some() {
504 self.state = ServerState::Exited;
505 return Ok(());
506 }
507 if Instant::now() >= deadline {
508 let _ = self.child.kill();
509 let _ = self.child.wait();
510 self.state = ServerState::Exited;
511 return Err(LspError::Timeout(format!(
512 "timed out waiting for {:?} to exit",
513 self.kind
514 )));
515 }
516 thread::sleep(EXIT_POLL_INTERVAL);
517 }
518 }
519
    /// Returns the client's current lifecycle state.
    pub fn state(&self) -> ServerState {
        self.state
    }
523
    /// Returns a clone of the server kind this client was spawned for.
    pub fn kind(&self) -> ServerKind {
        self.kind.clone()
    }
527
    /// Returns the root directory this client was spawned with.
    pub fn root(&self) -> &Path {
        &self.root
    }
531
532 fn ensure_can_send(&self) -> Result<(), LspError> {
533 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
534 return Err(LspError::ServerNotReady(format!(
535 "language server {:?} is not ready (state: {:?})",
536 self.kind, self.state
537 )));
538 }
539 Ok(())
540 }
541
542 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
543 self.pending
544 .lock()
545 .map_err(|_| io::Error::other("pending response map poisoned").into())
546 }
547
548 fn remove_pending(&self, id: &RequestId) {
549 if let Ok(mut pending) = self.pending.lock() {
550 pending.remove(id);
551 }
552 }
553}
554
/// Dropping the client tears the server process down unconditionally.
impl Drop for LspClient {
    /// Untracks the child's pid, then kills the child and waits for it so the
    /// process is reaped. Errors from `kill` and `wait` are ignored.
    fn drop(&mut self) {
        self.child_registry.untrack(self.child_pid);
        let _ = self.child.kill();
        let _ = self.child.wait();
    }
}
564
/// Removes a Windows verbatim prefix from `path`, if present.
///
/// - `\\?\UNC\server\share\...` becomes `\\server\share\...`
/// - `\\?\C:\...` becomes `C:\...`
/// - any other path is returned unchanged.
///
/// The prefix check runs on the lossy UTF-8 rendering of the path, so a path
/// that carries a verbatim prefix and also contains invalid Unicode comes
/// back with replacement characters.
fn normalize_windows_path(path: &Path) -> PathBuf {
    let raw = path.to_string_lossy();
    // Check the UNC form first: stripping only `\\?\` from it would leave the
    // relative-looking `UNC\server\share`, which is not a usable path.
    if let Some(unc_tail) = raw.strip_prefix(r"\\?\UNC\") {
        PathBuf::from(format!(r"\\{unc_tail}"))
    } else if let Some(stripped) = raw.strip_prefix(r"\\?\") {
        PathBuf::from(stripped)
    } else {
        path.to_path_buf()
    }
}
576
577fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
592 let mut caps = ServerDiagnosticCapabilities::default();
593
594 if let Some(provider) = value.get("diagnosticProvider") {
595 if provider.is_object() || provider.as_bool() == Some(true) {
598 caps.pull_diagnostics = true;
599 }
600
601 if let Some(obj) = provider.as_object() {
602 if obj
603 .get("workspaceDiagnostics")
604 .and_then(|v| v.as_bool())
605 .unwrap_or(false)
606 {
607 caps.workspace_diagnostics = true;
608 }
609 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
610 caps.identifier = Some(identifier.to_string());
611 }
612 }
613 }
614
615 if let Some(refresh) = value
618 .get("workspace")
619 .and_then(|w| w.get("diagnostic"))
620 .and_then(|d| d.get("refreshSupport"))
621 .and_then(|r| r.as_bool())
622 {
623 caps.refresh_support = refresh;
624 }
625
626 caps
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty capabilities object advertises nothing.
    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let ServerDiagnosticCapabilities {
            pull_diagnostics,
            workspace_diagnostics,
            identifier,
            ..
        } = parse_diagnostic_capabilities(&json!({}));

        assert!(!pull_diagnostics);
        assert!(!workspace_diagnostics);
        assert!(identifier.is_none());
    }

    /// An object-form provider enables pull diagnostics even when workspace
    /// diagnostics are explicitly off.
    #[test]
    fn parse_caps_basic_pull_only() {
        let parsed = parse_diagnostic_capabilities(&json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        }));

        assert!(parsed.pull_diagnostics);
        assert!(!parsed.workspace_diagnostics);
    }

    /// A fully populated provider yields every provider-derived field.
    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let parsed = parse_diagnostic_capabilities(&json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        }));

        assert!(parsed.pull_diagnostics);
        assert!(parsed.workspace_diagnostics);
        assert_eq!(parsed.identifier.as_deref(), Some("rust-analyzer"));
    }

    /// A bare `true` provider enables pull diagnostics only.
    #[test]
    fn parse_caps_provider_as_bare_true() {
        let parsed = parse_diagnostic_capabilities(&json!({
            "diagnosticProvider": true
        }));

        assert!(parsed.pull_diagnostics);
        assert!(!parsed.workspace_diagnostics);
    }

    /// Refresh support is read from `workspace.diagnostic`, independently of
    /// the provider.
    #[test]
    fn parse_caps_workspace_refresh_support() {
        let parsed = parse_diagnostic_capabilities(&json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        }));

        assert!(parsed.refresh_support);
        assert!(!parsed.pull_diagnostics);
    }
}