use std::io;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Header name, including its trailing colon, whose value is the frame body's length in bytes.
const CONTENT_LENGTH_PREFIX: &str = "Content-Length:";
/// Largest body size, in bytes, that `decode_frame` accepts (16 MiB).
const MAX_BODY_BYTES: usize = 16 << 20;
pub async fn encode_frame<W>(writer: &mut W, body: &[u8]) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let header = format!("Content-Length: {}\r\n\r\n", body.len());
writer.write_all(header.as_bytes()).await?;
writer.write_all(body).await?;
writer.flush().await?;
Ok(())
}
/// Reads one `Content-Length`-framed message from `reader` and returns its body.
///
/// Header lines are read until a blank line. Lines may end in `\r\n` or a lone
/// `\n`. The `Content-Length` header name is matched ASCII case-insensitively,
/// as HTTP-style header names are. Any other header line is ignored. If
/// `Content-Length` appears more than once, the last value wins. Exactly that
/// many bytes are then read as the body, so any bytes after it remain in
/// `reader` for the next call.
///
/// # Errors
///
/// - [`io::ErrorKind::UnexpectedEof`] if the stream ends before the header's
///   blank line or before the full body has arrived.
/// - [`io::ErrorKind::InvalidData`] in any of these cases:
///   - the `Content-Length` value does not parse as a `usize`;
///   - the value exceeds `MAX_BODY_BYTES`;
///   - no `Content-Length` header is present;
///   - the header section (blank terminator included) exceeds 8 KiB.
/// - Any other error from `reader`, including `InvalidData` when header bytes
///   are not valid UTF-8.
///
/// After an error the reader is left partway through the frame.
pub async fn decode_frame<R>(reader: &mut R) -> io::Result<Vec<u8>>
where
    R: AsyncBufRead + Unpin,
{
    // Cap on the whole header section, blank terminator line included. The body
    // is already bounded by `MAX_BODY_BYTES`. Without this cap, a peer that never
    // sends a newline would make `read_line` buffer input without limit.
    const MAX_HEADER_BYTES: usize = 8 * 1024;

    let mut content_length: Option<usize> = None;
    let mut header_bytes = 0usize;
    let mut line = String::new();
    loop {
        line.clear();
        let remaining = MAX_HEADER_BYTES - header_bytes;
        // `take` bounds how much `read_line` can pull in. It consumes exactly the
        // bytes it hands back, so nothing past this line is lost from `reader`.
        let n = {
            let mut limited = (&mut *reader).take(remaining as u64);
            limited.read_line(&mut line).await?
        };
        header_bytes += n;
        // A line that filled the whole remaining budget without reaching a newline
        // cannot complete within the cap. This also catches `remaining == 0`.
        if n == remaining && !line.ends_with('\n') {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("frame header exceeds {MAX_HEADER_BYTES} bytes"),
            ));
        }
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "stream closed before complete frame header",
            ));
        }
        let trimmed = line.trim_end_matches(['\r', '\n']);
        if trimmed.is_empty() {
            break;
        }
        if let Some(rest) = strip_prefix_ignore_ascii_case(trimmed, CONTENT_LENGTH_PREFIX) {
            let value = rest.trim();
            let parsed: usize = value.parse().map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("malformed Content-Length value: {value:?}"),
                )
            })?;
            // Reject oversized lengths before allocating the body buffer.
            if parsed > MAX_BODY_BYTES {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("body length {parsed} exceeds cap {MAX_BODY_BYTES}"),
                ));
            }
            content_length = Some(parsed);
        }
    }
    let len = content_length.ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "frame header missing Content-Length",
        )
    })?;
    let mut body = vec![0u8; len];
    reader.read_exact(&mut body).await?;
    Ok(body)
}

