1use rayon::iter::{
29 IntoParallelIterator, ParallelIterator,
30};
31
/// Configuration for the `*_with` parallel operations.
///
/// Build one with [`ParConfig::new`] (or `Default`) and the builder methods; the default
/// value carries no overrides.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParConfig {
    /// Thread count for a dedicated rayon pool; `None` means no dedicated pool is built and
    /// work runs on the caller's current rayon context.
    num_threads: Option<usize>,
    /// Optional chunk-size hint, set via `ParConfig::chunk_size`.
    chunk_size: Option<usize>,
}
49
50impl ParConfig {
51 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn threads(mut self, n: usize) -> Self {
60 self.num_threads = Some(n);
61 self
62 }
63
64 pub fn chunk_size(mut self, n: usize) -> Self {
68 self.chunk_size = Some(n);
69 self
70 }
71
72 fn run<F, R>(&self, f: F) -> R
77 where
78 F: FnOnce() -> R + Send,
79 R: Send,
80 {
81 if let Some(threads) = self.num_threads {
82 let pool = rayon::ThreadPoolBuilder::new()
83 .num_threads(threads)
84 .build()
85 .expect("failed to build rayon thread pool");
86 pool.install(f)
87 } else {
88 f()
89 }
90 }
91}
92
/// Parallel versions of common iterator operations, available on every [`IntoIterator`]
/// whose items are `Send`.
///
/// The provided blanket implementation first buffers the input into a `Vec` and then hands
/// it to rayon, so the source iterator itself does not need to be parallel or `Send`.
/// No thread pool is configured here; use [`ParIterWith`] to choose a thread count.
pub trait ParIter: IntoIterator {
    /// Applies `f` to every item in parallel and collects the results in input order.
    fn par_map<F, R>(self, f: F) -> Vec<R>
    where
        F: Fn(Self::Item) -> R + Sync + Send,
        R: Send;

    /// Keeps the items for which `f` returns `true`, preserving their relative order.
    fn par_filter<F>(self, f: F) -> Vec<Self::Item>
    where
        F: Fn(&Self::Item) -> bool + Sync + Send,
        Self::Item: Send;

    /// Calls `f` once per item in parallel.
    ///
    /// Calls may run concurrently and in any order, so any side effects in `f` must be
    /// thread-safe (e.g. atomics).
    fn par_for_each<F>(self, f: F)
    where
        F: Fn(Self::Item) + Sync + Send;

    /// Applies the fallible `f` to every item in parallel.
    ///
    /// Every item is processed, even after an error. Returns `Ok` with all values in input
    /// order if every call succeeded. Otherwise it returns `Err` with every error in input
    /// order, and the successful values are discarded.
    fn par_map_results<F, T, E>(self, f: F) -> Result<Vec<T>, Vec<E>>
    where
        F: Fn(Self::Item) -> Result<T, E> + Sync + Send,
        T: Send,
        E: Send;

    /// Maps each item to an iterator with `f` and concatenates all outputs in input order.
    ///
    /// Each iterator returned by `f` is consumed sequentially (rayon `flat_map_iter`).
    fn par_flat_map<F, I, R>(self, f: F) -> Vec<R>
    where
        F: Fn(Self::Item) -> I + Sync + Send,
        I: IntoIterator<Item = R>,
        R: Send;

    /// Returns `true` if `f` holds for at least one item; `false` for an empty input.
    fn par_any<F>(self, f: F) -> bool
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;

    /// Returns `true` if `f` holds for every item; `true` for an empty input.
    fn par_all<F>(self, f: F) -> bool
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;

