1use std::borrow::Cow;
2use std::io::{self, IoSliceMut};
3use std::iter::FusedIterator;
4#[cfg(feature = "tokio")]
5use std::pin::Pin;
6#[cfg(feature = "tokio")]
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10#[cfg(feature = "tokio")]
11use tokio::io::{ReadBuf, SeekFrom};
12
13use crate::progress_bar::ProgressBar;
14use crate::state::ProgressFinish;
15use crate::style::ProgressStyle;
16
17pub trait ProgressIterator
19where
20 Self: Sized + Iterator,
21{
22 fn try_progress(self) -> Option<ProgressBarIter<Self>> {
27 self.size_hint()
28 .1
29 .map(|len| self.progress_count(u64::try_from(len).unwrap()))
30 }
31
32 fn progress(self) -> ProgressBarIter<Self>
34 where
35 Self: ExactSizeIterator,
36 {
37 let len = u64::try_from(self.len()).unwrap();
38 self.progress_count(len)
39 }
40
41 fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
43 self.progress_with(ProgressBar::new(len))
44 }
45
46 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
48
49 fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
51 where
52 Self: ExactSizeIterator,
53 {
54 let len = u64::try_from(self.len()).unwrap();
55 let bar = ProgressBar::new(len).with_style(style);
56 self.progress_with(bar)
57 }
58}
59
/// Wraps an iterator, reader, writer, or stream and reports its progress to a [`ProgressBar`].
///
/// Produced by the [`ProgressIterator`] methods. When used as an iterator or stream, each yielded
/// item advances the bar by one. When used through the I/O trait impls in this module, the bar
/// follows the number of bytes transferred, smoothed across seeks by `seek_max`.
#[derive(Debug)]
pub struct ProgressBarIter<T> {
    /// The wrapped iterator / reader / writer / stream.
    pub(crate) it: T,
    /// The progress bar updated as `it` is consumed.
    pub progress: ProgressBar,
    /// Seek-history tracker used to compute the displayed position for I/O wrappers.
    pub(crate) seek_max: SeekMax,
}
67
68impl<T> ProgressBarIter<T> {
69 pub fn with_style(mut self, style: ProgressStyle) -> Self {
73 self.progress = self.progress.with_style(style);
74 self
75 }
76
77 pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
81 self.progress = self.progress.with_prefix(prefix);
82 self
83 }
84
85 pub fn with_message(mut self, message: impl Into<Cow<'static, str>>) -> Self {
89 self.progress = self.progress.with_message(message);
90 self
91 }
92
93 pub fn with_position(mut self, position: u64) -> Self {
97 self.progress = self.progress.with_position(position);
98 self
99 }
100
101 pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
105 self.progress = self.progress.with_elapsed(elapsed);
106 self
107 }
108
109 pub fn with_finish(mut self, finish: ProgressFinish) -> Self {
113 self.progress = self.progress.with_finish(finish);
114 self
115 }
116}
117
118impl<S, T: Iterator<Item = S>> Iterator for ProgressBarIter<T> {
119 type Item = S;
120
121 fn next(&mut self) -> Option<Self::Item> {
122 let item = self.it.next();
123
124 if item.is_some() {
125 self.progress.inc(1);
126 } else if !self.progress.is_finished() {
127 self.progress.finish_using_style();
128 }
129
130 item
131 }
132}
133
134impl<T: ExactSizeIterator> ExactSizeIterator for ProgressBarIter<T> {
135 fn len(&self) -> usize {
136 self.it.len()
137 }
138}
139
140impl<T: DoubleEndedIterator> DoubleEndedIterator for ProgressBarIter<T> {
141 fn next_back(&mut self) -> Option<Self::Item> {
142 let item = self.it.next_back();
143
144 if item.is_some() {
145 self.progress.inc(1);
146 } else if !self.progress.is_finished() {
147 self.progress.finish_using_style();
148 }
149
150 item
151 }
152}
153
154impl<T: FusedIterator> FusedIterator for ProgressBarIter<T> {}
155
156impl<R: io::Read> io::Read for ProgressBarIter<R> {
157 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
158 let inc = self.it.read(buf)?;
159 self.progress.set_position(
160 self.seek_max
161 .update_seq(self.progress.position(), inc as u64),
162 );
163 Ok(inc)
164 }
165
166 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
167 let inc = self.it.read_vectored(bufs)?;
168 self.progress.set_position(
169 self.seek_max
170 .update_seq(self.progress.position(), inc as u64),
171 );
172 Ok(inc)
173 }
174
175 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
176 let inc = self.it.read_to_string(buf)?;
177 self.progress.set_position(
178 self.seek_max
179 .update_seq(self.progress.position(), inc as u64),
180 );
181 Ok(inc)
182 }
183
184 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
185 self.it.read_exact(buf)?;
186 self.progress.set_position(
187 self.seek_max
188 .update_seq(self.progress.position(), buf.len() as u64),
189 );
190 Ok(())
191 }
192}
193
194impl<R: io::BufRead> io::BufRead for ProgressBarIter<R> {
195 fn fill_buf(&mut self) -> io::Result<&[u8]> {
196 self.it.fill_buf()
197 }
198
199 fn consume(&mut self, amt: usize) {
200 self.it.consume(amt);
201 self.progress.set_position(
202 self.seek_max
203 .update_seq(self.progress.position(), amt.try_into().unwrap()),
204 );
205 }
206}
207
208impl<S: io::Seek> io::Seek for ProgressBarIter<S> {
209 fn seek(&mut self, f: io::SeekFrom) -> io::Result<u64> {
210 self.it.seek(f).map(|pos| {
211 if f != io::SeekFrom::Current(0) {
212 self.progress.set_position(self.seek_max.update_seek(pos));
215 }
216
217 pos
218 })
219 }
220 fn stream_position(&mut self) -> io::Result<u64> {
223 self.it.stream_position()
224 }
225}
226
/// Computes a steadier position to display while the wrapped object seeks around.
///
/// Until the first seek, sequential updates report `prev_pos + delta` unchanged. A seek
/// allocates a history of the last `HISTORY` positions, and from then on the maximum of that
/// history is reported, so jumping backwards does not make the bar retreat. After `RESET`
/// consecutive sequential updates, the history is dropped again.
#[derive(Debug, Default)]
pub(crate) struct SeekMax<const RESET: u8 = 5, const HISTORY: usize = 10> {
    /// Lazily allocated history, paired with the number of sequential updates since the last seek.
    buf: Option<(Box<MaxRingBuf<HISTORY>>, u8)>,
}

