1use rayon::iter::{ParallelIterator, plumbing::{UnindexedConsumer, Consumer, Folder, Producer, ProducerCallback}, IndexedParallelIterator, IntoParallelIterator};
55use std::sync::{Arc,atomic::{AtomicUsize, Ordering}};
56
/// Parallel-iterator adaptor that counts how many items flow through the
/// wrapped iterator `I`.
///
/// Build one with [`ProgressAdaptor::new`] and take a handle to the live count
/// with [`ProgressAdaptor::items_processed`] *before* consuming the adaptor;
/// the handle can then be polled from another thread while the iteration runs.
#[derive(Debug)]
#[must_use = "parallel iterator adaptors are lazy and do nothing unless consumed"]
pub struct ProgressAdaptor<I> {
    /// The wrapped parallel iterator.
    inner: I,
    /// Counter shared with every consumer, folder, producer and iterator
    /// derived from this adaptor, and with every `ItemsProcessed` handle.
    items_processed: Arc<AtomicUsize>,
}
64
/// Cloneable, read-only handle to the item counter of a [`ProgressAdaptor`].
///
/// All clones share the same counter (they hold the same `Arc`), so a handle
/// obtained before the adaptor is consumed keeps reporting progress while the
/// parallel iteration runs.
#[derive(Clone, Debug)]
pub struct ItemsProcessed(Arc<AtomicUsize>);
67
/// Consumer wrapper that threads the shared item counter through every split
/// of the downstream consumer `C` and hands out counting folders.
struct ProgressConsumer<C> {
    /// Downstream consumer being wrapped.
    inner: C,
    /// Shared progress counter; every split gets a clone of the same `Arc`.
    items_processed: Arc<AtomicUsize>,
}
72
/// Folder wrapper that bumps the shared counter once for every item it passes
/// on to the downstream folder `F`.
struct ProgressFolder<F> {
    /// Downstream folder being wrapped.
    inner: F,
    /// Shared progress counter.
    items_processed: Arc<AtomicUsize>,
}
77
/// Producer wrapper used on the `with_producer` path; every piece produced by
/// splitting keeps sharing the same counter.
struct ProgressProducer<P> {
    /// Producer supplied by the wrapped indexed iterator.
    inner: P,
    /// Shared progress counter.
    items_processed: Arc<AtomicUsize>,
}
82
/// Sequential iterator returned by `ProgressProducer::into_iter`; each item is
/// counted at the moment it is yielded, from either end.
struct ProgressIterator<I> {
    /// Sequential iterator of the wrapped producer.
    inner: I,
    /// Shared progress counter.
    items_processed: Arc<AtomicUsize>,
}
87
88impl ProgressAdaptor<()> {
89 pub fn new<T>(iter: T) -> ProgressAdaptor<T::Iter> where T: IntoParallelIterator {
90 ProgressAdaptor {
91 inner: iter.into_par_iter(),
92 items_processed: Arc::new(AtomicUsize::new(0)),
93 }
94 }
95}
96
97impl<T> ProgressAdaptor<T> {
98 pub fn items_processed(&self) -> ItemsProcessed {
102 ItemsProcessed(self.items_processed.clone())
103 }
104}
105
106impl ItemsProcessed {
107 pub fn get(&self) -> usize {
108 self.0.load(Ordering::Relaxed)
109 }
110}
111
112impl<I> ParallelIterator for ProgressAdaptor<I> where I: ParallelIterator {
113 type Item=I::Item;
114
115 fn drive_unindexed<C>(self, consumer: C) -> C::Result
116 where
117 C: UnindexedConsumer<Self::Item> {
118 self.inner.drive_unindexed(ProgressConsumer {inner: consumer, items_processed: self.items_processed})
119 }
120}
121
122impl<I> IndexedParallelIterator for ProgressAdaptor<I> where I: IndexedParallelIterator {
123 fn len(&self) -> usize {
124 self.inner.len()
125 }
126
127 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
128 self.inner.drive(ProgressConsumer {inner: consumer, items_processed: self.items_processed})
129 }
130
131 fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
132 struct ProgressCB<CB> {
133 inner: CB,
134 items_processed: Arc<AtomicUsize>,
135 }
136 impl<CB,T> ProducerCallback<T> for ProgressCB<CB> where CB: ProducerCallback<T> {
137 type Output=CB::Output;
138
139 fn callback<P>(self, producer: P) -> Self::Output where P: Producer<Item = T> {
140 self.inner.callback(ProgressProducer{inner: producer, items_processed: self.items_processed})
141 }
142 }
143
144 self.inner.with_producer(ProgressCB{inner: callback, items_processed: self.items_processed})
145 }
146}
147
148impl<C,I> UnindexedConsumer<I> for ProgressConsumer<C> where C: UnindexedConsumer<I> {
149 fn split_off_left(&self) -> Self {
150 Self {inner: self.inner.split_off_left(), items_processed: self.items_processed.clone()}
151 }
152
153 fn to_reducer(&self) -> Self::Reducer {
154 self.inner.to_reducer()
155 }
156}
157
158impl<C,I> Consumer<I> for ProgressConsumer<C> where C: Consumer<I> {
159 type Folder = ProgressFolder<C::Folder>;
160
161 type Reducer = C::Reducer;
162
163 type Result = C::Result;
164
165 fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
166 let ProgressConsumer {inner, items_processed: entries_processed} = self;
167 let (left, right, reducer) = inner.split_at(index);
168 (ProgressConsumer{inner:left, items_processed: entries_processed.clone()},
169 ProgressConsumer{inner:right, items_processed: entries_processed},
170 reducer)
171 }
172
173 fn into_folder(self) -> Self::Folder {
174 ProgressFolder {inner: self.inner.into_folder(), items_processed: self.items_processed}
175 }
176
177 fn full(&self) -> bool {
178 self.inner.full()
179 }
180}
181
182impl<F,I> Folder<I> for ProgressFolder<F> where F: Folder<I> {
183 type Result=F::Result;
184
185 fn consume(self, item: I) -> Self {
186 let Self{inner, items_processed} = self;
187 let inner = inner.consume(item);
188 items_processed.fetch_add(1, Ordering::Relaxed);
189 Self {inner, items_processed}
190 }
191
192 fn complete(self) -> Self::Result {
193 self.inner.complete()
194 }
195
196 fn full(&self) -> bool {
197 self.inner.full()
198 }
199
200 }
206
207impl<P> Producer for ProgressProducer<P> where P: Producer {
208 type Item=P::Item;
209
210 type IntoIter=ProgressIterator<P::IntoIter>;
211
212 fn into_iter(self) -> Self::IntoIter {
213 ProgressIterator{inner: self.inner.into_iter(), items_processed: self.items_processed}
214 }
215
216 fn split_at(self, index: usize) -> (Self, Self) {
217 let Self{inner, items_processed}=self;
218 let (left,right) = inner.split_at(index);
219 (Self{inner: left, items_processed: items_processed.clone()}, Self{inner: right, items_processed})
220 }
221
222 fn min_len(&self) -> usize {
223 self.inner.min_len()
224 }
225
226 fn max_len(&self) -> usize {
227 self.inner.max_len()
228 }
229
230 fn fold_with<F>(self, folder: F) -> F
231 where
232 F: Folder<Self::Item>,
233 {
234 self.inner.fold_with(ProgressFolder{inner: folder, items_processed: self.items_processed}).inner
235 }
236}
237
238impl<I> Iterator for ProgressIterator<I> where I: Iterator {
239 type Item=I::Item;
240 fn next(&mut self) -> Option<Self::Item> {
241 let res = self.inner.next();
242 if res.is_some() {
243 self.items_processed.fetch_add(1, Ordering::Relaxed);
244 }
245 res
246 }
247 fn size_hint(&self) -> (usize, Option<usize>) {
248 self.inner.size_hint()
249 }
250}
251
252impl<I> DoubleEndedIterator for ProgressIterator<I> where I: DoubleEndedIterator {
253 fn next_back(&mut self) -> Option<Self::Item> {
254 let res = self.inner.next_back();
255 if res.is_some() {
256 self.items_processed.fetch_add(1, Ordering::Relaxed);
257 }
258 res
259 }
260}
261
262impl<I> ExactSizeIterator for ProgressIterator<I> where I: ExactSizeIterator {
263 fn len(&self) -> usize {
264 self.inner.len()
265 }
266}
267
#[cfg(test)]
mod tests {
    use std::hint::spin_loop;
    use std::sync::atomic::AtomicBool;

