Skip to main content

indicatif/
iter.rs

1use std::borrow::Cow;
2use std::io::{self, IoSliceMut};
3use std::iter::FusedIterator;
4#[cfg(feature = "tokio")]
5use std::pin::Pin;
6#[cfg(feature = "tokio")]
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10#[cfg(feature = "tokio")]
11use tokio::io::{ReadBuf, SeekFrom};
12
13use crate::progress_bar::ProgressBar;
14use crate::state::ProgressFinish;
15use crate::style::ProgressStyle;
16
17/// Wraps an iterator to display its progress.
18pub trait ProgressIterator
19where
20    Self: Sized + Iterator,
21{
22    /// Wrap an iterator with default styling. Uses [`Iterator::size_hint()`] to get length.
23    /// Returns `Some(..)` only if `size_hint.1` is [`Some`]. If you want to create a progress bar
24    /// even if `size_hint.1` returns [`None`] use [`progress_count()`](ProgressIterator::progress_count)
25    /// or [`progress_with()`](ProgressIterator::progress_with) instead.
26    fn try_progress(self) -> Option<ProgressBarIter<Self>> {
27        self.size_hint()
28            .1
29            .map(|len| self.progress_count(u64::try_from(len).unwrap()))
30    }
31
32    /// Wrap an iterator with default styling.
33    fn progress(self) -> ProgressBarIter<Self>
34    where
35        Self: ExactSizeIterator,
36    {
37        let len = u64::try_from(self.len()).unwrap();
38        self.progress_count(len)
39    }
40
41    /// Wrap an iterator with an explicit element count.
42    fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
43        self.progress_with(ProgressBar::new(len))
44    }
45
46    /// Wrap an iterator with a custom progress bar.
47    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
48
49    /// Wrap an iterator with a progress bar and style it.
50    fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
51    where
52        Self: ExactSizeIterator,
53    {
54        let len = u64::try_from(self.len()).unwrap();
55        let bar = ProgressBar::new(len).with_style(style);
56        self.progress_with(bar)
57    }
58}
59
60/// Wraps an iterator to display its progress.
61#[derive(Debug)]
62pub struct ProgressBarIter<T> {
63    pub(crate) it: T,
64    pub progress: ProgressBar,
65    pub(crate) seek_max: SeekMax,
66}
67
68impl<T> ProgressBarIter<T> {
69    /// Builder-like function for setting underlying progress bar's style.
70    ///
71    /// See [`ProgressBar::with_style()`].
72    pub fn with_style(mut self, style: ProgressStyle) -> Self {
73        self.progress = self.progress.with_style(style);
74        self
75    }
76
77    /// Builder-like function for setting underlying progress bar's prefix.
78    ///
79    /// See [`ProgressBar::with_prefix()`].
80    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
81        self.progress = self.progress.with_prefix(prefix);
82        self
83    }
84
85    /// Builder-like function for setting underlying progress bar's message.
86    ///
87    /// See [`ProgressBar::with_message()`].
88    pub fn with_message(mut self, message: impl Into<Cow<'static, str>>) -> Self {
89        self.progress = self.progress.with_message(message);
90        self
91    }
92
93    /// Builder-like function for setting underlying progress bar's position.
94    ///
95    /// See [`ProgressBar::with_position()`].
96    pub fn with_position(mut self, position: u64) -> Self {
97        self.progress = self.progress.with_position(position);
98        self
99    }
100
101    /// Builder-like function for setting underlying progress bar's elapsed time.
102    ///
103    /// See [`ProgressBar::with_elapsed()`].
104    pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
105        self.progress = self.progress.with_elapsed(elapsed);
106        self
107    }
108
109    /// Builder-like function for setting underlying progress bar's finish behavior.
110    ///
111    /// See [`ProgressBar::with_finish()`].
112    pub fn with_finish(mut self, finish: ProgressFinish) -> Self {
113        self.progress = self.progress.with_finish(finish);
114        self
115    }
116}
117
118impl<S, T: Iterator<Item = S>> Iterator for ProgressBarIter<T> {
119    type Item = S;
120
121    fn next(&mut self) -> Option<Self::Item> {
122        let item = self.it.next();
123
124        if item.is_some() {
125            self.progress.inc(1);
126        } else if !self.progress.is_finished() {
127            self.progress.finish_using_style();
128        }
129
130        item
131    }
132}
133
134impl<T: ExactSizeIterator> ExactSizeIterator for ProgressBarIter<T> {
135    fn len(&self) -> usize {
136        self.it.len()
137    }
138}
139
140impl<T: DoubleEndedIterator> DoubleEndedIterator for ProgressBarIter<T> {
141    fn next_back(&mut self) -> Option<Self::Item> {
142        let item = self.it.next_back();
143
144        if item.is_some() {
145            self.progress.inc(1);
146        } else if !self.progress.is_finished() {
147            self.progress.finish_using_style();
148        }
149
150        item
151    }
152}
153
154impl<T: FusedIterator> FusedIterator for ProgressBarIter<T> {}
155
156impl<R: io::Read> io::Read for ProgressBarIter<R> {
157    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
158        let inc = self.it.read(buf)?;
159        self.progress.set_position(
160            self.seek_max
161                .update_seq(self.progress.position(), inc as u64),
162        );
163        Ok(inc)
164    }
165
166    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
167        let inc = self.it.read_vectored(bufs)?;
168        self.progress.set_position(
169            self.seek_max
170                .update_seq(self.progress.position(), inc as u64),
171        );
172        Ok(inc)
173    }
174
175    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
176        let inc = self.it.read_to_string(buf)?;
177        self.progress.set_position(
178            self.seek_max
179                .update_seq(self.progress.position(), inc as u64),
180        );
181        Ok(inc)
182    }
183
184    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
185        self.it.read_exact(buf)?;
186        self.progress.set_position(
187            self.seek_max
188                .update_seq(self.progress.position(), buf.len() as u64),
189        );
190        Ok(())
191    }
192}
193
194impl<R: io::BufRead> io::BufRead for ProgressBarIter<R> {
195    fn fill_buf(&mut self) -> io::Result<&[u8]> {
196        self.it.fill_buf()
197    }
198
199    fn consume(&mut self, amt: usize) {
200        self.it.consume(amt);
201        self.progress.set_position(
202            self.seek_max
203                .update_seq(self.progress.position(), amt.try_into().unwrap()),
204        );
205    }
206}
207
208impl<S: io::Seek> io::Seek for ProgressBarIter<S> {
209    fn seek(&mut self, f: io::SeekFrom) -> io::Result<u64> {
210        self.it.seek(f).map(|pos| {
211            if f != io::SeekFrom::Current(0) {
212                // this kind of seek is used to find the current position, but does not alter it
213                // generally equivalent to stream_position()
214                self.progress.set_position(self.seek_max.update_seek(pos));
215            }
216
217            pos
218        })
219    }
220    // Pass this through to preserve optimizations that the inner I/O object may use here
221    // Also avoid sending a set_position update when the position hasn't changed
222    fn stream_position(&mut self) -> io::Result<u64> {
223        self.it.stream_position()
224    }
225}
226
/// Calculates a more stable visual position from jittery seeks to show to the user.
///
/// After the first seek, the position shown is the maximum of the last `HISTORY` real stream
/// positions (recorded on every seek and every sequential read/write). The history is dropped,
/// returning to plain sequential counting, once `RESET` sequential operations have been
/// performed in a row without an intervening seek.
#[derive(Debug, Default)]
pub(crate) struct SeekMax<const RESET: u8 = 5, const HISTORY: usize = 10> {
    /// `None` while only sequential operations have been seen. Otherwise it holds the recent
    /// real positions and the number of sequential operations since the last seek.
    buf: Option<(Box<MaxRingBuf<HISTORY>>, u8)>,
}

