zzz/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    fmt::{self, Write as _},
5    io::{self, stderr, Write as _},
6    sync::atomic::{AtomicUsize, Ordering::Relaxed},
7    sync::RwLock,
8    time::{Duration, Instant},
9};
10
11// ============================================================================================== //
12// [Prelude module]                                                                               //
13// ============================================================================================== //
14
15/// Mass-import for the main progress bar type as well as the convenience extension traits.
16pub mod prelude {
17    #[cfg(feature = "streams")]
18    pub use crate::ProgressBarStreamExt;
19    pub use crate::{ProgressBar, ProgressBarIterExt};
20}
21
22// ============================================================================================== //
23// [General configuration]                                                                        //
24// ============================================================================================== //
25
26#[doc(hidden)]
27#[deprecated(note = "renamed to just `Config`")]
28pub type ProgressBarConfig = Config;
29
30/// Configuration for a progress bar.
31///
32/// This is a separate struct from the actual progress bar in order to allow a
33/// configuration to be reused in different progress bar instances.
34#[derive(Clone)]
35pub struct Config {
36    /// Width of the progress bar.
37    pub width: Option<u32>,
38    /// Minimum width to bother with drawing the bar for.
39    pub min_bar_width: u32,
40    /// Theme to use when drawing.
41    pub theme: &'static dyn Theme,
42    /// Maximum redraw rate rate (draws per second).
43    pub max_fps: f32,
44    /// Called to determine whether the progress bar should be drawn or not.
45    ///
46    /// The default value always returns `true`.
47    pub should_draw: &'static (dyn Fn() -> bool + Sync),
48}
49
50static DEFAULT_CFG: Config = Config::const_default();
51
52impl Config {
53    /// `const` variant of [`Config::default`].
54    pub const fn const_default() -> Self {
55        Config {
56            width: None,
57            min_bar_width: 5,
58            theme: &DefaultTheme,
59            max_fps: 60.0,
60            should_draw: &|| true,
61        }
62    }
63}
64
65impl Default for Config {
66    #[inline]
67    fn default() -> Self {
68        Config::const_default()
69    }
70}
71
72/// Selects the currently active global configuration.
73///
74/// This stores a `*const ProgressBarConfig`. We use `AtomicUsize` instead of
75/// the seemingly more idiomatic `AtomicPtr` here because the latter requires a
76/// **mutable** pointer, which would in turn force us to take the config as
77/// mutable reference to not run into UB. There is no const variant of `AtomicPtr`.
78/// Using `AtomicUsize` seemed like the lesser evil here.
79static GLOBAL_CFG: AtomicUsize = AtomicUsize::new(0);
80
81/// Gets the currently active global configuration.
82pub fn global_config() -> &'static Config {
83    match GLOBAL_CFG.load(Relaxed) {
84        0 => &DEFAULT_CFG,
85        ptr => unsafe { &*(ptr as *const Config) }
86    }
87}
88
89/// Set a new global default configuration.
90///
91/// This configuration is used when no explicit per instance configuration
92/// is specified via [`ProgressBar::config`].
93pub fn set_global_config(new_cfg: &'static Config) {
94    GLOBAL_CFG.store(new_cfg as *const _ as _, Relaxed);
95}
96
97// ============================================================================================== //
98// [Utils]                                                                                        //
99// ============================================================================================== //
100
/// Wraps a value so it is aligned to, and padded out to, a full cache line.
/// This keeps independently updated atomics from sharing a line (false sharing).
///
/// The alignment is 128 bytes on x86_64 and 64 bytes everywhere else.
///
/// Adapted from crossbeam:
/// https://docs.rs/crossbeam/0.7.3/crossbeam/utils/struct.CachePadded.html
#[cfg_attr(target_arch = "x86_64", repr(align(128)))]
#[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))]
struct CachePadded<T>(T);

impl<T> std::ops::Deref for CachePadded<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let CachePadded(inner) = self;
        inner
    }
}

impl<T> std::ops::DerefMut for CachePadded<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let CachePadded(inner) = self;
        inner
    }
}
122
123// ============================================================================================== //
124// [Error type]                                                                                   //
125// ============================================================================================== //
126
/// Errors that can occur while drawing the progress bar.
#[derive(Debug)]
pub enum RenderError {
    /// Writing the rendered output failed.
    Io(io::Error),
    /// Formatting the progress bar text failed.
    Fmt(fmt::Error),
}

