rten_base/
iter.rs

1//! Iterators and iterator-related traits.
2
3use rayon::iter::plumbing::{Consumer, Producer, ProducerCallback, UnindexedConsumer, bridge};
4use rayon::prelude::*;
5
6mod range;
7pub use range::{RangeChunks, RangeChunksExact, range_chunks, range_chunks_exact};
8
/// A trait to simplify adding Rayon support to iterators.
///
/// This is used in combination with [`ParIter`] and assumes your iterator has
/// a known size and can be split at an arbitrary position.
///
/// Adding Rayon support to an iterator using this trait requires implementing
/// the following traits for the iterator:
///
/// 1. [`DoubleEndedIterator`] and [`ExactSizeIterator`]. These are requirements
///    for Rayon's [`IndexedParallelIterator`].
/// 2. [`SplitIterator`], to define how Rayon should split the iterator.
/// 3. [`IntoParallelIterator`], using [`ParIter<I>`] as the parallel iterator
///    type.
pub trait SplitIterator: DoubleEndedIterator + ExactSizeIterator {
    /// Split the iterator in two at a given index.
    ///
    /// The left result yields the first `index` items. The right result yields
    /// the remaining items, starting from position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is greater than the iterator's length, as reported by
    /// [`ExactSizeIterator::len`].
    fn split_at(self, index: usize) -> (Self, Self)
    where
        Self: Sized;
}
34
/// A parallel wrapper around a serial iterator.
///
/// This type should be used as the [`IntoParallelIterator::Iter`] associated
/// type in an implementation of [`IntoParallelIterator`] for `I`, where `I`
/// implements [`SplitIterator`].
///
/// The [`SplitIterator`] bound lives on the trait implementations rather than
/// on the struct itself. `ParIter<I>` is a Rayon parallel iterator only when
/// `I: SplitIterator + Send`.
#[derive(Clone, Debug)]
pub struct ParIter<I>(I);
40
41impl<I: SplitIterator> From<I> for ParIter<I> {
42    fn from(val: I) -> Self {
43        ParIter(val)
44    }
45}
46
47impl<I: SplitIterator + Send> ParallelIterator for ParIter<I>
48where
49    <I as Iterator>::Item: Send,
50{
51    type Item = I::Item;
52
53    fn drive_unindexed<C>(self, consumer: C) -> C::Result
54    where
55        C: UnindexedConsumer<Self::Item>,
56    {
57        bridge(self, consumer)
58    }
59
60    fn opt_len(&self) -> Option<usize> {
61        Some(ExactSizeIterator::len(&self.0))
62    }
63}
64
65impl<I: SplitIterator + Send> IndexedParallelIterator for ParIter<I>
66where
67    <I as Iterator>::Item: Send,
68{
69    fn drive<C>(self, consumer: C) -> C::Result
70    where
71        C: Consumer<Self::Item>,
72    {
73        bridge(self, consumer)
74    }
75
76    fn len(&self) -> usize {
77        ExactSizeIterator::len(&self.0)
78    }
79
80    fn with_producer<CB>(self, callback: CB) -> CB::Output
81    where
82        CB: ProducerCallback<Self::Item>,
83    {
84        callback.callback(self)
85    }
86}
87
88impl<I: SplitIterator + Send> Producer for ParIter<I> {
89    type Item = I::Item;
90
91    type IntoIter = I;
92
93    fn into_iter(self) -> Self::IntoIter {
94        self.0
95    }
96
97    fn split_at(self, index: usize) -> (Self, Self) {
98        let (left_inner, right_inner) = SplitIterator::split_at(self.0, index);
99        (Self(left_inner), Self(right_inner))
100    }
101}
102
103/// Wrapper around either a serial or parallel iterator, returned by
104/// [`MaybeParIter::maybe_par_iter`].
105pub enum MaybeParallel<PI: ParallelIterator, SI: Iterator<Item = PI::Item>> {
106    Serial(SI),
107    Parallel(PI),
108}
109
110impl<PI: ParallelIterator, SI: Iterator<Item = PI::Item>> MaybeParallel<PI, SI> {
111    pub fn for_each<F: Fn(PI::Item) + Send + Sync>(self, f: F) {
112        match self {
113            MaybeParallel::Serial(iter) => iter.for_each(f),
114            MaybeParallel::Parallel(iter) => iter.for_each(f),
115        }
116    }
117}
118
119/// Trait which allows use of Rayon parallelism to be conditionally enabled.
120///
121/// See <https://crates.io/crates/rayon-cond> for a more full-featured alternative.
122pub trait MaybeParIter {
123    type Item;
124    type ParIter: ParallelIterator<Item = Self::Item>;
125    type Iter: Iterator<Item = Self::Item>;
126
127    /// Return an iterator which executes either in serial on the current
128    /// thread, or in parallel in a Rayon thread pool if `parallel` is true.
129    fn maybe_par_iter(self, parallel: bool) -> MaybeParallel<Self::ParIter, Self::Iter>;
130}
131
132impl<Item, I: rayon::iter::IntoParallelIterator<Item = Item> + IntoIterator<Item = Item>>
133    MaybeParIter for I
134{
135    type Item = Item;
136    type ParIter = I::Iter;
137    type Iter = I::IntoIter;
138
139    fn maybe_par_iter(self, parallel: bool) -> MaybeParallel<Self::ParIter, Self::Iter> {
140        if parallel {
141            MaybeParallel::Parallel(self.into_par_iter())
142        } else {
143            MaybeParallel::Serial(self.into_iter())
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::ops::Range;
    use std::sync::atomic::{AtomicU32, Ordering};

    use rayon::prelude::*;

    use super::{MaybeParIter, ParIter, SplitIterator};

    /// Minimal `SplitIterator` over a range of indices. It exercises `ParIter`
    /// without depending on other iterator types in this crate.
    struct IndexIter(Range<usize>);

    impl Iterator for IndexIter {
        type Item = usize;

        fn next(&mut self) -> Option<usize> {
            self.0.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            self.0.size_hint()
        }
    }

    impl DoubleEndedIterator for IndexIter {
        fn next_back(&mut self) -> Option<usize> {
            self.0.next_back()
        }
    }

    impl ExactSizeIterator for IndexIter {}

    impl SplitIterator for IndexIter {
        fn split_at(self, index: usize) -> (Self, Self) {
            let Range { start, end } = self.0;
            assert!(index <= end - start, "split index out of bounds");
            let mid = start + index;
            (IndexIter(start..mid), IndexIter(mid..end))
        }
    }

    #[test]
    fn test_maybe_par_iter() {
        for parallel in [false, true] {
            let count = AtomicU32::new(0);
            (0..1000).maybe_par_iter(parallel).for_each(|_| {
                count.fetch_add(1, Ordering::SeqCst);
            });
            assert_eq!(count.load(Ordering::SeqCst), 1000);
        }
    }

    #[test]
    fn test_par_iter_preserves_order() {
        let expected: Vec<usize> = (0..1000).map(|x| x * 2).collect();

        // A small max length forces many `split_at` calls regardless of how
        // many threads the pool has.
        let actual: Vec<usize> = ParIter::from(IndexIter(0..1000))
            .with_max_len(16)
            .map(|x| x * 2)
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_par_iter_len_and_split() {
        let par_iter = ParIter::from(IndexIter(0..10));
        assert_eq!(par_iter.len(), 10);

        let (left, right) = IndexIter(0..10).split_at(4);
        assert_eq!(left.collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(right.collect::<Vec<_>>(), vec![4, 5, 6, 7, 8, 9]);
    }
}