impl<const RESET: u8, const HISTORY: usize> SeekMax<RESET, HISTORY> {
    /// Accounts for a sequential operation that advanced the stream by `delta` and returns the
    /// position to display.
    ///
    /// `prev_pos` is the position the progress bar currently shows. Without history (no seek
    /// yet, or history dropped) that is also the real stream position, and the result is
    /// simply `prev_pos + delta`.
    ///
    /// While history is kept, the shown position is the smoothed maximum and can be ahead of
    /// the real position, e.g. right after a backward seek. Advancing from it would record
    /// positions the stream never reached; after a rewind this could even push the bar past
    /// its length. Instead, the real position is tracked through the most recent history
    /// entry. While history is kept, `prev_pos` (and thus any external change to the bar's
    /// position) is not used.
    fn update_seq(&mut self, prev_pos: u64, delta: u64) -> u64 {
        let Some((history, seq)) = &mut self.buf else {
            return prev_pos + delta;
        };

        let new_pos = history.last() + delta;
        *seq += 1;
        if *seq >= RESET {
            // Enough sequential operations in a row: stop smoothing and show the real position.
            self.buf = None;
            return new_pos;
        }

        history.update(new_pos);
        history.max()
    }

    /// Accounts for a seek that moved the stream to `newpos` and returns the position to
    /// display.
    ///
    /// The first seek starts keeping history. Every seek restarts the count of sequential
    /// operations towards `RESET`.
    fn update_seek(&mut self, newpos: u64) -> u64 {
        let (b, seq) = self
            .buf
            .get_or_insert_with(|| (Box::new(MaxRingBuf::<HISTORY>::default()), 0));
        *seq = 0;
        b.update(newpos);
        b.max()
    }
}

