1use std::fmt::{self, Debug};
2
3use super::{chunks::ChunkProducer, plumbing::*, *};
4use crate::math::div_round_up;
5
/// Parallel iterator adaptor that splits an indexed base iterator into
/// consecutive chunks of `chunk_size` items and folds each chunk into a
/// single value.
///
/// Each chunk is folded sequentially, starting from a fresh `identity()`
/// value and combining items with `fold_op`. The final chunk may be shorter
/// than `chunk_size` when the base length is not a multiple of it. This is
/// the adaptor behind the `fold_chunks` method used in this module's tests.
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunks<I, ID, F>
where
    I: IndexedParallelIterator,
{
    /// The indexed iterator whose items are grouped into chunks.
    base: I,
    /// Maximum number of base items per chunk (the last chunk may be shorter).
    chunk_size: usize,
    /// Combines a chunk's running accumulator with its next item.
    fold_op: F,
    /// Produces the starting accumulator for each chunk.
    identity: ID,
}
33}
34
35impl<I, ID, U, F> FoldChunks<I, ID, F>
36where
37 I: IndexedParallelIterator,
38 ID: Fn() -> U + Send + Sync,
39 F: Fn(U, I::Item) -> U + Send + Sync,
40 U: Send,
41{
42 pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
44 FoldChunks {
45 base,
46 chunk_size,
47 identity,
48 fold_op,
49 }
50 }
51}
52
53impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
54where
55 I: IndexedParallelIterator,
56 ID: Fn() -> U + Send + Sync,
57 F: Fn(U, I::Item) -> U + Send + Sync,
58 U: Send,
59{
60 type Item = U;
61
62 fn drive_unindexed<C>(self, consumer: C) -> C::Result
63 where
64 C: Consumer<U>,
65 {
66 bridge(self, consumer)
67 }
68
69 fn opt_len(&self) -> Option<usize> {
70 Some(self.len())
71 }
72}
73
74impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
75where
76 I: IndexedParallelIterator,
77 ID: Fn() -> U + Send + Sync,
78 F: Fn(U, I::Item) -> U + Send + Sync,
79 U: Send,
80{
81 fn len(&self) -> usize {
82 div_round_up(self.base.len(), self.chunk_size)
83 }
84
85 fn drive<C>(self, consumer: C) -> C::Result
86 where
87 C: Consumer<Self::Item>,
88 {
89 bridge(self, consumer)
90 }
91
92 fn with_producer<CB>(self, callback: CB) -> CB::Output
93 where
94 CB: ProducerCallback<Self::Item>,
95 {
96 let len = self.base.len();
97 return self.base.with_producer(Callback {
98 chunk_size: self.chunk_size,
99 len,
100 identity: self.identity,
101 fold_op: self.fold_op,
102 callback,
103 });
104
105 struct Callback<CB, ID, F> {
106 chunk_size: usize,
107 len: usize,
108 identity: ID,
109 fold_op: F,
110 callback: CB,
111 }
112
113 impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
114 where
115 CB: ProducerCallback<U>,
116 ID: Fn() -> U + Send + Sync,
117 F: Fn(U, T) -> U + Send + Sync,
118 {
119 type Output = CB::Output;
120
121 fn callback<P>(self, base: P) -> CB::Output
122 where
123 P: Producer<Item = T>,
124 {
125 let identity = &self.identity;
126 let fold_op = &self.fold_op;
127 let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
128 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
129 self.callback.callback(producer)
130 }
131 }
132 }
133}
134
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;

    /// Identity value for the integer-sum folds below.
    fn id() -> i32 {
        0
    }

    /// Generic addition used as the per-chunk fold operation.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks() {
        let chars: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = chars
            .into_par_iter()
            .fold_chunks(4, String::new, |mut word, c| {
                word.push(c);
                word
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks(0, id, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let sums: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], sums);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let input: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(Vec::<i32>::new(), folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(1, [1].par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = Vec::new();

            // Forward, borrowing the input.
            v.par_iter()
                .fold_chunks(n, || 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            // Reversed, consuming the input; chunk results come out backwards.
            res.clear();
            v.into_par_iter()
                .fold_chunks(n, || 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let expected_rev: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(expected_rev, res, "Case {} reversed failed", i);
        }
    }
}