    /// Counts the items for which `f` returns `true`.
    fn par_count<F>(self, f: F) -> usize
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;
}
222
223impl<I> ParIter for I
224where
225 I: IntoIterator,
226 I::Item: Send,
227{
228 fn par_map<F, R>(self, f: F) -> Vec<R>
229 where
230 F: Fn(Self::Item) -> R + Sync + Send,
231 R: Send,
232 {
233 let items: Vec<_> = self.into_iter().collect();
234 items.into_par_iter().map(f).collect()
235 }
236
237 fn par_filter<F>(self, f: F) -> Vec<Self::Item>
238 where
239 F: Fn(&Self::Item) -> bool + Sync + Send,
240 Self::Item: Send,
241 {
242 let items: Vec<_> = self.into_iter().collect();
243 items.into_par_iter().filter(|x| f(x)).collect()
244 }
245
246 fn par_for_each<F>(self, f: F)
247 where
248 F: Fn(Self::Item) + Sync + Send,
249 {
250 let items: Vec<_> = self.into_iter().collect();
251 items.into_par_iter().for_each(f);
252 }
253
254 fn par_map_results<F, T, E>(self, f: F) -> Result<Vec<T>, Vec<E>>
255 where
256 F: Fn(Self::Item) -> Result<T, E> + Sync + Send,
257 T: Send,
258 E: Send,
259 {
260 let items: Vec<_> = self.into_iter().collect();
261 let results: Vec<Result<T, E>> = items.into_par_iter().map(f).collect();
262
263 let mut oks = Vec::new();
264 let mut errs = Vec::new();
265
266 for result in results {
267 match result {
268 Ok(v) => oks.push(v),
269 Err(e) => errs.push(e),
270 }
271 }
272
273 if errs.is_empty() {
274 Ok(oks)
275 } else {
276 Err(errs)
277 }
278 }
279
280 fn par_flat_map<F, II, R>(self, f: F) -> Vec<R>
281 where
282 F: Fn(Self::Item) -> II + Sync + Send,
283 II: IntoIterator<Item = R>,
284 R: Send,
285 {
286 let items: Vec<_> = self.into_iter().collect();
287 items.into_par_iter().flat_map_iter(f).collect()
288 }
289
290 fn par_any<F>(self, f: F) -> bool
291 where
292 F: Fn(&Self::Item) -> bool + Sync + Send,
293 {
294 let items: Vec<_> = self.into_iter().collect();
295 items.into_par_iter().any(|x| f(&x))
296 }
297
298 fn par_all<F>(self, f: F) -> bool
299 where
300 F: Fn(&Self::Item) -> bool + Sync + Send,
301 {
302 let items: Vec<_> = self.into_iter().collect();
303 items.into_par_iter().all(|x| f(&x))
304 }
305
306 fn par_count<F>(self, f: F) -> usize
307 where
308 F: Fn(&Self::Item) -> bool + Sync + Send,
309 {
310 let items: Vec<_> = self.into_iter().collect();
311 items.into_par_iter().filter(|x| f(x)).count()
312 }
313}
314
/// Variants of the [`ParIter`] operations that take a [`ParConfig`].
///
/// When the config sets a thread count, each call builds a dedicated rayon thread pool of
/// that size and runs the operation inside it. Otherwise the operation runs directly, like
/// its [`ParIter`] counterpart. In both cases the input is first buffered into a `Vec` on
/// the calling thread.
///
/// # Panics
///
/// Every method panics if a thread count is configured and the rayon pool cannot be built.
pub trait ParIterWith: IntoIterator {
    /// Like [`ParIter::par_map`], run under `config`.
    fn par_map_with<F, R>(self, config: &ParConfig, f: F) -> Vec<R>
    where
        F: Fn(Self::Item) -> R + Sync + Send,
        R: Send;

    /// Like [`ParIter::par_filter`], run under `config`.
    fn par_filter_with<F>(self, config: &ParConfig, f: F) -> Vec<Self::Item>
    where
        F: Fn(&Self::Item) -> bool + Sync + Send,
        Self::Item: Send;

    /// Like [`ParIter::par_for_each`], run under `config`.
    fn par_for_each_with<F>(self, config: &ParConfig, f: F)
    where
        F: Fn(Self::Item) + Sync + Send;

    /// Like [`ParIter::par_map_results`], run under `config`: `Ok` with all values if
    /// every call succeeded, else `Err` with every error in input order.
    fn par_map_results_with<F, T, E>(self, config: &ParConfig, f: F) -> Result<Vec<T>, Vec<E>>
    where
        F: Fn(Self::Item) -> Result<T, E> + Sync + Send,
        T: Send,
        E: Send;

