stream-transfer-limit 0.1.0

Byte-count transfer limits for fallible futures streams
Documentation
use crate::{ChunkLength, TransferCounter, TransferLimitError};
use futures::{Stream, StreamExt, TryStream, TryStreamExt, stream};

/// Progress callback type installed when no callback has been configured.
///
/// It is a plain function pointer that receives the cumulative byte count.
pub type NoopProgress<C = usize> = fn(C);

/// Default progress callback: accepts the cumulative byte count and ignores it.
fn noop_progress<C>(_bytes_seen: C) {}

/// Builder for applying byte-count transfer limits to fallible streams.
///
/// `C` is the counter type used for both the limit and the running total.
/// `P` is the progress callback, which defaults to a no-op function pointer.
#[derive(Clone, Debug)]
pub struct TransferLimit<C = usize, P = NoopProgress<C>> {
    /// Maximum cumulative byte count; `None` means no limit is enforced.
    limit: Option<C>,
    /// Cumulative byte count recorded so far.
    bytes_seen: C,
    /// Whether a chunk has already failed the limit/overflow check.
    failed: bool,
    /// Callback that receives the cumulative byte count after each recorded chunk.
    on_progress: P,
}

impl<C> Default for TransferLimit<C, NoopProgress<C>>
where
    C: TransferCounter,
{
    fn default() -> Self {
        Self::from_optional_limit(None)
    }
}

impl TransferLimit<usize, NoopProgress<usize>> {
    /// Build a `usize`-counted tracker that permits at most `limit` bytes.
    pub fn new(limit: usize) -> Self {
        Self::optional(Some(limit))
    }

    /// Build a `usize`-counted tracker; passing `None` disables the byte limit.
    pub fn optional(limit: Option<usize>) -> Self {
        Self::from_optional_limit(limit)
    }

    /// Build a `usize`-counted tracker that records progress but enforces no
    /// byte limit.
    pub fn unlimited() -> Self {
        Self::default()
    }
}

impl<C> TransferLimit<C, NoopProgress<C>>
where
    C: TransferCounter,
{
    /// Build a tracker capped at `limit` bytes, counting with `C`.
    pub fn from_limit(limit: C) -> Self {
        let cap = Some(limit);
        Self::from_optional_limit(cap)
    }

    /// Build a tracker counting with `C`; passing `None` disables the cap.
    ///
    /// The running total starts at `C::ZERO` and progress reporting is a
    /// no-op until [`TransferLimit::on_progress`] installs a callback.
    pub fn from_optional_limit(limit: Option<C>) -> Self {
        let on_progress: NoopProgress<C> = noop_progress;
        Self {
            on_progress,
            failed: false,
            bytes_seen: C::ZERO,
            limit,
        }
    }
}

impl<C, P> TransferLimit<C, P>
where
    C: TransferCounter,
{
    /// Set the maximum allowed number of bytes.
    ///
    /// A stream is allowed to produce exactly `limit` bytes. It fails on the
    /// first chunk that makes the cumulative total greater than `limit`.
    #[must_use]
    pub fn with_limit(mut self, limit: C) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Remove the maximum byte limit while keeping progress tracking.
    #[must_use]
    pub fn without_limit(mut self) -> Self {
        self.limit = None;
        self
    }

    /// Replace the progress callback.
    ///
    /// The callback receives the cumulative byte count after every chunk whose
    /// length was successfully added to the running total. This includes the
    /// chunk that crosses the limit. The callback is not called for a chunk
    /// whose length would overflow the counter type `C`, nor for errors
    /// produced by the inner stream.
    #[must_use]
    pub fn on_progress<F>(self, on_progress: F) -> TransferLimit<C, F>
    where
        F: FnMut(C),
    {
        TransferLimit {
            limit: self.limit,
            bytes_seen: self.bytes_seen,
            failed: self.failed,
            on_progress,
        }
    }

    /// Return the configured maximum byte count, if any.
    pub fn limit(&self) -> Option<C> {
        self.limit
    }

    /// Wrap a fallible stream and apply this transfer limit.
    ///
    /// The running total is reset to zero on every call, whatever state
    /// `self` carries. Successful chunks are passed through unchanged after
    /// their length has been recorded.
    ///
    /// # Errors
    ///
    /// The returned stream yields:
    /// - errors from the inner stream, wrapped with `TransferLimitError::inner`.
    ///   These do not end the stream; the next poll reads the next inner item.
    /// - a limit-exceeded or counter-overflow error from the byte check. After
    ///   such an error, the stream ends without polling the inner stream again.
    pub fn wrap<S>(
        mut self,
        stream: S,
    ) -> impl Stream<Item = Result<S::Ok, TransferLimitError<S::Error, C>>>
    where
        S: TryStream,
        S::Ok: ChunkLength,
        // No `Unpin` bound on `P` is needed. The state lives inside a
        // `Box::pin`ned `unfold`, so the returned stream is `Unpin` anyway.
        P: FnMut(C),
    {
        self.bytes_seen = C::ZERO;
        self.failed = false;

        // Pin the inner stream on the heap so `StreamExt::next` (which needs
        // `Unpin`) can be used even when `S` itself is `!Unpin`.
        let inner = Box::pin(stream.into_stream());
        Box::pin(stream::unfold(
            (inner, self),
            |(mut inner, mut tracker)| async move {
                // A previous chunk failed the byte check: terminate.
                if tracker.failed {
                    return None;
                }

                // `?` ends this stream once the inner stream is exhausted.
                let item = inner
                    .next()
                    .await?
                    .map_err(TransferLimitError::inner)
                    .and_then(|chunk| {
                        tracker
                            .record_chunk(chunk.chunk_len())
                            .inspect_err(|_| tracker.failed = true)
                            .map(|()| chunk)
                    });

                Some((item, (inner, tracker)))
            },
        ))
    }
}

impl<C, P> TransferLimit<C, P>
where
    C: TransferCounter,
    P: FnMut(C),
{
    /// Add `chunk_len` to the running total and check it against the limit.
    ///
    /// When the addition succeeds, the stored total is updated and the
    /// progress callback runs with the new total before the limit is checked.
    /// The callback therefore also sees the total for a chunk that crosses
    /// the limit.
    ///
    /// # Errors
    ///
    /// - `CounterOverflow` if the new total cannot be represented in `C`. In
    ///   that case the stored total is left unchanged and the callback is not
    ///   called.
    /// - `LimitExceeded` if a limit is configured and the new total is greater
    ///   than it.
    fn record_chunk<E>(&mut self, chunk_len: usize) -> Result<(), TransferLimitError<E, C>> {
        let Some(total) = self.bytes_seen.checked_add_chunk(chunk_len) else {
            return Err(TransferLimitError::CounterOverflow {
                bytes_seen: self.bytes_seen,
                chunk_len,
            });
        };

        self.bytes_seen = total;
        (self.on_progress)(total);

        match self.limit {
            Some(limit) if total > limit => Err(TransferLimitError::LimitExceeded {
                limit,
                actual: total,
            }),
            _ => Ok(()),
        }
    }
}