// stream-transfer-limit 0.1.0
//
// Byte-count transfer limits for fallible futures streams.
// Documentation
use crate::{ChunkLength, TransferLimit, TransferLimitError};
use bytes::Bytes;
use futures::{Stream, StreamExt, stream};
use std::{
    cell::RefCell,
    error::Error,
    fmt,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    task::Poll,
};

/// Unit error type used as the inner stream's error in these tests.
#[derive(Debug)]
struct TestError;

impl fmt::Display for TestError {
    /// Writes the fixed message `test stream error`.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(out, "test stream error")
    }
}

// The default `Error` methods are used as-is, so `source()` is `None`.
impl Error for TestError {}

/// Test chunk that holds no payload and only carries the length it reports.
#[derive(Debug)]
struct FakeChunk(usize);

impl ChunkLength for FakeChunk {
    /// Returns the wrapped length unchanged.
    fn chunk_len(&self) -> usize {
        let Self(reported_len) = self;
        *reported_len
    }
}

/// A transfer whose total equals the limit exactly is not rejected.
///
/// Two 2-byte chunks are sent through a limit of 4: both must be yielded
/// unchanged, the stream must then end, and the progress callback must have
/// recorded `2` followed by `4`.
#[test]
fn exact_limit_is_allowed() {
    futures::executor::block_on(async {
        let progress_log = Rc::new(RefCell::new(Vec::new()));
        let source = stream::iter([
            Ok::<_, TestError>(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let mut limited = TransferLimit::new(4)
            .on_progress({
                // The callback owns its own handle to the shared log.
                let progress_log = Rc::clone(&progress_log);
                move |total| progress_log.borrow_mut().push(total)
            })
            .wrap(source);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first, Bytes::from_static(b"ab"));

        let second = limited.next().await.unwrap().unwrap();
        assert_eq!(second, Bytes::from_static(b"cd"));

        let after_end = limited.next().await;
        assert!(after_end.is_none());

        assert_eq!(*progress_log.borrow(), vec![2, 4]);
    });
}

/// Exceeding the limit part-way through yields one error and then ends the stream.
///
/// The inner stream never finishes on its own: it produces 2 bytes, then
/// 3 bytes, then 1-byte chunks forever. With a limit of 4 the first chunk
/// passes, the second produces `LimitExceeded { limit: 4, actual: 5 }`, and the
/// next poll returns `None`. The poll counter must read exactly 2, showing the
/// inner stream was not polled again after the failure.
#[test]
fn limit_exceeded_across_chunks_fails_once_and_terminates() {
    futures::executor::block_on(async {
        let poll_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&poll_count);
        let endless = stream::poll_fn(move |_| {
            // `fetch_add` returns the previous value, i.e. the zero-based poll index.
            let chunk = match counter.fetch_add(1, Ordering::SeqCst) {
                0 => vec![1, 2],
                1 => vec![3, 4, 5],
                _ => vec![6],
            };
            Poll::Ready(Some(Ok::<_, TestError>(chunk)))
        });
        let mut limited = TransferLimit::new(4).wrap(endless);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first, vec![1, 2]);

        let second = limited.next().await.unwrap();
        assert!(matches!(
            second,
            Err(TransferLimitError::LimitExceeded {
                limit: 4,
                actual: 5
            })
        ));

        let after_failure = limited.next().await;
        assert!(after_failure.is_none());

        let total_polls = poll_count.load(Ordering::SeqCst);
        assert_eq!(total_polls, 2);
    });
}

/// A first chunk that is already larger than the limit is rejected outright.
///
/// A single 3-byte chunk against a limit of 2 must produce
/// `LimitExceeded { limit: 2, actual: 3 }`, after which the stream ends.
#[test]
fn single_oversized_chunk_is_rejected() {
    futures::executor::block_on(async {
        let source = stream::iter([Ok::<_, TestError>(Bytes::from_static(b"abc"))]);
        let mut limited = TransferLimit::new(2).wrap(source);

        let first = limited.next().await.unwrap();
        assert!(matches!(
            first,
            Err(TransferLimitError::LimitExceeded {
                limit: 2,
                actual: 3
            })
        ));

        let after_failure = limited.next().await;
        assert!(after_failure.is_none());
    });
}

