1use std::path::Path;
2use std::sync::atomic::{AtomicI64, Ordering};
3
4use anyhow::{bail, Context};
5use serde_json::Value;
6use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
7use tokio::process::{Child, ChildStdin, ChildStdout, Command};
8use tracing::debug;
9
10#[derive(Debug)]
12pub enum JsonRpcMessage {
13 Response {
14 id: i64,
15 result: Option<Value>,
16 error: Option<Value>,
17 },
18 Notification {
19 method: String,
20 params: Option<Value>,
21 },
22 ServerRequest {
23 id: Value,
24 method: String,
25 params: Option<Value>,
26 },
27}
28
/// A JSON-RPC connection to a language-server child process, exchanging
/// `Content-Length`-framed messages over the child's stdin and stdout.
pub struct LspTransport {
    // The server process. Rust drops fields in declaration order, so this is
    // dropped before the pipes below; `spawn` builds it with
    // `kill_on_drop(true)`, so dropping the transport kills the server.
    child: Child,
    // Buffered handle to the server's stdin; outgoing messages go here.
    writer: BufWriter<ChildStdin>,
    // Buffered handle to the server's stdout; incoming messages come from here.
    reader: BufReader<ChildStdout>,
    // Next request id to hand out (`spawn` starts it at 1).
    next_id: AtomicI64,
}
36
37impl LspTransport {
38 pub fn spawn(binary: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Self> {
43 let mut child = Command::new(binary)
44 .args(args)
45 .current_dir(cwd)
46 .stdin(std::process::Stdio::piped())
47 .stdout(std::process::Stdio::piped())
48 .stderr(std::process::Stdio::null())
49 .kill_on_drop(true)
50 .spawn()
51 .with_context(|| format!("failed to spawn LSP server: {binary}"))?;
52
53 let stdin = child.stdin.take().context("failed to open LSP stdin")?;
54 let stdout = child.stdout.take().context("failed to open LSP stdout")?;
55
56 Ok(Self {
57 child,
58 writer: BufWriter::new(stdin),
59 reader: BufReader::new(stdout),
60 next_id: AtomicI64::new(1),
61 })
62 }
63
64 pub async fn send_request(&mut self, method: &str, params: Value) -> anyhow::Result<i64> {
69 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
70 let message = serde_json::json!({
71 "jsonrpc": "2.0",
72 "id": id,
73 "method": method,
74 "params": params,
75 });
76 self.write_message(&message).await?;
77 debug!("sent request id={id} method={method}");
78 Ok(id)
79 }
80
81 pub async fn send_notification(&mut self, method: &str, params: Value) -> anyhow::Result<()> {
86 let message = serde_json::json!({
87 "jsonrpc": "2.0",
88 "method": method,
89 "params": params,
90 });
91 self.write_message(&message).await?;
92 debug!("sent notification method={method}");
93 Ok(())
94 }
95
96 pub async fn read_message(&mut self) -> anyhow::Result<JsonRpcMessage> {
101 let content_length = self.read_headers().await?;
102
103 let mut body = vec![0u8; content_length];
104 self.reader.read_exact(&mut body).await?;
105
106 let value: Value = serde_json::from_slice(&body)?;
107 classify_message(&value)
108 }
109
110 pub async fn kill(&mut self) -> anyhow::Result<()> {
115 self.child
116 .kill()
117 .await
118 .context("failed to kill LSP process")?;
119 let _ = self.child.wait().await; Ok(())
121 }
122
123 #[must_use]
125 pub fn is_alive(&mut self) -> bool {
126 self.child.try_wait().ok().flatten().is_none()
127 }
128
129 pub async fn write_raw(&mut self, data: &[u8]) -> anyhow::Result<()> {
134 self.writer.write_all(data).await?;
135 Ok(())
136 }
137
138 pub async fn flush(&mut self) -> anyhow::Result<()> {
143 self.writer.flush().await?;
144 Ok(())
145 }
146
147 async fn write_message(&mut self, message: &Value) -> anyhow::Result<()> {
148 let body = serde_json::to_string(message)?;
149 let header = format!("Content-Length: {}\r\n\r\n", body.len());
150
151 self.writer.write_all(header.as_bytes()).await?;
152 self.writer.write_all(body.as_bytes()).await?;
153 self.writer.flush().await?;
156 Ok(())
157 }
158
159 pub async fn flush_writer(&mut self) -> anyhow::Result<()> {
165 self.writer.flush().await?;
166 Ok(())
167 }
168
169 async fn read_headers(&mut self) -> anyhow::Result<usize> {
170 let mut content_length: Option<usize> = None;
171
172 loop {
173 let mut line = String::new();
174 let bytes_read = self.reader.read_line(&mut line).await?;
175 if bytes_read == 0 {
176 bail!("LSP server closed its stdout");
177 }
178
179 let trimmed = line.trim();
180 if trimmed.is_empty() {
181 break;
182 }
183
184 if let Some(len_str) = trimmed.strip_prefix("Content-Length: ") {
185 content_length = Some(len_str.parse().context("invalid Content-Length")?);
186 }
187 }
188
189 content_length.context("missing Content-Length header")
190 }
191}
192
193fn classify_message(value: &Value) -> anyhow::Result<JsonRpcMessage> {
194 if let Some(id) = value.get("id") {
196 if value.get("result").is_some() || value.get("error").is_some() {
197 let id = id.as_i64().context("response id must be an integer")?;
198 return Ok(JsonRpcMessage::Response {
199 id,
200 result: value.get("result").cloned(),
201 error: value.get("error").cloned(),
202 });
203 }
204
205 if let Some(method) = value.get("method").and_then(Value::as_str) {
207 return Ok(JsonRpcMessage::ServerRequest {
208 id: id.clone(),
209 method: method.to_string(),
210 params: value.get("params").cloned(),
211 });
212 }
213 }
214
215 if let Some(method) = value.get("method").and_then(Value::as_str) {
217 return Ok(JsonRpcMessage::Notification {
218 method: method.to_string(),
219 params: value.get("params").cloned(),
220 });
221 }
222
223 bail!("unrecognized JSON-RPC message: {value}")
224}
225
226#[must_use]
228pub fn frame_message(payload: &Value) -> Vec<u8> {
229 let body = serde_json::to_string(payload).unwrap_or_default();
230 let header = format!("Content-Length: {}\r\n\r\n", body.len());
231 let mut msg = header.into_bytes();
232 msg.extend_from_slice(body.as_bytes());
233 msg
234}
235
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn frame_encode_format() {
        let payload = json!({"jsonrpc": "2.0", "id": 1, "method": "test"});
        let framed = frame_message(&payload);
        let framed_str = String::from_utf8(framed).unwrap();

        assert!(framed_str.starts_with("Content-Length: "));
        assert!(framed_str.contains("\r\n\r\n"));

        let parts: Vec<&str> = framed_str.splitn(2, "\r\n\r\n").collect();
        let header = parts[0];
        let body = parts[1];

        let declared_len: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(declared_len, body.len());
    }

    #[test]
    fn frame_message_counts_bytes_and_round_trips() {
        // The non-ASCII character makes byte length and char count differ, which
        // pins that Content-Length is measured in bytes.
        let payload = json!({
            "jsonrpc": "2.0",
            "method": "window/logMessage",
            "params": {"message": "héllo"}
        });
        let text = String::from_utf8(frame_message(&payload)).unwrap();
        let (header, body) = text.split_once("\r\n\r\n").unwrap();
        let declared: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(declared, body.len());
        assert_ne!(declared, body.chars().count());
        assert_eq!(serde_json::from_str::<Value>(body).unwrap(), payload);
    }

    #[test]
    fn classify_response() {
        let msg = json!({"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {}}});
        let classified = classify_message(&msg).unwrap();
        assert!(matches!(classified, JsonRpcMessage::Response { id: 1, .. }));
    }

    #[test]
    fn classify_response_with_null_result() {
        // e.g. the LSP `shutdown` reply is `"result": null`; it is still a response.
        let msg = json!({"jsonrpc": "2.0", "id": 3, "result": null});
        match classify_message(&msg).unwrap() {
            JsonRpcMessage::Response { id, result, error } => {
                assert_eq!(id, 3);
                assert_eq!(result, Some(Value::Null));
                assert!(error.is_none());
            }
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[test]
    fn classify_error_response() {
        let msg = json!({"jsonrpc": "2.0", "id": 2, "error": {"code": -32600, "message": "bad"}});
        let classified = classify_message(&msg).unwrap();
        assert!(matches!(
            classified,
            JsonRpcMessage::Response {
                id: 2,
                error: Some(_),
                ..
            }
        ));
    }

    #[test]
    fn classify_response_with_non_integer_id_errors() {
        let msg = json!({"jsonrpc": "2.0", "id": "1", "result": {}});
        assert!(classify_message(&msg).is_err());
    }

    #[test]
    fn classify_notification() {
        let msg =
            json!({"jsonrpc": "2.0", "method": "textDocument/publishDiagnostics", "params": {}});
        let classified = classify_message(&msg).unwrap();
        assert!(
            matches!(classified, JsonRpcMessage::Notification { ref method, .. } if method == "textDocument/publishDiagnostics")
        );
    }

    #[test]
    fn classify_notification_without_params() {
        let msg = json!({"jsonrpc": "2.0", "method": "exit"});
        assert!(matches!(
            classify_message(&msg).unwrap(),
            JsonRpcMessage::Notification { params: None, .. }
        ));
    }

    #[test]
    fn classify_server_request() {
        let msg = json!({"jsonrpc": "2.0", "id": 5, "method": "window/workDoneProgress/create", "params": {}});
        let classified = classify_message(&msg).unwrap();
        assert!(
            matches!(classified, JsonRpcMessage::ServerRequest { ref method, .. } if method == "window/workDoneProgress/create")
        );
    }

    #[test]
    fn classify_server_request_keeps_raw_string_id() {
        let msg = json!({
            "jsonrpc": "2.0",
            "id": "abc",
            "method": "workspace/configuration",
            "params": {"items": []}
        });
        match classify_message(&msg).unwrap() {
            JsonRpcMessage::ServerRequest { id, method, params } => {
                assert_eq!(id, json!("abc"));
                assert_eq!(method, "workspace/configuration");
                assert_eq!(params, Some(json!({"items": []})));
            }
            other => panic!("expected ServerRequest, got {other:?}"),
        }
    }

    #[test]
    fn classify_id_without_method_or_result_errors() {
        let msg = json!({"jsonrpc": "2.0", "id": 4});
        assert!(classify_message(&msg).is_err());
    }

    #[test]
    fn request_ids_increment() {
        // Exercises the `AtomicI64::fetch_add` counter semantics that
        // `send_request` relies on; it does not drive `send_request` itself,
        // which would need a spawned server process.
        let next_id = AtomicI64::new(1);

        let id1 = next_id.fetch_add(1, Ordering::SeqCst);
        let id2 = next_id.fetch_add(1, Ordering::SeqCst);
        let id3 = next_id.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        assert_eq!(id3, 3);
    }

    #[test]
    fn frame_message_content_length_matches_body() {
        let payload = json!({"jsonrpc": "2.0", "method": "textDocument/didOpen", "params": {}});
        let framed = frame_message(&payload);
        let text = String::from_utf8(framed).unwrap();
        let (header, body) = text.split_once("\r\n\r\n").unwrap();
        let declared: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(declared, body.len());
        assert!(!body.is_empty());
    }

    #[test]
    fn classify_unrecognized_message_returns_error() {
        let msg = json!({"jsonrpc": "2.0"});
        let result = classify_message(&msg);
        assert!(result.is_err(), "message with no method or id should error");
    }
}
337}