/// Returns the rest of `s` after `prefix`, comparing ASCII letters case-insensitively.
///
/// Returns `None` when `s` does not start with `prefix`, including when `s` is
/// shorter than `prefix` or `prefix.len()` is not a char boundary in `s`.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    let head = s.get(..prefix.len())?;
    let tail = s.get(prefix.len()..)?;
    if head.eq_ignore_ascii_case(prefix) {
        Some(tail)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::BufReader;

    /// Encodes `body` into an in-memory buffer and decodes it back.
    async fn roundtrip(body: &[u8]) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        encode_frame(&mut buf, body).await.unwrap();
        let mut reader = BufReader::new(buf.as_slice());
        decode_frame(&mut reader).await.unwrap()
    }

    /// Decodes a single frame from raw wire text, returning the decoder's result.
    async fn decode_str(frame: &str) -> io::Result<Vec<u8>> {
        let mut reader = BufReader::new(frame.as_bytes());
        decode_frame(&mut reader).await
    }

    #[tokio::test]
    async fn roundtrip_simple_json_body() {
        let body = br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#;
        assert_eq!(roundtrip(body).await, body);
    }

    #[tokio::test]
    async fn roundtrip_unicode_body_preserves_byte_count() {
        let body = "{\"text\":\"hello 🦀 rust\"}".as_bytes();
        assert_eq!(roundtrip(body).await, body);
    }

    #[tokio::test]
    async fn roundtrip_empty_body() {
        assert_eq!(roundtrip(b"").await, b"");
    }

    #[tokio::test]
    async fn encode_format_is_content_length_blank_body() {
        let mut buf: Vec<u8> = Vec::new();
        encode_frame(&mut buf, b"hello").await.unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert_eq!(s, "Content-Length: 5\r\n\r\nhello");
    }

    #[tokio::test]
    async fn encode_empty_body_writes_zero_length_header() {
        let mut buf: Vec<u8> = Vec::new();
        encode_frame(&mut buf, b"").await.unwrap();
        assert_eq!(buf, b"Content-Length: 0\r\n\r\n");
    }

    #[tokio::test]
    async fn decode_multiple_frames_from_one_reader() {
        let mut buf: Vec<u8> = Vec::new();
        encode_frame(&mut buf, b"first").await.unwrap();
        encode_frame(&mut buf, b"second").await.unwrap();
        encode_frame(&mut buf, b"third").await.unwrap();
        let mut reader = BufReader::new(buf.as_slice());
        assert_eq!(decode_frame(&mut reader).await.unwrap(), b"first");
        assert_eq!(decode_frame(&mut reader).await.unwrap(), b"second");
        assert_eq!(decode_frame(&mut reader).await.unwrap(), b"third");
    }

    #[tokio::test]
    async fn decode_skips_unknown_headers_before_blank_line() {
        let frame = "Content-Type: application/vscode-jsonrpc; charset=utf-8\r\n\
                     Content-Length: 5\r\n\
                     \r\n\
                     hello";
        let body = decode_str(frame).await.unwrap();
        assert_eq!(body, b"hello");
    }

    #[tokio::test]
    async fn decode_tolerates_lone_newline_line_endings() {
        let body = decode_str("Content-Length: 5\n\nhello").await.unwrap();
        assert_eq!(body, b"hello");
    }

    #[tokio::test]
    async fn decode_trims_whitespace_around_length_value() {
        let body = decode_str("Content-Length:   5  \r\n\r\nhello").await.unwrap();
        assert_eq!(body, b"hello");
    }

    #[tokio::test]
    async fn decode_errors_on_missing_content_length() {
        let err = decode_str("Content-Type: foo\r\n\r\nbody").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Content-Length"));
    }

    #[tokio::test]
    async fn decode_errors_on_malformed_content_length() {
        let err = decode_str("Content-Length: not-a-number\r\n\r\n")
            .await
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("malformed"));
    }

    #[tokio::test]
    async fn decode_errors_when_content_length_exceeds_cap() {
        let huge = MAX_BODY_BYTES + 1;
        let frame = format!("Content-Length: {huge}\r\n\r\n");
        let err = decode_str(&frame).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("exceeds cap"));
    }

    #[tokio::test]
    async fn decode_errors_on_empty_stream() {
        let err = decode_str("").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn decode_errors_on_eof_mid_header() {
        let err = decode_str("Content-Length: 5\r\n").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn decode_errors_on_eof_mid_body() {
        let err = decode_str("Content-Length: 10\r\n\r\nabc").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn decode_handles_byte_at_a_time_reader() {
        let mut buf: Vec<u8> = Vec::new();
        encode_frame(&mut buf, b"hello").await.unwrap();
        let (client_read, mut server_write) = tokio::io::duplex(64);
        // Dribble the frame one byte per write, yielding between writes so the
        // decoder observes many partial reads.
        let writer = tokio::spawn(async move {
            for byte in buf {
                server_write.write_all(&[byte]).await.unwrap();
                tokio::task::yield_now().await;
            }
            drop(server_write);
        });
        let mut reader = BufReader::new(client_read);
        let body = decode_frame(&mut reader).await.unwrap();
        writer.await.unwrap();
        assert_eq!(body, b"hello");
    }

    #[tokio::test]
    async fn encode_decode_through_duplex_pipe() {
        let (mut client_side, mut server_side) = tokio::io::duplex(1024);
        let writer = tokio::spawn(async move {
            for msg in [b"alpha".as_slice(), b"beta", b"gamma"] {
                encode_frame(&mut server_side, msg).await.unwrap();
            }
            drop(server_side);
        });
        let mut reader = BufReader::new(&mut client_side);
        let mut got = Vec::new();
        for _ in 0..3 {
            got.push(decode_frame(&mut reader).await.unwrap());
        }
        writer.await.unwrap();
        assert_eq!(
            got,
            vec![b"alpha".to_vec(), b"beta".to_vec(), b"gamma".to_vec()]
        );
    }
}