read_progress_stream/
lib.rs

//! Wrapper for a stream that counts the number of bytes read, so that
//! uploads to S3 through the `rusoto_s3` crate can report progress for
//! larger files.
//!
//! See the `bytes_progress` test for example usage; run it with
//! `--nocapture` to see the mock progress bar:
//!
//! ```text
//! cargo test -- --nocapture
//! ```
11use bytes::Bytes;
12use futures::stream::Stream;
13use futures::task::{Context, Poll};
14use pin_project_lite::pin_project;
15use std::io::Result;
16use std::pin::Pin;
17
/// Callback that receives stream read progress.
///
/// It is invoked with two arguments, in this order:
/// 1. the number of bytes in the chunk that was just read;
/// 2. the running total of bytes read so far, including that chunk.
//
// `Send + Sync` lets a wrapped stream be handed to consumers that require
// thread-safe streams (e.g. rusoto's `ByteStream`, as used in the test).
// A boxed trait object is implicitly `'static`.
pub type ProgressHandler = Box<dyn FnMut(u64, u64) + Sync + Send>;
23
24pin_project! {
25    /// Wrap a stream and store the number of bytes read.
26    pub struct ReadProgressStream<T> {
27        #[pin]
28        inner: Pin<Box<T>>,
29        bytes_read: u64,
30        progress: ProgressHandler,
31        marker: std::marker::PhantomData<T>,
32    }
33}
34
35impl<T> ReadProgressStream<T>
36where
37    T: Stream<Item = Result<Bytes>> + Send + Sync + 'static,
38{
39    /// Create a wrapped stream.
40    ///
41    /// The progress function will be called as bytes are read from the underlying stream.
42    pub fn new(inner: T, progress: ProgressHandler) -> Self {
43        ReadProgressStream {
44            inner: Box::pin(inner),
45            progress,
46            bytes_read: 0,
47            marker: std::marker::PhantomData {},
48        }
49    }
50}
51
52impl<T> Stream for ReadProgressStream<T>
53where
54    T: Stream<Item = Result<Bytes>> + Send + Sync + 'static,
55{
56    type Item = Result<Bytes>;
57
58    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59        let this = self.project();
60        match this.inner.poll_next(cx) {
61            Poll::Ready(reader) => match reader {
62                Some(result) => match result {
63                    Ok(bytes) => {
64                        *this.bytes_read += bytes.len() as u64;
65                        (this.progress)(bytes.len() as u64, this.bytes_read.clone());
66                        Poll::Ready(Some(Ok(bytes)))
67                    }
68                    Err(e) => Poll::Ready(Some(Err(e))),
69                },
70                None => Poll::Ready(None),
71            },
72            Poll::Pending => Poll::Pending,
73        }
74    }
75}
76
#[test]
fn bytes_progress() -> Result<()> {
    use futures::{StreamExt, TryStreamExt};
    use pbr::{ProgressBar, Units};
    use rusoto_core::ByteStream;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::{path::PathBuf, thread, time::Duration};
    use tokio::fs::File;
    use tokio::runtime::Runtime;
    use tokio_util::codec::{BytesCodec, FramedRead};

    // `Runtime::new` returns `io::Result`, so propagate instead of panicking.
    let rt = Runtime::new()?;

    rt.block_on(async {
        let path = PathBuf::from("tests/big-enough-to-buffer.mp4");
        let file = File::open(&path).await?;
        let size = file.metadata().await?.len();
        let reader = FramedRead::new(file, BytesCodec::new()).map_ok(|r| r.freeze());

        // Mock progress bar
        let mut pb = ProgressBar::new(size);
        pb.set_units(Units::Bytes);
        pb.show_speed = false;
        if let Some(name) = path.file_name() {
            let msg = format!("{} ", name.to_string_lossy());
            pb.message(&msg);
        }

        // Record what the handler observed, so the test checks the reported
        // progress and not only the bytes that flowed through the stream.
        let summed = Arc::new(AtomicU64::new(0));
        let last_total = Arc::new(AtomicU64::new(0));
        let summed_cb = Arc::clone(&summed);
        let last_total_cb = Arc::clone(&last_total);

        // Progress handler to be called as bytes are read
        let progress = Box::new(move |amount: u64, total: u64| {
            summed_cb.fetch_add(amount, Ordering::SeqCst);
            last_total_cb.store(total, Ordering::SeqCst);
            pb.add(amount);
            // Deliberately slow so the bar is visible with `--nocapture`.
            // This blocks the runtime thread, which is acceptable only here.
            thread::sleep(Duration::from_millis(5));
        });

        // Wrap the read stream
        let stream = ReadProgressStream::new(reader, progress);

        // Normally this would be passed to a `rusoto` request object
        let body = ByteStream::new_with_size(stream, size as usize);

        // Consume the stream
        let mut content = FramedRead::new(body.into_async_read(), BytesCodec::new());

        let mut total_bytes = 0u64;
        while let Some(bytes) = content.next().await {
            total_bytes += bytes?.len() as u64;
        }
        assert_eq!(size, total_bytes);
        // The handler must have seen every byte, and its final running total
        // must match the file size.
        assert_eq!(size, summed.load(Ordering::SeqCst));
        assert_eq!(size, last_total.load(Ordering::SeqCst));

        Ok::<(), std::io::Error>(())
    })?;

    Ok(())
}