use std::io;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
/// Reads JSON-RPC messages from an upstream byte stream.
///
/// Each call to [`UpstreamReader::next_message`] accepts either framing, so the two may be
/// mixed on one stream:
///
/// * newline-delimited JSON — one message per line; or
/// * `Content-Length` framing — a `Content-Length: N` header line, any further header
///   lines, a blank line, then exactly `N` bytes of UTF-8 body.
pub struct UpstreamReader<R> {
    reader: BufReader<R>,
    // Scratch buffer reused for every line read so each line does not allocate afresh.
    buf: String,
}

impl<R: AsyncRead + Unpin> UpstreamReader<R> {
    /// Creates a message reader that buffers `reader` internally.
    pub fn new(reader: R) -> Self {
        Self {
            reader: BufReader::new(reader),
            buf: String::new(),
        }
    }

    /// Reads the next message, skipping any blank lines before it.
    ///
    /// Returns `Ok(None)` when the stream ends before any non-blank line is found.
    /// Newline-delimited messages are returned with surrounding whitespace trimmed;
    /// `Content-Length` bodies are returned exactly as received.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if a `Content-Length` value is not a valid `usize`, or if a framed
    ///   body is not valid UTF-8.
    /// * `UnexpectedEof` if the stream ends inside a framed message's headers or body.
    /// * Any error returned by the underlying reader or by `read_line`.
    pub async fn next_message(&mut self) -> io::Result<Option<String>> {
        loop {
            self.buf.clear();
            if self.reader.read_line(&mut self.buf).await? == 0 {
                return Ok(None);
            }
            let line = self.buf.trim();
            if line.is_empty() {
                continue;
            }
            if let Some(value) = content_length_value(line) {
                let value = value.trim();
                let len: usize = value.parse().map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("invalid Content-Length value {:?}: {}", value, e),
                    )
                })?;
                self.skip_headers().await?;
                return self.read_body(len).await.map(Some);
            }
            return Ok(Some(line.to_string()));
        }
    }

    /// Consumes the remaining header lines, up to and including the blank separator line.
    ///
    /// Any header other than the `Content-Length` line already consumed is ignored.
    async fn skip_headers(&mut self) -> io::Result<()> {
        loop {
            self.buf.clear();
            if self.reader.read_line(&mut self.buf).await? == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "EOF in Content-Length headers",
                ));
            }
            if self.buf.trim().is_empty() {
                return Ok(());
            }
        }
    }

    /// Reads exactly `len` body bytes and decodes them as UTF-8.
    async fn read_body(&mut self, len: usize) -> io::Result<String> {
        // Do not trust `len` for an up-front allocation: a bogus header such as
        // `Content-Length: 99999999999` must not allocate (or abort) before any body
        // bytes have arrived. Preallocate modestly and grow with the data actually read.
        let mut body = Vec::with_capacity(len.min(64 * 1024));
        let limit = u64::try_from(len).unwrap_or(u64::MAX);
        (&mut self.reader).take(limit).read_to_end(&mut body).await?;
        // `read_to_end` on a `take` stops early only if the inner stream hit EOF.
        if body.len() != len {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "EOF in Content-Length body",
            ));
        }
        String::from_utf8(body).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }
}