impl fmt::Display for RenderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RenderError::Fmt(e) => e.fmt(f),
            RenderError::Io(e) => e.fmt(f),
        }
    }
}

/// `RenderError` is a transparent wrapper. `Display` prints the wrapped error's
/// message, so `source` skips the wrapped error itself and continues with *its*
/// source. This way, reporters that walk the `source()` chain don't print the same
/// message twice.
impl std::error::Error for RenderError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RenderError::Io(e) => std::error::Error::source(e),
            RenderError::Fmt(e) => std::error::Error::source(e),
        }
    }
}

impl From<io::Error> for RenderError {
    fn from(e: io::Error) -> Self {
        RenderError::Io(e)
    }
}

impl From<fmt::Error> for RenderError {
    fn from(e: fmt::Error) -> Self {
        RenderError::Fmt(e)
    }
}
157
158// ============================================================================================== //
159// [Customizable printing]                                                                        //
160// ============================================================================================== //
161
162/// Trait defining how the progress bar is rendered.
163pub trait Theme: Sync {
164    fn render(&self, pb: &ProgressBar) -> Result<(), RenderError>;
165}
166
167#[derive(Debug, Default)]
168struct DefaultTheme;
169
/// Creates a unicode progress bar exactly `length` chars wide, e.g. `|█▋  |`.
///
/// The bar has 1/8-character resolution: the cell after the full blocks holds a
/// partial block. `progress` is clamped to `0.0..=1.0`, with NaN treated as 0. A
/// value that overshoots its target, or a target of 0, therefore cannot break the
/// layout. Lengths below 3 have no room for any content and yield only pipes.
fn bar(progress: f32, length: u32) -> String {
    // Left-aligned partial blocks from 1/8 (U+258F) up to 8/8 (U+2588), indexed by
    // the remainder of the 1/8 steps.
    const PARTIAL: [char; 8] = ['▏', '▎', '▍', '▌', '▋', '▊', '▉', '█'];

    if length < 3 {
        return "|".repeat(length as usize);
    }

    // `f32::max` returns the non-NaN operand, so NaN becomes 0.
    let progress = progress.max(0.0).min(1.0);
    let inner_len = length - 2;
    let max_steps = (inner_len - 1) * 8;
    // The `min` guards against float rounding for absurdly wide bars.
    let rescaled = ((progress * max_steps as f32).round() as u32).min(max_steps);
    let (full, rem) = (rescaled / 8, rescaled % 8);

    let bar = format!(
        "|{}{}{}|",
        "█".repeat(full as usize),
        PARTIAL[rem as usize],
        " ".repeat((inner_len - full - 1) as usize),
    );
    debug_assert_eq!(bar.chars().count() as u32, length);
    bar
}
189
/// Formats a duration as zero-padded `HH:MM:SS`.
///
/// Hours are not wrapped at 24, and sub-second precision is truncated.
fn human_time(duration: Duration) -> String {
    let secs = duration.as_secs();
    let (hours, rest) = (secs / 3600, secs % 3600);
    let (minutes, seconds) = (rest / 60, rest % 60);
    format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
}
197
/// Renders an indeterminate-progress spinner: a moon-phase glyph bouncing between
/// two brackets, e.g. `[   🌑    ]`.
///
/// `x` is the position within the animation period (expected in `0.0..=1.0`). The
/// glyph travels left to right during the first half and back during the second.
/// The result is exactly `width` chars long. Widths below 3 leave no room for the
/// brackets and glyph, so they yield blank padding.
fn spinner(x: f32, width: u32) -> String {
    // Moon phases U+1F311..=U+1F318, cycled as the glyph moves.
    const MOONS: [char; 8] = ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'];

    fn easing_inout_cubic(mut x: f32) -> f32 {
        x *= 2.0;

        if x < 1.0 {
            0.5 * x.powi(3)
        } else {
            x -= 2.;
            0.5 * (x.powi(3) + 2.)
        }
    }

    if width < 3 {
        return " ".repeat(width as usize);
    }

    // Subtract the two brackets and the spinner char.
    let inner_width = width - 3;

    // Make the spinner turn around after half the period.
    let x = ((-x + 0.5).abs() - 0.5) * -2.;

    // Apply easing function.
    let x = easing_inout_cubic(x).max(0.).min(1.);
    // Transform 0..1 scale to int width.
    let x = ((inner_width as f32) * x).round() as u32;

    let lpad = x as usize;
    let rpad = inner_width.saturating_sub(x) as usize;

    // Dividing by 8 slows the phase animation down relative to the movement.
    let ball = MOONS[(x / 8 % 8) as usize];

    let spinner = format!("[{}{}{}]", " ".repeat(lpad), ball, " ".repeat(rpad));
    debug_assert_eq!(spinner.chars().count() as u32, width);
    spinner
}
245
246/*
247barr1 = UInt32[0x00, 0x40, 0x04, 0x02, 0x01]
248barr2 = UInt32[0x00, 0x80, 0x20, 0x10, 0x08]
249function braille(a::Float64, b::Float64)
250    bchar(a::UInt32) = '⠀' + a
251    idx(x) = min(x * 4 + 1, 5) |> round |> UInt32
252
253    x = barr1[1:idx(a)] |> sum
254    x |= barr2[1:idx(b)] |> sum
255
256    x |> UInt32 |> bchar
257end
258*/
259
/// Terminal size `(columns, rows)` assumed when it can't be detected.
const FALLBACK_DIMENSIONS: (usize, usize) = (80, 30);