/// Ring buffer of the last `HISTORY` values that also remembers where the maximum is stored.
#[derive(Debug)]
struct MaxRingBuf<const HISTORY: usize = 10> {
    history: [u64; HISTORY],
    head: u8,    // slot the next update writes to; must be < HISTORY
    max_pos: u8, // slot currently holding the maximum; must be < HISTORY
}

impl<const HISTORY: usize> MaxRingBuf<HISTORY> {
    /// Stores `new` in the oldest slot, overwriting it, and keeps track of the maximum.
    ///
    /// Only an update that overwrites the slot holding the maximum with a *smaller* value
    /// triggers a linear scan of the buffer to find the new maximum. For values that grow
    /// roughly linearly with some jitter, as expected in this use case, the scan is avoided
    /// as long as a new maximum arrives at least once every `HISTORY` updates.
    ///
    /// The worst case is strictly decreasing input. There the maximum is always the oldest
    /// entry, so every update scans: `O(HISTORY)` per update. That is still constant time,
    /// because `HISTORY` is a compile-time constant.
    fn update(&mut self, new: u64) {
        let head = usize::from(self.head) % self.history.len();
        let max_pos = usize::from(self.max_pos) % self.history.len();
        let prev_max = self.history[max_pos];
        self.history[head] = new;

        if new > prev_max {
            // This is now the new maximum
            self.max_pos = self.head;
        } else if self.max_pos == self.head && new < prev_max {
            // The old maximum was just overwritten by a smaller value: scan for the new one.
            let (idx, _val) = self
                .history
                .iter()
                .enumerate()
                .max_by_key(|(_, v)| *v)
                .expect("array has fixded size > 0");
            // invariant_m: idx is from an enumeration of history
            self.max_pos = idx as u8;
        }

        self.head = (self.head + 1) % (self.history.len() as u8);
    }

    /// Returns the largest value currently held.
    fn max(&self) -> u64 {
        // exploit invariant_m to eliminate bounds checks & panic code path
        self.history[self.max_pos as usize % self.history.len()]
    }

    /// Returns the value stored by the most recent [`update`](Self::update), or 0 if
    /// `update` was never called.
    fn last(&self) -> u64 {
        let len = self.history.len();
        // `head` is the slot the next update writes to, so the previous slot holds the most
        // recent value. Adding `len - 1` instead of subtracting 1 avoids underflow at 0.
        self.history[(usize::from(self.head) + len - 1) % len]
    }
}

