par_iter/iter/
fold_chunks_with.rs

1use std::fmt::{self, Debug};
2
3use super::{chunks::ChunkProducer, plumbing::*, *};
4use crate::math::div_round_up;
5
/// `FoldChunksWith` is an iterator that groups elements of an underlying
/// iterator and applies a function over them, producing a single value for each
/// group.
///
/// Every group is folded starting from its own clone of the initial value
/// supplied when the adaptor was created.
///
/// This struct is created by the [`fold_chunks_with()`] method on
/// [`IndexedParallelIterator`].
///
/// [`fold_chunks_with()`]: trait.IndexedParallelIterator.html#method.fold_chunks_with
/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunksWith<I, U, F>
where
    I: IndexedParallelIterator,
{
    /// The underlying indexed parallel iterator whose items are grouped.
    base: I,
    /// Number of base items per group; the final group may hold fewer.
    chunk_size: usize,
    /// Initial accumulator value, cloned once for every group.
    item: U,
    /// Folding function applied to the accumulator and each item of a group.
    fold_op: F,
}
26
27impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("Fold")
30            .field("base", &self.base)
31            .field("chunk_size", &self.chunk_size)
32            .field("item", &self.item)
33            .finish()
34    }
35}
36
37impl<I, U, F> FoldChunksWith<I, U, F>
38where
39    I: IndexedParallelIterator,
40    U: Send + Clone,
41    F: Fn(U, I::Item) -> U + Send + Sync,
42{
43    /// Creates a new `FoldChunksWith` iterator
44    pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
45        FoldChunksWith {
46            base,
47            chunk_size,
48            item,
49            fold_op,
50        }
51    }
52}
53
54impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
55where
56    I: IndexedParallelIterator,
57    U: Send + Clone,
58    F: Fn(U, I::Item) -> U + Send + Sync,
59{
60    type Item = U;
61
62    fn drive_unindexed<C>(self, consumer: C) -> C::Result
63    where
64        C: Consumer<U>,
65    {
66        bridge(self, consumer)
67    }
68
69    fn opt_len(&self) -> Option<usize> {
70        Some(self.len())
71    }
72}
73
74impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
75where
76    I: IndexedParallelIterator,
77    U: Send + Clone,
78    F: Fn(U, I::Item) -> U + Send + Sync,
79{
80    fn len(&self) -> usize {
81        div_round_up(self.base.len(), self.chunk_size)
82    }
83
84    fn drive<C>(self, consumer: C) -> C::Result
85    where
86        C: Consumer<Self::Item>,
87    {
88        bridge(self, consumer)
89    }
90
91    fn with_producer<CB>(self, callback: CB) -> CB::Output
92    where
93        CB: ProducerCallback<Self::Item>,
94    {
95        let len = self.base.len();
96        return self.base.with_producer(Callback {
97            chunk_size: self.chunk_size,
98            len,
99            item: self.item,
100            fold_op: self.fold_op,
101            callback,
102        });
103
104        struct Callback<CB, T, F> {
105            chunk_size: usize,
106            len: usize,
107            item: T,
108            fold_op: F,
109            callback: CB,
110        }
111
112        impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
113        where
114            CB: ProducerCallback<U>,
115            U: Send + Clone,
116            F: Fn(U, T) -> U + Send + Sync,
117        {
118            type Output = CB::Output;
119
120            fn callback<P>(self, base: P) -> CB::Output
121            where
122                P: Producer<Item = T>,
123            {
124                let item = self.item;
125                let fold_op = &self.fold_op;
126                let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
127                let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
128                self.callback.callback(producer)
129            }
130        }
131    }
132}
133
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;

    /// Folding characters into strings yields one string per chunk, with a
    /// shorter final chunk.
    #[test]
    fn check_fold_chunks_with() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // Named stand-in for an addition closure, shared by the tests below.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    /// A chunk size of zero must panic.
    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks_with(0, 0, sum).collect();
    }

    /// Input length is an exact multiple of the chunk size.
    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks_with(3, 0, sum).collect();

        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    /// Empty input produces no chunks at all.
    #[test]
    fn check_fold_chunks_with_empty() {
        let v: Vec<i32> = Vec::new();
        let expected: Vec<i32> = Vec::new();

        let folded: Vec<i32> = v.into_par_iter().fold_chunks_with(2, 0, sum).collect();

        assert_eq!(expected, folded);
    }

    /// `len()` reports the number of chunks, counting a trailing partial one.
    #[test]
    fn check_fold_chunks_len() {
        let exact_pairs = (0..8).into_par_iter().fold_chunks_with(2, 0, sum).len();
        assert_eq!(4, exact_pairs);

        let exact_triples = (0..9).into_par_iter().fold_chunks_with(3, 0, sum).len();
        assert_eq!(3, exact_triples);

        let partial_tail = (0..8).into_par_iter().fold_chunks_with(3, 0, sum).len();
        assert_eq!(3, partial_tail);

        let single_short_chunk = [1].par_iter().fold_chunks_with(3, 0, sum).len();
        assert_eq!(1, single_short_chunk);

        let no_chunks = (0..0).into_par_iter().fold_chunks_with(3, 0, sum).len();
        assert_eq!(0, no_chunks);
    }

    /// Input lengths that are not a multiple of the chunk size, checked both
    /// in forward and in reversed order.
    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = Vec::new();

            // Forward pass over borrowed items.
            v.par_iter()
                .fold_chunks_with(n, 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            // Reversed pass over owned items, reusing the same buffer.
            res.clear();
            v.into_par_iter()
                .fold_chunks_with(n, 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let reversed: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(reversed, res, "Case {} reversed failed", i);
        }
    }
}