impl<const RESET: u8, const HISTORY: usize> SeekMax<RESET, HISTORY> {
    /// Records a sequential transfer of `delta` bytes following `prev_pos` and returns the
    /// position to display.
    fn update_seq(&mut self, prev_pos: u64, delta: u64) -> u64 {
        let next = prev_pos + delta;
        match &mut self.buf {
            None => next,
            Some((history, streak)) => {
                *streak += 1;
                if *streak < RESET {
                    history.update(next);
                    history.max()
                } else {
                    // Enough sequential operations in a row: stop tracking seek history.
                    self.buf = None;
                    next
                }
            }
        }
    }

    /// Records a seek to `newpos` and returns the position to display.
    ///
    /// Allocates the history if needed and resets the sequential-update counter.
    fn update_seek(&mut self, newpos: u64) -> u64 {
        let entry = self.buf.get_or_insert_with(|| (Box::default(), 0));
        entry.1 = 0;
        entry.0.update(newpos);
        entry.0.max()
    }
}

/// Fixed-size ring buffer of recent positions that also remembers which slot holds the largest.
#[derive(Debug)]
struct MaxRingBuf<const HISTORY: usize = 10> {
    /// Recorded positions; slots not yet written hold 0.
    history: [u64; HISTORY],
    /// Slot that the next `update` overwrites.
    head: u8,
    /// Slot currently holding the maximum value.
    max_pos: u8,
}