/// Determines the dimensions of stderr as `(columns, rows)`, falling back to 80×30.
#[cfg(feature = "auto-width")]
fn stderr_dimensions() -> (usize, usize) {
    // term_size can't query stderr on Windows, so measure stdout there and hope
    // both refer to the same console. term_size is unmaintained and should be
    // replaced eventually, but it does the job for now.
    #[cfg(target_os = "windows")]
    let detected = term_size::dimensions_stdout();
    #[cfg(not(target_os = "windows"))]
    let detected = term_size::dimensions_stderr();

    detected.unwrap_or(FALLBACK_DIMENSIONS)
}

/// Without the `auto-width` feature, the fallback size is always reported.
#[cfg(not(feature = "auto-width"))]
fn stderr_dimensions() -> (usize, usize) {
    FALLBACK_DIMENSIONS
}
278
279impl Theme for DefaultTheme {
280    fn render(&self, pb: &ProgressBar) -> Result<(), RenderError> {
281        let mut o = stderr();
282        let cfg = pb.active_config();
283
284        // Draw left side.
285        let left = {
286            let mut buf = String::new();
287
288            // If a description is set, print it.
289            if let Some(desc) = pb.message() {
290                write!(buf, "{} ", desc)?;
291            }
292
293            if let Some(progress) = pb.progress() {
294                write!(buf, "{:>6.2}% ", progress * 100.0)?;
295            }
296
297            buf
298        };
299
300        // Draw right side.
301        let right = {
302            let mut buf = String::new();
303
304            // Print "done/total" part
305            buf.write_char(' ')?;
306            pb.unit.write_total(&mut buf, pb.value())?;
307            buf.write_char('/')?;
308            match pb.target {
309                Some(target) => pb.unit.write_total(&mut buf, target)?,
310                None => buf.write_char('?')?,
311            }
312
313            // Print ETA / time elapsed.
314            if let Some(eta) = pb.eta() {
315                write!(buf, " [{}]", human_time(eta))?;
316            } else {
317                write!(buf, " [{}]", human_time(pb.elapsed()))?;
318            }
319
320            // Print iteration rate.
321            buf.write_str(" (")?;
322            pb.unit.write_rate(&mut buf, pb.iters_per_sec())?;
323            buf.write_char(')')?;
324
325            buf
326        };
327
328        let max_width = cfg
329            .width
330            .unwrap_or_else(|| stderr_dimensions().0 as u32);
331
332        let bar_width = max_width
333            .saturating_sub(left.len() as u32)
334            .saturating_sub(right.len() as u32);
335
336        write!(o, "{}", left)?;
337
338        if bar_width > cfg.min_bar_width {
339            // Draw a progress bar for known-length bars.
340            if let Some(progress) = pb.progress() {
341                write!(o, "{}", bar(progress, bar_width))?;
342            }
343            // And a spinner for unknown-length bars.
344            else {
345                let duration = Duration::from_secs(3);
346                let pos = pb.timer_progress(duration);
347                // Sub 1 from width because many terms render emojis with double width.
348                write!(o, "{}", spinner(pos, bar_width - 1))?;
349            }
350        }
351
352        write!(o, "{}\r", right)?;
353
354        o.flush().map_err(Into::into)
355    }
356}
357
358// ============================================================================================== //
359// [Units]                                                                                        //
360// ============================================================================================== //
361
/// Determines the unit used for printing amounts and iteration speed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Unit {
    /// Plain iteration counts, scaled with K/M/B suffixes.
    Iterations,
    /// Byte amounts, scaled with binary (KiB, MiB, …) suffixes.
    Bytes,
}
369
/// Picks a decimal magnitude suffix for a count. Returns the suffix and the divisor
/// to scale `x` by, e.g. `1500` → `("K", 1e3)`.
///
/// Thresholds are inclusive, so exactly 1000 is shown as `1.00K`, not `1000.00`.
fn human_iter_unit(x: usize) -> (&'static str, f32) {
    if x >= 1_000_000_000 {
        ("B", 1e9)
    } else if x >= 1_000_000 {
        ("M", 1e6)
    } else if x >= 1_000 {
        ("K", 1e3)
    } else {
        ("", 1e0)
    }
}
381
/// Picks a binary (IEC) unit for a byte count. Returns the suffix and the divisor
/// to scale `x` by, e.g. `2048` → `("KiB", 1024.0)`.
///
/// Thresholds are inclusive, so exactly 1024 bytes is shown as `1.00KiB`.
fn bytes_unit(x: usize) -> (&'static str, f32) {
    // Compare in u64: 1024^4 doesn't fit in a 32-bit `usize`, and `1024usize.pow(4)`
    // would overflow there. Widening `usize` to `u64` is lossless.
    let x = x as u64;
    if x >= 1 << 40 {
        ("TiB", 1024_f32.powi(4))
    } else if x >= 1 << 30 {
        ("GiB", 1024_f32.powi(3))
    } else if x >= 1 << 20 {
        ("MiB", 1024_f32.powi(2))
    } else if x >= 1 << 10 {
        ("KiB", 1024_f32.powi(1))
    } else {
        // Upper-case "B" for bytes; lower-case "b" conventionally means bits.
        ("B", 1024_f32.powi(0))
    }
}
395
396impl Unit {
397    /// Formats an absolute amount, e.g. "1200 iterations".
398    fn write_total<W: fmt::Write>(self, mut out: W, amount: usize) -> fmt::Result {
399        match self {
400            Unit::Iterations => {
401                let (unit, div) = human_iter_unit(amount);
402                write!(out, "{:.2}{}", (amount as f32) / div, unit)
403            }
404            Unit::Bytes => {
405                let (unit, div) = bytes_unit(amount);
406                write!(out, "{:.2}{}", (amount as f32) / div, unit)
407            }
408        }
409    }
410
411    /// Formats a rate of change, e.g. "1200 it/sec".
412    fn write_rate<W: fmt::Write>(self, mut out: W, rate: f32) -> fmt::Result {
413        match self {
414            Unit::Iterations => {
415                if rate >= 1.0 {
416                    let (unit, div) = human_iter_unit(rate as usize);
417                    write!(out, "{:.2}{} it/s", rate / div, unit)
418                } else {
419                    write!(out, "{:.0} s/it", 1.0 / rate)
420                }
421            }
422            Unit::Bytes => {
423                let (unit, div) = bytes_unit(rate as usize);
424                write!(out, "{:.2}{}/s", rate / div, unit)
425            }
426        }
427    }
428}
429
430// ============================================================================================== //
431// [Main progress bar struct]                                                                     //
432// ============================================================================================== //
433
434/// Progress bar to be rendered on the terminal.
435///
436/// # Example
437///
438/// ```rust
439/// use zzz::prelude::*;
440///
441/// let mut bar = ProgressBar::with_target(123);
442/// for _ in 0..123 {
443///     bar.add(1);
444/// }
445/// ```
446pub struct ProgressBar {
447    /// Configuration to use.
448    cfg: Option<&'static Config>,
449    /// The expected, possibly approximate target of the progress bar.
450    target: Option<usize>,
451    /// Whether the target was specified explicitly.
452    explicit_target: bool,
453    /// Whether the target was specified explicitly.
454    pub(crate) unit: Unit,
455    /// Creation time of the progress bar.
456    start: Instant,
457    /// Description of the progress bar, e.g. "Downloading image".
458    message: RwLock<Option<String>>,
459    /// Progress value displayed to the user.
460    value: CachePadded<AtomicUsize>,
461    /// Number of progress bar updates so far.
462    update_ctr: CachePadded<AtomicUsize>,
463    /// Next print at `update_ctr == next_print`.
464    next_print: CachePadded<AtomicUsize>,
465}
466
467impl Drop for ProgressBar {
468    fn drop(&mut self) {
469        if (self.active_config().should_draw)() {
470            self.redraw();
471            eprintln!();
472        }
473    }
474}
475
476/// Constructors.
477impl ProgressBar {
478    fn new(target: Option<usize>, explicit_target: bool) -> Self {
479        Self {
480            cfg: None,
481            target,
482            explicit_target,
483            start: Instant::now(),
484            unit: Unit::Iterations,
485            value: CachePadded(0.into()),
486            update_ctr: CachePadded(0.into()),
487            next_print: CachePadded(1.into()),
488            message: RwLock::new(None),
489        }
490    }
491
492    /// Creates a smart progress bar, attempting to infer the target from size hints.
493    pub fn smart() -> Self {
494        Self::new(None, false)
495    }
496
497    /// Creates a spinner, a progress bar with indeterminate target value.
498    pub fn spinner() -> Self {
499        Self::new(None, true)
500    }
501
502    /// Creates a progress bar with an explicit target value.
503    pub fn with_target(target: usize) -> Self {
504        Self::new(Some(target), true)
505    }
506}
507
508/// Builder-style methods.
509impl ProgressBar {
510    /// Replace the config of the progress bar.
511    ///
512    /// Takes precedence over a global config set via [`set_global_config`].
513    pub fn config(mut self, cfg: &'static Config) -> Self {
514        self.cfg = Some(cfg);
515        self
516    }
517
518    /// Force display as a spinner even if size hints are present.
519    pub fn force_spinner(mut self) -> Self {
520        self.explicit_target = true;
521        self.target = None;
522        self
523    }
524
525    /// Set the unit to be used when formatting values.
526    pub fn unit(mut self, unit: Unit) -> Self {
527        self.unit = unit;
528        self
529    }
530}
531
532/// Value manipulation and access.
533impl ProgressBar {
534    /// Returns the currently active configuration.
535    #[inline]
536    pub fn active_config(&self) -> &'static Config {
537        self.cfg.unwrap_or_else(global_config)
538    }
539
540    #[rustfmt::skip]
541    pub fn process_size_hint(&mut self, hint: (usize, Option<usize>)) {
542        // If an explicit target is set, disregard size hints.
543        if self.explicit_target {
544            return;
545        }
546
547        // Prefer hi over lo, treat lo = 0 as unknown.
548        self.target = match hint {
549            (_ , Some(hi)) => Some(hi),
550            (0 , None    ) => None,
551            (lo, None    ) => Some(lo),
552        };
553    }
554
555    /// Set the progress bar value to a new, absolute value.
556    ///
557    /// This doesn't automatically redraw the progress-bar.
558    ///
559    /// See `set_sync` for a thread-safe version.
560    #[inline]
561    pub fn set(&mut self, n: usize) {
562        *self.update_ctr.get_mut() += 1;
563        *self.value.get_mut() = n;
564    }
565
566    /// Synchronized version fo `set`.
567    #[inline]
568    pub fn set_sync(&self, n: usize) {
569        self.update_ctr.fetch_add(1, Relaxed);
570        self.value.store(n, Relaxed);
571    }
572
573    /// Add `n` to the value of the progress bar.
574    ///
575    /// See `add_sync` for a thread-safe version.
576    #[inline]
577    pub fn add(&mut self, n: usize) -> usize {
578        *self.value.get_mut() += n;
579        let prev = *self.update_ctr.get_mut();
580        *self.update_ctr.get_mut() += 1;
581        self.maybe_redraw(prev);
582        prev
583    }
584
585    /// Synchronized version fo `add`.
586    #[inline]
587    pub fn add_sync(&self, n: usize) -> usize {
588        self.value.fetch_add(n, Relaxed);
589        let prev = self.update_ctr.fetch_add(1, Relaxed);
590        self.maybe_redraw(prev);
591        prev
592    }
593
594    /// How often has the value been changed since creation?
595    #[inline]
596    pub fn update_ctr(&self) -> usize {
597        self.update_ctr.load(Relaxed)
598    }
599
600    /// Get the current value of the progress bar.
601    #[inline]
602    pub fn value(&self) -> usize {
603        self.value.load(Relaxed)
604    }
605
606    /// Get the current task description text.
607    pub fn message(&self) -> Option<String> {
608        self.message.read().unwrap().clone()
609    }
610
611    /// Set the current task description text.
612    pub fn set_message(&mut self, text: Option<impl Into<String>>) {
613        *self.message.get_mut().unwrap() = text.map(Into::into);
614    }
615
616    /// Synchronized version for `set_message`.
617    pub fn set_message_sync(&self, text: Option<impl Into<String>>) {
618        let mut message_lock = self.message.write().unwrap();
619        *message_lock = text.map(Into::into);
620    }
621
622    /// Calculate the current progress, `0.0 .. 1.0`.
623    #[inline]
624    pub fn progress(&self) -> Option<f32> {
625        let target = self.target?;
626        Some(self.value() as f32 / target as f32)
627    }
628
629    /// Calculate the elapsed time since creation of the progress bar.
630    pub fn elapsed(&self) -> Duration {
631        self.start.elapsed()
632    }
633
634    /// Estimate the duration until completion.
635    pub fn eta(&self) -> Option<Duration> {
636        // wen eta?!
637        let left = 1. / self.progress()?;
638        let elapsed = self.elapsed();
639        let estimated_total = elapsed.mul_f32(left);
640        Some(estimated_total.saturating_sub(elapsed))
641    }
642
643    /// Calculate the mean iterations per second since creation of the progress bar.
644    pub fn iters_per_sec(&self) -> f32 {
645        let elapsed_sec = self.elapsed().as_secs_f32();
646        self.value() as f32 / elapsed_sec
647    }
648
649    /// Calculate the mean progress bar updates per second since creation of the progress bar.
650    pub fn updates_per_sec(&self) -> f32 {
651        let elapsed_sec = self.elapsed().as_secs_f32();
652        self.update_ctr() as f32 / elapsed_sec
653    }
654
655    /// Calculates the progress of a rolling timer.
656    ///
657    /// Returned values are always between 0 and 1. Timers are calculated
658    /// from the start of the progress bar.
659    pub fn timer_progress(&self, timer: Duration) -> f32 {
660        let elapsed_sec = self.elapsed().as_secs_f32();
661        let timer_sec = timer.as_secs_f32();
662
663        (elapsed_sec % timer_sec) / timer_sec
664    }
665
666    /// Forces a redraw of the progress bar.
667    pub fn redraw(&self) {
668        self.active_config().theme.render(self).unwrap();
669        self.update_next_print();
670    }
671}
672
673/// Internals.
674impl ProgressBar {
675    #[inline]
676    fn next_print(&self) -> usize {
677        self.next_print.load(Relaxed)
678    }
679
680    /// Calculate next print
681    fn update_next_print(&self) {
682        // Give the loop some time to warm up.
683        if self.update_ctr() < 10 {
684            self.next_print.fetch_add(1, Relaxed);
685            return;
686        }
687
688        let freq = (self.updates_per_sec() / self.active_config().max_fps) as usize;
689        let freq = freq.max(1);
690
691        self.next_print.fetch_add(freq as usize, Relaxed);
692    }
693
694    #[inline]
695    fn maybe_redraw(&self, prev: usize) {
696        #[cold]
697        fn cold_redraw(this: &ProgressBar) {
698            if (this.active_config().should_draw)() {
699                this.redraw();
700            }
701        }
702
703        if prev == self.next_print() {
704            cold_redraw(self);
705        }
706    }
707}
708
709// ============================================================================================== //
710// [Iterator integration]                                                                         //
711// ============================================================================================== //
712
713/// Iterator / stream wrapper that automatically updates a progress bar during iteration.
714pub struct ProgressBarIter<Inner> {
715    bar: ProgressBar,
716    inner: Inner,
717}
718
719impl<Inner> ProgressBarIter<Inner> {
720    pub fn into_inner(self) -> Inner {
721        self.inner
722    }
723}
724
725impl<Inner: Iterator> Iterator for ProgressBarIter<Inner> {
726    type Item = Inner::Item;
727
728    fn next(&mut self) -> Option<Self::Item> {
729        let next = self.inner.next()?;
730        self.bar.add(1);
731        Some(next)
732    }
733}
734
735/// Extension trait implemented for all iterators, adding methods for
736/// conveniently adding a progress bar to an existing iterator.
737///
738/// # Example
739///
740/// ```rust
741/// # fn main() {
742/// use zzz::prelude::*;
743///
744/// for _ in (0..123).progress() {
745///     // ...
746/// }
747/// # }
748/// ```
749pub trait ProgressBarIterExt: Iterator + Sized {
750    fn progress(self) -> ProgressBarIter<Self> {
751        let mut bar = ProgressBar::smart();
752        bar.process_size_hint(self.size_hint());
753        ProgressBarIter { bar, inner: self }
754    }
755
756    fn with_progress(self, mut bar: ProgressBar) -> ProgressBarIter<Self> {
757        bar.process_size_hint(self.size_hint());
758        ProgressBarIter { bar, inner: self }
759    }
760}
761
762impl<Inner: Iterator + Sized> ProgressBarIterExt for Inner {}
763
764// ============================================================================================== //
765// [Stream integration]                                                                           //
766// ============================================================================================== //
767
#[cfg(feature = "streams")]
pub mod streams {
    use super::*;
    use core::pin::Pin;
    use futures_core::{
        task::{Context, Poll},
        Stream,
    };

