use bytes::Bytes;
use futures_util::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tonic::Status;
/// A pinned, boxed, `Send` stream of raw byte payloads.
///
/// Each item is either a message payload (`Ok(Bytes)`) or a gRPC [`Status`] error.
/// The `'static` bound is the default for boxed trait objects; it is written out here
/// to make the ownership requirement explicit.
pub type MessageStream =
    Pin<Box<dyn Stream<Item = Result<Bytes, Status>> + Send + 'static>>;
/// A streaming call description: target service and method names, a stream of raw
/// message payloads, and gRPC metadata.
///
/// NOTE(review): presumably `service_name`/`method_name` identify the gRPC method to
/// invoke and `message_stream` carries already-encoded request messages — confirm with
/// the code that builds/consumes this type.
pub struct StreamingRequest {
/// Name of the target service (presumably the fully-qualified gRPC service name — confirm).
pub service_name: String,
/// Name of the method to call on `service_name`.
pub method_name: String,
/// Stream of message payloads as raw bytes; an `Err` item carries a gRPC `Status`.
pub message_stream: MessageStream,
/// gRPC metadata associated with the call.
pub metadata: tonic::metadata::MetadataMap,
}
pub fn single_message_stream(message: Bytes) -> MessageStream {
Box::pin(futures_util::stream::once(async move { Ok(message) }))
}
/// Caps the total number of payload bytes that `inner` may yield.
///
/// With `limit == None` the stream is returned untouched. With `Some(max_bytes)` the stream
/// is wrapped so that `Ok` payload lengths are summed. The first payload that pushes the
/// running total above `max_bytes` is replaced by a `ResourceExhausted` status, and the
/// stream ends after it. `Err` items from `inner` are passed through.
pub fn limit_message_stream(inner: MessageStream, limit: Option<usize>) -> MessageStream {
    if let Some(max_bytes) = limit {
        let capped = LimitedMessageStream {
            inner,
            limit: max_bytes,
            consumed: 0,
            exhausted: false,
        };
        return Box::pin(capped);
    }
    // No limit configured: hand the original stream back without any wrapping.
    inner
}
/// Stream adapter that enforces a cumulative byte budget on an inner [`MessageStream`].
///
/// Each `Ok` payload's length is added to a running (saturating) total. Payloads pass
/// through while the total is at most `limit`. The first payload that pushes the total
/// past `limit` is replaced by a `ResourceExhausted` status. After that the adapter yields
/// `None` on every poll and never polls `inner` again. `Err` items from `inner` are
/// forwarded unchanged and do not count towards the budget.
struct LimitedMessageStream {
    /// The wrapped stream.
    inner: MessageStream,
    /// Maximum total number of `Ok` payload bytes allowed through.
    limit: usize,
    /// Running total of `Ok` payload bytes seen so far (saturating).
    consumed: usize,
    /// Set once the adapter has terminated, either because the budget was exceeded or
    /// because `inner` finished. `inner` is never polled again afterwards.
    exhausted: bool,
}