impl<const HISTORY: usize> MaxRingBuf<HISTORY> {
    /// Writes `new` into the `head` slot, keeps `max_pos` pointing at a maximal slot, and
    /// advances `head` (wrapping around).
    fn update(&mut self, new: u64) {
        let len = self.history.len();
        let slot = usize::from(self.head) % len;
        let old_max = self.history[usize::from(self.max_pos) % len];
        self.history[slot] = new;

        if new > old_max {
            // The incoming value is the new maximum.
            self.max_pos = self.head;
        } else if new < old_max && self.head == self.max_pos {
            // The maximum was just overwritten by something smaller: rescan the whole buffer.
            let (idx, _) = self
                .history
                .iter()
                .enumerate()
                .max_by_key(|&(_, &value)| value)
                .expect("array has fixded size > 0");
            self.max_pos = idx as u8;
        }

        self.head = (self.head + 1) % (len as u8);
    }

    /// Largest value currently held in the buffer.
    fn max(&self) -> u64 {
        let idx = usize::from(self.max_pos) % self.history.len();
        self.history[idx]
    }
}

impl<const HISTORY: usize> Default for MaxRingBuf<HISTORY> {
    /// Creates an all-zero buffer.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < HISTORY <= u8::MAX`. Slot indices are stored as `u8`.
    fn default() -> Self {
        assert!(HISTORY <= u8::MAX.into());
        assert!(HISTORY > 0);
        Self {
            history: [0; HISTORY],
            head: 0,
            max_pos: 0,
        }
    }
}
324
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for ProgressBarIter<W> {
    /// Forwards the write. A `Ready(Ok(n))` result moves the bar forward by `n` bytes
    /// (via `seek_max`); pending or failed writes leave the bar untouched.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let poll = Pin::new(&mut this.it).poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &poll {
            let shown = this
                .seek_max
                .update_seq(this.progress.position(), *written as u64);
            this.progress.set_position(shown);
        }
        poll
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().it).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().it).poll_shutdown(cx)
    }
}
351
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for ProgressBarIter<W> {
    /// Forwards the read and moves the bar forward by the number of bytes the inner reader
    /// appended to `buf`.
    ///
    /// Only `Ready(Ok(()))` counts. Errors and `Pending` leave the bar and the seek tracker
    /// untouched, matching the sync `Read` impl, which returns early on error.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let prev_len = buf.filled().len() as u64;
        let poll = Pin::new(&mut self.it).poll_read(cx, buf);
        if let Poll::Ready(Ok(())) = &poll {
            // saturating_sub: a misbehaving inner reader that shrinks `filled` must not
            // underflow (panic in debug, huge jump in release).
            let inc = (buf.filled().len() as u64).saturating_sub(prev_len);
            // Read the position first: `self` is a `Pin`, so field access goes through
            // `Deref`/`DerefMut` and the two borrows cannot overlap in one expression.
            let pos = self.progress.position();
            let new = self.seek_max.update_seq(pos, inc);
            self.progress.set_position(new);
        }
        poll
    }
}
371
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncSeek + Unpin> tokio::io::AsyncSeek for ProgressBarIter<W> {
    /// Starts a seek on the wrapped object; the bar is updated once the seek completes.
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        Pin::new(&mut self.get_mut().it).start_seek(position)
    }

    /// Polls the wrapped seek. A `Ready(Ok(pos))` is fed to `seek_max.update_seek`, and the bar
    /// shows the result.
    ///
    /// NOTE(review): unlike the sync `Seek` impl, position queries (`SeekFrom::Current(0)`)
    /// are not distinguished here, so they are also recorded as seeks.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let this = self.get_mut();
        let poll = Pin::new(&mut this.it).poll_complete(cx);
        if let Poll::Ready(Ok(pos)) = &poll {
            let shown = this.seek_max.update_seek(*pos);
            this.progress.set_position(shown);
        }
        poll
    }
}
389
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncBufRead + Unpin + tokio::io::AsyncRead> tokio::io::AsyncBufRead
    for ProgressBarIter<W>
{
    /// Exposes the wrapped reader's buffer; nothing is counted until bytes are consumed.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        Pin::new(&mut self.get_mut().it).poll_fill_buf(cx)
    }

    /// Consumes `amt` buffered bytes and moves the bar forward by that amount (via `seek_max`).
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        Pin::new(&mut this.it).consume(amt);
        let shown = this.seek_max.update_seq(this.progress.position(), amt as u64);
        this.progress.set_position(shown);
    }
}
407
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
impl<S: futures_core::Stream + Unpin> futures_core::Stream for ProgressBarIter<S> {
    type Item = S::Item;

