use crate::{Encoding, Status};
use futures_lite::{AsyncRead, Stream};
use std::{
pin::Pin,
task::{Context, Poll},
};
pub use crate::encoding::DEFAULT_MAX_MESSAGE_SIZE;
/// Byte length of a frame prefix: one compressed-flag byte followed by the
/// payload length encoded as a big-endian `u32`.
const PREFIX_LEN: usize = 1 + std::mem::size_of::<u32>();
/// A [`Stream`] of gRPC-style length-prefixed messages read from an
/// [`AsyncRead`] source.
///
/// Each frame's payload is decompressed with `encoding` when its compressed
/// flag is set, then converted into a `T` by `decode`.
pub struct MessageStream<T, R> {
    /// Converts one (already decompressed) payload into a message.
    decode: fn(&[u8]) -> Result<T, Status>,
    /// Encoding applied to frames whose compressed flag is set.
    encoding: Encoding,
    /// Largest payload length, in bytes, accepted from a frame prefix.
    max_message_size: usize,
    /// Progress through the frame currently being read.
    state: ReadState,
    /// Underlying byte source.
    reader: R,
}
/// Position of the frame reader within the current frame.
///
/// Partially-read bytes are kept here rather than on the stack, so that a
/// `Poll::Pending` from the reader can be resumed on the next poll without
/// losing data.
pub(crate) enum ReadState {
    /// Collecting the fixed-size prefix; the first `filled` bytes of `buf`
    /// have been read.
    ReadingPrefix {
        filled: usize,
        buf: [u8; PREFIX_LEN],
    },
    /// Collecting a payload whose length was taken from the prefix; the first
    /// `filled` bytes of `payload` have been read.
    ReadingPayload {
        payload: Vec<u8>,
        filled: usize,
        /// Value of the prefix's compressed flag for this frame.
        compressed: bool,
    },
    /// Terminal state: no further frames are read.
    Done,
}
impl ReadState {
    /// Returns the state for the start of a fresh frame: an empty prefix
    /// buffer with no bytes read yet.
    pub(crate) fn new() -> Self {
        let empty_prefix = [0; PREFIX_LEN];
        Self::ReadingPrefix {
            filled: 0,
            buf: empty_prefix,
        }
    }
}
impl<T, R> MessageStream<T, R> {
    /// Creates a stream that reads frames from `reader` and turns each payload
    /// into a `T` with `decode`.
    ///
    /// The size limit starts at [`DEFAULT_MAX_MESSAGE_SIZE`]. The encoding
    /// starts at `Encoding::Identity`, under which frames with the compressed
    /// flag set are rejected. Override them with
    /// [`with_max_message_size`](Self::with_max_message_size) and
    /// [`with_encoding`](Self::with_encoding).
    #[must_use]
    pub fn new(reader: R, decode: fn(&[u8]) -> Result<T, Status>) -> Self {
        Self {
            reader,
            state: ReadState::new(),
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            encoding: Encoding::Identity,
            decode,
        }
    }

    /// Sets the largest payload length, in bytes, that the stream accepts.
    ///
    /// A frame whose prefix declares a longer payload yields a
    /// `resource_exhausted` error and ends the stream. The limit is also
    /// forwarded to `Encoding::decompress`.
    // Builders consume `self`; dropping the result silently discards the stream.
    #[must_use = "builder methods take `self` by value; use the returned stream"]
    pub fn with_max_message_size(mut self, max: usize) -> Self {
        self.max_message_size = max;
        self
    }

