1use std::collections::HashMap;
2use std::io::{self, BufReader, BufWriter};
3use std::path::{Path, PathBuf};
4use std::process::{Child, Command, Stdio};
5use std::str::FromStr;
6use std::sync::atomic::{AtomicI64, Ordering};
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crossbeam_channel::{bounded, RecvTimeoutError, Sender};
12use serde::de::DeserializeOwned;
13use serde_json::{json, Value};
14
15use crate::lsp::jsonrpc::{
16 Notification, Request, RequestId, Response as JsonRpcResponse, ServerMessage,
17};
18use crate::lsp::registry::ServerKind;
19use crate::lsp::{transport, LspError};
20
/// Longest time `send_request` waits for a matching response before giving up.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
/// How long `shutdown` polls for the server process to exit before killing it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Sleep between exit-status polls while `shutdown` waits for the process.
const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(25);
24
25type PendingMap = HashMap<RequestId, Sender<JsonRpcResponse>>;
26
/// Lifecycle phase of an [`LspClient`]'s server process, as tracked by the client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerState {
    /// Process spawned; `initialize` has not been sent yet.
    Starting,
    /// `initialize` has been sent (the client also stays here if that request failed).
    Initializing,
    /// Handshake finished: the `initialized` notification has been sent.
    Ready,
    /// `shutdown` has been attempted; new requests and notifications are rejected.
    ShuttingDown,
    /// The process has terminated or been killed.
    Exited,
}
36
37#[derive(Debug)]
39pub enum LspEvent {
40 Notification {
42 server_kind: ServerKind,
43 root: PathBuf,
44 method: String,
45 params: Option<Value>,
46 },
47 ServerRequest {
49 server_kind: ServerKind,
50 root: PathBuf,
51 id: RequestId,
52 method: String,
53 params: Option<Value>,
54 },
55 ServerExited {
57 server_kind: ServerKind,
58 root: PathBuf,
59 },
60}
61
/// Diagnostic-related features extracted from a server's `initialize` result.
///
/// The `Default` value (everything off) means the server advertised none of them.
#[derive(Clone, Debug, Default)]
pub struct ServerDiagnosticCapabilities {
    /// `diagnosticProvider` is present as an object (or as the non-standard bare `true`).
    pub pull_diagnostics: bool,
    /// `diagnosticProvider.workspaceDiagnostics` is `true`.
    pub workspace_diagnostics: bool,
    /// `diagnosticProvider.identifier`, when it is a string.
    pub identifier: Option<String>,
    /// Value of `workspace.diagnostic.refreshSupport` in the server capabilities.
    /// NOTE(review): the LSP spec defines `refreshSupport` as a *client*
    /// capability; confirm servers really report it in this position.
    pub refresh_support: bool,
}
83
84pub struct LspClient {
86 kind: ServerKind,
87 root: PathBuf,
88 state: ServerState,
89 child: Child,
90 writer: Arc<Mutex<BufWriter<std::process::ChildStdin>>>,
91
92 pending: Arc<Mutex<PendingMap>>,
94 next_id: AtomicI64,
96 diagnostic_caps: Option<ServerDiagnosticCapabilities>,
100}
101
102impl LspClient {
103 pub fn spawn(
105 kind: ServerKind,
106 root: PathBuf,
107 binary: &Path,
108 args: &[String],
109 env: &HashMap<String, String>,
110 event_tx: Sender<LspEvent>,
111 ) -> io::Result<Self> {
112 let mut command = Command::new(binary);
113 command
114 .args(args)
115 .current_dir(&root)
116 .stdin(Stdio::piped())
117 .stdout(Stdio::piped())
118 .stderr(Stdio::null());
121 for (key, value) in env {
122 command.env(key, value);
123 }
124
125 let mut child = command.spawn()?;
126
127 let stdout = child
128 .stdout
129 .take()
130 .ok_or_else(|| io::Error::other("language server missing stdout pipe"))?;
131 let stdin = child
132 .stdin
133 .take()
134 .ok_or_else(|| io::Error::other("language server missing stdin pipe"))?;
135
136 let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
137 let pending = Arc::new(Mutex::new(PendingMap::new()));
138 let reader_pending = Arc::clone(&pending);
139 let reader_writer = Arc::clone(&writer);
140 let reader_kind = kind.clone();
141 let reader_root = root.clone();
142
143 thread::spawn(move || {
144 let mut reader = BufReader::new(stdout);
145 loop {
146 match transport::read_message(&mut reader) {
147 Ok(Some(ServerMessage::Response(response))) => {
148 if let Ok(mut guard) = reader_pending.lock() {
149 if let Some(tx) = guard.remove(&response.id) {
150 if tx.send(response).is_err() {
151 log::debug!("[aft-lsp] response channel closed");
152 }
153 }
154 } else {
155 let _ = event_tx.send(LspEvent::ServerExited {
156 server_kind: reader_kind.clone(),
157 root: reader_root.clone(),
158 });
159 break;
160 }
161 }
162 Ok(Some(ServerMessage::Notification { method, params })) => {
163 let _ = event_tx.send(LspEvent::Notification {
164 server_kind: reader_kind.clone(),
165 root: reader_root.clone(),
166 method,
167 params,
168 });
169 }
170 Ok(Some(ServerMessage::Request { id, method, params })) => {
171 let response_value = if method == "workspace/configuration" {
181 let item_count = params
184 .as_ref()
185 .and_then(|p| p.get("items"))
186 .and_then(|items| items.as_array())
187 .map_or(1, |arr| arr.len());
188 serde_json::Value::Array(vec![serde_json::Value::Null; item_count])
189 } else {
190 serde_json::Value::Null
191 };
192 if let Ok(mut w) = reader_writer.lock() {
193 let response = super::jsonrpc::OutgoingResponse::success(
194 id.clone(),
195 response_value,
196 );
197 let _ = transport::write_response(&mut *w, &response);
198 }
199 let _ = event_tx.send(LspEvent::ServerRequest {
201 server_kind: reader_kind.clone(),
202 root: reader_root.clone(),
203 id,
204 method,
205 params,
206 });
207 }
208 Ok(None) | Err(_) => {
209 if let Ok(mut guard) = reader_pending.lock() {
210 guard.clear();
211 }
212 let _ = event_tx.send(LspEvent::ServerExited {
213 server_kind: reader_kind.clone(),
214 root: reader_root.clone(),
215 });
216 break;
217 }
218 }
219 }
220 });
221
222 Ok(Self {
223 kind,
224 root,
225 state: ServerState::Starting,
226 child,
227 writer,
228 pending,
229 next_id: AtomicI64::new(1),
230 diagnostic_caps: None,
231 })
232 }
233
234 pub fn initialize(
236 &mut self,
237 workspace_root: &Path,
238 initialization_options: Option<serde_json::Value>,
239 ) -> Result<lsp_types::InitializeResult, LspError> {
240 self.ensure_can_send()?;
241 self.state = ServerState::Initializing;
242
243 let normalized = normalize_windows_path(workspace_root);
244 let root_url = url::Url::from_file_path(&normalized).map_err(|_| {
245 LspError::NotFound(format!(
246 "failed to convert workspace root '{}' to file URI",
247 workspace_root.display()
248 ))
249 })?;
250 let root_uri = lsp_types::Uri::from_str(root_url.as_str()).map_err(|_| {
251 LspError::NotFound(format!(
252 "failed to convert workspace root '{}' to file URI",
253 workspace_root.display()
254 ))
255 })?;
256
257 let mut params_value = json!({
258 "processId": std::process::id(),
259 "rootUri": root_uri,
260 "capabilities": {
261 "workspace": {
262 "workspaceFolders": true,
263 "configuration": true,
264 "diagnostic": {
269 "refreshSupport": false
270 }
271 },
272 "textDocument": {
273 "synchronization": {
274 "dynamicRegistration": false,
275 "didSave": true,
276 "willSave": false,
277 "willSaveWaitUntil": false
278 },
279 "publishDiagnostics": {
280 "relatedInformation": true,
281 "versionSupport": true,
282 "codeDescriptionSupport": true,
283 "dataSupport": true
284 },
285 "diagnostic": {
290 "dynamicRegistration": false,
291 "relatedDocumentSupport": true
292 }
293 }
294 },
295 "clientInfo": {
296 "name": "aft",
297 "version": env!("CARGO_PKG_VERSION")
298 },
299 "workspaceFolders": [
300 {
301 "uri": root_uri,
302 "name": workspace_root
303 .file_name()
304 .and_then(|name| name.to_str())
305 .unwrap_or("workspace")
306 }
307 ]
308 });
309 if let Some(initialization_options) = initialization_options {
310 params_value["initializationOptions"] = initialization_options;
311 }
312
313 let params = serde_json::from_value::<lsp_types::InitializeParams>(params_value)?;
314
315 let result = self.send_request::<lsp_types::request::Initialize>(params)?;
316
317 let caps_value = serde_json::to_value(&result.capabilities).unwrap_or(Value::Null);
322 self.diagnostic_caps = Some(parse_diagnostic_capabilities(&caps_value));
323
324 self.send_notification::<lsp_types::notification::Initialized>(serde_json::from_value(
325 json!({}),
326 )?)?;
327 self.state = ServerState::Ready;
328 Ok(result)
329 }
330
331 pub fn diagnostic_capabilities(&self) -> Option<&ServerDiagnosticCapabilities> {
335 self.diagnostic_caps.as_ref()
336 }
337
338 pub fn send_request<R>(&mut self, params: R::Params) -> Result<R::Result, LspError>
340 where
341 R: lsp_types::request::Request,
342 R::Params: serde::Serialize,
343 R::Result: DeserializeOwned,
344 {
345 self.ensure_can_send()?;
346
347 let id = RequestId::Int(self.next_id.fetch_add(1, Ordering::Relaxed));
348 let (tx, rx) = bounded(1);
349 {
350 let mut pending = self.lock_pending()?;
351 pending.insert(id.clone(), tx);
352 }
353
354 let request = Request::new(id.clone(), R::METHOD, Some(serde_json::to_value(params)?));
355 {
356 let mut writer = self
357 .writer
358 .lock()
359 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
360 if let Err(err) = transport::write_request(&mut *writer, &request) {
361 self.remove_pending(&id);
362 return Err(err.into());
363 }
364 }
365
366 let response = match rx.recv_timeout(REQUEST_TIMEOUT) {
367 Ok(response) => response,
368 Err(RecvTimeoutError::Timeout) => {
369 self.remove_pending(&id);
370 return Err(LspError::Timeout(format!(
371 "timed out waiting for '{}' response from {:?}",
372 R::METHOD,
373 self.kind
374 )));
375 }
376 Err(RecvTimeoutError::Disconnected) => {
377 self.remove_pending(&id);
378 return Err(LspError::ServerNotReady(format!(
379 "language server {:?} disconnected while waiting for '{}'",
380 self.kind,
381 R::METHOD
382 )));
383 }
384 };
385
386 if let Some(error) = response.error {
387 return Err(LspError::ServerError {
388 code: error.code,
389 message: error.message,
390 });
391 }
392
393 serde_json::from_value(response.result.unwrap_or(Value::Null)).map_err(Into::into)
394 }
395
396 pub fn send_notification<N>(&mut self, params: N::Params) -> Result<(), LspError>
398 where
399 N: lsp_types::notification::Notification,
400 N::Params: serde::Serialize,
401 {
402 self.ensure_can_send()?;
403 let notification = Notification::new(N::METHOD, Some(serde_json::to_value(params)?));
404 let mut writer = self
405 .writer
406 .lock()
407 .map_err(|_| LspError::ServerNotReady("writer lock poisoned".to_string()))?;
408 transport::write_notification(&mut *writer, ¬ification)?;
409 Ok(())
410 }
411
412 pub fn shutdown(&mut self) -> Result<(), LspError> {
414 if self.state == ServerState::Exited {
415 return Ok(());
416 }
417
418 if self.child.try_wait()?.is_some() {
419 self.state = ServerState::Exited;
420 return Ok(());
421 }
422
423 if let Err(err) = self.send_request::<lsp_types::request::Shutdown>(()) {
424 self.state = ServerState::ShuttingDown;
425 if self.child.try_wait()?.is_some() {
426 self.state = ServerState::Exited;
427 return Ok(());
428 }
429 return Err(err);
430 }
431
432 self.state = ServerState::ShuttingDown;
433
434 if let Err(err) = self.send_notification::<lsp_types::notification::Exit>(()) {
435 if self.child.try_wait()?.is_some() {
436 self.state = ServerState::Exited;
437 return Ok(());
438 }
439 return Err(err);
440 }
441
442 let deadline = Instant::now() + SHUTDOWN_TIMEOUT;
443 loop {
444 if self.child.try_wait()?.is_some() {
445 self.state = ServerState::Exited;
446 return Ok(());
447 }
448 if Instant::now() >= deadline {
449 let _ = self.child.kill();
450 let _ = self.child.wait();
451 self.state = ServerState::Exited;
452 return Err(LspError::Timeout(format!(
453 "timed out waiting for {:?} to exit",
454 self.kind
455 )));
456 }
457 thread::sleep(EXIT_POLL_INTERVAL);
458 }
459 }
460
461 pub fn state(&self) -> ServerState {
462 self.state
463 }
464
465 pub fn kind(&self) -> ServerKind {
466 self.kind.clone()
467 }
468
469 pub fn root(&self) -> &Path {
470 &self.root
471 }
472
473 fn ensure_can_send(&self) -> Result<(), LspError> {
474 if matches!(self.state, ServerState::ShuttingDown | ServerState::Exited) {
475 return Err(LspError::ServerNotReady(format!(
476 "language server {:?} is not ready (state: {:?})",
477 self.kind, self.state
478 )));
479 }
480 Ok(())
481 }
482
483 fn lock_pending(&self) -> Result<std::sync::MutexGuard<'_, PendingMap>, LspError> {
484 self.pending
485 .lock()
486 .map_err(|_| io::Error::other("pending response map poisoned").into())
487 }
488
489 fn remove_pending(&self, id: &RequestId) {
490 if let Ok(mut pending) = self.pending.lock() {
491 pending.remove(id);
492 }
493 }
494}
495
496impl Drop for LspClient {
497 fn drop(&mut self) {
498 let _ = self.child.kill();
499 let _ = self.child.wait();
500 }
501}
502
/// Removes a Windows verbatim (`\\?\`) prefix, such as the one produced by
/// `std::fs::canonicalize` on Windows, so the path converts to a conventional
/// `file://` URI.
///
/// - `\\?\C:\dir` becomes `C:\dir`.
/// - `\\?\UNC\server\share\dir` becomes `\\server\share\dir`. Stripping only
///   `\\?\` would leave the relative-looking `UNC\server\share\dir`, which
///   cannot be turned into a file URI.
///
/// Any other path, including one that is not valid UTF-8, is returned
/// unchanged rather than being lossily re-encoded.
fn normalize_windows_path(path: &Path) -> PathBuf {
    const VERBATIM_PREFIX: &str = r"\\?\";
    const VERBATIM_UNC_PREFIX: &str = r"\\?\UNC\";

    let Some(text) = path.to_str() else {
        return path.to_path_buf();
    };

    if let Some(rest) = text.strip_prefix(VERBATIM_UNC_PREFIX) {
        PathBuf::from(format!(r"\\{rest}"))
    } else if let Some(rest) = text.strip_prefix(VERBATIM_PREFIX) {
        PathBuf::from(rest)
    } else {
        path.to_path_buf()
    }
}
514
515fn parse_diagnostic_capabilities(value: &Value) -> ServerDiagnosticCapabilities {
530 let mut caps = ServerDiagnosticCapabilities::default();
531
532 if let Some(provider) = value.get("diagnosticProvider") {
533 if provider.is_object() || provider.as_bool() == Some(true) {
536 caps.pull_diagnostics = true;
537 }
538
539 if let Some(obj) = provider.as_object() {
540 if obj
541 .get("workspaceDiagnostics")
542 .and_then(|v| v.as_bool())
543 .unwrap_or(false)
544 {
545 caps.workspace_diagnostics = true;
546 }
547 if let Some(identifier) = obj.get("identifier").and_then(|v| v.as_str()) {
548 caps.identifier = Some(identifier.to_string());
549 }
550 }
551 }
552
553 if let Some(refresh) = value
556 .get("workspace")
557 .and_then(|w| w.get("diagnostic"))
558 .and_then(|d| d.get("refreshSupport"))
559 .and_then(|r| r.as_bool())
560 {
561 caps.refresh_support = refresh;
562 }
563
564 caps
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_caps_no_diagnostic_provider() {
        let value = json!({});
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_basic_pull_only() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": false,
                "workspaceDiagnostics": false
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_full_pull_with_workspace() {
        let value = json!({
            "diagnosticProvider": {
                "interFileDependencies": true,
                "workspaceDiagnostics": true,
                "identifier": "rust-analyzer"
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.workspace_diagnostics);
        assert_eq!(caps.identifier.as_deref(), Some("rust-analyzer"));
    }

    #[test]
    fn parse_caps_provider_as_bare_true() {
        let value = json!({
            "diagnosticProvider": true
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
    }

    #[test]
    fn parse_caps_provider_as_bare_false() {
        // An explicit `false` provider must not enable pull diagnostics.
        let value = json!({
            "diagnosticProvider": false
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.pull_diagnostics);
        assert!(!caps.workspace_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_ignores_non_string_identifier() {
        let value = json!({
            "diagnosticProvider": {
                "identifier": 42
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.pull_diagnostics);
        assert!(caps.identifier.is_none());
    }

    #[test]
    fn parse_caps_workspace_refresh_support() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": true
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(caps.refresh_support);
        assert!(!caps.pull_diagnostics);
    }

    #[test]
    fn parse_caps_workspace_refresh_support_explicit_false() {
        let value = json!({
            "workspace": {
                "diagnostic": {
                    "refreshSupport": false
                }
            }
        });
        let caps = parse_diagnostic_capabilities(&value);
        assert!(!caps.refresh_support);
    }

    #[test]
    fn normalize_strips_verbatim_disk_prefix_only() {
        assert_eq!(
            normalize_windows_path(Path::new(r"\\?\C:\work")),
            PathBuf::from(r"C:\work")
        );
        assert_eq!(
            normalize_windows_path(Path::new("/srv/work")),
            PathBuf::from("/srv/work")
        );
    }
}