s3util-rs 0.2.0

S3 utility commands (cp, mv, rm, etc.)
Documentation
use std::io::Result;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_channel::Sender;
use leaky_bucket::RateLimiter;
use pin_project::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::runtime::Handle;
use tokio::task;

use crate::storage::checksum::AdditionalChecksum;
use crate::types::{ObjectChecksum, SyncStatistics};

#[pin_project]
pub struct AsyncReadWithCallback<R: AsyncRead + Send + Sync> {
    #[pin]
    inner: R,
    stats_sender: Option<Sender<SyncStatistics>>,
    bandwidth_limiter: Option<Arc<RateLimiter>>,
    additional_checksum: Option<Arc<AdditionalChecksum>>,
    object_checksum: Option<ObjectChecksum>,
}

impl<R: AsyncRead + Send + Sync> AsyncReadWithCallback<R> {
    pub fn new(
        inner: R,
        stats_sender: Option<Sender<SyncStatistics>>,
        bandwidth_limiter: Option<Arc<RateLimiter>>,
        additional_checksum: Option<Arc<AdditionalChecksum>>,
        object_checksum: Option<ObjectChecksum>,
    ) -> Self {
        Self {
            inner,
            stats_sender,
            bandwidth_limiter,
            additional_checksum,
            object_checksum,
        }
    }
}

impl<R: AsyncRead + Send + Sync> AsyncRead for AsyncReadWithCallback<R> {
    #[rustfmt::skip] // For coverage tool incorrectness
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = self.project();

        let before = buf.filled().len();

        let result = this.inner.poll_read(cx, buf);
        if !result.is_ready() {
            return result;
        }

        let after = buf.filled().len();

        let sync_bytes = after - before;

        if let Some(bandwidth_limiter) = this.bandwidth_limiter.clone() {
            if 0 < sync_bytes {
                task::block_in_place(move || {
                    Handle::current().block_on(async move {
                        bandwidth_limiter.acquire(sync_bytes).await;
                    });
                });
            }
        }

        if 0 < sync_bytes
            && let Some(stats_sender) = this.stats_sender.as_ref()
        {
            let _ = stats_sender.send_blocking(SyncStatistics::SyncBytes(sync_bytes as u64));
        }

        result
    }
}

#[cfg(test)]
#[cfg(target_family = "unix")]
mod tests {
    use super::*;
    use tokio::fs::File;
    use tokio::io::AsyncReadExt;
    use tracing_subscriber::EnvFilter;

    const TEST_DATA_SIZE: usize = 5;

    #[tokio::test]
    async fn callback_test() {
        init_dummy_tracing_subscriber();

        let file = File::open("test_data/5byte.dat").await.unwrap();
        let (stats_sender, stats_receiver) = async_channel::unbounded();
        let mut file_with_callback =
            AsyncReadWithCallback::new(file, Some(stats_sender), None, None, None);

        let mut buffer = Vec::new();
        file_with_callback.read_to_end(&mut buffer).await.unwrap();

        assert!(!stats_receiver.is_empty());
        assert_eq!(buffer.len(), TEST_DATA_SIZE);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn callback_test_with_bandwidth_limiter() {
        init_dummy_tracing_subscriber();

        let file = File::open("test_data/5byte.dat").await.unwrap();
        let (stats_sender, stats_receiver) = async_channel::unbounded();
        let limiter = Arc::new(
            RateLimiter::builder()
                .max(TEST_DATA_SIZE)
                .initial(TEST_DATA_SIZE)
                .refill(TEST_DATA_SIZE)
                .build(),
        );
        let mut file_with_callback =
            AsyncReadWithCallback::new(file, Some(stats_sender), Some(limiter), None, None);

        let mut buffer = Vec::new();
        file_with_callback.read_to_end(&mut buffer).await.unwrap();

        assert!(!stats_receiver.is_empty());
        assert_eq!(buffer.len(), TEST_DATA_SIZE);
    }

    #[tokio::test]
    async fn callback_test_none_sender_is_silent() {
        init_dummy_tracing_subscriber();

        let file = File::open("test_data/5byte.dat").await.unwrap();
        // Hold a stats channel but do NOT wire it to the wrapper. The wrapper
        // is constructed with None and must not emit any SyncBytes.
        let (_stats_sender, stats_receiver) = async_channel::unbounded::<SyncStatistics>();
        let mut file_with_callback = AsyncReadWithCallback::new(file, None, None, None, None);

        let mut buffer = Vec::new();
        file_with_callback.read_to_end(&mut buffer).await.unwrap();

        assert_eq!(buffer.len(), TEST_DATA_SIZE);
        assert!(stats_receiver.is_empty());
    }

    fn init_dummy_tracing_subscriber() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(
                EnvFilter::try_from_default_env()
                    .or_else(|_| EnvFilter::try_new("dummy=trace"))
                    .unwrap(),
            )
            .try_init();
    }
}