    /// Sets the encoding used to decompress frames whose compressed flag is set.
    #[must_use = "builder methods take `self` by value; use the returned stream"]
    pub fn with_encoding(mut self, encoding: Encoding) -> Self {
        self.encoding = encoding;
        self
    }
}
/// Yields each decoded message in order. See `poll_read_message` for how
/// errors are reported and when the stream ends.
// No `T: 'static` bound: nothing here stores or erases `T`, so streams of
// borrowed message types are allowed too.
impl<T, R> Stream for MessageStream<T, R>
where
    R: AsyncRead + Unpin,
{
    type Item = Result<T, Status>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` needs `Self: Unpin`, which the `R: Unpin` bound provides;
        // that lets us borrow the reader and the state as separate fields.
        let this = self.get_mut();
        poll_read_message(
            Pin::new(&mut this.reader),
            &mut this.state,
            cx,
            this.decode,
            this.encoding,
            this.max_message_size,
        )
    }
}
pub(crate) fn poll_read_message<T, R>(
mut reader: Pin<&mut R>,
state: &mut ReadState,
cx: &mut Context<'_>,
decode: fn(&[u8]) -> Result<T, Status>,
encoding: Encoding,
max_message_size: usize,
) -> Poll<Option<Result<T, Status>>>
where
R: AsyncRead + ?Sized,
{
loop {
match state {
ReadState::Done => return Poll::Ready(None),
ReadState::ReadingPrefix { buf, filled } => {
while *filled < PREFIX_LEN {
let dst = &mut buf[*filled..];
match reader.as_mut().poll_read(cx, dst) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
*state = ReadState::Done;
return Poll::Ready(Some(Err(Status::unavailable(format!(
"read error: {e}"
)))));
}
Poll::Ready(Ok(0)) => {
if *filled == 0 {
*state = ReadState::Done;
return Poll::Ready(None);
} else {
*state = ReadState::Done;
return Poll::Ready(Some(Err(Status::internal(
"unexpected EOF in frame prefix",
))));
}
}
Poll::Ready(Ok(n)) => *filled += n,
}
}
let compressed = buf[0] != 0;
let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
if len > max_message_size {
*state = ReadState::Done;
return Poll::Ready(Some(Err(Status::resource_exhausted(format!(
"received message of {len} bytes exceeds limit of {max_message_size}"
)))));
}
*state = ReadState::ReadingPayload {
compressed,
payload: vec![0u8; len],
filled: 0,
};
}
ReadState::ReadingPayload {
compressed,
payload,
filled,
} => {
while *filled < payload.len() {
let dst = &mut payload[*filled..];
match reader.as_mut().poll_read(cx, dst) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
*state = ReadState::Done;
return Poll::Ready(Some(Err(Status::unavailable(format!(
"read error: {e}"
)))));
}
Poll::Ready(Ok(0)) => {
*state = ReadState::Done;
return Poll::Ready(Some(Err(Status::internal(
"unexpected EOF in frame payload",
))));
}
Poll::Ready(Ok(n)) => *filled += n,
}
}
let compressed = *compressed;
let payload = std::mem::take(payload);
*state = ReadState::new();
let bytes = if compressed {
if matches!(encoding, Encoding::Identity) {
return Poll::Ready(Some(Err(Status::internal(
"received compressed message but no encoding negotiated",
))));
}
match encoding.decompress(&payload, max_message_size) {
Ok(b) => b,
Err(status) => return Poll::Ready(Some(Err(status))),
}
} else {
payload
};
return Poll::Ready(Some(decode(&bytes)));
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Code, Codec, codec::Prost};
    use futures_lite::{StreamExt, future::block_on};

    /// Builds a bare frame prefix: the compressed flag, then `len` as a
    /// big-endian `u32`.
    fn prefix(flag: u8, len: u32) -> Vec<u8> {
        let mut out = Vec::with_capacity(PREFIX_LEN);
        out.push(flag);
        out.extend_from_slice(&len.to_be_bytes());
        out
    }

    /// Builds a complete uncompressed frame carrying `payload`.
    fn frame(payload: &[u8]) -> Vec<u8> {
        let mut out = prefix(0, payload.len() as u32);
        out.extend_from_slice(payload);
        out
    }

    fn vec_decode(bytes: &[u8]) -> Result<Vec<u8>, Status> {
        <Prost as Codec<Vec<u8>>>::decode(bytes)
    }

    // Named so it does not shadow the glob-imported `Stream` trait.
    type TestStream<'a> = MessageStream<Vec<u8>, &'a [u8]>;

    fn new_stream(bytes: &[u8]) -> TestStream<'_> {
        MessageStream::new(bytes, vec_decode)
    }

    #[test]
    fn empty_input_yields_none() {
        let mut s = new_stream(&[]);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn single_empty_message() {
        let body = frame(&[]);
        let mut s = new_stream(&body);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert!(msg.is_empty());
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn multiple_messages() {
        let body = [
            frame(&[0x0A, 0x02, b'h', b'i']),
            frame(&[0x0A, 0x03, b'b', b'y', b'e']),
        ]
        .concat();
        let mut s = new_stream(&body);
        let first = block_on(s.next()).unwrap().unwrap();
        let second = block_on(s.next()).unwrap().unwrap();
        assert_eq!(first, b"hi");
        assert_eq!(second, b"bye");
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_prefix_at_eof_is_error() {
        let body = [0u8; 3];
        let mut s = new_stream(&body);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
        assert!(block_on(s.next()).is_none());
    }

    #[test]
    fn partial_payload_at_eof_is_error() {
        // The prefix promises 10 bytes, but only 3 follow.
        let mut body = prefix(0, 10);
        body.extend_from_slice(&[1, 2, 3]);
        let mut s = new_stream(&body);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
    }

    #[test]
    fn oversized_message_is_resource_exhausted() {
        let body = prefix(0, 100);
        let mut s = new_stream(&body).with_max_message_size(50);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::ResourceExhausted);
    }

    #[test]
    fn compressed_flag_with_identity_encoding_is_internal() {
        let body = prefix(1, 0);
        let mut s = new_stream(&body);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::Internal);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn compressed_frame_decompressed_with_gzip() {
        let inner = [0x0Au8, 0x02, b'h', b'i'];
        let compressed = Encoding::Gzip.compress(&inner).unwrap();
        let mut body = prefix(1, compressed.len() as u32);
        body.extend_from_slice(&compressed);
        let mut s = new_stream(&body).with_encoding(Encoding::Gzip);
        let msg = block_on(s.next()).unwrap().unwrap();
        assert_eq!(msg, b"hi");
    }

    #[test]
    fn codec_decode_failure_propagates_invalid_argument() {
        let body = frame(&[0xFF, 0xFF, 0xFF, 0xFF]);
        let mut s = new_stream(&body);
        let err = block_on(s.next()).unwrap().unwrap_err();
        assert_eq!(err.code, Code::InvalidArgument);
    }
}