//! Parallelism utilities: a global on/off switch, an execution-strategy type,
//! work-partitioning helpers, and thread-pool abstractions.

use core::sync::atomic::{AtomicBool, Ordering};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Process-global flag that forces every helper in this module onto the
/// sequential code path when set.
static PARALLELISM_DISABLED: AtomicBool = AtomicBool::new(false);

/// Globally disables parallel execution; helpers fall back to sequential code paths.
pub fn disable_global_parallelism() {
    PARALLELISM_DISABLED.store(true, Ordering::SeqCst);
}

/// Re-enables parallel execution after a call to `disable_global_parallelism`.
pub fn enable_global_parallelism() {
    PARALLELISM_DISABLED.store(false, Ordering::SeqCst);
}

/// Returns `true` unless parallelism has been globally disabled.
pub fn is_parallelism_enabled() -> bool {
    !PARALLELISM_DISABLED.load(Ordering::SeqCst)
}

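/// Execution strategy for the helpers in this module.
///
/// A minimal usage sketch (marked `ignore` because the crate path is not shown
/// here; it assumes `Par` and `ParThreshold` are in scope):
///
/// ```ignore
/// // The default strategy is Rayon when the "parallel" feature is enabled,
/// // sequential otherwise.
/// let par = Par::default();
/// let threshold = ParThreshold::default();
///
/// // Small inputs stay sequential regardless of the strategy.
/// assert!(!threshold.should_parallelize(16, par));
/// ```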
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Par {
    Seq,
    #[cfg(feature = "parallel")]
    Rayon,
    #[cfg(feature = "parallel")]
    RayonWith(usize),
}

#[allow(clippy::derivable_impls)]
impl Default for Par {
    fn default() -> Self {
        #[cfg(feature = "parallel")]
        {
            Par::Rayon
        }
        #[cfg(not(feature = "parallel"))]
        {
            Par::Seq
        }
    }
}

impl Par {
    #[inline]
    pub fn is_sequential(&self) -> bool {
        match self {
            Par::Seq => true,
            #[cfg(feature = "parallel")]
            _ => !is_parallelism_enabled(),
        }
    }

    #[cfg(feature = "parallel")]
    pub fn num_threads(&self) -> usize {
        if !is_parallelism_enabled() {
            return 1;
        }

        match self {
            Par::Seq => 1,
            Par::Rayon => rayon::current_num_threads(),
            Par::RayonWith(n) => *n,
        }
    }

    #[cfg(not(feature = "parallel"))]
    pub fn num_threads(&self) -> usize {
        1
    }
}

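/// Heuristic thresholds that decide when a workload is large enough to split
/// across threads.
///
/// A minimal sketch (`ignore`d; assumes `ParThreshold` and `Par` are in scope):
///
/// ```ignore
/// let threshold = ParThreshold::new(100, 10);
///
/// // A sequential strategy never parallelizes.
/// assert!(!threshold.should_parallelize(1_000, Par::Seq));
///
/// // Below `min_elements`, work stays sequential even for parallel strategies.
/// assert!(!threshold.should_parallelize(50, Par::default()));
/// ```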
#[derive(Debug, Clone, Copy)]
pub struct ParThreshold {
    pub min_elements: usize,
    pub min_work_per_thread: usize,
}

impl Default for ParThreshold {
    fn default() -> Self {
        ParThreshold {
            min_elements: 4096,
            min_work_per_thread: 256,
        }
    }
}

impl ParThreshold {
    pub const fn new(min_elements: usize, min_work_per_thread: usize) -> Self {
        ParThreshold {
            min_elements,
            min_work_per_thread,
        }
    }

    #[inline]
    pub fn should_parallelize(&self, total_work: usize, par: Par) -> bool {
        if par.is_sequential() {
            return false;
        }

        if total_work < self.min_elements {
            return false;
        }

        let threads = par.num_threads();
        if threads <= 1 {
            return false;
        }

        total_work / threads >= self.min_work_per_thread
    }
}

#[derive(Debug, Clone, Copy)]
pub struct WorkRange {
    pub start: usize,
    pub end: usize,
}

impl WorkRange {
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        WorkRange { start, end }
    }

    #[inline]
    pub const fn len(&self) -> usize {
        self.end - self.start
    }

    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.start >= self.end
    }
}

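/// Splits `total` items into contiguous, non-overlapping ranges, one chunk per
/// thread (the final chunk may be smaller).
///
/// A minimal sketch (`ignore`d; assumes `partition_work` is in scope):
///
/// ```ignore
/// let ranges = partition_work(10, 4);
/// // ceil(10 / 4) = 3, so the chunks are [0, 3), [3, 6), [6, 9), [9, 10).
/// assert_eq!(ranges.len(), 4);
/// assert_eq!(ranges.iter().map(|r| r.len()).sum::<usize>(), 10);
/// ```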
pub fn partition_work(total: usize, num_threads: usize) -> Vec<WorkRange> {
    if num_threads == 0 || total == 0 {
        return vec![];
    }

    if num_threads == 1 {
        return vec![WorkRange::new(0, total)];
    }

    let chunk_size = total.div_ceil(num_threads);
    let mut ranges = Vec::with_capacity(num_threads);

    let mut start = 0;
    while start < total {
        let end = (start + chunk_size).min(total);
        ranges.push(WorkRange::new(start, end));
        start = end;
    }

    ranges
}

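/// Runs `f` over chunked sub-ranges of `0..total`, in parallel when the
/// threshold and strategy allow it, and over the single full range otherwise.
///
/// A minimal sketch (`ignore`d; assumes this module's items and `AtomicUsize`
/// are in scope):
///
/// ```ignore
/// use core::sync::atomic::{AtomicUsize, Ordering};
///
/// let sum = AtomicUsize::new(0);
/// for_each_range(10_000, Par::default(), &ParThreshold::default(), |range| {
///     sum.fetch_add(range.len(), Ordering::Relaxed);
/// });
/// assert_eq!(sum.load(Ordering::Relaxed), 10_000);
/// ```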
#[inline]
pub fn for_each_range<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(WorkRange) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        f(WorkRange::new(0, total));
        return;
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges.into_par_iter().for_each(|range| {
            f(range);
        });
    }

    #[cfg(not(feature = "parallel"))]
    {
        f(WorkRange::new(0, total));
    }
}

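/// Maps chunked sub-ranges of `0..total` to partial results and folds them
/// together, in parallel when the threshold and strategy allow it.
///
/// A minimal sketch (`ignore`d; assumes this module's items are in scope):
///
/// ```ignore
/// // Count elements by summing the length of every chunk.
/// let count = map_reduce(
///     100_000,
///     Par::default(),
///     &ParThreshold::default(),
///     0usize,
///     |range| range.len(),
///     |a, b| a + b,
/// );
/// assert_eq!(count, 100_000);
/// ```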
#[allow(unused_variables)]
pub fn map_reduce<T, Map, Reduce>(
    total: usize,
    par: Par,
    threshold: &ParThreshold,
    identity: T,
    map: Map,
    reduce: Reduce,
) -> T
where
    T: Clone + Send + Sync,
    Map: Fn(WorkRange) -> T + Send + Sync,
    Reduce: Fn(T, T) -> T + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        return map(WorkRange::new(0, total));
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }

    #[cfg(not(feature = "parallel"))]
    {
        map(WorkRange::new(0, total))
    }
}

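/// Calls `f` once per index in `0..total`, in parallel when the threshold and
/// strategy allow it.
///
/// A minimal sketch (`ignore`d; assumes this module's items and `AtomicUsize`
/// are in scope):
///
/// ```ignore
/// use core::sync::atomic::{AtomicUsize, Ordering};
///
/// let hits = AtomicUsize::new(0);
/// for_each_indexed(8, Par::Seq, &ParThreshold::default(), |_i| {
///     hits.fetch_add(1, Ordering::Relaxed);
/// });
/// assert_eq!(hits.load(Ordering::Relaxed), 8);
/// ```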
pub fn for_each_indexed<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(usize) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        for i in 0..total {
            f(i);
        }
        return;
    }

    #[cfg(feature = "parallel")]
    {
        (0..total).into_par_iter().for_each(f);
    }

    #[cfg(not(feature = "parallel"))]
    {
        for i in 0..total {
            f(i);
        }
    }
}

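/// Minimal execution-backend abstraction: the sequential and Rayon-backed
/// pools below implement it so callers can be written once against the trait.
///
/// A minimal sketch (`ignore`d; assumes `ThreadPool` and `SequentialPool` are
/// in scope):
///
/// ```ignore
/// fn sum_to<P: ThreadPool>(pool: &P, n: usize) -> usize {
///     pool.map_reduce(0..n, 0usize, |i| i, |a, b| a + b)
/// }
///
/// assert_eq!(sum_to(&SequentialPool, 10), 45);
/// ```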
pub trait ThreadPool: Send + Sync {
    fn num_threads(&self) -> usize;

    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static;

    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send;

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync;

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync;
}

#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialPool;

impl ThreadPool for SequentialPool {
    #[inline]
    fn num_threads(&self) -> usize {
        1
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        f();
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        (a(), b())
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        for i in range {
            f(i);
        }
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        let mut acc = identity;
        for i in range {
            acc = reduce(acc, map(i));
        }
        acc
    }
}

#[cfg(feature = "parallel")]
#[derive(Debug, Clone, Copy, Default)]
pub struct RayonGlobalPool;

#[cfg(feature = "parallel")]
impl ThreadPool for RayonGlobalPool {
    #[inline]
    fn num_threads(&self) -> usize {
        rayon::current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        rayon::spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        rayon::join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        range.into_par_iter().for_each(f);
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        range
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }
}

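/// A `ThreadPool` backed by a dedicated, caller-configured Rayon pool rather
/// than the global one.
///
/// A minimal sketch (`ignore`d; assumes `CustomRayonPool` is in scope and the
/// "parallel" feature is enabled):
///
/// ```ignore
/// let pool = CustomRayonPool::new(2).expect("failed to build pool");
/// assert_eq!(pool.num_threads(), 2);
///
/// // Run work inside this pool instead of the global one.
/// let total = pool.install(|| (0..100).sum::<usize>());
/// assert_eq!(total, 4950);
/// ```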
#[cfg(feature = "parallel")]
pub struct CustomRayonPool {
    pool: rayon::ThreadPool,
}

#[cfg(feature = "parallel")]
impl CustomRayonPool {
    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()?;
        Ok(CustomRayonPool { pool })
    }

    pub fn with_builder<F>(configure: F) -> Result<Self, rayon::ThreadPoolBuildError>
    where
        F: FnOnce(rayon::ThreadPoolBuilder) -> rayon::ThreadPoolBuilder,
    {
        let builder = rayon::ThreadPoolBuilder::new();
        let pool = configure(builder).build()?;
        Ok(CustomRayonPool { pool })
    }

    pub fn install<R, F>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }
}

#[cfg(feature = "parallel")]
impl ThreadPool for CustomRayonPool {
    #[inline]
    fn num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.pool.spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.pool.join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        self.pool.install(|| {
            range.into_par_iter().for_each(f);
        });
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        self.pool.install(|| {
            range
                .into_par_iter()
                .map(map)
                .reduce(|| identity.clone(), reduce)
        })
    }
}

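/// Borrows a pool and pairs it with a `ParThreshold`, so small workloads run
/// inline while larger ones are handed to the pool.
///
/// A minimal sketch (`ignore`d; assumes `PoolScope` and `SequentialPool` are
/// in scope):
///
/// ```ignore
/// let pool = SequentialPool;
/// let scope = PoolScope::new(&pool);
///
/// let sum = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
/// assert_eq!(sum, (0..100).sum::<usize>());
/// ```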
pub struct PoolScope<'a, P: ThreadPool> {
    pool: &'a P,
    threshold: ParThreshold,
}

impl<'a, P: ThreadPool> PoolScope<'a, P> {
    pub fn new(pool: &'a P) -> Self {
        PoolScope {
            pool,
            threshold: ParThreshold::default(),
        }
    }

    pub fn with_threshold(pool: &'a P, threshold: ParThreshold) -> Self {
        PoolScope { pool, threshold }
    }

    #[inline]
    pub fn num_threads(&self) -> usize {
        self.pool.num_threads()
    }

    #[inline]
    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.pool.join(a, b)
    }

    pub fn for_each<F>(&self, total: usize, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
            for i in 0..total {
                f(i);
            }
        } else {
            self.pool.for_each(0..total, f);
        }
    }

    pub fn for_each_range<F>(&self, total: usize, f: F)
    where
        F: Fn(WorkRange) + Send + Sync,
    {
        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
            f(WorkRange::new(0, total));
        } else {
            // The chunks come from partition_work but are invoked on the calling
            // thread, since the ThreadPool trait exposes no range-based dispatch.
            let ranges = partition_work(total, self.pool.num_threads());
            for range in ranges {
                f(range);
            }
        }
    }

    pub fn map_reduce<T, Map, Reduce>(
        &self,
        total: usize,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
            let mut acc = identity;
            for i in 0..total {
                acc = reduce(acc, map(i));
            }
            acc
        } else {
            self.pool.map_reduce(0..total, identity, map, reduce)
        }
    }
}

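/// Returns the default pool for the current feature set; with the "parallel"
/// feature this is the Rayon global pool.
///
/// A minimal sketch (`ignore`d; assumes `default_pool` and `PoolScope` are in
/// scope):
///
/// ```ignore
/// let pool = default_pool();
/// let scope = PoolScope::new(&pool);
/// assert!(scope.num_threads() >= 1);
/// ```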
#[cfg(feature = "parallel")]
pub fn default_pool() -> RayonGlobalPool {
    RayonGlobalPool
}

#[cfg(not(feature = "parallel"))]
pub fn default_pool() -> SequentialPool {
    SequentialPool
}

#[cfg(feature = "parallel")]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, RayonGlobalPool>) -> R,
{
    let pool = RayonGlobalPool;
    f(PoolScope::new(&pool))
}

#[cfg(not(feature = "parallel"))]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, SequentialPool>) -> R,
{
    let pool = SequentialPool;
    f(PoolScope::new(&pool))
}

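/// Per-thread accumulators guarded by mutexes and indexed by Rayon's current
/// thread index, for building partial results without contending on one lock.
///
/// A minimal sketch (`ignore`d; assumes `ThreadLocalAccum` and the Rayon
/// prelude are in scope):
///
/// ```ignore
/// let accum = ThreadLocalAccum::new(0usize);
/// (0..1_000usize).into_par_iter().for_each(|i| {
///     *accum.get() += i;
/// });
/// assert_eq!(accum.reduce(|a, b| a + b), (0..1_000).sum::<usize>());
/// ```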
#[cfg(feature = "parallel")]
pub struct ThreadLocalAccum<T> {
    values: Vec<std::sync::Mutex<T>>,
}

#[cfg(feature = "parallel")]
impl<T: Clone + Send> ThreadLocalAccum<T> {
    pub fn new(identity: T) -> Self {
        let num_threads = rayon::current_num_threads();
        let values = (0..num_threads)
            .map(|_| std::sync::Mutex::new(identity.clone()))
            .collect();
        ThreadLocalAccum { values }
    }

    pub fn get(&self) -> std::sync::MutexGuard<'_, T> {
        // Outside a Rayon pool there is no thread index, so fall back to slot 0;
        // the modulo also guards against indices beyond the number of slots.
        let thread_idx = rayon::current_thread_index().unwrap_or(0) % self.values.len();
        self.values[thread_idx].lock().unwrap()
    }

    pub fn reduce<F>(self, f: F) -> T
    where
        F: Fn(T, T) -> T,
    {
        self.values
            .into_iter()
            .map(|m| m.into_inner().unwrap())
            .reduce(f)
            .unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_partition_work() {
        let ranges = partition_work(100, 4);
        assert_eq!(ranges.len(), 4);

        let mut covered = [false; 100];
        for range in &ranges {
            for (offset, covered_elem) in covered[range.start..range.end].iter_mut().enumerate() {
                let i = range.start + offset;
                assert!(!*covered_elem, "Overlap at {}", i);
                *covered_elem = true;
            }
        }
        assert!(covered.iter().all(|&x| x), "Not all elements covered");
    }

    #[test]
    fn test_partition_work_uneven() {
        let ranges = partition_work(10, 4);

        let total: usize = ranges.iter().map(|r| r.len()).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_partition_work_single() {
        let ranges = partition_work(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[0].end, 100);
    }

    #[test]
    fn test_threshold() {
        let threshold = ParThreshold::new(100, 10);

        assert!(!threshold.should_parallelize(50, Par::Seq));
        assert!(!threshold.should_parallelize(50, Par::default()));

        #[cfg(feature = "parallel")]
        {
            // Assumes parallelism is globally enabled and Rayon reports more
            // than one worker thread; on a single-core runner this would not hold.
            assert!(threshold.should_parallelize(1000, Par::Rayon));
        }
    }

    #[test]
    fn test_global_parallelism() {
        // This test mutates process-global state, so it can race with other
        // tests (such as test_threshold) that read the flag concurrently.
        let was_enabled = is_parallelism_enabled();

        disable_global_parallelism();
        assert!(!is_parallelism_enabled());

        enable_global_parallelism();
        assert!(is_parallelism_enabled());

        if !was_enabled {
            disable_global_parallelism();
        }
    }

    #[test]
    fn test_sequential_map_reduce() {
        let result = map_reduce(
            100,
            Par::Seq,
            &ParThreshold::default(),
            0usize,
            |range| range.len(),
            |a, b| a + b,
        );
        assert_eq!(result, 100);
    }

    #[test]
    fn test_sequential_pool() {
        let pool = SequentialPool;

        assert_eq!(pool.num_threads(), 1);

        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        let sum = std::sync::atomic::AtomicUsize::new(0);
        pool.for_each(0..10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);

        let result = pool.map_reduce(0..10, 0, |i| i, |a, b| a + b);
        assert_eq!(result, 45);
    }

    #[test]
    fn test_pool_scope() {
        let pool = SequentialPool;
        let scope = PoolScope::new(&pool);

        assert_eq!(scope.num_threads(), 1);

        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        let sum = std::sync::atomic::AtomicUsize::new(0);
        scope.for_each(10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);
    }

    #[test]
    fn test_pool_scope_with_threshold() {
        let pool = SequentialPool;
        let threshold = ParThreshold::new(50, 10);
        let scope = PoolScope::with_threshold(&pool, threshold);

        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[test]
    fn test_default_pool() {
        let pool = default_pool();
        assert!(pool.num_threads() >= 1);
    }

    #[test]
    fn test_with_default_pool() {
        let result = with_default_pool(|scope| scope.num_threads());
        assert!(result >= 1);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rayon_global_pool() {
        let pool = RayonGlobalPool;

        assert!(pool.num_threads() >= 1);

        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool() {
        let pool = CustomRayonPool::new(2).expect("Failed to create pool");

        assert_eq!(pool.num_threads(), 2);

        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        let result = pool.install(|| (0..100).into_par_iter().sum::<usize>());
        assert_eq!(result, (0..100).sum());
    }
}