rayon_progress/
lib.rs

//! Rayon makes many operations faster by spreading them across multiple threads, but such
//! operations are rarely, if ever, instantaneous. When processing, e.g., an iterator with
//! millions of items, it is therefore often useful to display a progress bar so the user knows
//! approximately how long they will have to wait.
5//!
6//! This crate provides a thin wrapper over any type implementing `ParallelIterator` that keeps
7//! track of the number of items that have been processed and allows accessing that value from
8//! another thread.
9//!
10//! ```
11//! use rayon::prelude::*;
12//! use std::sync::{Arc, Mutex};
13//! use std::time::Duration;
14//! use std::thread::sleep;
15//! use rayon_progress::ProgressAdaptor;
16//!
17//! let it = ProgressAdaptor::new(0..1000); // the constructor calls into_par_iter()
18//! // get a handle to the number of items processed
19//! // calling `progress.get()` repeatedly will return increasing values as processing continues
20//! let progress = it.items_processed();
21//! // it.len() is available when the underlying iterator implements IndexedParallelIterator (e.g.
22//! // Range, Vec etc.)
23//! // in other cases your code will have to either display indefinite progress or make an educated guess
24//! let total = it.len();
//! // One way of transferring the result back to the thread that asked for it. You could also use
//! // `tokio::sync::oneshot`, `std::sync::mpsc` etc., or simply run the progress bar in a separate
//! // thread that exits once it is notified that processing is complete.
28//! let result = Arc::new(Mutex::new(None::<u32>));
29//! 
//! // Note that we wrap the iterator in a `ProgressAdaptor` *before* chaining any processing steps.
//! // This is important for the count to be accurate, especially if later processing steps yield
//! // fewer items (e.g. `filter()`).
33//! rayon::spawn({
34//!     let result = result.clone();
35//!     move || {
36//!         let sum = it.map(|i| {
37//!                sleep(Duration::from_millis(10)); // simulate some nontrivial computation
38//!                i
39//!             })
40//!             .sum();
41//!         *result.lock().unwrap() = Some(sum);
42//!     }
43//! });
44//!
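//! while result.lock().unwrap().is_none() {
//!     let percent = (progress.get() * 100) / total;
//!     println!("Processing... {}% complete", percent);
//!     // poll periodically instead of spinning, which would flood stdout and waste a core
//!     sleep(Duration::from_millis(100));
//! }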
49//! if let Some(result) = result.lock().unwrap().take() {
50//!     println!("Done! Result was: {}", result);
51//! };
52//!
53//! ```
54use rayon::iter::{ParallelIterator, plumbing::{UnindexedConsumer, Consumer, Folder, Producer, ProducerCallback}, IndexedParallelIterator, IntoParallelIterator};
55use std::sync::{Arc,atomic::{AtomicUsize, Ordering}};
56
/// A wrapper around a [`ParallelIterator`] that lets you check, from another thread, how many
/// items have been processed at any given time -- for example, to display a progress bar during
/// a long-running Rayon operation.
///
/// Create one with [`ProgressAdaptor::new`] and grab a counter handle with
/// [`ProgressAdaptor::items_processed`] *before* consuming the adaptor.
#[derive(Debug)]
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
pub struct ProgressAdaptor<I> {
    /// The wrapped parallel iterator.
    inner: I,
    /// Shared counter, incremented once for each item that passes through this adaptor.
    items_processed: Arc<AtomicUsize>,
}
64
/// A cheap-to-clone handle to the progress counter of a [`ProgressAdaptor`].
///
/// Every clone shares the same underlying counter, so all clones observe the same value.
#[derive(Clone, Debug)]
pub struct ItemsProcessed(Arc<AtomicUsize>);
67
/// Consumer wrapper that carries the shared progress counter through every split, so that the
/// folders it eventually turns into can record each item they consume.
struct ProgressConsumer<Inner> {
    /// Counter shared with the `ProgressAdaptor` (and its `ItemsProcessed` handles).
    items_processed: Arc<AtomicUsize>,
    /// The downstream consumer being wrapped.
    inner: Inner,
}
72
/// Folder wrapper that bumps the shared counter once for every item it consumes.
struct ProgressFolder<Inner> {
    /// Counter shared with the `ProgressAdaptor` (and its `ItemsProcessed` handles).
    items_processed: Arc<AtomicUsize>,
    /// The downstream folder being wrapped.
    inner: Inner,
}
77
/// Producer wrapper used on the indexed (`with_producer`) path; items are counted either by the
/// folder it interposes in `fold_with` or by the iterator it returns from `into_iter`.
struct ProgressProducer<Inner> {
    /// Counter shared with the `ProgressAdaptor` (and its `ItemsProcessed` handles).
    items_processed: Arc<AtomicUsize>,
    /// The upstream producer being wrapped.
    inner: Inner,
}
82
/// Sequential iterator wrapper that bumps the shared counter every time it yields an item,
/// from either end.
struct ProgressIterator<Inner> {
    /// Counter shared with the `ProgressAdaptor` (and its `ItemsProcessed` handles).
    items_processed: Arc<AtomicUsize>,
    /// The sequential iterator being wrapped.
    inner: Inner,
}
87
88impl ProgressAdaptor<()> {
89    pub fn new<T>(iter: T) -> ProgressAdaptor<T::Iter> where T: IntoParallelIterator {
90        ProgressAdaptor {
91            inner: iter.into_par_iter(),
92            items_processed: Arc::new(AtomicUsize::new(0)),
93        }
94    }
95}
96
97impl<T> ProgressAdaptor<T> {
98    /// Returns a cheap-to-clone handle that can be used to get the number of items processed so
99    /// far.  This method does not take a snapshot -- the value returned by the returned handle's
100    /// `get()` method will update as processing continues.
101    pub fn items_processed(&self) -> ItemsProcessed {
102        ItemsProcessed(self.items_processed.clone())
103    }
104}
105
106impl ItemsProcessed {
107    pub fn get(&self) -> usize {
108        self.0.load(Ordering::Relaxed)
109    }
110}
111
112impl<I> ParallelIterator for ProgressAdaptor<I> where I: ParallelIterator {
113    type Item=I::Item;
114
115    fn drive_unindexed<C>(self, consumer: C) -> C::Result
116    where
117        C: UnindexedConsumer<Self::Item> {
118        self.inner.drive_unindexed(ProgressConsumer {inner: consumer, items_processed: self.items_processed})
119    }
120}
121
122impl<I> IndexedParallelIterator for ProgressAdaptor<I> where I: IndexedParallelIterator {
123    fn len(&self) -> usize {
124        self.inner.len()
125    }
126
127    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
128        self.inner.drive(ProgressConsumer {inner: consumer, items_processed: self.items_processed})
129    }
130
131    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
132        struct ProgressCB<CB> {
133            inner: CB,
134            items_processed: Arc<AtomicUsize>,
135        }
136        impl<CB,T> ProducerCallback<T> for ProgressCB<CB> where CB: ProducerCallback<T> {
137            type Output=CB::Output;
138
139            fn callback<P>(self, producer: P) -> Self::Output where P: Producer<Item = T> {
140                self.inner.callback(ProgressProducer{inner: producer, items_processed: self.items_processed})
141            }
142        }
143        
144        self.inner.with_producer(ProgressCB{inner: callback, items_processed: self.items_processed})
145    }
146}
147
148impl<C,I> UnindexedConsumer<I> for ProgressConsumer<C> where C: UnindexedConsumer<I> {
149    fn split_off_left(&self) -> Self {
150        Self {inner: self.inner.split_off_left(), items_processed: self.items_processed.clone()}
151    }
152
153    fn to_reducer(&self) -> Self::Reducer {
154        self.inner.to_reducer()
155    }
156}
157
158impl<C,I> Consumer<I> for ProgressConsumer<C> where C: Consumer<I> {
159    type Folder = ProgressFolder<C::Folder>;
160
161    type Reducer = C::Reducer;
162
163    type Result = C::Result;
164
165    fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
166        let ProgressConsumer {inner, items_processed: entries_processed} = self;
167        let (left, right, reducer) = inner.split_at(index);
168        (ProgressConsumer{inner:left, items_processed: entries_processed.clone()},
169         ProgressConsumer{inner:right, items_processed: entries_processed},
170         reducer)
171    }
172
173    fn into_folder(self) -> Self::Folder {
174        ProgressFolder {inner: self.inner.into_folder(), items_processed: self.items_processed}
175    }
176
177    fn full(&self) -> bool {
178        self.inner.full()
179    }
180}
181
182impl<F,I> Folder<I> for ProgressFolder<F> where F: Folder<I> {
183    type Result=F::Result;
184
185    fn consume(self, item: I) -> Self {
186        let Self{inner, items_processed} = self;
187        let inner = inner.consume(item);
188        items_processed.fetch_add(1, Ordering::Relaxed);
189        Self {inner, items_processed}
190    }
191
192    fn complete(self) -> Self::Result {
193        self.inner.complete()
194    }
195
196    fn full(&self) -> bool {
197        self.inner.full()
198    }
199
200    // the more I think about it, the more I realize the optimization with consume_iter() is
201    // unsound and will produce wrong results.  best case scenario the count gets updated far too
202    // early.  worst case the consumer realizes it is full midway through the iterator and we 
203    // report more items were processed than actually were.
204    // so i have left it out.
205}
206
207impl<P> Producer for ProgressProducer<P> where P: Producer {
208    type Item=P::Item;
209
210    type IntoIter=ProgressIterator<P::IntoIter>;
211
212    fn into_iter(self) -> Self::IntoIter {
213        ProgressIterator{inner: self.inner.into_iter(), items_processed: self.items_processed}
214    }
215
216    fn split_at(self, index: usize) -> (Self, Self) {
217        let Self{inner, items_processed}=self;
218        let (left,right) = inner.split_at(index);
219        (Self{inner: left, items_processed: items_processed.clone()}, Self{inner: right, items_processed})
220    }
221
222    fn min_len(&self) -> usize {
223        self.inner.min_len()
224    }
225
226    fn max_len(&self) -> usize {
227        self.inner.max_len()
228    }
229
230    fn fold_with<F>(self, folder: F) -> F
231    where
232        F: Folder<Self::Item>,
233    {
234        self.inner.fold_with(ProgressFolder{inner: folder, items_processed: self.items_processed}).inner
235    }
236}
237
238impl<I> Iterator for ProgressIterator<I> where I: Iterator {
239    type Item=I::Item;
240    fn next(&mut self) -> Option<Self::Item> {
241        let res = self.inner.next();
242        if res.is_some() {
243            self.items_processed.fetch_add(1, Ordering::Relaxed);
244        }
245        res
246    }
247    fn size_hint(&self) -> (usize, Option<usize>) {
248        self.inner.size_hint()
249    }
250}
251
252impl<I> DoubleEndedIterator for ProgressIterator<I> where I: DoubleEndedIterator {
253    fn next_back(&mut self) -> Option<Self::Item> {
254        let res = self.inner.next_back();
255        if res.is_some() {
256            self.items_processed.fetch_add(1, Ordering::Relaxed);
257        }
258        res
259    }
260}
261
262impl<I> ExactSizeIterator for ProgressIterator<I> where I: ExactSizeIterator {
263    fn len(&self) -> usize {
264        self.inner.len()
265    }
266}
267
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicBool;

    use super::*;

    /// Checks that the counter is observable from another thread *while* processing is still
    /// running: items `>= 500` cannot finish until a watcher has seen the count reach 500.
    ///
    /// The watcher runs on a plain OS thread, not via `rayon::spawn`. A spinning Rayon job would
    /// occupy a pool worker, and with a single-threaded pool (e.g. `RAYON_NUM_THREADS=1`)
    /// nothing would be left to process the items, deadlocking the test.
    #[test]
    fn progress_is_visible_during_processing() {
        let flag = Arc::new(AtomicBool::new(false));
        let iter = ProgressAdaptor::new(0..1000);
        let items_processed = iter.items_processed();

        let watcher = std::thread::spawn({
            let flag = Arc::clone(&flag);
            let items_processed = items_processed.clone();
            move || {
                while items_processed.get() < 500 {
                    std::thread::yield_now();
                }
                flag.store(true, Ordering::Release);
            }
        });

        let sum: u64 = iter
            .map(|x| {
                if x >= 500 {
                    // block until the watcher has observed 500 completed items
                    while !flag.load(Ordering::Acquire) {
                        std::thread::yield_now();
                    }
                }
                x
            })
            .sum();

        watcher.join().expect("watcher thread panicked");
        assert_eq!(sum, 499500);
        // every item passed through the adaptor exactly once
        assert_eq!(items_processed.get(), 1000);
    }

    /// Exercises the indexed `with_producer` path (used by `zip`), where counting happens in
    /// `ProgressProducer`/`ProgressIterator` rather than in `ProgressConsumer`.
    #[test]
    fn counts_items_on_indexed_producer_path() {
        let iter = ProgressAdaptor::new(0..100u32);
        let items_processed = iter.items_processed();
        let matching = iter.zip(0..100u32).filter(|(a, b)| a == b).count();
        assert_eq!(matching, 100);
        assert_eq!(items_processed.get(), 100);
    }
}