    /// Polls the wrapped stream and advances the bar by one for each yielded item.
    ///
    /// When the stream ends, the bar is finished via `finish_using_style`. The finish is guarded
    /// by `is_finished()`, as in the `Iterator` impl, so polling after termination does not
    /// finish the bar again.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let item = std::pin::Pin::new(&mut this.it).poll_next(cx);
        match &item {
            std::task::Poll::Ready(Some(_)) => this.progress.inc(1),
            std::task::Poll::Ready(None) => {
                if !this.progress.is_finished() {
                    this.progress.finish_using_style();
                }
            }
            std::task::Poll::Pending => {}
        }
        item
    }

    /// Wrapping neither adds nor removes items, so the inner stream's bounds apply unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}
427
428impl<W: io::Write> io::Write for ProgressBarIter<W> {
429 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
430 self.it.write(buf).map(|inc| {
431 self.progress.set_position(
432 self.seek_max
433 .update_seq(self.progress.position(), inc as u64),
434 );
435 inc
436 })
437 }
438
439 fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result<usize> {
440 self.it.write_vectored(bufs).map(|inc| {
441 self.progress.set_position(
442 self.seek_max
443 .update_seq(self.progress.position(), inc as u64),
444 );
445 inc
446 })
447 }
448
449 fn flush(&mut self) -> io::Result<()> {
450 self.it.flush()
451 }
452
453 }
457
458impl<S, T: Iterator<Item = S>> ProgressIterator for T {
459 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
460 ProgressBarIter {
461 it: self,
462 progress,
463 seek_max: SeekMax::default(),
464 }
465 }
466}
467
#[cfg(test)]
mod test {
    use crate::iter::{ProgressBarIter, ProgressIterator};
    use crate::progress_bar::ProgressBar;
    use crate::ProgressStyle;

    /// Every constructor in `ProgressIterator` yields the wrapped items unchanged.
    #[test]
    fn it_can_wrap_an_iterator() {
        let items = [1, 2, 3];
        let assert_doubles = |wrapped: ProgressBarIter<_>| {
            assert_eq!(wrapped.map(|x| x * 2).collect::<Vec<_>>(), vec![2, 4, 6]);
        };

        assert_doubles(items.iter().progress());
        assert_doubles(items.iter().progress_count(3));

        let bar = ProgressBar::new(items.len() as u64);
        assert_doubles(items.iter().progress_with(bar));

        let style = ProgressStyle::default_bar()
            .template("{wide_bar:.red} {percent}/100%")
            .unwrap();
        assert_doubles(items.iter().progress_with_style(style));
    }

    /// After the initial maximum is overwritten, the ring buffer reports the next-largest value.
    #[test]
    fn test_max_ring_buf() {
        use crate::iter::MaxRingBuf;

        let mut ring = MaxRingBuf::<10>::default();
        ring.update(100);
        assert_eq!(ring.max(), 100);
        for value in (90..=99).rev() {
            ring.update(value);
        }
        assert_eq!(ring.max(), 99);
    }
}