impl<const HISTORY: usize> Default for MaxRingBuf<HISTORY> {
    /// Creates an all-zero buffer.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < HISTORY <= 255`, because the slot indices are stored as `u8`.
    fn default() -> Self {
        assert!(HISTORY <= u8::MAX.into());
        assert!(HISTORY > 0);
        Self {
            history: [0; HISTORY],
            head: 0,
            max_pos: 0,
        }
    }
}
324
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for ProgressBarIter<W> {
    /// Writes through to the inner writer and advances the bar by the bytes it accepted.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let poll = Pin::new(&mut self.it).poll_write(cx, buf);
        if let Poll::Ready(Ok(inc)) = &poll {
            // Separate statements: `self` is a `Pin`, so each field access goes through
            // `DerefMut` and the borrows cannot overlap.
            let pos = self.progress.position();
            let new = self.seek_max.update_seq(pos, *inc as u64);
            self.progress.set_position(new);
        }
        poll
    }

    /// Forwards vectored writes so the inner writer's vectored path is kept (mirroring the
    /// synchronous `write_vectored` forwarding), advancing the bar by the bytes written.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let poll = Pin::new(&mut self.it).poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(inc)) = &poll {
            let pos = self.progress.position();
            let new = self.seek_max.update_seq(pos, *inc as u64);
            self.progress.set_position(new);
        }
        poll
    }

    /// Reports whether the inner writer has an efficient vectored-write implementation.
    fn is_write_vectored(&self) -> bool {
        self.it.is_write_vectored()
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.it).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.it).poll_shutdown(cx)
    }
}
351
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for ProgressBarIter<W> {
    /// Reads through to the inner reader. Once the poll completes (successfully or not), the
    /// bar advances by however much the filled part of `buf` grew.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let filled_before = buf.filled().len() as u64;
        let outcome = Pin::new(&mut self.it).poll_read(cx, buf);
        if outcome.is_ready() {
            let grown = buf.filled().len() as u64 - filled_before;
            // Kept as separate statements: field access through the `Pin` uses
            // `Deref`/`DerefMut`, so the two borrows cannot overlap in one expression.
            let current = self.progress.position();
            let shown = self.seek_max.update_seq(current, grown);
            self.progress.set_position(shown);
        }
        outcome
    }
}
371
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncSeek + Unpin> tokio::io::AsyncSeek for ProgressBarIter<W> {
    fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        Pin::new(&mut self.it).start_seek(position)
    }

    /// Completes a seek on the inner value. On success, the bar moves to the (smoothed)
    /// position that was reached.
    ///
    /// NOTE(review): tokio's `seek()` helper appears to call `poll_complete` once *before*
    /// `start_seek` as well. That would record the pre-seek position too, and a
    /// `SeekFrom::Current(0)` query would count as a seek here, unlike the synchronous
    /// impl. Confirm whether this is intended.
    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let outcome = Pin::new(&mut self.it).poll_complete(cx);
        match &outcome {
            Poll::Ready(Ok(reached)) => {
                let shown = self.seek_max.update_seek(*reached);
                self.progress.set_position(shown);
            }
            Poll::Ready(Err(_)) | Poll::Pending => {}
        }
        outcome
    }
}
389
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<W: tokio::io::AsyncBufRead + Unpin + tokio::io::AsyncRead> tokio::io::AsyncBufRead
    for ProgressBarIter<W>
{
    /// Forwards to the inner reader. Buffered bytes only count as progress once consumed.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        Pin::new(&mut self.get_mut().it).poll_fill_buf(cx)
    }

    /// Marks `amt` buffered bytes as consumed and advances the bar by the same amount.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        Pin::new(&mut self.it).consume(amt);
        // Separate statements: field access goes through the `Pin`'s `Deref`/`DerefMut`.
        let current = self.progress.position();
        let shown = self.seek_max.update_seq(current, amt as u64);
        self.progress.set_position(shown);
    }
}
407
#[cfg(feature = "futures")]
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
impl<S: futures_core::Stream + Unpin> futures_core::Stream for ProgressBarIter<S> {
    type Item = S::Item;

