Skip to main content

indicatif/
rayon.rs

1use rayon::iter::plumbing::{Consumer, Folder, Producer, ProducerCallback, UnindexedConsumer};
2use rayon::iter::{IndexedParallelIterator, ParallelIterator};
3
4use crate::{iter::SeekMax, ProgressBar, ProgressBarIter};
5
6/// Wraps a Rayon parallel iterator.
7///
8/// See [`ProgressIterator`](trait.ProgressIterator.html) for method
9/// documentation.
10#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
11pub trait ParallelProgressIterator
12where
13    Self: Sized + ParallelIterator,
14{
15    /// Wrap an iterator with a custom progress bar.
16    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
17
18    /// Wrap an iterator with an explicit element count.
19    fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
20        self.progress_with(ProgressBar::new(len))
21    }
22
23    fn progress(self) -> ProgressBarIter<Self>
24    where
25        Self: IndexedParallelIterator,
26    {
27        let len = u64::try_from(self.len()).unwrap();
28        self.progress_count(len)
29    }
30
31    /// Wrap an iterator with a progress bar and style it.
32    fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
33    where
34        Self: IndexedParallelIterator,
35    {
36        let len = u64::try_from(self.len()).unwrap();
37        let bar = ProgressBar::new(len).with_style(style);
38        self.progress_with(bar)
39    }
40}
41
42impl<S: Send, T: ParallelIterator<Item = S>> ParallelProgressIterator for T {
43    fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
44        ProgressBarIter {
45            it: self,
46            progress,
47            seek_max: SeekMax::default(),
48        }
49    }
50}
51
52impl<S: Send, T: IndexedParallelIterator<Item = S>> IndexedParallelIterator for ProgressBarIter<T> {
53    fn len(&self) -> usize {
54        self.it.len()
55    }
56
57    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> <C as Consumer<Self::Item>>::Result {
58        let consumer = ProgressConsumer::new(consumer, self.progress);
59        self.it.drive(consumer)
60    }
61
62    fn with_producer<CB: ProducerCallback<Self::Item>>(
63        self,
64        callback: CB,
65    ) -> <CB as ProducerCallback<Self::Item>>::Output {
66        return self.it.with_producer(Callback {
67            callback,
68            progress: self.progress,
69        });
70
71        struct Callback<CB> {
72            callback: CB,
73            progress: ProgressBar,
74        }
75
76        impl<T, CB: ProducerCallback<T>> ProducerCallback<T> for Callback<CB> {
77            type Output = CB::Output;
78
79            fn callback<P>(self, base: P) -> CB::Output
80            where
81                P: Producer<Item = T>,
82            {
83                let producer = ProgressProducer {
84                    base,
85                    progress: self.progress,
86                };
87                self.callback.callback(producer)
88            }
89        }
90    }
91}
92
93struct ProgressProducer<T> {
94    base: T,
95    progress: ProgressBar,
96}
97
98impl<T, P: Producer<Item = T>> Producer for ProgressProducer<P> {
99    type Item = T;
100    type IntoIter = ProgressBarIter<P::IntoIter>;
101
102    fn into_iter(self) -> Self::IntoIter {
103        ProgressBarIter {
104            it: self.base.into_iter(),
105            progress: self.progress,
106            seek_max: SeekMax::default(),
107        }
108    }
109
110    fn min_len(&self) -> usize {
111        self.base.min_len()
112    }
113
114    fn max_len(&self) -> usize {
115        self.base.max_len()
116    }
117
118    fn split_at(self, index: usize) -> (Self, Self) {
119        let (left, right) = self.base.split_at(index);
120        (
121            ProgressProducer {
122                base: left,
123                progress: self.progress.clone(),
124            },
125            ProgressProducer {
126                base: right,
127                progress: self.progress,
128            },
129        )
130    }
131}
132
133struct ProgressConsumer<C> {
134    base: C,
135    progress: ProgressBar,
136}
137
138impl<C> ProgressConsumer<C> {
139    fn new(base: C, progress: ProgressBar) -> Self {
140        ProgressConsumer { base, progress }
141    }
142}
143
144impl<T, C: Consumer<T>> Consumer<T> for ProgressConsumer<C> {
145    type Folder = ProgressFolder<C::Folder>;
146    type Reducer = C::Reducer;
147    type Result = C::Result;
148
149    fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
150        let (left, right, reducer) = self.base.split_at(index);
151        (
152            ProgressConsumer::new(left, self.progress.clone()),
153            ProgressConsumer::new(right, self.progress),
154            reducer,
155        )
156    }
157
158    fn into_folder(self) -> Self::Folder {
159        ProgressFolder {
160            base: self.base.into_folder(),
161            progress: self.progress,
162        }
163    }
164
165    fn full(&self) -> bool {
166        self.base.full()
167    }
168}
169
170impl<T, C: UnindexedConsumer<T>> UnindexedConsumer<T> for ProgressConsumer<C> {
171    fn split_off_left(&self) -> Self {
172        ProgressConsumer::new(self.base.split_off_left(), self.progress.clone())
173    }
174
175    fn to_reducer(&self) -> Self::Reducer {
176        self.base.to_reducer()
177    }
178}
179
180struct ProgressFolder<C> {
181    base: C,
182    progress: ProgressBar,
183}
184
185impl<T, C: Folder<T>> Folder<T> for ProgressFolder<C> {
186    type Result = C::Result;
187
188    fn consume(self, item: T) -> Self {
189        self.progress.inc(1);
190        ProgressFolder {
191            base: self.base.consume(item),
192            progress: self.progress,
193        }
194    }
195
196    fn complete(self) -> C::Result {
197        self.base.complete()
198    }
199
200    fn full(&self) -> bool {
201        self.base.full()
202    }
203}
204
205impl<S: Send, T: ParallelIterator<Item = S>> ParallelIterator for ProgressBarIter<T> {
206    type Item = S;
207
208    fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
209        let consumer1 = ProgressConsumer::new(consumer, self.progress.clone());
210        self.it.drive_unindexed(consumer1)
211    }
212}
213
#[cfg(test)]
mod test {
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    use crate::{ParallelProgressIterator, ProgressBar, ProgressBarIter, ProgressStyle};

    #[test]
    fn it_can_wrap_a_parallel_iterator() {
        let v = vec![1, 2, 3];
        fn wrap<'a, T: ParallelIterator<Item = &'a i32>>(it: ProgressBarIter<T>) {
            assert_eq!(it.map(|x| x * 2).collect::<Vec<_>>(), vec![2, 4, 6]);
        }

        wrap(v.par_iter().progress_count(3));
        wrap({
            let pb = ProgressBar::new(v.len() as u64);
            v.par_iter().progress_with(pb)
        });

        wrap({
            let style = ProgressStyle::default_bar()
                .template("{wide_bar:.red} {percent}/100%")
                .unwrap();
            v.par_iter().progress_with_style(style)
        });

        // `progress()` derives the bar length from the indexed iterator itself.
        wrap(v.par_iter().progress());
    }

    #[test]
    fn it_advances_the_bar_once_per_item() {
        let v = vec![1, 2, 3];
        let pb = ProgressBar::new(v.len() as u64);
        let doubled: Vec<i32> = v
            .par_iter()
            .progress_with(pb.clone())
            .map(|x| x * 2)
            .collect();
        assert_eq!(doubled, vec![2, 4, 6]);
        assert_eq!(pb.position(), 3);
    }
}