impl Stream for LimitedMessageStream {
    type Item = Result<Bytes, Status>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.exhausted {
            return Poll::Ready(None);
        }
        match self.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => {
                // Fuse the adapter. The `Stream` contract lets `inner` panic or misbehave
                // if it is polled again after returning `None`, so it is never polled again.
                self.exhausted = true;
                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(status))) => Poll::Ready(Some(Err(status))),
            Poll::Ready(Some(Ok(bytes))) => {
                self.consumed = self.consumed.saturating_add(bytes.len());
                if self.consumed > self.limit {
                    // Surface one error in place of the oversized payload, then stop.
                    self.exhausted = true;
                    Poll::Ready(Some(Err(Status::resource_exhausted(format!(
                        "stream response size exceeded {} bytes",
                        self.limit
                    )))))
                } else {
                    Poll::Ready(Some(Ok(bytes)))
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.exhausted {
            (0, Some(0))
        } else {
            // Every item yielded maps to exactly one inner item, but the budget can cut the
            // stream short. Only the inner upper bound remains valid; the lower bound is 0.
            (0, self.inner.size_hint().1)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    #[tokio::test]
    async fn test_single_message_stream() {
        let mut stream = single_message_stream(Bytes::from("single"));
        let msg = stream.next().await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from("single"));
        assert!(stream.next().await.is_none());
    }

    /// Builds a stream of four 100-byte payloads (400 bytes in total).
    fn make_four_message_stream() -> MessageStream {
        let messages: Vec<Result<Bytes, Status>> = vec![
            Ok(Bytes::from(vec![0u8; 100])),
            Ok(Bytes::from(vec![1u8; 100])),
            Ok(Bytes::from(vec![2u8; 100])),
            Ok(Bytes::from(vec![3u8; 100])),
        ];
        Box::pin(futures_util::stream::iter(messages))
    }

    #[tokio::test]
    async fn test_limit_message_stream_none_passes_all_messages() {
        let stream = make_four_message_stream();
        let mut limited = limit_message_stream(stream, None);
        let mut count = 0;
        while let Some(item) = limited.next().await {
            assert!(item.is_ok(), "expected Ok but got error: {:?}", item);
            count += 1;
        }
        assert_eq!(count, 4, "all four messages should pass through when limit is None");
    }

    #[tokio::test]
    async fn test_limit_message_stream_exact_limit_passes_all_messages() {
        let stream = make_four_message_stream();
        let mut limited = limit_message_stream(stream, Some(400));
        let mut count = 0;
        while let Some(item) = limited.next().await {
            assert!(item.is_ok(), "expected Ok but got error: {:?}", item);
            count += 1;
        }
        assert_eq!(count, 4, "all four messages should pass when limit == total size");
    }

    #[tokio::test]
    async fn test_limit_message_stream_exceeded_aborts_stream() {
        let stream = make_four_message_stream();
        let mut limited = limit_message_stream(stream, Some(200));
        let item1 = limited.next().await.expect("should have item 1");
        assert!(item1.is_ok(), "item 1 should be Ok");
        let item2 = limited.next().await.expect("should have item 2");
        assert!(item2.is_ok(), "item 2 should be Ok");
        let item3 = limited.next().await.expect("should have item 3");
        let err = item3.expect_err("item 3 should be a resource_exhausted error");
        assert_eq!(err.code(), tonic::Code::ResourceExhausted);
        assert!(
            err.message().contains("200"),
            "error message should mention the limit: {}",
            err.message()
        );
        let item4 = limited.next().await;
        assert!(item4.is_none(), "stream should be terminated after resource_exhausted");
    }

    #[tokio::test]
    async fn test_limit_message_stream_propagates_inner_errors() {
        let messages: Vec<Result<Bytes, Status>> = vec![
            Ok(Bytes::from(vec![0u8; 50])),
            Err(Status::internal("upstream failure")),
            Ok(Bytes::from(vec![0u8; 50])),
        ];
        let stream = Box::pin(futures_util::stream::iter(messages));
        let mut limited = limit_message_stream(stream, Some(1000));
        let item1 = limited.next().await.unwrap();
        assert!(item1.is_ok());
        let item2 = limited.next().await.unwrap();
        let err = item2.expect_err("should propagate inner error");
        assert_eq!(err.code(), tonic::Code::Internal);
        assert_eq!(err.message(), "upstream failure");
        // An inner error must not terminate the limited stream.
        let item3 = limited
            .next()
            .await
            .expect("stream should continue after an inner error");
        assert!(item3.is_ok(), "payload after an inner error should pass through");
        assert!(limited.next().await.is_none(), "stream should end with the inner stream");
    }

    #[tokio::test]
    async fn test_limit_message_stream_counts_bytes_across_inner_errors() {
        let messages: Vec<Result<Bytes, Status>> = vec![
            Ok(Bytes::from(vec![0u8; 50])),
            Err(Status::internal("upstream failure")),
            Ok(Bytes::from(vec![0u8; 50])),
            Ok(Bytes::from(vec![0u8; 50])),
        ];
        let stream = Box::pin(futures_util::stream::iter(messages));
        let mut limited = limit_message_stream(stream, Some(80));
        assert!(limited.next().await.unwrap().is_ok(), "first 50 bytes fit the budget");
        let inner_err = limited.next().await.unwrap().expect_err("inner error passes through");
        assert_eq!(inner_err.code(), tonic::Code::Internal);
        // 50 + 50 = 100 > 80: bytes before the inner error still count towards the budget.
        let err = limited
            .next()
            .await
            .unwrap()
            .expect_err("cumulative size should exceed the limit");
        assert_eq!(err.code(), tonic::Code::ResourceExhausted);
        assert!(limited.next().await.is_none(), "stream should stop after the limit error");
    }

    #[tokio::test]
    async fn test_limit_message_stream_single_oversized_message() {
        let mut limited = limit_message_stream(single_message_stream(Bytes::from("toolong")), Some(3));
        let err = limited
            .next()
            .await
            .expect("should yield the limit error")
            .expect_err("a 7-byte payload exceeds a 3-byte limit");
        assert_eq!(err.code(), tonic::Code::ResourceExhausted);
        assert!(limited.next().await.is_none());
    }
}