    impl<Inner: Stream> Stream for ProgressBarIter<Inner> {
        type Item = Inner::Item;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // SAFETY: this is the structural pin projection `pin_project` would
            // generate, without the dependency. `inner` is treated as pinned and is
            // never moved out while pinned: `ProgressBarIter` has no `Drop` impl, and
            // `into_inner` needs an owned, hence unpinned, value. `ProgressBarIter` is
            // only auto-`Unpin` when `Inner` is. `bar` is not structurally pinned,
            // so handing out `&mut` to it is fine.
            let (inner, bar) = unsafe {
                let this = self.get_unchecked_mut();
                (Pin::new_unchecked(&mut this.inner), &mut this.bar)
            };

            match inner.poll_next(cx) {
                x @ Poll::Ready(Some(_)) => {
                    bar.add(1);
                    x
                }
                x => x,
            }
        }

        /// Forwards the wrapped stream's hint; the wrapper yields the same items.
        fn size_hint(&self) -> (usize, Option<usize>) {
            Stream::size_hint(&self.inner)
        }
    }

    /// Extension trait implemented for all streams, adding methods for conveniently adding a
    /// progress bar to an existing stream.
    pub trait ProgressBarStreamExt: Stream + Sized {
        /// Wraps the stream with a new [`ProgressBar::smart`] bar, whose target is
        /// inferred from the stream's `size_hint`.
        fn progress(self) -> ProgressBarIter<Self> {
            let mut bar = ProgressBar::smart();
            bar.process_size_hint(self.size_hint());
            ProgressBarIter { bar, inner: self }
        }

        /// Wraps the stream with the given `bar`. Unless the bar's target was set
        /// explicitly, it is inferred from the stream's `size_hint`.
        fn with_progress(self, mut bar: ProgressBar) -> ProgressBarIter<Self> {
            bar.process_size_hint(self.size_hint());
            ProgressBarIter { bar, inner: self }
        }
    }

    impl<Inner: Stream + Sized> ProgressBarStreamExt for Inner {}
}

#[cfg(feature = "streams")]
pub use streams::*;
818
819// ============================================================================================== //
820// [Tests]                                                                                        //
821// ============================================================================================== //
822
#[cfg(doctest)]
mod tests {
    // Ensure the examples in README.md work.
    //
    // The README becomes the doc comment of an empty `extern` block, so rustdoc
    // collects and runs its code samples as doctests.
    // NOTE(review): the crate root also uses README.md as its documentation
    // (`#![doc = include_str!("../README.md")]`), so these samples are probably
    // doctested twice. Confirm, and consider dropping this module.
    #[doc = include_str!("../README.md")]
    extern "C" {}
}
835
836// ============================================================================================== //