1use rayon::prelude::*;
30
/// A value that can absorb another value of the same type in place.
///
/// The `Default` supertrait provides a starting value and `Send` allows values to be handed
/// between worker threads. For a parallel combination of many values to be independent of how
/// the work was split, `merge` should be associative and `Default::default()` should behave as
/// an identity for it.
pub trait Mergeable: Send + Default {
    /// Folds `rhs` into `self`, consuming `rhs`.
    fn merge(&mut self, rhs: Self);
}
39
40pub fn fold_reduce<T, I, F, M>(items: I, fold_fn: F, merge_fn: M) -> T
62where
63 T: Send + Default,
64 I: ParallelIterator,
65 F: Fn(T, I::Item) -> T + Sync + Send,
66 M: Fn(T, T) -> T + Sync + Send,
67{
68 items.fold(T::default, fold_fn).reduce(T::default, merge_fn)
69}
70
71pub fn fold_reduce_with<T, I, Init, F, M>(items: I, init: Init, fold_fn: F, merge_fn: M) -> T
75where
76 T: Send,
77 I: ParallelIterator,
78 Init: Fn() -> T + Sync + Send + Clone,
79 F: Fn(T, I::Item) -> T + Sync + Send,
80 M: Fn(T, T) -> T + Sync + Send,
81{
82 items.fold(init.clone(), fold_fn).reduce(init, merge_fn)
83}
84
85pub fn parallel_count<T, I, P>(items: I, predicate: P) -> usize
101where
102 T: Send,
103 I: ParallelIterator<Item = T>,
104 P: Fn(&T) -> bool + Sync + Send,
105{
106 items
107 .fold(|| 0usize, |count, item| count + predicate(&item) as usize)
108 .reduce(|| 0, |a, b| a + b)
109}
110
111pub fn parallel_sum<T, I, F>(items: I, extract: F) -> f64
124where
125 T: Send,
126 I: ParallelIterator<Item = T>,
127 F: Fn(&T) -> f64 + Sync + Send,
128{
129 items
130 .fold(|| 0.0f64, |sum, item| sum + extract(&item))
131 .reduce(|| 0.0, |a, b| a + b)
132}
133
134pub fn parallel_sum_i64<T, I, F>(items: I, extract: F) -> i64
136where
137 T: Send,
138 I: ParallelIterator<Item = T>,
139 F: Fn(&T) -> i64 + Sync + Send,
140{
141 items
142 .fold(|| 0i64, |sum, item| sum + extract(&item))
143 .reduce(|| 0, |a, b| a + b)
144}
145
146pub fn parallel_min<T, I, F, V>(items: I, extract: F) -> Option<V>
150where
151 T: Send,
152 V: Send + Ord + Copy,
153 I: ParallelIterator<Item = T>,
154 F: Fn(&T) -> V + Sync + Send,
155{
156 items
157 .fold(
158 || None,
159 |min: Option<V>, item| {
160 let val = extract(&item);
161 Some(match min {
162 Some(m) if m < val => m,
163 _ => val,
164 })
165 },
166 )
167 .reduce(
168 || None,
169 |a, b| match (a, b) {
170 (Some(va), Some(vb)) => Some(if va < vb { va } else { vb }),
171 (Some(v), None) | (None, Some(v)) => Some(v),
172 (None, None) => None,
173 },
174 )
175}
176
177pub fn parallel_max<T, I, F, V>(items: I, extract: F) -> Option<V>
181where
182 T: Send,
183 V: Send + Ord + Copy,
184 I: ParallelIterator<Item = T>,
185 F: Fn(&T) -> V + Sync + Send,
186{
187 items
188 .fold(
189 || None,
190 |max: Option<V>, item| {
191 let val = extract(&item);
192 Some(match max {
193 Some(m) if m > val => m,
194 _ => val,
195 })
196 },
197 )
198 .reduce(
199 || None,
200 |a, b| match (a, b) {
201 (Some(va), Some(vb)) => Some(if va > vb { va } else { vb }),
202 (Some(v), None) | (None, Some(v)) => Some(v),
203 (None, None) => None,
204 },
205 )
206}
207
208pub fn parallel_try_collect<T, E, I, F, R>(items: I, process: F) -> (Vec<R>, Vec<E>)
230where
231 T: Send,
232 E: Send,
233 R: Send,
234 I: ParallelIterator<Item = T>,
235 F: Fn(T) -> Result<R, E> + Sync + Send,
236{
237 items
238 .fold(
239 || (Vec::new(), Vec::new()),
240 |(mut ok, mut err), item| {
241 match process(item) {
242 Ok(r) => ok.push(r),
243 Err(e) => err.push(e),
244 }
245 (ok, err)
246 },
247 )
248 .reduce(
249 || (Vec::new(), Vec::new()),
250 |(mut ok1, mut err1), (ok2, err2)| {
251 ok1.extend(ok2);
252 err1.extend(err2);
253 (ok1, err1)
254 },
255 )
256}
257
258pub fn parallel_stats<T, I, F>(items: I, extract: F) -> (usize, f64, Option<f64>, Option<f64>)
262where
263 T: Send,
264 I: ParallelIterator<Item = T>,
265 F: Fn(&T) -> f64 + Sync + Send,
266{
267 items
268 .fold(
269 || (0usize, 0.0f64, None::<f64>, None::<f64>),
270 |(count, sum, min, max), item| {
271 let val = extract(&item);
272 (
273 count + 1,
274 sum + val,
275 Some(match min {
276 Some(m) if m < val => m,
277 _ => val,
278 }),
279 Some(match max {
280 Some(m) if m > val => m,
281 _ => val,
282 }),
283 )
284 },
285 )
286 .reduce(
287 || (0, 0.0, None, None),
288 |(c1, s1, min1, max1), (c2, s2, min2, max2)| {
289 let min = match (min1, min2) {
290 (Some(a), Some(b)) => Some(a.min(b)),
291 (Some(v), None) | (None, Some(v)) => Some(v),
292 (None, None) => None,
293 };
294 let max = match (max1, max2) {
295 (Some(a), Some(b)) => Some(a.max(b)),
296 (Some(v), None) | (None, Some(v)) => Some(v),
297 (None, None) => None,
298 };
299 (c1 + c2, s1 + s2, min, max)
300 },
301 )
302}
303
304pub fn parallel_partition<T, I, K, V, KeyFn, ValFn>(
322 items: I,
323 key_fn: KeyFn,
324 val_fn: ValFn,
325) -> std::collections::HashMap<K, Vec<V>>
326where
327 T: Send,
328 K: Send + Eq + std::hash::Hash + Clone,
329 V: Send,
330 I: ParallelIterator<Item = T>,
331 KeyFn: Fn(&T) -> K + Sync + Send,
332 ValFn: Fn(T) -> V + Sync + Send,
333{
334 items
335 .fold(std::collections::HashMap::new, |mut map, item| {
336 let key = key_fn(&item);
337 let val = val_fn(item);
338 map.entry(key).or_insert_with(Vec::new).push(val);
339 map
340 })
341 .reduce(std::collections::HashMap::new, |mut map1, map2| {
342 for (key, mut values) in map2 {
343 map1.entry(key).or_insert_with(Vec::new).append(&mut values);
344 }
345 map1
346 })
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_count() {
        let numbers: Vec<i32> = (0..1000).collect();
        let even_count = parallel_count(numbers.par_iter(), |n| *n % 2 == 0);
        assert_eq!(even_count, 500);

        let empty: Vec<i32> = Vec::new();
        assert_eq!(parallel_count(empty.par_iter(), |_| true), 0);
    }

    #[test]
    fn test_parallel_sum() {
        let numbers: Vec<i32> = (1..=100).collect();
        let total = parallel_sum(numbers.par_iter(), |n| f64::from(**n));
        assert!((total - 5050.0).abs() < 0.001);
    }

    #[test]
    fn test_parallel_sum_i64() {
        let numbers: Vec<i32> = (1..=100).collect();
        let total = parallel_sum_i64(numbers.par_iter(), |n| i64::from(**n));
        assert_eq!(total, 5050);
    }

    #[test]
    fn test_parallel_min() {
        let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
        let min = parallel_min(numbers.par_iter(), |n| *n);
        assert_eq!(min, Some(&1));

        let empty: Vec<i32> = vec![];
        let min_empty: Option<&i32> = parallel_min(empty.par_iter(), |n| *n);
        assert_eq!(min_empty, None);
    }

    #[test]
    fn test_parallel_max() {
        let numbers: Vec<i32> = vec![5, 3, 8, 1, 9, 2];
        let max = parallel_max(numbers.par_iter(), |n| *n);
        assert_eq!(max, Some(&9));

        let empty: Vec<i32> = vec![];
        let max_empty: Option<&i32> = parallel_max(empty.par_iter(), |n| *n);
        assert_eq!(max_empty, None);
    }

    #[test]
    fn test_parallel_try_collect() {
        let items = vec!["1", "two", "3", "four", "5"];
        let (mut successes, errors): (Vec<i32>, Vec<String>) =
            parallel_try_collect(items.into_par_iter(), |s| {
                s.parse::<i32>().map_err(|e| e.to_string())
            });

        successes.sort_unstable();
        assert_eq!(successes, vec![1, 3, 5]);
        assert_eq!(errors.len(), 2);
        assert!(errors.iter().all(|e| !e.is_empty()));
    }

    #[test]
    fn test_parallel_stats() {
        let numbers: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (count, sum, min, max) = parallel_stats(numbers.into_par_iter(), |n| *n);

        assert_eq!(count, 5);
        assert!((sum - 15.0).abs() < 0.001);
        assert!((min.unwrap() - 1.0).abs() < 0.001);
        assert!((max.unwrap() - 5.0).abs() < 0.001);

        let empty: Vec<f64> = Vec::new();
        let (count, sum, min, max) = parallel_stats(empty.into_par_iter(), |n| *n);
        assert_eq!(count, 0);
        assert!(sum.abs() < 0.001);
        assert_eq!(min, None);
        assert_eq!(max, None);
    }

    #[test]
    fn test_parallel_partition() {
        let items: Vec<(i32, &str)> = vec![(1, "a"), (2, "b"), (1, "c"), (2, "d"), (1, "e")];
        let groups = parallel_partition(items.into_par_iter(), |(k, _)| *k, |(_, v)| v);

        assert_eq!(groups.len(), 2);

        let mut ones = groups.get(&1).cloned().unwrap_or_default();
        ones.sort_unstable();
        assert_eq!(ones, vec!["a", "c", "e"]);

        let mut twos = groups.get(&2).cloned().unwrap_or_default();
        twos.sort_unstable();
        assert_eq!(twos, vec!["b", "d"]);
    }

    #[test]
    fn test_fold_reduce() {
        let items: Vec<i32> = (1..=10).collect();
        let sum: i32 = fold_reduce(items.into_par_iter(), |acc, item| acc + item, |a, b| a + b);
        assert_eq!(sum, 55);
    }

    #[test]
    fn test_fold_reduce_with_custom_init() {
        let items: Vec<i32> = (1..=10).collect();
        // 100 is an identity for `a + b - 100`. Every partial result carries one +100 offset and
        // each merge removes one, so exactly one offset survives regardless of how rayon splits:
        // 55 + 100.
        let sum: i32 = fold_reduce_with(
            items.into_par_iter(),
            || 100,
            |acc, item| acc + item,
            |a, b| a + b - 100,
        );
        assert_eq!(sum, 155);
    }
}