1use core::pin::Pin;
31use std::{
32 fmt,
33 time::{Duration, Instant},
34};
35
36pub use for_futures::FReportReadProgress as AsyncReadProgressExt;
37
38#[cfg(feature = "with-tokio")]
39pub use for_tokio::TReportReadProgress as TokioAsyncReadProgressExt;
40
41#[must_use = "streams do nothing unless polled"]
43pub struct LogStreamProgress<St, F> {
44 inner: St,
45 callback: F,
46 state: State,
47}
48
/// Progress bookkeeping kept by a `LogStreamProgress` adapter.
struct State {
    /// Total number of bytes read through the adapter so far.
    bytes_read: usize,
    /// Minimum time that has to pass between two callback invocations.
    at_most_ever: Duration,
    /// Start of the current throttling interval: set at construction and
    /// refreshed after every callback invocation.
    last_call_at: Instant,
}
55
56impl<St, F: FnMut(usize)> LogStreamProgress<St, F> {
59 pin_utils::unsafe_pinned!(inner: St);
60 pin_utils::unsafe_unpinned!(callback: F);
61 pin_utils::unsafe_unpinned!(state: State);
62
63 fn update(mut self: Pin<&mut Self>, bytes_read: usize) {
64 let mut state = self.as_mut().state();
65 state.bytes_read += bytes_read;
66 let read = state.bytes_read;
67
68 if state.last_call_at.elapsed() >= state.at_most_ever {
69 (self.as_mut().callback())(read);
70
71 self.as_mut().state().last_call_at = Instant::now();
72 }
73 }
74}
75
76impl<T, U> Unpin for LogStreamProgress<T, U>
77where
78 T: Unpin,
79 U: Unpin,
80{
81}
82
83impl<St, F> fmt::Debug for LogStreamProgress<St, F>
84where
85 St: fmt::Debug,
86{
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("LogStreamProgress")
89 .field("stream", &self.inner)
90 .field("at_most_ever", &self.state.at_most_ever)
91 .field("last_call_at", &self.state.last_call_at)
92 .finish()
93 }
94}
95
96mod for_futures {
97 use core::{
98 pin::Pin,
99 task::{Context, Poll},
100 };
101 use futures_io::{AsyncRead as FAsyncRead, IoSliceMut};
102 use std::{
103 io,
104 time::{Duration, Instant},
105 };
106
107 pub trait FReportReadProgress {
112 fn report_progress<F>(
113 self,
114 at_most_ever: Duration,
115 callback: F,
116 ) -> super::LogStreamProgress<Self, F>
117 where
118 Self: Sized,
119 F: FnMut(usize),
120 {
121 let state = super::State {
122 bytes_read: 0,
123 at_most_ever,
124 last_call_at: Instant::now(),
125 };
126 super::LogStreamProgress {
127 inner: self,
128 callback,
129 state,
130 }
131 }
132 }
133
134 impl<R: FAsyncRead + ?Sized> FReportReadProgress for R {}
135
136 impl<'a, St, F> FAsyncRead for super::LogStreamProgress<St, F>
137 where
138 St: FAsyncRead,
139 F: FnMut(usize),
140 {
141 fn poll_read(
142 mut self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &mut [u8],
145 ) -> Poll<io::Result<usize>> {
146 match self.as_mut().inner().poll_read(cx, buf) {
147 Poll::Pending => Poll::Pending,
148 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
149 Poll::Ready(Ok(bytes_read)) => {
150 self.update(bytes_read);
151 Poll::Ready(Ok(bytes_read))
152 }
153 }
154 }
155
156 fn poll_read_vectored(
157 mut self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 bufs: &mut [IoSliceMut<'_>],
160 ) -> Poll<io::Result<usize>> {
161 match self.as_mut().inner().poll_read_vectored(cx, bufs) {
162 Poll::Pending => Poll::Pending,
163 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
164 Poll::Ready(Ok(bytes_read)) => {
165 self.update(bytes_read);
166 Poll::Ready(Ok(bytes_read))
167 }
168 }
169 }
170 }
171}
172
173#[cfg(feature = "with-tokio")]
174mod for_tokio {
175 use core::{
176 pin::Pin,
177 task::{Context, Poll},
178 };
179 use std::{
180 io,
181 time::{Duration, Instant},
182 };
183 use tokio::io::AsyncRead as TAsyncRead;
184
185 pub trait TReportReadProgress {
190 fn report_progress<F>(
191 self,
192 at_most_ever: Duration,
193 callback: F,
194 ) -> super::LogStreamProgress<Self, F>
195 where
196 Self: Sized,
197 F: FnMut(usize),
198 {
199 let state = super::State {
200 bytes_read: 0,
201 at_most_ever,
202 last_call_at: Instant::now(),
203 };
204 super::LogStreamProgress {
205 inner: self,
206 callback,
207 state,
208 }
209 }
210 }
211
212 impl<R: TAsyncRead + ?Sized> TReportReadProgress for R {}
213
214 impl<'a, St, F> TAsyncRead for super::LogStreamProgress<St, F>
215 where
216 St: TAsyncRead,
217 F: FnMut(usize),
218 {
219 fn poll_read(
220 mut self: Pin<&mut Self>,
221 cx: &mut Context<'_>,
222 buf: &mut tokio::io::ReadBuf<'_>,
223 ) -> Poll<io::Result<()>> {
224 let bytes_in_buffer_before = buf.filled().len();
225 match self.as_mut().inner().poll_read(cx, buf) {
226 Poll::Pending => Poll::Pending,
227 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
228 Poll::Ready(Ok(())) => {
229 let bytes_read = buf.filled().len() - bytes_in_buffer_before;
230 self.update(bytes_read);
231 Poll::Ready(Ok(()))
232 }
233 }
234 }
235 }
236
/// Error type of the `Result<Bytes, _>` items fed to `StreamReader` by the
/// tests in this module. The tests only ever produce `Ok` items.
#[cfg(test)]
type SomeError = tokio::task::JoinError;
240
/// Smoke test: a tokio reader wrapped with `report_progress` still yields
/// exactly the bytes of the underlying stream.
#[test]
fn works_with_tokios_async_read() {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;
    use tokio_util::io::StreamReader;

    let src = vec![1u8, 2, 3, 4, 5];
    // `src` is moved into the stream below, so keep a copy to compare against.
    let expected = src.clone();
    let total_size = src.len();
    let data: Vec<Result<Bytes, SomeError>> = vec![Ok(Bytes::from(src))];
    let reader = StreamReader::new(tokio_stream::iter(data));
    let mut buf = Vec::new();

    let mut reader = reader.report_progress(
        Duration::from_millis(20),
        |bytes_read| eprintln!("read {}/{}", bytes_read, total_size),
    );

    let read = tokio::runtime::Runtime::new()
        .unwrap()
        .block_on(async { reader.read_to_end(&mut buf).await })
        .expect("reading from an in-memory stream should not fail");

    // The wrapper must be transparent: same byte count, same contents.
    assert_eq!(read, total_size);
    assert_eq!(buf, expected);
}
262
263 #[tokio::test]
264 async fn does_delays_and_stuff() {
265 use bytes::Bytes;
266 use std::sync::{Arc, RwLock};
267 use tokio::{io::AsyncReadExt, sync::mpsc, time::sleep};
268 use tokio_util::io::StreamReader;
269
270 let (data_writer, data_reader): (_, mpsc::Receiver<Result<Bytes, SomeError>>) =
271 mpsc::channel(1);
272
273 tokio::spawn(async move {
274 for i in 0u8..10 {
275 dbg!(i);
276 data_writer
277 .send(Ok(Bytes::from_static(&[1u8, 2, 3, 4])))
278 .await
279 .unwrap();
280 sleep(Duration::from_millis(10)).await;
281 }
282 drop(data_writer);
283 });
284
285 let total_size = 4 * 10i32;
286 let reader = StreamReader::new(tokio_stream::wrappers::ReceiverStream::new(data_reader));
287 let mut buf = Vec::new();
288
289 let log = Arc::new(RwLock::new(Vec::new()));
290 let log_writer = log.clone();
291
292 let mut reader = reader.report_progress(
293 Duration::from_millis(10),
294 |bytes_read| {
295 log_writer.write().unwrap().push(format!(
296 "read {}/{}",
297 dbg!(bytes_read),
298 total_size
299 ));
300 },
301 );
302
303 assert!(reader.read_to_end(&mut buf).await.is_ok());
304 dbg!("read it");
305
306 let log = log.read().unwrap();
307 assert_eq!(
308 *log,
309 &[
310 "read 8/40".to_string(),
311 "read 12/40".to_string(),
312 "read 16/40".to_string(),
313 "read 20/40".to_string(),
314 "read 24/40".to_string(),
315 "read 28/40".to_string(),
316 "read 32/40".to_string(),
317 "read 36/40".to_string(),
318 "read 40/40".to_string(),
319 "read 40/40".to_string(),
320 ]
321 );
322 }
323
324 #[tokio::test]
325 async fn does_delays_and_stuff_real_good() {
326 use bytes::Bytes;
327 use std::sync::{Arc, RwLock};
328 use tokio::{io::AsyncReadExt, sync::mpsc, time::sleep};
329 use tokio_util::io::StreamReader;
330
331 let (data_writer, data_reader): (_, mpsc::Receiver<Result<Bytes, SomeError>>) =
332 mpsc::channel(1);
333
334 tokio::spawn(async move {
335 for i in 0u8..10 {
336 dbg!(i);
337 data_writer
338 .send(Ok(Bytes::from_static(&[1u8, 2, 3, 4])))
339 .await
340 .unwrap();
341 sleep(Duration::from_millis(5)).await;
342 }
343 drop(data_writer);
344 });
345
346 let total_size = 4 * 10i32;
347 let reader = StreamReader::new(tokio_stream::wrappers::ReceiverStream::new(data_reader));
348 let mut buf = Vec::new();
349
350 let log = Arc::new(RwLock::new(Vec::new()));
351 let log_writer = log.clone();
352
353 let mut reader = reader.report_progress(
354 Duration::from_millis(10),
355 |bytes_read| {
356 log_writer.write().unwrap().push(format!(
357 "read {}/{}",
358 dbg!(bytes_read),
359 total_size
360 ));
361 },
362 );
363
364 assert!(reader.read_to_end(&mut buf).await.is_ok());
365 dbg!("read it");
366
367 let log = log.read().unwrap();
368 assert_eq!(
369 *log,
370 &[
371 "read 12/40".to_string(),
372 "read 20/40".to_string(),
373 "read 28/40".to_string(),
374 "read 36/40".to_string(),
375 "read 40/40".to_string(),
376 ]
377 );
378 }
379}