    /// Polls the wrapped stream, advancing the bar by one per item.
    ///
    /// When the stream ends, the bar is finished via `finish_using_style()`, but only if it is
    /// not already finished. This matches the synchronous `Iterator` impl, so polling an
    /// exhausted stream again does not finish the bar a second time.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let item = std::pin::Pin::new(&mut this.it).poll_next(cx);
        match &item {
            std::task::Poll::Ready(Some(_)) => this.progress.inc(1),
            std::task::Poll::Ready(None) => {
                if !this.progress.is_finished() {
                    this.progress.finish_using_style();
                }
            }
            std::task::Poll::Pending => {}
        }
        item
    }

    /// Forwards the wrapped stream's bounds; wrapping does not change how many items remain.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}
427
428impl<W: io::Write> io::Write for ProgressBarIter<W> {
429    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
430        self.it.write(buf).map(|inc| {
431            self.progress.set_position(
432                self.seek_max
433                    .update_seq(self.progress.position(), inc as u64),
434            );
435            inc
436        })
437    }
438
439    fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result<usize> {
440        self.it.write_vectored(bufs).map(|inc| {
441            self.progress.set_position(
442                self.seek_max
443                    .update_seq(self.progress.position(), inc as u64),
444            );
445            inc
446        })
447    }
448
449    fn flush(&mut self) -> io::Result<()> {
450        self.it.flush()
451    }
452
453    // write_fmt can not be captured with reasonable effort.
454    // as it uses write_all internally by default that should not be a problem.
455    // fn write_fmt(&mut self, fmt: fmt::Arguments) -> io::Result<()>;
456}
457
458impl<S, T: Iterator<Item = S>> ProgressIterator for T {
459    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
460        ProgressBarIter {
461            it: self,
462            progress,
463            seek_max: SeekMax::default(),
464        }
465    }
466}
467
#[cfg(test)]
mod test {
    use crate::iter::{ProgressBarIter, ProgressIterator};
    use crate::progress_bar::ProgressBar;
    use crate::ProgressStyle;

    #[test]
    fn it_can_wrap_an_iterator() {
        let v = [1, 2, 3];
        let wrap = |it: ProgressBarIter<_>| {
            assert_eq!(it.map(|x| x * 2).collect::<Vec<_>>(), vec![2, 4, 6]);
        };

        wrap(v.iter().progress());
        wrap(v.iter().progress_count(3));
        wrap({
            let pb = ProgressBar::new(v.len() as u64);
            v.iter().progress_with(pb)
        });
        wrap({
            let style = ProgressStyle::default_bar()
                .template("{wide_bar:.red} {percent}/100%")
                .unwrap();
            v.iter().progress_with_style(style)
        });
    }

    #[test]
    fn test_max_ring_buf() {
        use crate::iter::MaxRingBuf;
        let mut max = MaxRingBuf::<10>::default();
        max.update(100);
        assert_eq!(max.max(), 100);
        for i in 0..10 {
            max.update(99 - i);
        }
        assert_eq!(max.max(), 99);
    }

    /// Covers the `SeekMax` smoothing: plain counting before any seek, showing the recent
    /// maximum after a backward seek, dropping history after `RESET` sequential operations,
    /// and forgetting values that fall out of the `HISTORY` window.
    #[test]
    fn test_seek_max() {
        use crate::iter::SeekMax;

        let mut seek_max = SeekMax::<5, 10>::default();
        // No seek yet: positions simply advance.
        assert_eq!(seek_max.update_seq(0, 10), 10);
        // A backward seek keeps showing the furthest recent position.
        assert_eq!(seek_max.update_seek(100), 100);
        assert_eq!(seek_max.update_seek(20), 100);
        // Sequential operations keep the maximum until RESET of them happen in a row...
        for pos in [20, 30, 40, 50] {
            assert_eq!(seek_max.update_seq(pos, 10), 100);
        }
        // ...then the history is dropped and the real position is shown again.
        assert_eq!(seek_max.update_seq(60, 10), 70);
        assert_eq!(seek_max.update_seq(70, 5), 75);

        // Only the last HISTORY positions are considered.
        let mut windowed = SeekMax::<5, 3>::default();
        assert_eq!(windowed.update_seek(100), 100);
        assert_eq!(windowed.update_seek(1), 100);
        assert_eq!(windowed.update_seek(2), 100);
        assert_eq!(windowed.update_seek(3), 3);
    }
}