1use rayon::iter::plumbing::{Consumer, Folder, Producer, ProducerCallback, UnindexedConsumer};
2use rayon::iter::{IndexedParallelIterator, ParallelIterator};
3
4use crate::{iter::SeekMax, ProgressBar, ProgressBarIter};
5
6#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
11pub trait ParallelProgressIterator
12where
13 Self: Sized + ParallelIterator,
14{
15 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self>;
17
18 fn progress_count(self, len: u64) -> ProgressBarIter<Self> {
20 self.progress_with(ProgressBar::new(len))
21 }
22
23 fn progress(self) -> ProgressBarIter<Self>
24 where
25 Self: IndexedParallelIterator,
26 {
27 let len = u64::try_from(self.len()).unwrap();
28 self.progress_count(len)
29 }
30
31 fn progress_with_style(self, style: crate::ProgressStyle) -> ProgressBarIter<Self>
33 where
34 Self: IndexedParallelIterator,
35 {
36 let len = u64::try_from(self.len()).unwrap();
37 let bar = ProgressBar::new(len).with_style(style);
38 self.progress_with(bar)
39 }
40}
41
42impl<S: Send, T: ParallelIterator<Item = S>> ParallelProgressIterator for T {
43 fn progress_with(self, progress: ProgressBar) -> ProgressBarIter<Self> {
44 ProgressBarIter {
45 it: self,
46 progress,
47 seek_max: SeekMax::default(),
48 }
49 }
50}
51
52impl<S: Send, T: IndexedParallelIterator<Item = S>> IndexedParallelIterator for ProgressBarIter<T> {
53 fn len(&self) -> usize {
54 self.it.len()
55 }
56
57 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> <C as Consumer<Self::Item>>::Result {
58 let consumer = ProgressConsumer::new(consumer, self.progress);
59 self.it.drive(consumer)
60 }
61
62 fn with_producer<CB: ProducerCallback<Self::Item>>(
63 self,
64 callback: CB,
65 ) -> <CB as ProducerCallback<Self::Item>>::Output {
66 return self.it.with_producer(Callback {
67 callback,
68 progress: self.progress,
69 });
70
71 struct Callback<CB> {
72 callback: CB,
73 progress: ProgressBar,
74 }
75
76 impl<T, CB: ProducerCallback<T>> ProducerCallback<T> for Callback<CB> {
77 type Output = CB::Output;
78
79 fn callback<P>(self, base: P) -> CB::Output
80 where
81 P: Producer<Item = T>,
82 {
83 let producer = ProgressProducer {
84 base,
85 progress: self.progress,
86 };
87 self.callback.callback(producer)
88 }
89 }
90 }
91}
92
93struct ProgressProducer<T> {
94 base: T,
95 progress: ProgressBar,
96}
97
98impl<T, P: Producer<Item = T>> Producer for ProgressProducer<P> {
99 type Item = T;
100 type IntoIter = ProgressBarIter<P::IntoIter>;
101
102 fn into_iter(self) -> Self::IntoIter {
103 ProgressBarIter {
104 it: self.base.into_iter(),
105 progress: self.progress,
106 seek_max: SeekMax::default(),
107 }
108 }
109
110 fn min_len(&self) -> usize {
111 self.base.min_len()
112 }
113
114 fn max_len(&self) -> usize {
115 self.base.max_len()
116 }
117
118 fn split_at(self, index: usize) -> (Self, Self) {
119 let (left, right) = self.base.split_at(index);
120 (
121 ProgressProducer {
122 base: left,
123 progress: self.progress.clone(),
124 },
125 ProgressProducer {
126 base: right,
127 progress: self.progress,
128 },
129 )
130 }
131}
132
133struct ProgressConsumer<C> {
134 base: C,
135 progress: ProgressBar,
136}
137
138impl<C> ProgressConsumer<C> {
139 fn new(base: C, progress: ProgressBar) -> Self {
140 ProgressConsumer { base, progress }
141 }
142}
143
144impl<T, C: Consumer<T>> Consumer<T> for ProgressConsumer<C> {
145 type Folder = ProgressFolder<C::Folder>;
146 type Reducer = C::Reducer;
147 type Result = C::Result;
148
149 fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
150 let (left, right, reducer) = self.base.split_at(index);
151 (
152 ProgressConsumer::new(left, self.progress.clone()),
153 ProgressConsumer::new(right, self.progress),
154 reducer,
155 )
156 }
157
158 fn into_folder(self) -> Self::Folder {
159 ProgressFolder {
160 base: self.base.into_folder(),
161 progress: self.progress,
162 }
163 }
164
165 fn full(&self) -> bool {
166 self.base.full()
167 }
168}
169
170impl<T, C: UnindexedConsumer<T>> UnindexedConsumer<T> for ProgressConsumer<C> {
171 fn split_off_left(&self) -> Self {
172 ProgressConsumer::new(self.base.split_off_left(), self.progress.clone())
173 }
174
175 fn to_reducer(&self) -> Self::Reducer {
176 self.base.to_reducer()
177 }
178}
179
180struct ProgressFolder<C> {
181 base: C,
182 progress: ProgressBar,
183}
184
185impl<T, C: Folder<T>> Folder<T> for ProgressFolder<C> {
186 type Result = C::Result;
187
188 fn consume(self, item: T) -> Self {
189 self.progress.inc(1);
190 ProgressFolder {
191 base: self.base.consume(item),
192 progress: self.progress,
193 }
194 }
195
196 fn complete(self) -> C::Result {
197 self.base.complete()
198 }
199
200 fn full(&self) -> bool {
201 self.base.full()
202 }
203}
204
205impl<S: Send, T: ParallelIterator<Item = S>> ParallelIterator for ProgressBarIter<T> {
206 type Item = S;
207
208 fn drive_unindexed<C: UnindexedConsumer<Self::Item>>(self, consumer: C) -> C::Result {
209 let consumer1 = ProgressConsumer::new(consumer, self.progress.clone());
210 self.it.drive_unindexed(consumer1)
211 }
212}
213
#[cfg(test)]
mod test {
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    use crate::{ParallelProgressIterator, ProgressBar, ProgressBarIter, ProgressStyle};

    /// Checks that each way of attaching a progress bar yields an iterator
    /// that still produces the mapped items in order when collected.
    #[test]
    fn it_can_wrap_a_parallel_iterator() {
        // Doubles every element through the wrapped iterator and checks the
        // collected output.
        fn wrap<'a, T: ParallelIterator<Item = &'a i32>>(it: ProgressBarIter<T>) {
            let doubled: Vec<i32> = it.map(|x| x * 2).collect();
            assert_eq!(doubled, vec![2, 4, 6]);
        }

        let v = vec![1, 2, 3];

        // Explicit length.
        wrap(v.par_iter().progress_count(3));

        // Pre-built bar.
        let pb = ProgressBar::new(v.len() as u64);
        wrap(v.par_iter().progress_with(pb));

        // Length taken from the iterator, custom style.
        let style = ProgressStyle::default_bar()
            .template("{wide_bar:.red} {percent}/100%")
            .unwrap();
        wrap(v.par_iter().progress_with_style(style));
    }
}