async_read_progress/
lib.rs

1//! Report the progress of an async read operation
2//!
3//! As promised [on Twitter](https://twitter.com/killercup/status/1254695847796842498).
4//!
5//! # Examples
6//!
7//! ```
8//! # fn main() {
9//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
10//! use futures::{
11//!     io::AsyncReadExt,
12//!     stream::{self, TryStreamExt},
13//! };
14//! use async_read_progress::*;
15//!
16//! let src = vec![1u8, 2, 3, 4, 5];
17//! let total_size = src.len();
18//! let reader = stream::iter(vec![Ok(src)]).into_async_read();
19//!
20//! let mut reader = reader.report_progress(
21//!     /* only call every */ std::time::Duration::from_millis(20),
22//!     |bytes_read| eprintln!("read {}/{}", bytes_read, total_size),
23//! );
24//! #
25//! # let mut buf = Vec::new();
26//! # assert!(reader.read_to_end(&mut buf).await.is_ok());
27//! # });
28//! # }
29//! ```
30use core::pin::Pin;
31use std::{
32    fmt,
33    time::{Duration, Instant},
34};
35
36pub use for_futures::FReportReadProgress as AsyncReadProgressExt;
37
38#[cfg(feature = "with-tokio")]
39pub use for_tokio::TReportReadProgress as TokioAsyncReadProgressExt;
40
41/// Reader for the `report_progress` method.
42#[must_use = "streams do nothing unless polled"]
43pub struct LogStreamProgress<St, F> {
44    inner: St,
45    callback: F,
46    state: State,
47}
48
/// Bookkeeping used by the `AsyncRead` impls of `LogStreamProgress`.
struct State {
    /// Total number of bytes read through the adapter so far.
    bytes_read: usize,
    /// Minimum time that must pass between two callback invocations.
    at_most_ever: Duration,
    /// When the callback last finished running. Initialised to the
    /// construction time, so the first report happens no earlier than
    /// `at_most_ever` after the adapter was created.
    last_call_at: Instant,
}
55
56// TODO: Remove this comment after someone who knows how this actually works has
57// reviewed/fixed this.
58impl<St, F: FnMut(usize)> LogStreamProgress<St, F> {
59    pin_utils::unsafe_pinned!(inner: St);
60    pin_utils::unsafe_unpinned!(callback: F);
61    pin_utils::unsafe_unpinned!(state: State);
62
63    fn update(mut self: Pin<&mut Self>, bytes_read: usize) {
64        let mut state = self.as_mut().state();
65        state.bytes_read += bytes_read;
66        let read = state.bytes_read;
67
68        if state.last_call_at.elapsed() >= state.at_most_ever {
69            (self.as_mut().callback())(read);
70
71            self.as_mut().state().last_call_at = Instant::now();
72        }
73    }
74}
75
76impl<T, U> Unpin for LogStreamProgress<T, U>
77where
78    T: Unpin,
79    U: Unpin,
80{
81}
82
83impl<St, F> fmt::Debug for LogStreamProgress<St, F>
84where
85    St: fmt::Debug,
86{
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("LogStreamProgress")
89            .field("stream", &self.inner)
90            .field("at_most_ever", &self.state.at_most_ever)
91            .field("last_call_at", &self.state.last_call_at)
92            .finish()
93    }
94}
95
96mod for_futures {
97    use core::{
98        pin::Pin,
99        task::{Context, Poll},
100    };
101    use futures_io::{AsyncRead as FAsyncRead, IoSliceMut};
102    use std::{
103        io,
104        time::{Duration, Instant},
105    };
106
107    /// An extension trait which adds the `report_progress` method to
108    /// `AsyncRead` types.
109    ///
110    /// Note: This is for [`futures_io::AsyncRead`].
111    pub trait FReportReadProgress {
112        fn report_progress<F>(
113            self,
114            at_most_ever: Duration,
115            callback: F,
116        ) -> super::LogStreamProgress<Self, F>
117        where
118            Self: Sized,
119            F: FnMut(usize),
120        {
121            let state = super::State {
122                bytes_read: 0,
123                at_most_ever,
124                last_call_at: Instant::now(),
125            };
126            super::LogStreamProgress {
127                inner: self,
128                callback,
129                state,
130            }
131        }
132    }
133
134    impl<R: FAsyncRead + ?Sized> FReportReadProgress for R {}
135
136    impl<'a, St, F> FAsyncRead for super::LogStreamProgress<St, F>
137    where
138        St: FAsyncRead,
139        F: FnMut(usize),
140    {
141        fn poll_read(
142            mut self: Pin<&mut Self>,
143            cx: &mut Context<'_>,
144            buf: &mut [u8],
145        ) -> Poll<io::Result<usize>> {
146            match self.as_mut().inner().poll_read(cx, buf) {
147                Poll::Pending => Poll::Pending,
148                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
149                Poll::Ready(Ok(bytes_read)) => {
150                    self.update(bytes_read);
151                    Poll::Ready(Ok(bytes_read))
152                }
153            }
154        }
155
156        fn poll_read_vectored(
157            mut self: Pin<&mut Self>,
158            cx: &mut Context<'_>,
159            bufs: &mut [IoSliceMut<'_>],
160        ) -> Poll<io::Result<usize>> {
161            match self.as_mut().inner().poll_read_vectored(cx, bufs) {
162                Poll::Pending => Poll::Pending,
163                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
164                Poll::Ready(Ok(bytes_read)) => {
165                    self.update(bytes_read);
166                    Poll::Ready(Ok(bytes_read))
167                }
168            }
169        }
170    }
171}
172
#[cfg(feature = "with-tokio")]
mod for_tokio {
    use core::{
        pin::Pin,
        task::{Context, Poll},
    };
    use std::{
        io,
        time::{Duration, Instant},
    };
    use tokio::io::AsyncRead as TAsyncRead;

    /// An extension trait which adds the `report_progress` method to
    /// `AsyncRead` types.
    ///
    /// Note: This is for [`tokio::io::AsyncRead`].
    pub trait TReportReadProgress {
        /// Wraps `self` so that `callback` receives the total number of bytes
        /// read so far, at most once every `at_most_ever`.
        ///
        /// The interval starts counting when this method is called, so no
        /// report is made before `at_most_ever` has elapsed. A read that adds
        /// no bytes (end of input) still counts as a read and may trigger a
        /// report with the unchanged total.
        fn report_progress<F>(
            self,
            at_most_ever: Duration,
            callback: F,
        ) -> super::LogStreamProgress<Self, F>
        where
            Self: Sized,
            F: FnMut(usize),
        {
            let state = super::State {
                bytes_read: 0,
                at_most_ever,
                last_call_at: Instant::now(),
            };
            super::LogStreamProgress {
                inner: self,
                callback,
                state,
            }
        }
    }

    impl<R: TAsyncRead + ?Sized> TReportReadProgress for R {}

    impl<St, F> TAsyncRead for super::LogStreamProgress<St, F>
    where
        St: TAsyncRead,
        F: FnMut(usize),
    {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let filled_before = buf.filled().len();
            let poll = self.as_mut().inner().poll_read(cx, buf);
            if let Poll::Ready(Ok(())) = &poll {
                // A well-behaved reader only grows `filled`; saturate so a
                // reader that shrinks it cannot cause an underflow panic.
                let bytes_read = buf.filled().len().saturating_sub(filled_before);
                self.update(bytes_read);
            }
            poll
        }
    }

    /// Error type used only to annotate the `Result` items of the test streams below.
    #[cfg(test)]
    type SomeError = tokio::task::JoinError;

    #[test]
    fn works_with_tokios_async_read() {
        use bytes::Bytes;
        use tokio::io::AsyncReadExt;
        use tokio_util::io::StreamReader;

        let src = vec![1u8, 2, 3, 4, 5];
        let total_size = src.len();
        let data: Vec<Result<Bytes, SomeError>> = vec![Ok(Bytes::from(src))];
        let reader = StreamReader::new(tokio_stream::iter(data));
        let mut buf = Vec::new();

        let mut reader = reader.report_progress(
            /* only call every */ Duration::from_millis(20),
            |bytes_read| eprintln!("read {}/{}", bytes_read, total_size),
        );

        tokio::runtime::Runtime::new().unwrap().block_on(async {
            assert!(reader.read_to_end(&mut buf).await.is_ok());
        });
    }

    #[tokio::test]
    async fn does_delays_and_stuff() {
        use bytes::Bytes;
        use std::sync::{Arc, RwLock};
        use tokio::{io::AsyncReadExt, sync::mpsc, time::sleep};
        use tokio_util::io::StreamReader;

        let (data_writer, data_reader): (_, mpsc::Receiver<Result<Bytes, SomeError>>) =
            mpsc::channel(1);

        tokio::spawn(async move {
            for _ in 0u8..10 {
                data_writer
                    .send(Ok(Bytes::from_static(&[1u8, 2, 3, 4])))
                    .await
                    .unwrap();
                sleep(Duration::from_millis(10)).await;
            }
            drop(data_writer);
        });

        let total_size = 4 * 10i32;
        let reader = StreamReader::new(tokio_stream::wrappers::ReceiverStream::new(data_reader));
        let mut buf = Vec::new();

        let log = Arc::new(RwLock::new(Vec::new()));
        let log_writer = log.clone();

        let mut reader = reader.report_progress(
            /* only call every */ Duration::from_millis(10),
            |bytes_read| {
                log_writer
                    .write()
                    .unwrap()
                    .push(format!("read {}/{}", bytes_read, total_size));
            },
        );

        assert!(reader.read_to_end(&mut buf).await.is_ok());

        let log = log.read().unwrap();
        assert_eq!(
            *log,
            &[
                "read 8/40".to_string(),
                "read 12/40".to_string(),
                "read 16/40".to_string(),
                "read 20/40".to_string(),
                "read 24/40".to_string(),
                "read 28/40".to_string(),
                "read 32/40".to_string(),
                "read 36/40".to_string(),
                "read 40/40".to_string(),
                "read 40/40".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn does_delays_and_stuff_real_good() {
        use bytes::Bytes;
        use std::sync::{Arc, RwLock};
        use tokio::{io::AsyncReadExt, sync::mpsc, time::sleep};
        use tokio_util::io::StreamReader;

        let (data_writer, data_reader): (_, mpsc::Receiver<Result<Bytes, SomeError>>) =
            mpsc::channel(1);

        tokio::spawn(async move {
            for _ in 0u8..10 {
                data_writer
                    .send(Ok(Bytes::from_static(&[1u8, 2, 3, 4])))
                    .await
                    .unwrap();
                sleep(Duration::from_millis(5)).await;
            }
            drop(data_writer);
        });

        let total_size = 4 * 10i32;
        let reader = StreamReader::new(tokio_stream::wrappers::ReceiverStream::new(data_reader));
        let mut buf = Vec::new();

        let log = Arc::new(RwLock::new(Vec::new()));
        let log_writer = log.clone();

        let mut reader = reader.report_progress(
            /* only call every */ Duration::from_millis(10),
            |bytes_read| {
                log_writer
                    .write()
                    .unwrap()
                    .push(format!("read {}/{}", bytes_read, total_size));
            },
        );

        assert!(reader.read_to_end(&mut buf).await.is_ok());

        let log = log.read().unwrap();
        assert_eq!(
            *log,
            &[
                "read 12/40".to_string(),
                "read 20/40".to_string(),
                "read 28/40".to_string(),
                "read 36/40".to_string(),
                "read 40/40".to_string(),
            ]
        );
    }
}