    use super::*;

    /// Every item is counted exactly once on the consumer/folder path
    /// (`map` + `collect`).
    #[test]
    fn counts_every_item_through_consumer() {
        let iter = ProgressAdaptor::new(0..1000u32);
        let progress = iter.items_processed();
        let doubled: Vec<u32> = iter.map(|x| x * 2).collect();
        assert_eq!(doubled.len(), 1000);
        assert_eq!(doubled[999], 1998);
        assert_eq!(progress.get(), 1000);
    }

    /// Every item is counted exactly once on the producer path: `zip` drives
    /// the adaptor through `with_producer`, i.e. `ProgressProducer` and
    /// `ProgressIterator`.
    #[test]
    fn counts_every_item_through_producer() {
        let iter = ProgressAdaptor::new(0..100u32);
        let progress = iter.items_processed();
        let dot: u32 = iter.zip(0..100u32).map(|(a, b)| a * b).sum();
        // sum of x^2 for x in 0..100 = 99 * 100 * 199 / 6
        assert_eq!(dot, 328_350);
        assert_eq!(progress.get(), 100);
    }

    /// Progress must be observable *while* the iteration is running. Items
    /// `>= 500` block until a watcher task has seen 500 processed items.
    ///
    /// A dedicated two-thread pool is used on purpose. On the global pool with
    /// a single worker (single-core host or `RAYON_NUM_THREADS=1`), the
    /// busy-spinning watcher would occupy that only worker, and the test would
    /// deadlock.
    #[test]
    fn progress_is_visible_during_iteration() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .expect("failed to build a two-thread rayon pool");

        let release = Arc::new(AtomicBool::new(false));
        let watcher_release = Arc::clone(&release);
        let iter = ProgressAdaptor::new(0..1000u64);
        let items_processed = iter.items_processed();

        pool.spawn(move || {
            while items_processed.get() < 500 {
                spin_loop();
            }
            watcher_release.store(true, Ordering::Release);
        });

        let sum = pool.install(|| {
            iter.map(|x| {
                if x >= 500 {
                    while !release.load(Ordering::Acquire) {
                        spin_loop();
                    }
                }
                x
            })
            .sum::<u64>()
        });
        assert_eq!(sum, 499_500);
    }
}