/// Returns the text after the colon if `line` starts with a `Content-Length:` header.
///
/// The header name is compared ASCII case-insensitively, since HTTP-style header names are
/// case-insensitive. `str::get` keeps the slice on a char boundary for non-ASCII input.
fn content_length_value(line: &str) -> Option<&str> {
    const NAME: &str = "Content-Length:";
    let prefix = line.get(..NAME.len())?;
    if prefix.eq_ignore_ascii_case(NAME) {
        Some(&line[NAME.len()..])
    } else {
        None
    }
}
/// Sends `json` as one newline-terminated message, then flushes `writer`.
///
/// The message bytes are written unchanged and followed by a single `\n`. NOTE(review):
/// presumably callers pass compact JSON; an embedded newline would look like a message
/// boundary to a newline-delimited peer.
///
/// # Errors
///
/// Propagates the first error from writing either piece or from the final flush.
pub async fn write_newline_delimited<W: AsyncWrite + Unpin + ?Sized>(
    writer: &mut W,
    json: &str,
) -> io::Result<()> {
    let pieces: [&[u8]; 2] = [json.as_bytes(), b"\n"];
    for piece in &pieces {
        writer.write_all(piece).await?;
    }
    writer.flush().await
}
/// Frames `json` with an LSP-style `Content-Length` header.
///
/// The declared length counts UTF-8 bytes, not characters. The result is the header
/// line, a blank line (`\r\n\r\n`), then the raw body bytes.
pub fn encode_content_length(json: &str) -> Vec<u8> {
    let payload = json.as_bytes();
    let mut framed = format!("Content-Length: {}\r\n\r\n", payload.len()).into_bytes();
    framed.reserve_exact(payload.len());
    framed.extend_from_slice(payload);
    framed
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn test_parse_newline_delimited() {
        let input = b"{\"jsonrpc\":\"2.0\",\"id\":1}\n";
        let mut reader = UpstreamReader::new(&input[..]);
        let msg = reader.next_message().await.unwrap().unwrap();
        assert_eq!(msg, "{\"jsonrpc\":\"2.0\",\"id\":1}");
    }

    #[tokio::test]
    async fn test_parse_content_length_frame() {
        let body = r#"{"jsonrpc":"2.0","id":2}"#;
        let framed = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let mut reader = UpstreamReader::new(framed.as_bytes());
        let msg = reader.next_message().await.unwrap().unwrap();
        assert_eq!(msg, body);
    }

    #[tokio::test]
    async fn test_parse_content_length_with_extra_header() {
        let body = r#"{"jsonrpc":"2.0","id":3}"#;
        let framed = format!(
            "Content-Length: {}\r\nContent-Type: application/json\r\n\r\n{}",
            body.len(),
            body
        );
        let mut reader = UpstreamReader::new(framed.as_bytes());
        let msg = reader.next_message().await.unwrap().unwrap();
        assert_eq!(msg, body);
    }

    #[tokio::test]
    async fn test_parse_partial_content_length() {
        // The frame arrives as two separate chunks split mid-body, so the body
        // must be assembled across more than one underlying read.
        let body = r#"{"jsonrpc":"2.0","method":"test"}"#;
        let framed = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let (head, tail) = framed.as_bytes().split_at(framed.len() - body.len() / 2);
        let mut reader = UpstreamReader::new(head.chain(tail));
        let msg = reader.next_message().await.unwrap().unwrap();
        assert_eq!(msg, body);
        assert!(reader.next_message().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_truncated_content_length_body_is_eof_error() {
        let body = r#"{"id":4}"#;
        let framed = format!("Content-Length: {}\r\n\r\n{}", body.len() + 10, body);
        let mut reader = UpstreamReader::new(framed.as_bytes());
        let err = reader.next_message().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_eof_in_content_length_headers_is_error() {
        let input = b"Content-Length: 5\r\n";
        let mut reader = UpstreamReader::new(&input[..]);
        let err = reader.next_message().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_invalid_content_length_is_invalid_data() {
        let input = b"Content-Length: abc\r\n\r\n";
        let mut reader = UpstreamReader::new(&input[..]);
        let err = reader.next_message().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_parse_multiple_newline_messages() {
        let input = b"{\"id\":1}\n{\"id\":2}\n{\"id\":3}\n";
        let mut reader = UpstreamReader::new(&input[..]);
        assert_eq!(
            reader.next_message().await.unwrap().unwrap(),
            "{\"id\":1}"
        );
        assert_eq!(
            reader.next_message().await.unwrap().unwrap(),
            "{\"id\":2}"
        );
        assert_eq!(
            reader.next_message().await.unwrap().unwrap(),
            "{\"id\":3}"
        );
        assert!(reader.next_message().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_write_newline_delimited() {
        let mut buf = Vec::new();
        write_newline_delimited(&mut buf, r#"{"id":1}"#)
            .await
            .unwrap();
        assert_eq!(buf, b"{\"id\":1}\n");
    }

    #[tokio::test]
    async fn test_content_length_frame_roundtrip() {
        let original = r#"{"jsonrpc":"2.0","id":99,"method":"ping"}"#;
        let encoded = encode_content_length(original);
        let mut reader = UpstreamReader::new(&encoded[..]);
        let decoded = reader.next_message().await.unwrap().unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn test_eof_returns_none() {
        let input = b"";
        let mut reader = UpstreamReader::new(&input[..]);
        assert!(reader.next_message().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_blank_lines_skipped() {
        let input = b"\n\n{\"id\":1}\n\n";
        let mut reader = UpstreamReader::new(&input[..]);
        let msg = reader.next_message().await.unwrap().unwrap();
        assert_eq!(msg, "{\"id\":1}");
    }
}