use std::path::Path;
use std::sync::atomic::{AtomicI64, Ordering};
use anyhow::{bail, Context};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tracing::debug;
#[derive(Debug)]
pub enum JsonRpcMessage {
Response {
id: i64,
result: Option<Value>,
error: Option<Value>,
},
Notification {
method: String,
params: Option<Value>,
},
ServerRequest {
id: Value,
method: String,
params: Option<Value>,
},
}
pub struct LspTransport {
child: Child,
writer: BufWriter<ChildStdin>,
reader: BufReader<ChildStdout>,
next_id: AtomicI64,
}
impl LspTransport {
pub fn spawn(binary: &str, args: &[&str], cwd: &Path) -> anyhow::Result<Self> {
let mut child = Command::new(binary)
.args(args)
.current_dir(cwd)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::null())
.kill_on_drop(true)
.spawn()
.with_context(|| format!("failed to spawn LSP server: {binary}"))?;
let stdin = child.stdin.take().context("failed to open LSP stdin")?;
let stdout = child.stdout.take().context("failed to open LSP stdout")?;
Ok(Self {
child,
writer: BufWriter::new(stdin),
reader: BufReader::new(stdout),
next_id: AtomicI64::new(1),
})
}
pub async fn send_request(&mut self, method: &str, params: Value) -> anyhow::Result<i64> {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let message = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
});
self.write_message(&message).await?;
debug!("sent request id={id} method={method}");
Ok(id)
}
pub async fn send_notification(&mut self, method: &str, params: Value) -> anyhow::Result<()> {
let message = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
});
self.write_message(&message).await?;
debug!("sent notification method={method}");
Ok(())
}
pub async fn read_message(&mut self) -> anyhow::Result<JsonRpcMessage> {
let content_length = self.read_headers().await?;
let mut body = vec![0u8; content_length];
self.reader.read_exact(&mut body).await?;
let value: Value = serde_json::from_slice(&body)?;
classify_message(&value)
}
pub async fn kill(&mut self) -> anyhow::Result<()> {
self.child
.kill()
.await
.context("failed to kill LSP process")?;
let _ = self.child.wait().await; Ok(())
}
#[must_use]
pub fn is_alive(&mut self) -> bool {
self.child.try_wait().ok().flatten().is_none()
}
pub async fn write_raw(&mut self, data: &[u8]) -> anyhow::Result<()> {
self.writer.write_all(data).await?;
Ok(())
}
pub async fn flush(&mut self) -> anyhow::Result<()> {
self.writer.flush().await?;
Ok(())
}
async fn write_message(&mut self, message: &Value) -> anyhow::Result<()> {
let body = serde_json::to_string(message)?;
let header = format!("Content-Length: {}\r\n\r\n", body.len());
self.writer.write_all(header.as_bytes()).await?;
self.writer.write_all(body.as_bytes()).await?;
self.writer.flush().await?;
Ok(())
}
pub async fn flush_writer(&mut self) -> anyhow::Result<()> {
self.writer.flush().await?;
Ok(())
}
async fn read_headers(&mut self) -> anyhow::Result<usize> {
let mut content_length: Option<usize> = None;
loop {
let mut line = String::new();
let bytes_read = self.reader.read_line(&mut line).await?;
if bytes_read == 0 {
bail!("LSP server closed its stdout");
}
let trimmed = line.trim();
if trimmed.is_empty() {
break;
}
if let Some(len_str) = trimmed.strip_prefix("Content-Length: ") {
content_length = Some(len_str.parse().context("invalid Content-Length")?);
}
}
content_length.context("missing Content-Length header")
}
}
fn classify_message(value: &Value) -> anyhow::Result<JsonRpcMessage> {
if let Some(id) = value.get("id") {
if value.get("result").is_some() || value.get("error").is_some() {
let id = id.as_i64().context("response id must be an integer")?;
return Ok(JsonRpcMessage::Response {
id,
result: value.get("result").cloned(),
error: value.get("error").cloned(),
});
}
if let Some(method) = value.get("method").and_then(Value::as_str) {
return Ok(JsonRpcMessage::ServerRequest {
id: id.clone(),
method: method.to_string(),
params: value.get("params").cloned(),
});
}
}
if let Some(method) = value.get("method").and_then(Value::as_str) {
return Ok(JsonRpcMessage::Notification {
method: method.to_string(),
params: value.get("params").cloned(),
});
}
bail!("unrecognized JSON-RPC message: {value}")
}
#[must_use]
pub fn frame_message(payload: &Value) -> Vec<u8> {
let body = serde_json::to_string(payload).unwrap_or_default();
let header = format!("Content-Length: {}\r\n\r\n", body.len());
let mut msg = header.into_bytes();
msg.extend_from_slice(body.as_bytes());
msg
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn frame_encode_format() {
let payload = json!({"jsonrpc": "2.0", "id": 1, "method": "test"});
let framed = frame_message(&payload);
let framed_str = String::from_utf8(framed).unwrap();
assert!(framed_str.starts_with("Content-Length: "));
assert!(framed_str.contains("\r\n\r\n"));
let parts: Vec<&str> = framed_str.splitn(2, "\r\n\r\n").collect();
let header = parts[0];
let body = parts[1];
let declared_len: usize = header
.strip_prefix("Content-Length: ")
.unwrap()
.parse()
.unwrap();
assert_eq!(declared_len, body.len());
}
#[test]
fn classify_response() {
let msg = json!({"jsonrpc": "2.0", "id": 1, "result": {"capabilities": {}}});
let classified = classify_message(&msg).unwrap();
assert!(matches!(classified, JsonRpcMessage::Response { id: 1, .. }));
}
#[test]
fn classify_error_response() {
let msg = json!({"jsonrpc": "2.0", "id": 2, "error": {"code": -32600, "message": "bad"}});
let classified = classify_message(&msg).unwrap();
assert!(matches!(
classified,
JsonRpcMessage::Response {
id: 2,
error: Some(_),
..
}
));
}
#[test]
fn classify_notification() {
let msg =
json!({"jsonrpc": "2.0", "method": "textDocument/publishDiagnostics", "params": {}});
let classified = classify_message(&msg).unwrap();
assert!(
matches!(classified, JsonRpcMessage::Notification { ref method, .. } if method == "textDocument/publishDiagnostics")
);
}
#[test]
fn classify_server_request() {
let msg = json!({"jsonrpc": "2.0", "id": 5, "method": "window/workDoneProgress/create", "params": {}});
let classified = classify_message(&msg).unwrap();
assert!(
matches!(classified, JsonRpcMessage::ServerRequest { ref method, .. } if method == "window/workDoneProgress/create")
);
}
#[test]
fn request_ids_increment() {
let next_id = AtomicI64::new(1);
let id1 = next_id.fetch_add(1, Ordering::SeqCst);
let id2 = next_id.fetch_add(1, Ordering::SeqCst);
let id3 = next_id.fetch_add(1, Ordering::SeqCst);
assert_eq!(id1, 1);
assert_eq!(id2, 2);
assert_eq!(id3, 3);
}
#[test]
fn frame_message_content_length_matches_body() {
let payload = json!({"jsonrpc": "2.0", "method": "textDocument/didOpen", "params": {}});
let framed = frame_message(&payload);
let text = String::from_utf8(framed).unwrap();
let (header, body) = text.split_once("\r\n\r\n").unwrap();
let declared: usize = header
.strip_prefix("Content-Length: ")
.unwrap()
.parse()
.unwrap();
assert_eq!(declared, body.len());
assert!(!body.is_empty());
}
#[test]
fn classify_unrecognized_message_returns_error() {
let msg = json!({"jsonrpc": "2.0"});
let result = classify_message(&msg);
assert!(result.is_err(), "message with no method or id should error");
}
}