use std::io::{self, ErrorKind};
use std::str;
use anyhow::{bail, ensure, Context, Result};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;
use crate::lsp::jsonrpc::Message;
/// Decodes LSP base-protocol frames (a header section followed by a JSON body) from an
/// async buffered byte stream into [`Message`]s.
pub struct LspReader<R> {
    /// Underlying byte source that frames are read from.
    reader: R,
    /// Label attached to the `trace!` event logged for every decoded message.
    tag: &'static str,
    /// Messages still pending from the most recent batch, stored last-to-first so that
    /// `pop` yields them in arrival order.
    batch: Vec<Message>,
    /// Scratch buffer reused for header lines and message bodies.
    buffer: Vec<u8>,
}
/// Parsed header section of one LSP base-protocol frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// Length of the message body in bytes, taken from the `Content-Length` header.
    pub content_length: usize,
    /// Raw value of the optional `Content-Type` header, if one was sent.
    // Parsed and stored, but apparently not read anywhere yet (hence the lint allowance);
    // NOTE(review): drop the `allow` once a consumer exists.
    #[allow(dead_code)]
    pub content_type: Option<String>,
}
impl<R> LspReader<R>
where
    R: AsyncBufRead + Unpin,
{
    /// Creates a reader that decodes LSP base-protocol frames from `reader`.
    ///
    /// `tag` is only used to label the `trace!` event emitted for every decoded message.
    pub fn new(reader: R, tag: &'static str) -> Self {
        LspReader {
            reader,
            batch: Vec::new(),
            buffer: Vec::with_capacity(1024),
            tag,
        }
    }

    /// Reads one header section, up to and including the blank `\r\n` line that ends it.
    ///
    /// Header names are matched case-insensitively; only `Content-Length` (required) and
    /// `Content-Type` (optional) are accepted.
    ///
    /// Returns `Ok(None)` if the stream is at end-of-file at the start of a header line, or
    /// if reading fails with `ConnectionReset`, `ConnectionAborted` or `BrokenPipe`.
    ///
    /// # Errors
    ///
    /// Fails on any other I/O error, on a line that is not `\r\n`-terminated or not valid
    /// UTF-8, on a line without the `": "` separator, on an unknown or repeated header name,
    /// on an unparsable `Content-Length` value, or when `Content-Length` is missing.
    pub async fn read_header(&mut self) -> Result<Option<Header>> {
        let mut content_type = None;
        let mut content_length = None;
        loop {
            self.buffer.clear();
            match self.reader.read_until(b'\n', &mut self.buffer).await {
                Ok(0) => return Ok(None),
                Ok(_) => {}
                Err(err) => match err.kind() {
                    ErrorKind::ConnectionReset
                    | ErrorKind::ConnectionAborted
                    | ErrorKind::BrokenPipe => return Ok(None),
                    _ => bail!(err),
                },
            }
            let header_text = self
                .buffer
                .strip_suffix(b"\r\n")
                .context(r"malformed header, missing `\r\n` terminator")?;
            let header_text = str::from_utf8(header_text)
                .context("malformed header, ascii encoding is a subset of utf-8")?;
            // An empty line terminates the header section.
            if header_text.is_empty() {
                break;
            }
            let (name, value) = match header_text.split_once(": ") {
                Some(split) => split,
                None => bail!("malformed header, missing value separator: {}", header_text),
            };
            match name.to_ascii_lowercase().as_str() {
                "content-type" => {
                    ensure!(content_type.is_none(), "repeated header content-type");
                    content_type = Some(value.to_owned());
                }
                "content-length" => {
                    ensure!(content_length.is_none(), "repeated header content-length");
                    // `usize::from_str` rejects surrounding whitespace, so tolerate incidental
                    // padding such as `Content-Length: 42 `.
                    let length = value
                        .trim()
                        .parse::<usize>()
                        .context("content-length header")?;
                    content_length = Some(length);
                }
                _ => bail!("unknown header name: {name:?}"),
            }
        }
        let content_length = content_length.context("missing required header content-length")?;
        Ok(Some(Header {
            content_length,
            content_type,
        }))
    }

    /// Returns the next incoming message, or `Ok(None)` once the peer has gone away.
    ///
    /// A body that is a JSON array is treated as a batch: its first element is returned and
    /// the remaining elements are handed out, in order, by subsequent calls before any further
    /// input is read.
    ///
    /// `Ok(None)` is returned when [`Self::read_header`] reports end of stream, when the body
    /// ends before `Content-Length` bytes arrive, or when the connection is reset, aborted or
    /// broken while reading the body.
    ///
    /// # Errors
    ///
    /// Fails on a malformed header, on other I/O errors, on a body that is not UTF-8 or does
    /// not decode as a message (or array of messages), and on an empty batch.
    pub async fn read_message(&mut self) -> Result<Option<Message>> {
        // Hand out messages left over from a previously received batch first.
        if let Some(pending) = self.batch.pop() {
            trace!(message = ?pending, "<- {}", self.tag);
            return Ok(Some(pending));
        }
        let header = match self.read_header().await.context("parsing header")? {
            Some(header) => header,
            None => return Ok(None),
        };
        if !self.read_body(header.content_length).await? {
            return Ok(None);
        }
        let message = self.decode_body()?;
        trace!(?message, "<- {}", self.tag);
        Ok(Some(message))
    }

    /// Reads exactly `content_length` body bytes into `self.buffer`.
    ///
    /// Returns `Ok(false)` if the stream ends early or the connection is reset, aborted or
    /// broken; other I/O errors are propagated.
    async fn read_body(&mut self, content_length: usize) -> Result<bool> {
        self.buffer.clear();
        // The length is peer-supplied: read through `take` instead of pre-sizing the buffer so
        // a bogus huge `Content-Length` cannot force a giant (or aborting) allocation before
        // any body bytes have arrived. The buffer only grows as data is actually received.
        // `usize -> u64` is lossless on every target Rust supports.
        let result = (&mut self.reader)
            .take(content_length as u64)
            .read_to_end(&mut self.buffer)
            .await;
        match result {
            // Fewer bytes than announced means the stream ended mid-body.
            Ok(received) => Ok(received == content_length),
            Err(err) => match err.kind() {
                ErrorKind::UnexpectedEof
                | ErrorKind::ConnectionReset
                | ErrorKind::ConnectionAborted
                | ErrorKind::BrokenPipe => Ok(false),
                _ => bail!(err),
            },
        }
    }

    /// Decodes the body currently held in `self.buffer`.
    ///
    /// For a JSON array (a batch), the first element is returned and the rest are stored in
    /// `self.batch` in reverse order, so that `Vec::pop` yields them in arrival order.
    fn decode_body(&mut self) -> Result<Message> {
        let bytes = self.buffer.as_slice();
        let body = str::from_utf8(bytes)
            .with_context(|| {
                let lossy_utf8 = String::from_utf8_lossy(bytes);
                format!("parsing body `{lossy_utf8}`")
            })
            .context("parsing LSP message")?;
        // JSON permits insignificant whitespace before the first token, so look past it when
        // deciding between a batch (array) and a single message.
        if body.trim_start().starts_with('[') {
            let mut batch: Vec<Message> = serde_json::from_str(body)
                .with_context(|| format!("parsing body `{body}`"))
                .context("parsing LSP message")?;
            batch.reverse();
            let message = batch.pop().context("received an empty batch")?;
            self.batch = batch;
            Ok(message)
        } else {
            serde_json::from_str(body)
                .with_context(|| format!("parsing body `{body}`"))
                .context("parsing LSP message")
        }
    }
}
/// Encodes [`Message`]s as LSP base-protocol frames (`Content-Length` header followed by a
/// JSON body) onto an async byte sink.
pub struct LspWriter<W> {
    /// Destination of the encoded frames.
    writer: W,
    /// Scratch space for the serialized JSON body, reused between messages.
    buffer: Vec<u8>,
    /// Label attached to the `trace!` event logged for every outgoing message.
    tag: &'static str,
}

impl<W> LspWriter<W>
where
    W: AsyncWrite + Unpin,
{
    /// Wraps `writer`; `tag` labels this direction in trace output.
    pub fn new(writer: W, tag: &'static str) -> Self {
        Self {
            writer,
            tag,
            buffer: Vec::with_capacity(1024),
        }
    }

    /// Serializes `message`, writes it as one `Content-Length`-prefixed frame, and flushes
    /// the underlying writer.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from writing the header, writing the body, or flushing.
    ///
    /// # Panics
    ///
    /// If `serde_json` fails to serialize `message`, which is treated as a bug.
    pub async fn write_message(&mut self, message: &Message) -> io::Result<()> {
        trace!(?message, "-> {}", self.tag);

        let body = &mut self.buffer;
        body.clear();
        serde_json::to_writer(&mut *body, message).expect("BUG: invalid message");
        let body_len = body.len();

        let frame_header = format!("Content-Length: {}\r\n\r\n", body_len);
        self.writer.write_all(frame_header.as_bytes()).await?;
        self.writer.write_all(self.buffer.as_slice()).await?;
        self.writer.flush().await
    }
}