Skip to main content

atomcode_core/lsp/
client.rs

1//! LspClient — manages a single language server process.
2//!
3//! Spawns the server, performs the LSP initialize handshake, and runs a
4//! background reader task that dispatches responses and
5//! `textDocument/publishDiagnostics` notifications.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use anyhow::{Context, Result};
13use serde_json::{json, Value};
14use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
15use tokio::process::{Child, Command};
16use tokio::sync::{oneshot, Mutex, RwLock};
17
18use super::jsonrpc;
19use super::types::{Diagnostic, DiagnosticSeverity};
20
21/// Convert a file:// URI to a PathBuf, handling platform differences and URL encoding.
22fn uri_to_path(uri: &str) -> PathBuf {
23    if uri.starts_with("file://") {
24        // Use the url crate for proper URI parsing (handles Windows paths and % encoding).
25        url::Url::parse(uri)
26            .ok()
27            .and_then(|url| url.to_file_path().ok())
28            .unwrap_or_else(|| PathBuf::from(uri))
29    } else {
30        PathBuf::from(uri)
31    }
32}
33
34/// Convert a local path to a standards-compliant file:// URI.
35fn path_to_uri(path: &Path) -> String {
36    url::Url::from_file_path(path)
37        .map(|url| url.to_string())
38        .unwrap_or_else(|_| format!("file://{}", path.display()))
39}
40
/// Tracks the state of an open document for didOpen/didChange versioning.
///
/// Entries are created and updated by `LspClient::next_sync_action` and
/// removed by `LspClient::close_document`.
#[derive(Debug, Clone)]
pub struct OpenDocumentState {
    /// Document URI; `next_sync_action` fills this with `path_to_uri(path)`.
    pub uri: String,
    /// Language identifier passed in when the document was first synced.
    pub language_id: String,
    /// Document version: set to 1 on first sync and incremented by one on
    /// every later sync of the same path.
    pub version: i32,
}
48
/// Which notification `sync_document` should send for a given path, as
/// decided by `LspClient::next_sync_action`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DocumentSyncAction {
    /// The path was not tracked yet: send `textDocument/didOpen`.
    DidOpen,
    /// The path was already tracked: send `textDocument/didChange` carrying
    /// the freshly incremented `version`.
    DidChange { version: i32 },
}
54
/// A running language server client.
///
/// Owns the spawned server process, the writer for its stdin, and the
/// background task that reads its stdout. The maps are wrapped in `Arc` so
/// the reader task can share them with the client.
pub struct LspClient {
    /// Next JSON-RPC request id.
    next_id: AtomicU64,
    /// Pending request id → response sender.
    /// The sender carries `Ok(result)` or `Err(error)` taken from the response.
    pending: Arc<RwLock<HashMap<u64, oneshot::Sender<Result<Value, Value>>>>>,
    /// Cached diagnostics per file path.
    /// Updated by the reader task on `textDocument/publishDiagnostics`.
    diagnostics_cache: Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>>,
    /// Writer half of the server's stdin (behind Mutex for Send safety).
    writer: Arc<Mutex<BufWriter<tokio::process::ChildStdin>>>,
    /// Handle to the spawned server process.
    child: Mutex<Child>,
    /// Handle to the background reader task.
    /// `None` until `start` stores it, and again after `shutdown` takes it.
    reader_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
    /// The project root URI used during initialize.
    // Only written in `start`; nothing in the code visible here reads it back,
    // which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    root_uri: String,
    /// Tracks open documents for proper didOpen/didChange versioning.
    opened_documents: Arc<RwLock<HashMap<PathBuf, OpenDocumentState>>>,
}
75
76impl LspClient {
77    /// Spawn a language server and perform the initialize handshake.
78    pub async fn start(
79        config: &super::registry::LspServerConfig,
80        project_root: &Path,
81        language_id: &str,
82    ) -> Result<Self> {
83        let mut cmd = Command::new(&config.command);
84        cmd.args(&config.args)
85            .stdin(std::process::Stdio::piped())
86            .stdout(std::process::Stdio::piped())
87            .stderr(std::process::Stdio::null())
88            .kill_on_drop(true);
89        crate::process_utils::suppress_console_window(&mut cmd);
90        let mut child = cmd
91            .spawn()
92            .with_context(|| format!("Failed to spawn LSP server: {}", config.command))?;
93
94        let stdin = child
95            .stdin
96            .take()
97            .context("Failed to open LSP server stdin")?;
98        let stdout = child
99            .stdout
100            .take()
101            .context("Failed to open LSP server stdout")?;
102
103        let writer = Arc::new(Mutex::new(BufWriter::new(stdin)));
104        let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<Result<Value, Value>>>>> =
105            Arc::new(RwLock::new(HashMap::new()));
106        let diagnostics_cache: Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>> =
107            Arc::new(RwLock::new(HashMap::new()));
108        let opened_documents: Arc<RwLock<HashMap<PathBuf, OpenDocumentState>>> =
109            Arc::new(RwLock::new(HashMap::new()));
110
111        let root_uri = path_to_uri(project_root);
112
113        let client = Self {
114            next_id: AtomicU64::new(1),
115            pending: pending.clone(),
116            diagnostics_cache: diagnostics_cache.clone(),
117            writer: writer.clone(),
118            child: Mutex::new(child),
119            reader_handle: Mutex::new(None),
120            root_uri: root_uri.clone(),
121            opened_documents: opened_documents.clone(),
122        };
123
124        // Spawn background reader BEFORE the initialize handshake so the
125        // response is actually consumed.
126        let reader_pending = pending.clone();
127        let reader_diags = diagnostics_cache.clone();
128        let reader_handle = tokio::spawn(async move {
129            let mut reader = BufReader::new(stdout);
130            loop {
131                match jsonrpc::read_message(&mut reader).await {
132                    Ok(msg) => {
133                        Self::dispatch_message(msg, &reader_pending, &reader_diags).await;
134                    }
135                    Err(_) => {
136                        // Server closed stdout — exit reader loop.
137                        break;
138                    }
139                }
140            }
141        });
142        *client.reader_handle.lock().await = Some(reader_handle);
143
144        // Perform initialize handshake.
145        let init_params = json!({
146            "processId": std::process::id(),
147            "rootUri": root_uri,
148            "capabilities": {
149                "textDocument": {
150                    "publishDiagnostics": {
151                        "relatedInformation": true
152                    },
153                    "synchronization": {
154                        "didOpen": true,
155                        "didChange": true
156                    }
157                }
158            },
159            "clientInfo": {
160                "name": "atomcode",
161                "version": env!("CARGO_PKG_VERSION")
162            }
163        });
164
165        let _init_result = client
166            .send_request("initialize", Some(init_params))
167            .await
168            .with_context(|| {
169                format!(
170                    "LSP initialize handshake failed for {} (language: {})",
171                    config.command, language_id,
172                )
173            })?;
174
175        // Send initialized notification.
176        client
177            .send_notification("initialized", Some(json!({})))
178            .await?;
179
180        Ok(client)
181    }
182
183    /// Return cached diagnostics for a file.
184    pub async fn diagnostics(&self, path: &Path) -> Vec<Diagnostic> {
185        let cache = self.diagnostics_cache.read().await;
186        cache.get(path).cloned().unwrap_or_default()
187    }
188
189    /// Return all cached diagnostics across all files.
190    pub async fn all_diagnostics(&self) -> Vec<Diagnostic> {
191        let cache = self.diagnostics_cache.read().await;
192        cache.values().flatten().cloned().collect()
193    }
194
195    /// Notify the server that a file was opened.
196    pub async fn did_open(&self, path: &Path, content: &str, language_id: &str) -> Result<()> {
197        let uri = path_to_uri(path);
198        self.send_notification(
199            "textDocument/didOpen",
200            Some(json!({
201                "textDocument": {
202                    "uri": uri,
203                    "languageId": language_id,
204                    "version": 1,
205                    "text": content
206                }
207            })),
208        )
209        .await
210    }
211
212    /// Notify the server that a file changed.
213    pub async fn did_change(&self, path: &Path, content: &str, version: i32) -> Result<()> {
214        let uri = path_to_uri(path);
215        self.send_notification(
216            "textDocument/didChange",
217            Some(json!({
218                "textDocument": {
219                    "uri": uri,
220                    "version": version
221                },
222                "contentChanges": [{ "text": content }]
223            })),
224        )
225        .await
226    }
227
228    /// Notify the server that a file was closed.
229    pub async fn did_close(&self, path: &Path) -> Result<()> {
230        let uri = path_to_uri(path);
231        self.send_notification(
232            "textDocument/didClose",
233            Some(json!({
234                "textDocument": { "uri": uri }
235            })),
236        )
237        .await
238    }
239
240    /// Sync a document with the server, using didOpen for first open and didChange for updates.
241    /// This is the preferred method for notifying the server about file changes.
242    pub async fn sync_document(&self, path: &Path, content: &str, language_id: &str) -> Result<()> {
243        match Self::next_sync_action(&self.opened_documents, path, language_id).await {
244            DocumentSyncAction::DidOpen => {
245                self.did_open(path, content, language_id).await?;
246            }
247            DocumentSyncAction::DidChange { version } => {
248                self.did_change(path, content, version).await?;
249            }
250        }
251
252        Ok(())
253    }
254
255    /// Close a document, sending didClose and removing from tracking.
256    pub async fn close_document(&self, path: &Path) -> Result<()> {
257        let mut opened = self.opened_documents.write().await;
258        if opened.remove(path).is_some() {
259            drop(opened); // Release lock before async call.
260            self.did_close(path).await?;
261        }
262        Ok(())
263    }
264
265    /// Graceful shutdown: send shutdown request, then exit notification, then kill.
266    pub async fn shutdown(&self) -> Result<()> {
267        // Try to send shutdown request (ignore errors — server may already be dead).
268        let _ = tokio::time::timeout(
269            std::time::Duration::from_secs(5),
270            self.send_request("shutdown", None),
271        )
272        .await;
273
274        // Send exit notification.
275        let _ = self.send_notification("exit", None).await;
276
277        // Give the process a moment to exit, then kill.
278        let mut child = self.child.lock().await;
279        let _ = tokio::time::timeout(std::time::Duration::from_secs(2), child.wait()).await;
280        let _ = child.kill().await;
281
282        // Abort the reader task.
283        if let Some(handle) = self.reader_handle.lock().await.take() {
284            handle.abort();
285        }
286
287        Ok(())
288    }
289
290    // -----------------------------------------------------------------------
291    // Internal helpers
292    // -----------------------------------------------------------------------
293
294    async fn next_sync_action(
295        opened_documents: &RwLock<HashMap<PathBuf, OpenDocumentState>>,
296        path: &Path,
297        language_id: &str,
298    ) -> DocumentSyncAction {
299        let mut opened = opened_documents.write().await;
300
301        if let Some(state) = opened.get_mut(path) {
302            state.version += 1;
303            return DocumentSyncAction::DidChange {
304                version: state.version,
305            };
306        }
307
308        opened.insert(
309            path.to_path_buf(),
310            OpenDocumentState {
311                uri: path_to_uri(path),
312                language_id: language_id.to_string(),
313                version: 1,
314            },
315        );
316        DocumentSyncAction::DidOpen
317    }
318
319    /// Send a JSON-RPC request and wait for the response (30s timeout).
320    async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
321        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
322
323        let request = jsonrpc::Request {
324            jsonrpc: "2.0".into(),
325            id,
326            method: method.into(),
327            params,
328        };
329
330        let (tx, rx) = oneshot::channel();
331        {
332            let mut pending = self.pending.write().await;
333            pending.insert(id, tx);
334        }
335
336        let body = serde_json::to_vec(&request)?;
337        let msg = jsonrpc::encode(&body);
338
339        {
340            let mut writer = self.writer.lock().await;
341            writer.write_all(&msg).await?;
342            writer.flush().await?;
343        }
344
345        let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx)
346            .await
347            .context("LSP request timed out after 30s")?
348            .context("LSP response channel closed")?
349            .map_err(|error| anyhow::anyhow!("LSP request '{}' failed: {}", method, error))?;
350
351        Ok(response)
352    }
353
354    /// Send a JSON-RPC notification (no response expected).
355    async fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
356        let notification = jsonrpc::Notification {
357            jsonrpc: "2.0".into(),
358            method: method.into(),
359            params,
360        };
361
362        let body = serde_json::to_vec(&notification)?;
363        let msg = jsonrpc::encode(&body);
364
365        let mut writer = self.writer.lock().await;
366        writer.write_all(&msg).await?;
367        writer.flush().await?;
368
369        Ok(())
370    }
371
372    /// Dispatch a received message to the appropriate handler.
373    async fn dispatch_message(
374        msg: Value,
375        pending: &RwLock<HashMap<u64, oneshot::Sender<Result<Value, Value>>>>,
376        diagnostics_cache: &RwLock<HashMap<PathBuf, Vec<Diagnostic>>>,
377    ) {
378        // Check if it's a response (has "id" and "result" or "error").
379        if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
380            let mut pending = pending.write().await;
381            if let Some(tx) = pending.remove(&id) {
382                let result = if let Some(result) = msg.get("result") {
383                    Ok(result.clone())
384                } else if let Some(e) = msg.get("error") {
385                    Err(e.clone())
386                } else {
387                    Ok(Value::Null)
388                };
389                let _ = tx.send(result);
390            }
391            return;
392        }
393
394        // Check if it's a notification.
395        if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
396            if method == "textDocument/publishDiagnostics" {
397                if let Some(params) = msg.get("params") {
398                    Self::handle_diagnostics(params, diagnostics_cache).await;
399                }
400            }
401            // Ignore other notifications (window/logMessage, etc.).
402        }
403    }
404
405    /// Parse and cache `textDocument/publishDiagnostics` notifications.
406    async fn handle_diagnostics(
407        params: &Value,
408        diagnostics_cache: &RwLock<HashMap<PathBuf, Vec<Diagnostic>>>,
409    ) {
410        let uri = match params.get("uri").and_then(|v| v.as_str()) {
411            Some(u) => u,
412            None => return,
413        };
414
415        // Convert file:// URI to path.
416        let file_path = uri_to_path(uri);
417
418        let display_path = file_path.display().to_string();
419
420        let diagnostics: Vec<Diagnostic> = params
421            .get("diagnostics")
422            .and_then(|v| v.as_array())
423            .map(|arr| {
424                arr.iter()
425                    .filter_map(|d| {
426                        let range = d.get("range")?;
427                        let start = range.get("start")?;
428                        let end = range.get("end");
429
430                        // LSP positions are 0-based; display as 1-based.
431                        let line = start.get("line")?.as_u64()? as u32 + 1;
432                        let column = start.get("character")?.as_u64()? as u32 + 1;
433
434                        let end_line = end
435                            .and_then(|e| e.get("line"))
436                            .and_then(|v| v.as_u64())
437                            .map(|v| v as u32 + 1);
438                        let end_column = end
439                            .and_then(|e| e.get("character"))
440                            .and_then(|v| v.as_u64())
441                            .map(|v| v as u32 + 1);
442
443                        let severity = d
444                            .get("severity")
445                            .and_then(|v| v.as_u64())
446                            .map(|v| DiagnosticSeverity::from_lsp(v as u32))
447                            .unwrap_or(DiagnosticSeverity::Error);
448
449                        let message = d
450                            .get("message")
451                            .and_then(|v| v.as_str())
452                            .unwrap_or("")
453                            .to_string();
454
455                        let source = d.get("source").and_then(|v| v.as_str()).map(String::from);
456
457                        let code = d.get("code").and_then(|v| {
458                            v.as_str()
459                                .map(String::from)
460                                .or_else(|| v.as_u64().map(|n| n.to_string()))
461                        });
462
463                        Some(Diagnostic {
464                            file: display_path.clone(),
465                            line,
466                            column,
467                            end_line,
468                            end_column,
469                            severity,
470                            message,
471                            source,
472                            code,
473                        })
474                    })
475                    .collect()
476            })
477            .unwrap_or_default();
478
479        let mut cache = diagnostics_cache.write().await;
480        if diagnostics.is_empty() {
481            cache.remove(&file_path);
482        } else {
483            cache.insert(file_path, diagnostics);
484        }
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared map of in-flight requests, as used by `dispatch_message`.
    type PendingMap = Arc<RwLock<HashMap<u64, oneshot::Sender<Result<Value, Value>>>>>;
    /// Shared diagnostics cache, as used by `dispatch_message`.
    type DiagnosticsMap = Arc<RwLock<HashMap<PathBuf, Vec<Diagnostic>>>>;

    /// Fresh, empty pending-request map.
    fn empty_pending() -> PendingMap {
        Arc::new(RwLock::new(HashMap::new()))
    }

    /// Fresh, empty diagnostics cache.
    fn empty_diagnostics() -> DiagnosticsMap {
        Arc::new(RwLock::new(HashMap::new()))
    }

    /// A successful response completes the matching pending request with its result.
    #[tokio::test]
    async fn dispatch_response_resolves_pending() {
        let pending = empty_pending();
        let diags = empty_diagnostics();

        let (tx, rx) = oneshot::channel();
        pending.write().await.insert(42, tx);

        let response = json!({
            "jsonrpc": "2.0",
            "id": 42,
            "result": { "capabilities": {} }
        });
        LspClient::dispatch_message(response, &pending, &diags).await;

        let result = rx.await.unwrap().unwrap();
        assert!(result.get("capabilities").is_some());
        assert!(pending.read().await.is_empty());
    }

    /// An error response completes the matching pending request with the error object.
    #[tokio::test]
    async fn dispatch_error_response_rejects_pending() {
        let pending = empty_pending();
        let diags = empty_diagnostics();

        let (tx, rx) = oneshot::channel();
        pending.write().await.insert(7, tx);

        let response = json!({
            "jsonrpc": "2.0",
            "id": 7,
            "error": {
                "code": -32602,
                "message": "invalid initialize params"
            }
        });
        LspClient::dispatch_message(response, &pending, &diags).await;

        let error = rx.await.unwrap().unwrap_err();
        assert_eq!(error["code"], -32602);
        assert_eq!(error["message"], "invalid initialize params");
        assert!(pending.read().await.is_empty());
    }

    /// A publishDiagnostics notification lands in the cache with 1-based positions.
    #[tokio::test]
    async fn dispatch_diagnostics_notification_caches() {
        let pending = empty_pending();
        let diags = empty_diagnostics();

        let notification = json!({
            "jsonrpc": "2.0",
            "method": "textDocument/publishDiagnostics",
            "params": {
                "uri": "file:///tmp/test.rs",
                "diagnostics": [
                    {
                        "range": {
                            "start": { "line": 9, "character": 4 },
                            "end": { "line": 9, "character": 14 }
                        },
                        "severity": 1,
                        "message": "unused variable",
                        "source": "rust-analyzer",
                        "code": "E0001"
                    }
                ]
            }
        });
        LspClient::dispatch_message(notification, &pending, &diags).await;

        let cache = diags.read().await;
        let file_diags = cache.get(&PathBuf::from("/tmp/test.rs")).unwrap();
        assert_eq!(file_diags.len(), 1);

        // 0-based LSP → 1-based display
        let first = &file_diags[0];
        assert_eq!(first.line, 10);
        assert_eq!(first.column, 5);
        assert_eq!(first.severity, DiagnosticSeverity::Error);
        assert_eq!(first.message, "unused variable");
    }

    /// Publishing an empty diagnostics list removes the file's cache entry.
    #[tokio::test]
    async fn empty_diagnostics_clears_cache() {
        let pending = empty_pending();
        let diags = empty_diagnostics();
        let path = PathBuf::from("/tmp/test.rs");

        // Pre-populate cache.
        let stale = Diagnostic {
            file: "/tmp/test.rs".into(),
            line: 1,
            column: 1,
            end_line: None,
            end_column: None,
            severity: DiagnosticSeverity::Error,
            message: "old error".into(),
            source: None,
            code: None,
        };
        diags.write().await.insert(path.clone(), vec![stale]);

        // Publish empty diagnostics.
        let notification = json!({
            "jsonrpc": "2.0",
            "method": "textDocument/publishDiagnostics",
            "params": {
                "uri": "file:///tmp/test.rs",
                "diagnostics": []
            }
        });
        LspClient::dispatch_message(notification, &pending, &diags).await;

        let cache = diags.read().await;
        assert!(cache.get(&path).is_none());
    }

    #[test]
    fn uri_to_path_handles_unix_path() {
        assert_eq!(
            uri_to_path("file:///tmp/test.rs"),
            PathBuf::from("/tmp/test.rs")
        );
    }

    #[test]
    fn uri_to_path_handles_encoded_spaces() {
        assert_eq!(
            uri_to_path("file:///tmp/my%20file.rs"),
            PathBuf::from("/tmp/my file.rs")
        );
    }

    /// Reserved characters are percent-encoded and survive a round trip.
    #[test]
    fn path_to_uri_encodes_spaces_and_fragments() {
        let path = PathBuf::from("/tmp/my file#1.rs");
        let uri = path_to_uri(&path);
        assert!(
            uri.contains("my%20file%231.rs"),
            "path_to_uri must percent-encode reserved characters: {uri}"
        );
        assert_eq!(uri_to_path(&uri), path);
    }

    #[cfg(windows)]
    #[test]
    fn uri_to_path_handles_windows_path() {
        assert_eq!(
            uri_to_path("file:///C:/Users/test.rs"),
            PathBuf::from("C:/Users/test.rs")
        );
    }

    #[test]
    fn open_document_state_tracks_version() {
        let state = OpenDocumentState {
            uri: "file:///tmp/test.rs".to_string(),
            language_id: "rust".to_string(),
            version: 1,
        };
        assert_eq!(state.version, 1);
    }

    /// First sync opens the document; later syncs bump the version by one each time.
    #[tokio::test]
    async fn sync_action_uses_did_open_then_did_change_versions() {
        let opened: RwLock<HashMap<PathBuf, OpenDocumentState>> = RwLock::new(HashMap::new());
        let path = PathBuf::from("/tmp/test.rs");

        let first = LspClient::next_sync_action(&opened, &path, "rust").await;
        assert_eq!(first, DocumentSyncAction::DidOpen);

        let second = LspClient::next_sync_action(&opened, &path, "rust").await;
        assert_eq!(second, DocumentSyncAction::DidChange { version: 2 });

        let third = LspClient::next_sync_action(&opened, &path, "rust").await;
        assert_eq!(third, DocumentSyncAction::DidChange { version: 3 });

        let state = opened.read().await.get(&path).cloned().unwrap();
        assert_eq!(state.version, 3);
        assert_eq!(state.language_id, "rust");
        assert_eq!(state.uri, path_to_uri(&path));
    }
}