rayon_cancel/
lib.rs

1//! Rayon does not natively support interrupting long-running computations.
2//! This crate provides an iterator adapter for rayon that can be interrupted during computation.
3//!
//! The provided [`CancelAdapter`] can wrap any rayon iterator and lets you stop the processing of new items at any given point.
5//! The adapter provides a handle to cancel the computation and a handle to access the number of processed items.
6//!
7//! By design, the adapter cannot interrupt processing individual items.
8//! Once the computation is cancelled, the adapter will stop producing or consuming new items.
9//! Which items are processed before the computation stops is non-deterministic and depends on the way rayon distributes the work.
10//!
11//! Using this adapter may be less efficient than using the underlying iterator directly as the
12//! number of items produced by the iterator cannot be known in advance.
13//!
14//! If you only need access to the number of processed items, you may want to have a look at the [rayon-progress](https://crates.io/crates/rayon-progress) crate.
15//!
16//! # Example
17//! ```
18//! use rayon::prelude::*;
19//! use rayon_cancel::CancelAdapter;
20//! let adapter = CancelAdapter::new(0..100000);
21//! let canceller = adapter.canceller();
22//! let progress = adapter.counter();
23//! std::thread::spawn(move || {
24//!     while progress.get() < 1000 {
25//!        std::thread::sleep(std::time::Duration::from_millis(2));
26//!     }
27//!     canceller.cancel();
28//! });
29//! let count = adapter.counter();
30//! // some expensive computation
31//! let processed: Vec<_> = adapter.filter(|_| true).map(|i| {
32//!     std::thread::sleep(std::time::Duration::from_millis(20));
33//!     i
34//! }).collect();
//! assert!(count.get() >= 1000);
36//! assert!(count.get() < 100000);
37//! // `processed` contains `count` items, but which ones is non-deterministic
38//! assert_eq!(processed.len(), count.get());
39//! ```
40use std::sync::Arc;
41use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
42use rayon::iter::IntoParallelIterator;
43use rayon::iter::plumbing::{Consumer, Folder, UnindexedConsumer};
44use rayon::prelude::ParallelIterator;
45
/// Handle used to request cancellation of a running iterator computation.
///
/// Cloning the handle clones the inner [`Arc`], so all clones operate on the same flag.
#[derive(Debug, Clone)]
pub struct CancelHandle {
    /// Shared flag; `true` once cancellation has been requested.
    cancelled: Arc<AtomicBool>,
}

impl CancelHandle {
    /// Request cancellation by setting the shared flag to `true`.
    ///
    /// Calling this more than once has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once cancellation has been requested through this handle or a clone of it.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }
}
63
/// Read-only access to the number of items processed by the iterator.
///
/// Note that the value is only an upper bound: it counts the items produced by the adapter,
/// not the items consumed by the operations that follow it.
#[derive(Debug, Clone)]
pub struct CountHandle {
    /// Shared item counter.
    count: Arc<AtomicUsize>,
}

