1use std::sync::Arc;
41use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
42use rayon::iter::IntoParallelIterator;
43use rayon::iter::plumbing::{Consumer, Folder, UnindexedConsumer};
44use rayon::prelude::ParallelIterator;
45
/// Shareable handle used to request that a [`CancelAdapter`] stop yielding items.
///
/// Cloning the handle clones the inner `Arc`, so every clone raises and observes
/// the very same flag.
#[derive(Clone, Debug)]
pub struct CancelHandle {
    /// Flag shared with the adapter (and, through it, its consumers and folders).
    cancelled: Arc<AtomicBool>,
}

impl CancelHandle {
    /// Raises the shared cancellation flag.
    ///
    /// The store is `Relaxed`: the flag is a standalone stop signal and is not
    /// used to publish any other data.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once the shared flag has been raised by any holder.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }
}
63
/// Read-only handle onto the number of items a [`CancelAdapter`] has forwarded
/// to its downstream consumer.
///
/// Clones share the same underlying counter.
#[derive(Clone, Debug)]
pub struct CountHandle {
    /// Counter shared with the adapter's folders, which bump it once per item.
    count: Arc<AtomicUsize>,
}

impl CountHandle {
    /// Returns the current counter value.
    ///
    /// The load is `Relaxed`; while an iteration is still running, the value can
    /// lag behind increments happening concurrently on other threads.
    pub fn get(&self) -> usize {
        let counter: &AtomicUsize = &self.count;
        counter.load(Ordering::Relaxed)
    }
}
77
78#[derive(Debug, Clone)]
83pub struct CancelAdapter<I> {
84 inner: I,
85 cancel: Arc<AtomicBool>,
86 count: Arc<AtomicUsize>,
87}
88
89impl<I: ParallelIterator> CancelAdapter<I> {
90 pub fn new<It: IntoParallelIterator<Iter=I>>(inner: It) -> Self {
94 Self {
95 inner: inner.into_par_iter(),
96 cancel: Arc::new(AtomicBool::new(false)),
97 count: Arc::new(AtomicUsize::new(0)),
98 }
99 }
100
101 pub fn canceller(&self) -> CancelHandle {
103 CancelHandle {
104 cancelled: self.cancel.clone(),
105 }
106 }
107
108 pub fn counter(&self) -> CountHandle {
110 CountHandle {
111 count: self.count.clone(),
112 }
113 }
114}
115
116impl<I> ParallelIterator for CancelAdapter<I>
117where
118 I: ParallelIterator,
119{
120 type Item = I::Item;
121
122 fn drive_unindexed<C>(self, consumer: C) -> C::Result
123 where
124 C: UnindexedConsumer<Self::Item>,
125 {
126 if self.cancel.load(Ordering::Relaxed) {
127 return consumer.split_off_left().into_folder().complete();
128 }
129 self.inner.drive_unindexed(CancelConsumer {
130 inner: consumer,
131 cancel: self.cancel,
132 count: self.count,
133 })
134 }
135}
136
/// Consumer wrapper that carries the shared cancellation flag and item counter
/// along as rayon splits the work, so each resulting [`CancelFolder`] sees them.
#[derive(Clone, Debug)]
struct CancelConsumer<C> {
    /// Downstream consumer that receives the items allowed through.
    inner: C,
    /// Cancellation flag, the same `Arc` held by the adapter's [`CancelHandle`]s.
    cancel: Arc<AtomicBool>,
    /// Forwarded-item counter, the same `Arc` held by the adapter's [`CountHandle`]s.
    count: Arc<AtomicUsize>,
}
143
/// Folder wrapper that skips items once cancellation is requested and counts
/// every item it forwards to the inner folder.
#[derive(Clone, Debug)]
struct CancelFolder<F> {
    /// Downstream folder that accumulates the forwarded items.
    inner: F,
    /// Cancellation flag shared with the owning consumer and the adapter.
    cancel: Arc<AtomicBool>,
    /// Forwarded-item counter shared with the owning consumer and the adapter.
    count: Arc<AtomicUsize>,
}
150
151impl<Item, C> Consumer<Item> for CancelConsumer<C>
152where
153 C: Consumer<Item>,
154{
155 type Folder = CancelFolder<C::Folder>;
156 type Reducer = C::Reducer;
157 type Result = C::Result;
158
159 fn split_at(self, index: usize) -> (Self, Self, Self::Reducer) {
160 let Self { inner, cancel, count } = self;
161 let (left, right, reducer) = inner.split_at(index);
162 (
163 CancelConsumer {
164 inner: left,
165 cancel: cancel.clone(),
166 count: count.clone(),
167 },
168 CancelConsumer {
169 inner: right,
170 cancel,
171 count,
172 },
173 reducer,
174 )
175 }
176
177 fn into_folder(self) -> Self::Folder {
178 CancelFolder {
179 inner: self.inner.into_folder(),
180 cancel: self.cancel,
181 count: self.count,
182 }
183 }
184
185 fn full(&self) -> bool {
186 self.cancel.load(Ordering::Relaxed) || self.inner.full()
187 }
188}
189
190impl<F: Folder<I>, I> Folder<I> for CancelFolder<F> {
191 type Result = F::Result;
192
193 fn consume(self, item: I) -> Self {
194 if self.cancel.load(Ordering::Relaxed) {
195 return self;
196 }
197 self.count.fetch_add(1, Ordering::Relaxed);
198 Self {
199 inner: self.inner.consume(item),
200 cancel: self.cancel,
201 count: self.count,
202 }
203 }
204
205 fn complete(self) -> Self::Result {
206 self.inner.complete()
207 }
208
209 fn full(&self) -> bool {
210 self.cancel.load(Ordering::Relaxed) || self.inner.full()
211 }
212}
213
214impl<Item, C> UnindexedConsumer<Item> for CancelConsumer<C>
215where
216 C: UnindexedConsumer<Item>,
217{
218 fn split_off_left(&self) -> Self {
219 CancelConsumer {
220 inner: self.inner.split_off_left(),
221 cancel: self.cancel.clone(),
222 count: self.count.clone(),
223 }
224 }
225
226 fn to_reducer(&self) -> Self::Reducer {
227 self.inner.to_reducer()
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Requests cancellation through `handle` from a dedicated OS thread after `delay`.
    ///
    /// A plain `std::thread` is used instead of `rayon::spawn`. A rayon job must
    /// wait for a free worker in the same pool that is busy running the iteration
    /// under test, and the iterations of other tests the harness runs concurrently
    /// on the global pool. The cancel could therefore fire much later than `delay`,
    /// and while sleeping it would tie up one of those workers.
    fn cancel_after(handle: CancelHandle, delay: Duration) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            thread::sleep(delay);
            handle.cancel();
        })
    }

    /// Identity map with a short sleep, so each item costs real time and the
    /// cancel request lands mid-iteration.
    fn slow_identity<T>(item: T) -> T {
        thread::sleep(Duration::from_millis(20));
        item
    }

    #[test]
    fn test_cancel() {
        let adapter = CancelAdapter::new(0..10000);
        let canceller = cancel_after(adapter.canceller(), Duration::from_millis(100));
        let count = adapter.counter();
        let total: usize = adapter.filter(|_| true).map(slow_identity).count();
        canceller.join().expect("canceller thread panicked");
        assert!(count.get() < 10000);
        assert_eq!(total, count.get());
        println!("total: {}, count: {}", total, count.get());
    }

    #[test]
    fn test_no_cancel() {
        let adapter = CancelAdapter::new(0..10000);
        let count = adapter.counter();
        let total: usize = adapter.filter(|_| true).count();
        assert_eq!(total, 10000);
        assert_eq!(total, count.get());
        println!("total: {}, count: {}", total, count.get());
    }

    #[test]
    fn test_cancel_collect() {
        let iter = CancelAdapter::new(0..10000);
        let count = iter.counter();
        let canceller = cancel_after(iter.canceller(), Duration::from_millis(510));
        let items: Vec<usize> = iter.filter(|_| true).map(slow_identity).collect();
        canceller.join().expect("canceller thread panicked");
        assert!(count.get() < 10000);
        assert_eq!(items.len(), count.get());
    }
}