par_iter/iter/
fold_chunks.rs

1use std::fmt::{self, Debug};
2
3use super::{chunks::ChunkProducer, plumbing::*, *};
4use crate::math::div_round_up;
5
6/// `FoldChunks` is an iterator that groups elements of an underlying iterator
7/// and applies a function over them, producing a single value for each group.
8///
9/// This struct is created by the [`fold_chunks()`] method on
10/// [`IndexedParallelIterator`]
11///
12/// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks
13/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
14#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
15#[derive(Clone)]
16pub struct FoldChunks<I, ID, F>
17where
18    I: IndexedParallelIterator,
19{
20    base: I,
21    chunk_size: usize,
22    fold_op: F,
23    identity: ID,
24}
25
26impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("Fold")
29            .field("base", &self.base)
30            .field("chunk_size", &self.chunk_size)
31            .finish()
32    }
33}
34
35impl<I, ID, U, F> FoldChunks<I, ID, F>
36where
37    I: IndexedParallelIterator,
38    ID: Fn() -> U + Send + Sync,
39    F: Fn(U, I::Item) -> U + Send + Sync,
40    U: Send,
41{
42    /// Creates a new `FoldChunks` iterator
43    pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
44        FoldChunks {
45            base,
46            chunk_size,
47            identity,
48            fold_op,
49        }
50    }
51}
52
53impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
54where
55    I: IndexedParallelIterator,
56    ID: Fn() -> U + Send + Sync,
57    F: Fn(U, I::Item) -> U + Send + Sync,
58    U: Send,
59{
60    type Item = U;
61
62    fn drive_unindexed<C>(self, consumer: C) -> C::Result
63    where
64        C: Consumer<U>,
65    {
66        bridge(self, consumer)
67    }
68
69    fn opt_len(&self) -> Option<usize> {
70        Some(self.len())
71    }
72}
73
74impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
75where
76    I: IndexedParallelIterator,
77    ID: Fn() -> U + Send + Sync,
78    F: Fn(U, I::Item) -> U + Send + Sync,
79    U: Send,
80{
81    fn len(&self) -> usize {
82        div_round_up(self.base.len(), self.chunk_size)
83    }
84
85    fn drive<C>(self, consumer: C) -> C::Result
86    where
87        C: Consumer<Self::Item>,
88    {
89        bridge(self, consumer)
90    }
91
92    fn with_producer<CB>(self, callback: CB) -> CB::Output
93    where
94        CB: ProducerCallback<Self::Item>,
95    {
96        let len = self.base.len();
97        return self.base.with_producer(Callback {
98            chunk_size: self.chunk_size,
99            len,
100            identity: self.identity,
101            fold_op: self.fold_op,
102            callback,
103        });
104
105        struct Callback<CB, ID, F> {
106            chunk_size: usize,
107            len: usize,
108            identity: ID,
109            fold_op: F,
110            callback: CB,
111        }
112
113        impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
114        where
115            CB: ProducerCallback<U>,
116            ID: Fn() -> U + Send + Sync,
117            F: Fn(U, T) -> U + Send + Sync,
118        {
119            type Output = CB::Output;
120
121            fn callback<P>(self, base: P) -> CB::Output
122            where
123                P: Producer<Item = T>,
124            {
125                let identity = &self.identity;
126                let fold_op = &self.fold_op;
127                let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
128                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
129                self.callback.callback(producer)
130            }
131        }
132    }
133}
134
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;

    /// Folding characters four at a time yields one `String` per group,
    /// with a shorter final group.
    #[test]
    fn check_fold_chunks() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks(4, String::new, |mut word, ch| {
                word.push(ch);
                word
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // Plain functions standing in for the identity and fold closures used by
    // the tests that follow.
    fn id() -> i32 {
        0
    }
    fn sum<T, U>(acc: T, item: U) -> T
    where
        T: Add<U, Output = T>,
    {
        acc + item
    }

    /// A chunk size of zero must panic with the documented message.
    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input: Vec<i32> = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks(0, id, sum).collect();
    }

    /// Input length is an exact multiple of the chunk size.
    #[test]
    fn check_fold_chunks_even_size() {
        let expected = vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9];
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(expected, folded);
    }

    /// An empty input produces no folded values.
    #[test]
    fn check_fold_chunks_empty() {
        let input: Vec<i32> = Vec::new();
        let expected: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(expected, folded);
    }

    /// `len()` reports the number of chunks, counting a partial last chunk.
    #[test]
    fn check_fold_chunks_len() {
        let chunk_count = |range: std::ops::Range<i32>, chunk_size: usize| {
            range.into_par_iter().fold_chunks(chunk_size, id, sum).len()
        };

        assert_eq!(4, chunk_count(0..8, 2));
        assert_eq!(3, chunk_count(0..9, 3));
        assert_eq!(3, chunk_count(0..8, 3));
        assert_eq!(1, [1].par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, chunk_count(0..0, 3));
    }

    /// Inputs whose length is not a multiple of the chunk size, checked in
    /// forward order and again through `rev()`.
    #[test]
    fn check_fold_chunks_uneven() {
        // (input, chunk size, expected folded values)
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (case, (input, chunk_size, expected)) in cases.into_iter().enumerate() {
            let mut out: Vec<u32> = Vec::new();

            // Forward pass over borrowed items.
            input
                .par_iter()
                .fold_chunks(chunk_size, || 0, sum)
                .collect_into_vec(&mut out);
            assert_eq!(expected, out, "Case {} failed", case);

            // Reversed pass over owned items, reusing the same output buffer.
            out.clear();
            input
                .into_par_iter()
                .fold_chunks(chunk_size, || 0, sum)
                .rev()
                .collect_into_vec(&mut out);

            let mut expected_rev: Vec<u32> = expected;
            expected_rev.reverse();
            assert_eq!(expected_rev, out, "Case {} reversed failed", case);
        }
    }
}