Skip to main content

codetether_agent/lsp/
transport.rs

1//! LSP transport layer - stdio implementation with Content-Length framing
2//!
3//! LSP uses a special framing format with Content-Length headers:
4//! ```text
5//! Content-Length: 123\r\n
6//! \r\n
7//! <JSON payload>
8//! ```
9
10use super::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
11use anyhow::{Context, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
15use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, Command};
17use tokio::sync::{RwLock, mpsc, oneshot};
18use tracing::{debug, error, trace, warn};
19
20/// LSP Transport for communicating with language servers
21pub struct LspTransport {
22    /// The child process (kept alive for the transport lifetime)
23    _child: Child,
24    /// Channel for sending messages
25    tx: mpsc::Sender<String>,
26    /// Pending requests waiting for responses
27    pending: Arc<RwLock<HashMap<i64, oneshot::Sender<JsonRpcResponse>>>>,
28    /// Request ID counter
29    request_id: AtomicI64,
30    /// Whether the server is initialized
31    initialized: std::sync::atomic::AtomicBool,
32    /// Per-request timeout in milliseconds.
33    timeout_ms: u64,
34    /// Recent stderr lines from the language server for diagnostics.
35    recent_stderr: Arc<RwLock<Vec<String>>>,
36    /// Server command for diagnostics.
37    command: String,
38    /// Last diagnostics published by the language server, keyed by URI.
39    diagnostics: Arc<RwLock<HashMap<String, Vec<lsp_types::Diagnostic>>>>,
40    /// Monotonic counter bumped every time the server publishes diagnostics
41    /// for any URI. Used by callers to detect fresh publications after a
42    /// `textDocument/didChange` so stale cached diagnostics don't leak into
43    /// post-edit validation.
44    diag_publish_seq: Arc<AtomicU64>,
45}
46
47impl LspTransport {
48    /// Spawn a language server and create a transport
49    pub async fn spawn(command: &str, args: &[String], timeout_ms: u64) -> Result<Self> {
50        let mut child = Command::new(command)
51            .args(args)
52            .stdin(std::process::Stdio::piped())
53            .stdout(std::process::Stdio::piped())
54            .stderr(std::process::Stdio::piped())
55            .spawn()
56            .with_context(|| format!("Failed to spawn language server '{command}'"))?;
57
58        let stdout = child
59            .stdout
60            .take()
61            .ok_or_else(|| anyhow::anyhow!("No stdout"))?;
62        let stderr = child
63            .stderr
64            .take()
65            .ok_or_else(|| anyhow::anyhow!("No stderr"))?;
66        let mut stdin = child
67            .stdin
68            .take()
69            .ok_or_else(|| anyhow::anyhow!("No stdin"))?;
70
71        let (write_tx, mut write_rx) = mpsc::channel::<String>(100);
72        let pending: Arc<RwLock<HashMap<i64, oneshot::Sender<JsonRpcResponse>>>> =
73            Arc::new(RwLock::new(HashMap::new()));
74        let recent_stderr = Arc::new(RwLock::new(Vec::new()));
75        let diagnostics = Arc::new(RwLock::new(HashMap::new()));
76        let diag_publish_seq = Arc::new(AtomicU64::new(0));
77
78        // Writer task - sends messages with Content-Length framing
79        let pending_clone = Arc::clone(&pending);
80        tokio::spawn(async move {
81            while let Some(msg) = write_rx.recv().await {
82                let content_length = msg.len();
83                let header = format!("Content-Length: {}\r\n\r\n", content_length);
84                trace!("LSP TX header: {}", header.trim());
85                trace!("LSP TX body: {}", msg);
86
87                if let Err(e) = stdin.write_all(header.as_bytes()).await {
88                    error!("Failed to write header to LSP server: {}", e);
89                    break;
90                }
91                if let Err(e) = stdin.write_all(msg.as_bytes()).await {
92                    error!("Failed to write body to LSP server: {}", e);
93                    break;
94                }
95                if let Err(e) = stdin.flush().await {
96                    error!("Failed to flush LSP server stdin: {}", e);
97                    break;
98                }
99            }
100            pending_clone.write().await.clear();
101        });
102
103        // Stderr task - capture recent diagnostics from the language server.
104        let recent_stderr_clone = Arc::clone(&recent_stderr);
105        let stderr_command = command.to_string();
106        tokio::spawn(async move {
107            let mut reader = BufReader::new(stderr);
108            let mut line = String::new();
109            loop {
110                line.clear();
111                match reader.read_line(&mut line).await {
112                    Ok(0) => return,
113                    Ok(_) => {
114                        let trimmed = line.trim().to_string();
115                        if trimmed.is_empty() {
116                            continue;
117                        }
118                        warn!(command = %stderr_command, stderr = %trimmed, "Language server stderr");
119                        let mut guard = recent_stderr_clone.write().await;
120                        guard.push(trimmed);
121                        if guard.len() > 20 {
122                            let excess = guard.len() - 20;
123                            guard.drain(0..excess);
124                        }
125                    }
126                    Err(e) => {
127                        warn!(command = %stderr_command, error = %e, "Failed reading language server stderr");
128                        return;
129                    }
130                }
131            }
132        });
133
134        // Reader task - parses Content-Length framed responses and notifications.
135        let pending_clone = Arc::clone(&pending);
136        let diagnostics_clone = Arc::clone(&diagnostics);
137        let diag_publish_seq_clone = Arc::clone(&diag_publish_seq);
138        tokio::spawn(async move {
139            let mut reader = BufReader::new(stdout);
140            let mut header_buf = String::new();
141
142            loop {
143                header_buf.clear();
144                let mut content_length: Option<usize> = None;
145
146                loop {
147                    header_buf.clear();
148                    match reader.read_line(&mut header_buf).await {
149                        Ok(0) => {
150                            debug!("LSP server closed connection");
151                            return;
152                        }
153                        Ok(_) => {
154                            let line = header_buf.trim();
155                            if line.is_empty() {
156                                break;
157                            }
158                            if let Some(stripped) = line.strip_prefix("Content-Length:")
159                                && let Ok(len) = stripped.trim().parse::<usize>()
160                            {
161                                content_length = Some(len);
162                            }
163                        }
164                        Err(e) => {
165                            error!("Failed to read header from LSP server: {}", e);
166                            return;
167                        }
168                    }
169                }
170
171                let Some(len) = content_length else {
172                    warn!("LSP message missing Content-Length header");
173                    continue;
174                };
175
176                let mut body_buf = vec![0u8; len];
177                match reader.read_exact(&mut body_buf).await {
178                    Ok(_) => {
179                        let body = String::from_utf8_lossy(&body_buf);
180                        trace!("LSP RX: {}", body);
181
182                        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&body) {
183                            let mut pending_guard = pending_clone.write().await;
184                            if let Some(tx) = pending_guard.remove(&response.id) {
185                                let id = response.id;
186                                if tx.send(response).is_err() {
187                                    warn!("Request {} receiver dropped", id);
188                                }
189                            } else {
190                                debug!("Received response for unknown request {}", response.id);
191                            }
192                            continue;
193                        }
194
195                        match serde_json::from_str::<serde_json::Value>(&body) {
196                            Ok(value) => {
197                                if value.get("method").and_then(serde_json::Value::as_str)
198                                    == Some("textDocument/publishDiagnostics")
199                                {
200                                    if let Some(params) = value.get("params") {
201                                        let uri = params
202                                            .get("uri")
203                                            .and_then(serde_json::Value::as_str)
204                                            .unwrap_or_default()
205                                            .to_string();
206                                        let diagnostics = params
207                                            .get("diagnostics")
208                                            .cloned()
209                                            .and_then(|v| serde_json::from_value(v).ok())
210                                            .unwrap_or_default();
211                                        if !uri.is_empty() {
212                                            diagnostics_clone
213                                                .write()
214                                                .await
215                                                .insert(uri, diagnostics);
216                                            diag_publish_seq_clone.fetch_add(1, Ordering::SeqCst);
217                                        }
218                                    }
219                                } else {
220                                    debug!(
221                                        "Ignoring LSP notification/message without tracked handler: {}",
222                                        body
223                                    );
224                                }
225                            }
226                            Err(e) => {
227                                debug!("Failed to parse LSP message: {} - body: {}", e, body);
228                            }
229                        }
230                    }
231                    Err(e) => {
232                        error!("Failed to read LSP message body: {}", e);
233                        return;
234                    }
235                }
236            }
237        });
238
239        Ok(Self {
240            _child: child,
241            tx: write_tx,
242            pending,
243            request_id: AtomicI64::new(1),
244            initialized: std::sync::atomic::AtomicBool::new(false),
245            timeout_ms,
246            recent_stderr,
247            command: command.to_string(),
248            diagnostics,
249            diag_publish_seq,
250        })
251    }
252
253    /// Send a request and wait for response
254    pub async fn request(
255        &self,
256        method: &str,
257        params: Option<serde_json::Value>,
258    ) -> Result<JsonRpcResponse> {
259        let id = self.request_id.fetch_add(1, Ordering::SeqCst);
260        let request = JsonRpcRequest::new(id, method, params);
261
262        let (tx, rx) = oneshot::channel();
263        self.pending.write().await.insert(id, tx);
264
265        let json = serde_json::to_string(&request)?;
266        self.tx.send(json).await?;
267
268        let response = tokio::time::timeout(std::time::Duration::from_millis(self.timeout_ms), rx)
269            .await
270            .map_err(|_| {
271                let stderr_summary = self.stderr_summary();
272                anyhow::anyhow!(
273                    "LSP request timeout for method: {} (server: {}, timeout: {}ms{})",
274                    method,
275                    self.command,
276                    self.timeout_ms,
277                    stderr_summary
278                        .as_deref()
279                        .map(|summary| format!(", recent stderr: {summary}"))
280                        .unwrap_or_default()
281                )
282            })?
283            .map_err(|_| anyhow::anyhow!("LSP response channel closed"))?;
284
285        Ok(response)
286    }
287
288    fn stderr_summary(&self) -> Option<String> {
289        self.recent_stderr.try_read().ok().and_then(|lines| {
290            if lines.is_empty() {
291                None
292            } else {
293                Some(lines.join(" | "))
294            }
295        })
296    }
297
298    /// Send a notification (no response expected)
299    pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
300        let notification = JsonRpcNotification::new(method, params);
301        let json = serde_json::to_string(&notification)?;
302        self.tx.send(json).await?;
303        Ok(())
304    }
305
306    /// Return the last diagnostics published by the language server.
307    pub async fn diagnostics_snapshot(&self) -> HashMap<String, Vec<lsp_types::Diagnostic>> {
308        self.diagnostics.read().await.clone()
309    }
310
311    /// Current publish sequence counter. Increments every time the language
312    /// server publishes diagnostics for any URI. Callers can read this before
313    /// a `textDocument/didChange` and then wait via
314    /// [`Self::wait_for_publish_after`] for the server to republish.
315    pub fn diagnostics_publish_seq(&self) -> u64 {
316        self.diag_publish_seq.load(Ordering::SeqCst)
317    }
318
319    /// Remove any cached diagnostics for `uri`. Useful before a didChange so
320    /// stale entries can't be returned while the server recomputes.
321    pub async fn invalidate_diagnostics(&self, uri: &str) {
322        self.diagnostics.write().await.remove(uri);
323    }
324
325    /// Wait until the server publishes diagnostics with a sequence greater
326    /// than `baseline`, or the timeout elapses. Returns `true` if a new
327    /// publication arrived in time.
328    pub async fn wait_for_publish_after(
329        &self,
330        baseline: u64,
331        timeout: std::time::Duration,
332    ) -> bool {
333        let deadline = std::time::Instant::now() + timeout;
334        loop {
335            if self.diag_publish_seq.load(Ordering::SeqCst) > baseline {
336                return true;
337            }
338            if std::time::Instant::now() >= deadline {
339                return false;
340            }
341            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
342        }
343    }
344
345    /// Check if the server is initialized
346    pub fn is_initialized(&self) -> bool {
347        self.initialized.load(std::sync::atomic::Ordering::SeqCst)
348    }
349
350    /// Mark the server as initialized
351    pub fn set_initialized(&self, value: bool) {
352        self.initialized
353            .store(value, std::sync::atomic::Ordering::SeqCst);
354    }
355}
356
357impl Drop for LspTransport {
358    fn drop(&mut self) {
359        if self.is_initialized() {
360            tracing::debug!("LspTransport dropped while still initialized");
361        }
362    }
363}