Skip to main content

krait/lsp/
transport.rs

1use std::path::Path;
2use std::sync::atomic::{AtomicI64, Ordering};
3
4use anyhow::{bail, Context};
5use serde_json::Value;
6use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
7use tokio::process::{Child, ChildStdin, ChildStdout, Command};
8use tracing::debug;
9
10/// A JSON-RPC message received from the LSP server.
11#[derive(Debug)]
12pub enum JsonRpcMessage {
13    Response {
14        id: i64,
15        result: Option<Value>,
16        error: Option<Value>,
17    },
18    Notification {
19        method: String,
20        params: Option<Value>,
21    },
22    ServerRequest {
23        id: Value,
24        method: String,
25        params: Option<Value>,
26    },
27}
28
29/// Transport layer for communicating with an LSP server over stdio.
30pub struct LspTransport {
31    child: Child,
32    writer: BufWriter<ChildStdin>,
33    reader: BufReader<ChildStdout>,
34    next_id: AtomicI64,
35}
36
37impl LspTransport {
38    /// Spawn an LSP server process and connect to its stdio.
39    ///
40    /// # Errors
41    /// Returns an error if the binary cannot be spawned.
42    pub fn spawn(binary: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Self> {
43        let mut child = Command::new(binary)
44            .args(args)
45            .current_dir(cwd)
46            .stdin(std::process::Stdio::piped())
47            .stdout(std::process::Stdio::piped())
48            .stderr(std::process::Stdio::null())
49            .kill_on_drop(true)
50            .spawn()
51            .with_context(|| format!("failed to spawn LSP server: {binary}"))?;
52
53        let stdin = child.stdin.take().context("failed to open LSP stdin")?;
54        let stdout = child.stdout.take().context("failed to open LSP stdout")?;
55
56        Ok(Self {
57            child,
58            writer: BufWriter::new(stdin),
59            reader: BufReader::new(stdout),
60            next_id: AtomicI64::new(1),
61        })
62    }
63
64    /// Send a JSON-RPC request. Returns the request ID.
65    ///
66    /// # Errors
67    /// Returns an error on IO or serialization failure.
68    pub async fn send_request(&mut self, method: &str, params: Value) -> anyhow::Result<i64> {
69        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
70        let message = serde_json::json!({
71            "jsonrpc": "2.0",
72            "id": id,
73            "method": method,
74            "params": params,
75        });
76        self.write_message(&message).await?;
77        debug!("sent request id={id} method={method}");
78        Ok(id)
79    }
80
81    /// Send a JSON-RPC notification (no response expected).
82    ///
83    /// # Errors
84    /// Returns an error on IO or serialization failure.
85    pub async fn send_notification(&mut self, method: &str, params: Value) -> anyhow::Result<()> {
86        let message = serde_json::json!({
87            "jsonrpc": "2.0",
88            "method": method,
89            "params": params,
90        });
91        self.write_message(&message).await?;
92        debug!("sent notification method={method}");
93        Ok(())
94    }
95
96    /// Read the next JSON-RPC message from the server.
97    ///
98    /// # Errors
99    /// Returns an error on IO, framing, or JSON parse failure.
100    pub async fn read_message(&mut self) -> anyhow::Result<JsonRpcMessage> {
101        let content_length = self.read_headers().await?;
102
103        let mut body = vec![0u8; content_length];
104        self.reader.read_exact(&mut body).await?;
105
106        let value: Value = serde_json::from_slice(&body)?;
107        classify_message(&value)
108    }
109
110    /// Kill the child process.
111    ///
112    /// # Errors
113    /// Returns an error if the kill signal fails.
114    pub async fn kill(&mut self) -> anyhow::Result<()> {
115        self.child
116            .kill()
117            .await
118            .context("failed to kill LSP process")?;
119        let _ = self.child.wait().await; // reap zombie
120        Ok(())
121    }
122
123    /// Check if the child process is still running.
124    #[must_use]
125    pub fn is_alive(&mut self) -> bool {
126        self.child.try_wait().ok().flatten().is_none()
127    }
128
129    /// Write raw bytes to the server's stdin.
130    ///
131    /// # Errors
132    /// Returns an error on IO failure.
133    pub async fn write_raw(&mut self, data: &[u8]) -> anyhow::Result<()> {
134        self.writer.write_all(data).await?;
135        Ok(())
136    }
137
138    /// Flush the writer.
139    ///
140    /// # Errors
141    /// Returns an error on IO failure.
142    pub async fn flush(&mut self) -> anyhow::Result<()> {
143        self.writer.flush().await?;
144        Ok(())
145    }
146
147    async fn write_message(&mut self, message: &Value) -> anyhow::Result<()> {
148        let body = serde_json::to_string(message)?;
149        let header = format!("Content-Length: {}\r\n\r\n", body.len());
150
151        self.writer.write_all(header.as_bytes()).await?;
152        self.writer.write_all(body.as_bytes()).await?;
153        // Flush is deferred — callers that send batches should call flush_writer()
154        // explicitly after the last message in the batch.
155        self.writer.flush().await?;
156        Ok(())
157    }
158
159    /// Flush the write buffer. Call at the end of a batch of requests
160    /// to avoid per-message syscall overhead.
161    ///
162    /// # Errors
163    /// Returns an error on IO failure.
164    pub async fn flush_writer(&mut self) -> anyhow::Result<()> {
165        self.writer.flush().await?;
166        Ok(())
167    }
168
169    async fn read_headers(&mut self) -> anyhow::Result<usize> {
170        let mut content_length: Option<usize> = None;
171
172        loop {
173            let mut line = String::new();
174            let bytes_read = self.reader.read_line(&mut line).await?;
175            if bytes_read == 0 {
176                bail!("LSP server closed its stdout");
177            }
178
179            let trimmed = line.trim();
180            if trimmed.is_empty() {
181                break;
182            }
183
184            if let Some(len_str) = trimmed.strip_prefix("Content-Length: ") {
185                content_length = Some(len_str.parse().context("invalid Content-Length")?);
186            }
187        }
188
189        content_length.context("missing Content-Length header")
190    }
191}
192
193fn classify_message(value: &Value) -> anyhow::Result<JsonRpcMessage> {
194    // Response: has "id" and ("result" or "error")
195    if let Some(id) = value.get("id") {
196        if value.get("result").is_some() || value.get("error").is_some() {
197            let id = id.as_i64().context("response id must be an integer")?;
198            return Ok(JsonRpcMessage::Response {
199                id,
200                result: value.get("result").cloned(),
201                error: value.get("error").cloned(),
202            });
203        }
204
205        // Server request: has "id" and "method"
206        if let Some(method) = value.get("method").and_then(Value::as_str) {
207            return Ok(JsonRpcMessage::ServerRequest {
208                id: id.clone(),
209                method: method.to_string(),
210                params: value.get("params").cloned(),
211            });
212        }
213    }
214
215    // Notification: has "method" but no "id"
216    if let Some(method) = value.get("method").and_then(Value::as_str) {
217        return Ok(JsonRpcMessage::Notification {
218            method: method.to_string(),
219            params: value.get("params").cloned(),
220        });
221    }
222
223    bail!("unrecognized JSON-RPC message: {value}")
224}
225
226/// Encode a JSON-RPC payload with Content-Length framing (for testing).
227#[must_use]
228pub fn frame_message(payload: &Value) -> Vec<u8> {
229    let body = serde_json::to_string(payload).unwrap_or_default();
230    let header = format!("Content-Length: {}\r\n\r\n", body.len());
231    let mut msg = header.into_bytes();
232    msg.extend_from_slice(body.as_bytes());
233    msg
234}
235
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn frame_encode_format() {
        let payload = json!({"jsonrpc": "2.0", "id": 1, "method": "test"});
        let framed = frame_message(&payload);
        let framed_str = String::from_utf8(framed).unwrap();

        assert!(framed_str.starts_with("Content-Length: "));
        assert!(framed_str.contains("\r\n\r\n"));

        let parts: Vec<&str> = framed_str.splitn(2, "\r\n\r\n").collect();
        let header = parts[0];
        let body = parts[1];

        let declared_len: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(declared_len, body.len());
    }

    #[test]
    fn classify_response() {
        let msg = json!({"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {}}});
        let classified = classify_message(&msg).unwrap();
        assert!(matches!(classified, JsonRpcMessage::Response { id: 1, .. }));
    }

    #[test]
    fn classify_error_response() {
        let msg = json!({"jsonrpc": "2.0", "id": 2, "error": {"code": -32600, "message": "bad"}});
        let classified = classify_message(&msg).unwrap();
        assert!(matches!(
            classified,
            JsonRpcMessage::Response {
                id: 2,
                error: Some(_),
                ..
            }
        ));
    }

    #[test]
    fn classify_response_with_null_result() {
        // e.g. the reply to `shutdown`: `result` is present but JSON null.
        let msg = json!({"jsonrpc": "2.0", "id": 7, "result": null});
        let classified = classify_message(&msg).unwrap();
        assert!(matches!(
            classified,
            JsonRpcMessage::Response {
                id: 7,
                result: Some(Value::Null),
                error: None,
            }
        ));
    }

    #[test]
    fn classify_response_with_non_integer_id_errors() {
        let msg = json!({"jsonrpc": "2.0", "id": "abc", "result": {}});
        assert!(classify_message(&msg).is_err());
    }

    #[test]
    fn classify_notification() {
        let msg =
            json!({"jsonrpc": "2.0", "method": "textDocument/publishDiagnostics", "params": {}});
        let classified = classify_message(&msg).unwrap();
        assert!(
            matches!(classified, JsonRpcMessage::Notification { ref method, .. } if method == "textDocument/publishDiagnostics")
        );
    }

    #[test]
    fn classify_server_request() {
        let msg = json!({"jsonrpc": "2.0", "id": 5, "method": "window/workDoneProgress/create", "params": {}});
        let classified = classify_message(&msg).unwrap();
        assert!(
            matches!(classified, JsonRpcMessage::ServerRequest { ref method, .. } if method == "window/workDoneProgress/create")
        );
    }

    #[test]
    fn classify_server_request_keeps_string_id() {
        let msg = json!({
            "jsonrpc": "2.0",
            "id": "req-1",
            "method": "workspace/configuration",
            "params": {"items": []}
        });
        match classify_message(&msg).unwrap() {
            JsonRpcMessage::ServerRequest { id, method, params } => {
                assert_eq!(id, json!("req-1"));
                assert_eq!(method, "workspace/configuration");
                assert_eq!(params, Some(json!({"items": []})));
            }
            other => panic!("expected ServerRequest, got {other:?}"),
        }
    }

    #[test]
    fn request_ids_increment() {
        let next_id = AtomicI64::new(1);

        let id1 = next_id.fetch_add(1, Ordering::SeqCst);
        let id2 = next_id.fetch_add(1, Ordering::SeqCst);
        let id3 = next_id.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    #[test]
    fn frame_message_content_length_matches_body() {
        let payload = json!({"jsonrpc": "2.0", "method": "textDocument/didOpen", "params": {}});
        let framed = frame_message(&payload);
        let text = String::from_utf8(framed).unwrap();
        let (header, body) = text.split_once("\r\n\r\n").unwrap();
        let declared: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(declared, body.len());
        assert!(!body.is_empty());
    }

    #[test]
    fn framed_message_round_trips_through_classifier() {
        let payload = json!({"jsonrpc": "2.0", "id": 9, "result": {"ok": true}});
        let text = String::from_utf8(frame_message(&payload)).unwrap();
        let (_, body) = text.split_once("\r\n\r\n").unwrap();
        let value: Value = serde_json::from_str(body).unwrap();
        assert_eq!(value, payload);
        assert!(matches!(
            classify_message(&value).unwrap(),
            JsonRpcMessage::Response { id: 9, error: None, .. }
        ));
    }

    #[test]
    fn classify_unrecognized_message_returns_error() {
        let msg = json!({"jsonrpc": "2.0"});
        let result = classify_message(&msg);
        assert!(result.is_err(), "message with no method or id should error");
    }

    #[test]
    fn classify_id_without_method_or_payload_errors() {
        let msg = json!({"jsonrpc": "2.0", "id": 3});
        assert!(classify_message(&msg).is_err());
    }
}