    /// Like [`ParIter::par_flat_map`], run under `config`.
    fn par_flat_map_with<F, I, R>(self, config: &ParConfig, f: F) -> Vec<R>
    where
        F: Fn(Self::Item) -> I + Sync + Send,
        I: IntoIterator<Item = R>,
        R: Send;

    /// Like [`ParIter::par_any`], run under `config`.
    fn par_any_with<F>(self, config: &ParConfig, f: F) -> bool
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;

    /// Like [`ParIter::par_all`], run under `config`.
    fn par_all_with<F>(self, config: &ParConfig, f: F) -> bool
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;

    /// Like [`ParIter::par_count`], run under `config`.
    fn par_count_with<F>(self, config: &ParConfig, f: F) -> usize
    where
        F: Fn(&Self::Item) -> bool + Sync + Send;
}
366
367impl<I> ParIterWith for I
368where
369 I: IntoIterator,
370 I::Item: Send,
371{
372 fn par_map_with<F, R>(self, config: &ParConfig, f: F) -> Vec<R>
373 where
374 F: Fn(Self::Item) -> R + Sync + Send,
375 R: Send,
376 {
377 let items: Vec<_> = self.into_iter().collect();
378 config.run(|| items.into_par_iter().map(f).collect())
379 }
380
381 fn par_filter_with<F>(self, config: &ParConfig, f: F) -> Vec<Self::Item>
382 where
383 F: Fn(&Self::Item) -> bool + Sync + Send,
384 Self::Item: Send,
385 {
386 let items: Vec<_> = self.into_iter().collect();
387 config.run(|| items.into_par_iter().filter(|x| f(x)).collect())
388 }
389
390 fn par_for_each_with<F>(self, config: &ParConfig, f: F)
391 where
392 F: Fn(Self::Item) + Sync + Send,
393 {
394 let items: Vec<_> = self.into_iter().collect();
395 config.run(|| items.into_par_iter().for_each(f));
396 }
397
398 fn par_map_results_with<F, T, E>(self, config: &ParConfig, f: F) -> Result<Vec<T>, Vec<E>>
399 where
400 F: Fn(Self::Item) -> Result<T, E> + Sync + Send,
401 T: Send,
402 E: Send,
403 {
404 let items: Vec<_> = self.into_iter().collect();
405 let results: Vec<Result<T, E>> =
406 config.run(|| items.into_par_iter().map(f).collect());
407
408 let mut oks = Vec::new();
409 let mut errs = Vec::new();
410
411 for result in results {
412 match result {
413 Ok(v) => oks.push(v),
414 Err(e) => errs.push(e),
415 }
416 }
417
418 if errs.is_empty() {
419 Ok(oks)
420 } else {
421 Err(errs)
422 }
423 }
424
425 fn par_flat_map_with<F, II, R>(self, config: &ParConfig, f: F) -> Vec<R>
426 where
427 F: Fn(Self::Item) -> II + Sync + Send,
428 II: IntoIterator<Item = R>,
429 R: Send,
430 {
431 let items: Vec<_> = self.into_iter().collect();
432 config.run(|| items.into_par_iter().flat_map_iter(f).collect())
433 }
434
435 fn par_any_with<F>(self, config: &ParConfig, f: F) -> bool
436 where
437 F: Fn(&Self::Item) -> bool + Sync + Send,
438 {
439 let items: Vec<_> = self.into_iter().collect();
440 config.run(|| items.into_par_iter().any(|x| f(&x)))
441 }
442
443 fn par_all_with<F>(self, config: &ParConfig, f: F) -> bool
444 where
445 F: Fn(&Self::Item) -> bool + Sync + Send,
446 {
447 let items: Vec<_> = self.into_iter().collect();
448 config.run(|| items.into_par_iter().all(|x| f(&x)))
449 }
450
451 fn par_count_with<F>(self, config: &ParConfig, f: F) -> usize
452 where
453 F: Fn(&Self::Item) -> bool + Sync + Send,
454 {
455 let items: Vec<_> = self.into_iter().collect();
456 config.run(|| items.into_par_iter().filter(|x| f(x)).count())
457 }
458}
459
460pub fn par_map<T, F, R>(items: impl IntoIterator<Item = T>, f: F) -> Vec<R>
475where
476 T: Send,
477 F: Fn(T) -> R + Sync + Send,
478 R: Send,
479{
480 let items: Vec<T> = items.into_iter().collect();
481 items.into_par_iter().map(f).collect()
482}
483
484pub fn par_filter<T, F>(items: impl IntoIterator<Item = T>, f: F) -> Vec<T>
495where
496 T: Send,
497 F: Fn(&T) -> bool + Sync + Send,
498{
499 let items: Vec<T> = items.into_iter().collect();
500 items.into_par_iter().filter(|x| f(x)).collect()
501}
502
503pub fn par_for_each<T, F>(items: impl IntoIterator<Item = T>, f: F)
519where
520 T: Send,
521 F: Fn(T) + Sync + Send,
522{
523 let items: Vec<T> = items.into_iter().collect();
524 items.into_par_iter().for_each(f);
525}
526
527pub fn par_map_results<T, F, R, E>(
539 items: impl IntoIterator<Item = T>,
540 f: F,
541) -> Result<Vec<R>, Vec<E>>
542where
543 T: Send,
544 F: Fn(T) -> Result<R, E> + Sync + Send,
545 R: Send,
546 E: Send,
547{
548 let items: Vec<T> = items.into_iter().collect();
549 let results: Vec<Result<R, E>> = items.into_par_iter().map(f).collect();
550
551 let mut oks = Vec::new();
552 let mut errs = Vec::new();
553
554 for result in results {
555 match result {
556 Ok(v) => oks.push(v),
557 Err(e) => errs.push(e),
558 }
559 }
560
561 if errs.is_empty() {
562 Ok(oks)
563 } else {
564 Err(errs)
565 }
566}
567
568pub fn par_chunks<T, F, R>(
582 items: impl IntoIterator<Item = T>,
583 chunk_size: usize,
584 f: F,
585) -> Vec<R>
586where
587 T: Send,
588 F: Fn(Vec<T>) -> R + Sync + Send,
589 R: Send,
590{
591 let mut items: Vec<T> = items.into_iter().collect();
592 let mut chunks: Vec<Vec<T>> = Vec::new();
593 while !items.is_empty() {
594 let at = chunk_size.min(items.len());
595 let rest = items.split_off(at);
596 chunks.push(items);
597 items = rest;
598 }
599 chunks.into_par_iter().map(f).collect()
600}
601
#[cfg(test)]
mod tests {
    //! Unit tests for the parallel helpers, the `*_with` variants and `ParConfig`.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_par_map_squares() {
        let input: Vec<i32> = (0..100).collect();
        let result: Vec<i32> = input.par_map(|x| x * x);
        let expected: Vec<i32> = (0..100).map(|x| x * x).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_par_filter() {
        let evens: Vec<i32> = (0..20).par_filter(|x| x % 2 == 0);
        let expected: Vec<i32> = (0..20).step_by(2).collect();
        assert_eq!(evens, expected);
    }