impl CountHandle {
    /// Returns the current value of the shared counter.
    pub fn get(&self) -> usize {
        let counter: &AtomicUsize = &self.count;
        counter.load(Ordering::Relaxed)
    }
}
77
78/// Iterator adapter that can be interrupted during computation.
79///
80/// The adapter allows provides the number of items processed by the iterator which can be used to estimate the progress.
81///
82#[derive(Debug, Clone)]
83pub struct CancelAdapter<I> {
84    inner: I,
85    cancel: Arc<AtomicBool>,
86    count: Arc<AtomicUsize>,
87}
88
89impl<I: ParallelIterator> CancelAdapter<I> {
90    /// Wrap the given iterator with a cancel adapter.
91    ///
92    /// `inner` can be any type that implements [`IntoParallelIterator`].
93    pub fn new<It: IntoParallelIterator<Iter=I>>(inner: It) -> Self {
94        Self {
95            inner: inner.into_par_iter(),
96            cancel: Arc::new(AtomicBool::new(false)),
97            count: Arc::new(AtomicUsize::new(0)),
98        }
99    }
100
101    /// Handle to cancel the computation.
102    pub fn canceller(&self) -> CancelHandle {
103        CancelHandle {
104            cancelled: self.cancel.clone(),
105        }
106    }
107
108    /// Handle to the number of processed items.
109    pub fn counter(&self) -> CountHandle {
110        CountHandle {
111            count: self.count.clone(),
112        }
113    }
114}
115
116impl<I> ParallelIterator for CancelAdapter<I>
117where
118    I: ParallelIterator,
119{
120    type Item = I::Item;
121
122    fn drive_unindexed<C>(self, consumer: C) -> C::Result
123    where
124        C: UnindexedConsumer<Self::Item>,
125    {
126        if self.cancel.load(Ordering::Relaxed) {
127            return consumer.split_off_left().into_folder().complete();
128        }
129        self.inner.drive_unindexed(CancelConsumer {
130            inner: consumer,
131            cancel: self.cancel,
132            count: self.count,
133        })
134    }
135}
136
/// Consumer wrapper that carries the cancellation flag and the item counter alongside the
/// wrapped consumer.
#[derive(Debug, Clone)]
struct CancelConsumer<C> {
    /// The wrapped consumer.
    inner: C,
    /// Shared cancellation flag.
    cancel: Arc<AtomicBool>,
    /// Shared item counter.
    count: Arc<AtomicUsize>,
}
143
/// Folder wrapper that carries the cancellation flag and the item counter alongside the
/// wrapped folder.
#[derive(Debug, Clone)]
struct CancelFolder<F> {
    /// The wrapped folder.
    inner: F,
    /// Shared cancellation flag.
    cancel: Arc<AtomicBool>,
    /// Shared item counter.
    count: Arc<AtomicUsize>,
}
150
151impl<Item, C> Consumer<Item> for CancelConsumer<C>
152where
153    C: Consumer<Item>,
154{
155    type Folder = CancelFolder<C::Folder>;
156    type Reducer = C::Reducer;
157    type Result = C::Result;
158
159    fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
160        let Self { inner, cancel, count } = self;
161        let (left, right, reducer) = inner.split_at(index);
162        (
163            CancelConsumer {
164                inner: left,
165                cancel: cancel.clone(),
166                count: count.clone(),
167            },
168            CancelConsumer {
169                inner: right,
170                cancel,
171                count,
172            },
173            reducer,
174        )
175    }
176
177    fn into_folder(self) -> Self::Folder {
178        CancelFolder {
179            inner: self.inner.into_folder(),
180            cancel: self.cancel,
181            count: self.count,
182        }
183    }
184
185    fn full(&self) -> bool {
186        self.cancel.load(Ordering::Relaxed) || self.inner.full()
187    }
188}
189
190impl<F: Folder<I>, I> Folder<I> for CancelFolder<F> {
191    type Result = F::Result;
192
193    fn consume(self, item: I) -> Self {
194        if self.cancel.load(Ordering::Relaxed) {
195            return self;
196        }
197        self.count.fetch_add(1, Ordering::Relaxed);
198        Self {
199            inner: self.inner.consume(item),
200            cancel: self.cancel,
201            count: self.count,
202        }
203    }
204
205    fn complete(self) -> Self::Result {
206        self.inner.complete()
207    }
208
209    fn full(&self) -> bool {
210        self.cancel.load(Ordering::Relaxed) || self.inner.full()
211    }
212}
213
214impl<Item, C> UnindexedConsumer<Item> for CancelConsumer<C>
215where
216    C: UnindexedConsumer<Item>,
217{
218    fn split_off_left(&self) -> Self {
219        CancelConsumer {
220            inner: self.inner.split_off_left(),
221            cancel: self.cancel.clone(),
222            count: self.count.clone(),
223        }
224    }
225
226    fn to_reducer(&self) -> Self::Reducer {
227        self.inner.to_reducer()
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Spawn an OS thread that requests cancellation through `canceller` after `delay`.
    ///
    /// A plain thread is used instead of `rayon::spawn` so that the sleeping canceller does
    /// not occupy a worker of the pool that also runs the iterator under test; on a pool with
    /// a single worker such a job could not overlap with the iteration at all.
    fn cancel_after(canceller: CancelHandle, delay: Duration) {
        thread::spawn(move || {
            thread::sleep(delay);
            canceller.cancel();
        });
    }

    /// Cancelling while `count()` is running stops early, and the counter matches the result.
    #[test]
    fn test_cancel() {
        let adapter = CancelAdapter::new(0..10000);
        let count = adapter.counter();
        cancel_after(adapter.canceller(), Duration::from_millis(100));
        // simulate expensive computation
        let total: usize = adapter
            .filter(|_| true)
            .map(|i| {
                thread::sleep(Duration::from_millis(20));
                i
            })
            .count();
        assert!(count.get() < 10000);
        assert_eq!(total, count.get());
    }

    /// Without cancellation every item is produced and counted.
    #[test]
    fn test_no_cancel() {
        let adapter = CancelAdapter::new(0..10000);
        let count = adapter.counter();
        let total: usize = adapter.filter(|_| true).count();
        assert_eq!(total, 10000);
        assert_eq!(total, count.get());
    }

    /// Cancelling while `collect()` is running yields exactly as many items as were counted.
    #[test]
    fn test_cancel_collect() {
        let iter = CancelAdapter::new(0..10000);
        let count = iter.counter();
        cancel_after(iter.canceller(), Duration::from_millis(510));
        let items: Vec<usize> = iter
            .filter(|_| true)
            .map(|i| {
                thread::sleep(Duration::from_millis(20));
                i
            })
            .collect();
        assert!(count.get() < 10000);
        assert_eq!(items.len(), count.get());
    }

    /// Cancelling before the iterator is driven produces no items and leaves the counter at zero.
    #[test]
    fn test_cancel_before_start() {
        let adapter = CancelAdapter::new(0..10000);
        let count = adapter.counter();
        let canceller = adapter.canceller();
        canceller.cancel();
        assert!(canceller.is_cancelled());
        let items: Vec<usize> = adapter.filter(|_| true).collect();
        assert!(items.is_empty());
        assert_eq!(count.get(), 0);
    }
}