use std::{
pin::Pin,
task::{Context, Poll},
};
use bytes::{BufMut, Bytes, BytesMut};
use futures::Stream;
use log::*;
/// Errors that can be yielded while decoding length-delimited protobuf
/// messages from a byte stream.
#[derive(Debug, thiserror::Error)]
pub enum ProtobufStreamError {
/// The underlying byte stream produced a `reqwest::Error`.
#[error("Reqwest error: {0}")]
ReqwestError(#[from] reqwest::Error),
/// `prost` failed to decode a length delimiter or a message.
#[error("Prost decode error: {0}")]
DecodeError(#[from] prost::DecodeError),
/// A length delimiter announced a message of `len` bytes, which is larger
/// than the configured maximum of `max` bytes.
#[error("Message size {len} exceeds maximum allowed size of {max} bytes")]
MessageSizeExceeded { len: usize, max: usize },
}
/// Adapter that turns a stream of raw byte chunks into a stream of decoded
/// `TMsg` values.
///
/// Incoming chunks are accumulated in an internal buffer and decoded as
/// length-delimited protobuf messages by the `Stream` implementation.
pub struct ProtobufStream<TMsg> {
// Source of raw chunks; boxed as a trait object so any `Send` stream can be stored.
bytes_stream: Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
// Bytes received so far that the decoder has not consumed yet.
buf: BytesMut,
// Largest announced message length accepted before `MessageSizeExceeded` is reported.
max_message_size: usize,
// Ties the struct to the message type without storing a value of it.
_marker: std::marker::PhantomData<TMsg>,
}
impl<TMsg> ProtobufStream<TMsg> {
    /// Wraps `bytes_stream` with the default message size limit of 16 MiB.
    pub fn new(bytes_stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static) -> Self {
        Self::with_max_message_size(bytes_stream, 16 * 1024 * 1024)
    }

    /// Wraps `bytes_stream`, rejecting any message whose announced length
    /// exceeds `max_message_size` bytes.
    ///
    /// The limit is only an upper bound: the internal buffer starts small and
    /// grows on demand as data arrives.
    pub fn with_max_message_size(
        bytes_stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
        max_message_size: usize,
    ) -> Self {
        // Pre-allocating the full limit would cost 16 MiB per stream by default
        // and would fail outright for very large limits (e.g. `usize::MAX` used
        // as "no limit"), so cap the initial allocation instead.
        const INITIAL_BUF_CAPACITY: usize = 8 * 1024;

        Self {
            bytes_stream: Box::pin(bytes_stream),
            buf: BytesMut::with_capacity(max_message_size.min(INITIAL_BUF_CAPACITY)),
            max_message_size,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<TMsg: prost::Message + Default + Unpin> Stream for ProtobufStream<TMsg> {
type Item = Result<TMsg, ProtobufStreamError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
match this.bytes_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buf.put(bytes);
let tmp_slice = &this.buf[..];
if tmp_slice.len() < 10 && tmp_slice.iter().take(10).all(|byte| byte & 0x80 != 0) {
continue;
}
let len = prost::decode_length_delimiter(tmp_slice)?;
if len > this.max_message_size {
return Poll::Ready(Some(Err(ProtobufStreamError::MessageSizeExceeded {
len,
max: this.max_message_size,
})));
}
let len_delim_len = prost::length_delimiter_len(len);
if this.buf.len() < len + len_delim_len {
trace!(
"Buffering: have {} bytes, need {} bytes (including {} bytes for length delimiter)",
this.buf.len(),
len + len_delim_len,
len_delim_len
);
continue;
}
let msg = TMsg::decode_length_delimited(&mut this.buf)?;
break Poll::Ready(Some(Ok(msg)));
},
Poll::Ready(Some(Err(e))) => {
break Poll::Ready(Some(Err(e.into())));
},
Poll::Ready(None) => break Poll::Ready(None),
Poll::Pending => break Poll::Pending,
}
}
}
}