    #[test]
    fn test_par_for_each() {
        let counter = AtomicUsize::new(0);
        let items: Vec<usize> = (0..50).collect();
        items.par_for_each(|_| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 50);
    }

    #[test]
    fn test_par_map_results_all_ok() {
        let result: Result<Vec<i32>, Vec<String>> =
            vec![1, 2, 3].par_map_results(|x| Ok(x * 2));
        assert_eq!(result, Ok(vec![2, 4, 6]));
    }

    #[test]
    fn test_par_map_results_some_err() {
        let result: Result<Vec<i32>, Vec<String>> = vec![1, -1, 2, -2].par_map_results(|x| {
            if x > 0 {
                Ok(x)
            } else {
                Err(format!("negative: {}", x))
            }
        });
        assert!(result.is_err());
        let errs = result.unwrap_err();
        assert_eq!(errs.len(), 2);
    }

    #[test]
    fn test_par_map_results_errors_keep_input_order() {
        let result: Result<Vec<i32>, Vec<String>> = vec![1, -1, 2, -2].par_map_results(|x| {
            if x > 0 {
                Ok(x)
            } else {
                Err(format!("negative: {}", x))
            }
        });
        assert_eq!(
            result,
            Err(vec!["negative: -1".to_string(), "negative: -2".to_string()])
        );
    }