/// An error from the inner stream is surfaced without being mistaken for a limit error.
///
/// The yielded error must report `is_limit_exceeded() == false`, expose the
/// inner error's message through `Error::source`, and hand the inner error
/// back from `into_inner()`.
#[test]
fn inner_stream_error_is_preserved_as_source() {
    futures::executor::block_on(async {
        let source = stream::iter([Err::<Bytes, _>(TestError)]);
        let mut limited = TransferLimit::new(10).wrap(source);

        let failure = limited.next().await.unwrap().unwrap_err();
        assert!(!failure.is_limit_exceeded());

        let source_text = Error::source(&failure).map(|cause| cause.to_string());
        assert_eq!(source_text, Some(String::from("test stream error")));

        // `into_inner` consumes the error, so it is checked last.
        let recovered = failure.into_inner();
        assert!(matches!(
            recovered,
            Some(cause) if cause.to_string() == "test stream error"
        ));
    });
}

/// The progress callback is given running totals, not per-chunk sizes.
///
/// Chunks of 1, 2 and 3 bytes are sent through a limit of 3. The first two
/// pass and the third fails with `LimitExceeded { limit: 3, actual: 6 }`. The
/// recorded progress must be `[1, 3, 6]`, so the callback also fired for the
/// chunk that exceeded the limit.
#[test]
fn progress_callback_receives_cumulative_byte_counts() {
    futures::executor::block_on(async {
        let progress_log = Rc::new(RefCell::new(Vec::new()));
        let source = stream::iter([
            Ok::<_, TestError>(vec![1]),
            Ok(vec![2, 3]),
            Ok(vec![4, 5, 6]),
        ]);
        let mut limited = TransferLimit::new(3)
            .on_progress({
                let progress_log = Rc::clone(&progress_log);
                move |total| progress_log.borrow_mut().push(total)
            })
            .wrap(source);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first, vec![1]);

        let second = limited.next().await.unwrap().unwrap();
        assert_eq!(second, vec![2, 3]);

        let third = limited.next().await.unwrap();
        assert!(matches!(
            third,
            Err(TransferLimitError::LimitExceeded {
                limit: 3,
                actual: 6
            })
        ));

        assert_eq!(*progress_log.borrow(), vec![1, 3, 6]);
    });
}

/// `TransferLimit::unlimited()` still reports progress but rejects nothing.
///
/// Chunks of 3 and 4 bytes must both be yielded, the stream must end normally,
/// and the progress callback must have recorded `3` followed by `7`.
#[test]
fn unlimited_transfer_counts_bytes_without_enforcing_a_limit() {
    futures::executor::block_on(async {
        let progress_log = Rc::new(RefCell::new(Vec::new()));
        let source = stream::iter([Ok::<_, TestError>(vec![1, 2, 3]), Ok(vec![4, 5, 6, 7])]);
        let mut counted = TransferLimit::unlimited()
            .on_progress({
                let progress_log = Rc::clone(&progress_log);
                move |total| progress_log.borrow_mut().push(total)
            })
            .wrap(source);

        let first = counted.next().await.unwrap().unwrap();
        assert_eq!(first, vec![1, 2, 3]);

        let second = counted.next().await.unwrap().unwrap();
        assert_eq!(second, vec![4, 5, 6, 7]);

        let after_end = counted.next().await;
        assert!(after_end.is_none());

        assert_eq!(*progress_log.borrow(), vec![3, 7]);
    });
}

/// Each wrapped stream enforces the limit it was built with.
///
/// The same 3-byte payload is fed to two wrappers: the one limited to 2 must
/// fail with a limit-exceeded error, while the one limited to 3 must yield the
/// chunk.
#[test]
fn different_instances_can_use_different_dynamic_limits() {
    futures::executor::block_on(async {
        /// Builds a fresh one-chunk stream carrying three bytes.
        fn three_byte_source() -> impl Stream<Item = Result<Vec<u8>, TestError>> {
            stream::iter([Ok(vec![1, 2, 3])])
        }

        let mut tenant_a = TransferLimit::new(2).wrap(three_byte_source());
        let mut tenant_b = TransferLimit::new(3).wrap(three_byte_source());

        let rejected = tenant_a.next().await.unwrap().unwrap_err();
        assert!(rejected.is_limit_exceeded());

        let accepted = tenant_b.next().await.unwrap().unwrap();
        assert_eq!(accepted, vec![1, 2, 3]);
    });
}

