1use std::fmt::{self, Debug};
2
3use super::{chunks::ChunkProducer, plumbing::*, *};
4use crate::math::div_round_up;
5
6#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16#[derive(Clone)]
17pub struct FoldChunksWith<I, U, F>
18where
19 I: IndexedParallelIterator,
20{
21 base: I,
22 chunk_size: usize,
23 item: U,
24 fold_op: F,
25}
26
27impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("Fold")
30 .field("base", &self.base)
31 .field("chunk_size", &self.chunk_size)
32 .field("item", &self.item)
33 .finish()
34 }
35}
36
37impl<I, U, F> FoldChunksWith<I, U, F>
38where
39 I: IndexedParallelIterator,
40 U: Send + Clone,
41 F: Fn(U, I::Item) -> U + Send + Sync,
42{
43 pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
45 FoldChunksWith {
46 base,
47 chunk_size,
48 item,
49 fold_op,
50 }
51 }
52}
53
54impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
55where
56 I: IndexedParallelIterator,
57 U: Send + Clone,
58 F: Fn(U, I::Item) -> U + Send + Sync,
59{
60 type Item = U;
61
62 fn drive_unindexed<C>(self, consumer: C) -> C::Result
63 where
64 C: Consumer<U>,
65 {
66 bridge(self, consumer)
67 }
68
69 fn opt_len(&self) -> Option<usize> {
70 Some(self.len())
71 }
72}
73
74impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
75where
76 I: IndexedParallelIterator,
77 U: Send + Clone,
78 F: Fn(U, I::Item) -> U + Send + Sync,
79{
80 fn len(&self) -> usize {
81 div_round_up(self.base.len(), self.chunk_size)
82 }
83
84 fn drive<C>(self, consumer: C) -> C::Result
85 where
86 C: Consumer<Self::Item>,
87 {
88 bridge(self, consumer)
89 }
90
91 fn with_producer<CB>(self, callback: CB) -> CB::Output
92 where
93 CB: ProducerCallback<Self::Item>,
94 {
95 let len = self.base.len();
96 return self.base.with_producer(Callback {
97 chunk_size: self.chunk_size,
98 len,
99 item: self.item,
100 fold_op: self.fold_op,
101 callback,
102 });
103
104 struct Callback<CB, T, F> {
105 chunk_size: usize,
106 len: usize,
107 item: T,
108 fold_op: F,
109 callback: CB,
110 }
111
112 impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
113 where
114 CB: ProducerCallback<U>,
115 U: Send + Clone,
116 F: Fn(U, T) -> U + Send + Sync,
117 {
118 type Output = CB::Output;
119
120 fn callback<P>(self, base: P) -> CB::Output
121 where
122 P: Producer<Item = T>,
123 {
124 let item = self.item;
125 let fold_op = &self.fold_op;
126 let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
127 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
128 self.callback.callback(producer)
129 }
130 }
131 }
132}
133
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;

    /// Generic addition used as the fold step in the numeric tests.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks_with() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3]
            .into_par_iter()
            .fold_chunks_with(0, 0, sum)
            .collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks_with(3, 0, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_with_empty() {
        let input: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks_with(2, 0, sum).collect();
        assert!(folded.is_empty());
    }

    #[test]
    fn check_fold_chunks_len() {
        // (range end, chunk size, expected number of chunks)
        for &(end, chunk, expected) in &[(8, 2, 4), (9, 3, 3), (8, 3, 3), (0, 3, 0)] {
            assert_eq!(
                expected,
                (0..end).into_par_iter().fold_chunks_with(chunk, 0, sum).len()
            );
        }
        assert_eq!(1, [1].par_iter().fold_chunks_with(3, 0, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![1 + 2, 3]),
        ];

        for (case_idx, (input, size, expected)) in cases.into_iter().enumerate() {
            let mut out: Vec<u32> = Vec::new();
            input
                .par_iter()
                .fold_chunks_with(size, 0, sum)
                .collect_into_vec(&mut out);
            assert_eq!(expected, out, "Case {} failed", case_idx);

            out.clear();
            input
                .into_par_iter()
                .fold_chunks_with(size, 0, sum)
                .rev()
                .collect_into_vec(&mut out);
            let expected_rev: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(expected_rev, out, "Case {} reversed failed", case_idx);
        }
    }
}