    #[test]
    fn test_par_flat_map() {
        let result: Vec<i32> = vec![1, 2, 3].par_flat_map(|x| vec![x, x * 10]);
        assert_eq!(result, vec![1, 10, 2, 20, 3, 30]);
    }

    #[test]
    fn test_par_any() {
        assert!(vec![1, 2, 3, 4, 5].par_any(|x| *x == 3));
        assert!(!vec![1, 2, 3, 4, 5].par_any(|x| *x == 10));
    }

    #[test]
    fn test_par_all() {
        assert!(vec![2, 4, 6, 8].par_all(|x| *x % 2 == 0));
        assert!(!vec![2, 4, 5, 8].par_all(|x| *x % 2 == 0));
    }

    #[test]
    fn test_par_count() {
        let count = vec![1, 2, 3, 4, 5, 6].par_count(|x| *x % 2 == 0);
        assert_eq!(count, 3);
    }

    #[test]
    fn test_par_chunks() {
        let sums = par_chunks(vec![1, 2, 3, 4, 5, 6], 2, |chunk| {
            chunk.into_iter().sum::<i32>()
        });
        assert_eq!(sums, vec![3, 7, 11]);
    }

    #[test]
    fn test_par_chunks_uneven_tail() {
        let sums = par_chunks(vec![1, 2, 3, 4, 5], 2, |chunk| {
            chunk.into_iter().sum::<i32>()
        });
        assert_eq!(sums, vec![3, 7, 5]);
    }

    #[test]
    fn test_empty_collection() {
        let result: Vec<i32> = Vec::<i32>::new().par_map(|x| x * x);
        assert!(result.is_empty());

        let filtered: Vec<i32> = Vec::<i32>::new().par_filter(|_| true);
        assert!(filtered.is_empty());

        assert!(!Vec::<i32>::new().par_any(|_| true));
        assert!(Vec::<i32>::new().par_all(|_| false));
        assert_eq!(Vec::<i32>::new().par_count(|_| true), 0);
    }

    #[test]
    fn test_single_element() {
        let result: Vec<i32> = vec![42].par_map(|x| x * 2);
        assert_eq!(result, vec![84]);

        assert!(vec![42].par_any(|x| *x == 42));
        assert!(vec![42].par_all(|x| *x == 42));
        assert_eq!(vec![42].par_count(|x| *x == 42), 1);
    }