/// With a `u64` counter, progress values and error fields are `u64`.
///
/// Two 2-byte chunks are sent through `TransferLimit::<u64>::from_limit(3)`.
/// The first passes, the second fails with
/// `LimitExceeded { limit: 3_u64, actual: 4_u64 }`, and the recorded progress
/// is `[2_u64, 4_u64]`.
#[test]
fn u64_counter_reports_u64_progress_and_errors() {
    futures::executor::block_on(async {
        let progress_log = Rc::new(RefCell::new(Vec::new()));
        let source = stream::iter([Ok::<_, TestError>(vec![1, 2]), Ok(vec![3, 4])]);
        let mut limited = TransferLimit::<u64>::from_limit(3)
            .on_progress({
                let progress_log = Rc::clone(&progress_log);
                move |total| progress_log.borrow_mut().push(total)
            })
            .wrap(source);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first, vec![1, 2]);

        let second = limited.next().await.unwrap();
        assert!(matches!(
            second,
            Err(TransferLimitError::LimitExceeded {
                limit: 3_u64,
                actual: 4_u64
            })
        ));

        assert_eq!(*progress_log.borrow(), vec![2_u64, 4_u64]);
    });
}

/// Overflowing the byte counter is reported as an error.
///
/// On an unlimited transfer, a chunk reporting `usize::MAX` bytes passes; a
/// following 1-byte chunk must produce
/// `CounterOverflow { bytes_seen: usize::MAX, chunk_len: 1 }`, after which the
/// stream ends.
#[test]
fn usize_counter_overflow_errors_instead_of_saturating() {
    futures::executor::block_on(async {
        let source = stream::iter([Ok::<_, TestError>(FakeChunk(usize::MAX)), Ok(FakeChunk(1))]);
        let mut counted = TransferLimit::unlimited().wrap(source);

        let first = counted.next().await.unwrap().unwrap();
        assert_eq!(first.0, usize::MAX);

        let second = counted.next().await.unwrap();
        assert!(matches!(
            second,
            Err(TransferLimitError::CounterOverflow {
                bytes_seen: usize::MAX,
                chunk_len: 1,
            })
        ));

        let after_failure = counted.next().await;
        assert!(after_failure.is_none());
    });
}

/// A `u128` counter keeps counting past `usize::MAX`.
///
/// With the limit set to `usize::MAX + 1` (computed in `u128`), a chunk
/// reporting `usize::MAX` bytes followed by a 1-byte chunk must both pass, the
/// stream must end normally, and the recorded progress must be
/// `[usize::MAX, usize::MAX + 1]` as `u128` values.
#[test]
fn u128_counter_can_count_beyond_usize_max() {
    futures::executor::block_on(async {
        let progress_log = Rc::new(RefCell::new(Vec::new()));
        // Widen before adding so the sum cannot overflow `usize`.
        let one_past_usize = usize::MAX as u128 + 1;
        let source = stream::iter([Ok::<_, TestError>(FakeChunk(usize::MAX)), Ok(FakeChunk(1))]);
        let mut limited = TransferLimit::<u128>::from_limit(one_past_usize)
            .on_progress({
                let progress_log = Rc::clone(&progress_log);
                move |total| progress_log.borrow_mut().push(total)
            })
            .wrap(source);

        let first = limited.next().await.unwrap().unwrap();
        assert_eq!(first.0, usize::MAX);

        let second = limited.next().await.unwrap().unwrap();
        assert_eq!(second.0, 1);

        let after_end = limited.next().await;
        assert!(after_end.is_none());

        assert_eq!(
            *progress_log.borrow(),
            vec![usize::MAX as u128, one_past_usize]
        );
    });
}