    #[test]
    fn test_ordering_preserved_par_map() {
        let input: Vec<i32> = (0..1000).collect();
        let result: Vec<i32> = input.par_map(|x| x * 2);
        let expected: Vec<i32> = (0..1000).map(|x| x * 2).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_par_config_builder_records_settings() {
        let config = ParConfig::new().threads(3).chunk_size(8);
        assert_eq!(config.num_threads, Some(3));
        assert_eq!(config.chunk_size, Some(8));

        let untouched = ParConfig::new();
        assert_eq!(untouched.num_threads, None);
        assert_eq!(untouched.chunk_size, None);
    }

    #[test]
    fn test_par_config_with_threads() {
        let config = ParConfig::new().threads(2);
        let result: Vec<i32> = (0..100).par_map_with(&config, |x| x * x);
        let expected: Vec<i32> = (0..100).map(|x| x * x).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_par_config_without_threads() {
        let config = ParConfig::new();
        let result: Vec<i32> = (0..50).par_map_with(&config, |x| x + 1);
        let expected: Vec<i32> = (1..51).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_par_config_chunk_size_preserves_results() {
        let config = ParConfig::new().threads(2).chunk_size(7);
        let result: Vec<i32> = (0..100).par_map_with(&config, |x| x * 3);
        let expected: Vec<i32> = (0..100).map(|x| x * 3).collect();
        assert_eq!(result, expected);
        assert_eq!((0..100i32).par_count_with(&config, |x| *x % 10 == 0), 10);
    }

    #[test]
    fn test_par_config_filter_with() {
        let config = ParConfig::new().threads(2);
        let result: Vec<i32> = (0..20).par_filter_with(&config, |x| x % 2 == 0);
        let expected: Vec<i32> = (0..20).step_by(2).collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_par_for_each_with() {
        let config = ParConfig::new().threads(2);
        let total = AtomicUsize::new(0);
        (1..=10usize).par_for_each_with(&config, |x| {
            total.fetch_add(x, Ordering::Relaxed);
        });
        assert_eq!(total.load(Ordering::Relaxed), 55);
    }

    #[test]
    fn test_par_map_results_with() {
        let config = ParConfig::new().threads(2);
        let ok: Result<Vec<i32>, Vec<String>> =
            vec![1, 2, 3].par_map_results_with(&config, |x| Ok(x * 10));
        assert_eq!(ok, Ok(vec![10, 20, 30]));

        let err: Result<Vec<i32>, Vec<i32>> =
            vec![1, -2, 3, -4].par_map_results_with(&config, |x| {
                if x > 0 {
                    Ok(x)
                } else {
                    Err(x)
                }
            });
        assert_eq!(err, Err(vec![-2, -4]));
    }

    #[test]
    fn test_par_flat_map_with() {
        let config = ParConfig::new().threads(2);
        let result: Vec<i32> = vec![1, 2, 3].par_flat_map_with(&config, |x| vec![x; 2]);
        assert_eq!(result, vec![1, 1, 2, 2, 3, 3]);
    }

    #[test]
    fn test_par_any_all_count_with() {
        let config = ParConfig::new().threads(2);
        assert!(vec![1, 2, 3].par_any_with(&config, |x| *x == 2));
        assert!(!vec![1, 2, 3].par_any_with(&config, |x| *x == 7));
        assert!(vec![2, 4, 6].par_all_with(&config, |x| *x % 2 == 0));
        assert!(!vec![2, 3, 6].par_all_with(&config, |x| *x % 2 == 0));
        assert_eq!(vec![1, 2, 3, 4].par_count_with(&config, |x| *x > 2), 2);
    }

    #[test]
    fn test_standalone_par_map() {
        let result = par_map(0..5, |x| x * x);
        assert_eq!(result, vec![0, 1, 4, 9, 16]);
    }

    #[test]
    fn test_standalone_par_filter() {
        let result = par_filter(0..10, |x| x % 2 == 0);
        assert_eq!(result, vec![0, 2, 4, 6, 8]);
    }

    #[test]
    fn test_standalone_par_for_each() {
        let counter = AtomicUsize::new(0);
        par_for_each(0..10usize, |_| {
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn test_standalone_par_map_results() {
        let result: Result<Vec<i32>, Vec<&str>> =
            par_map_results(vec![1, 2, 3], |x| Ok(x * 2));
        assert_eq!(result, Ok(vec![2, 4, 6]));
    }

    #[test]
    fn test_standalone_par_map_results_err() {
        let result: Result<Vec<i32>, Vec<i32>> =
            par_map_results(vec![3, -1, 4, -5], |x| if x < 0 { Err(x) } else { Ok(x) });
        assert_eq!(result, Err(vec![-1, -5]));
    }

    #[test]
    fn test_par_chunks_single_chunk() {
        let result = par_chunks(vec![1, 2, 3], 10, |chunk| chunk.len());
        assert_eq!(result, vec![3]);
    }

    #[test]
    fn test_par_chunks_empty() {
        let result = par_chunks(Vec::<i32>::new(), 5, |chunk| chunk.len());
        assert!(result.is_empty());
    }
}