oxiblas_core/
parallel.rs

1//! Parallelization primitives for OxiBLAS.
2//!
3//! This module provides:
4//! - Parallel execution modes
5//! - Work partitioning utilities
6//! - Thread-local accumulation patterns
7
8use core::sync::atomic::{AtomicBool, Ordering};
9
10#[cfg(feature = "parallel")]
11use rayon::prelude::*;
12
/// Global kill-switch for parallelism; `false` (enabled) by default.
static PARALLELISM_DISABLED: AtomicBool = AtomicBool::new(false);

/// Disables global parallelism.
///
/// Useful for debugging, or in environments where spawning worker
/// threads is problematic.
pub fn disable_global_parallelism() {
    PARALLELISM_DISABLED.store(true, Ordering::SeqCst);
}

/// Enables global parallelism (the default state).
pub fn enable_global_parallelism() {
    PARALLELISM_DISABLED.store(false, Ordering::SeqCst);
}

/// Returns true if parallelism is currently enabled.
pub fn is_parallelism_enabled() -> bool {
    let disabled = PARALLELISM_DISABLED.load(Ordering::SeqCst);
    !disabled
}
33
/// Parallelization mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Par {
    /// Sequential execution.
    Seq,
    /// Parallel execution with the default thread pool.
    #[cfg(feature = "parallel")]
    Rayon,
    /// Parallel execution with a specific number of threads.
    #[cfg(feature = "parallel")]
    RayonWith(usize),
}

// The default variant depends on feature flags (Rayon when "parallel" is
// enabled, Seq otherwise), so this impl cannot simply be derived.
#[allow(clippy::derivable_impls)]
impl Default for Par {
    /// With the "parallel" feature, defaults to the rayon global pool.
    #[cfg(feature = "parallel")]
    fn default() -> Self {
        Par::Rayon
    }

    /// Without the "parallel" feature, sequential is the only option.
    #[cfg(not(feature = "parallel"))]
    fn default() -> Self {
        Par::Seq
    }
}
62
63impl Par {
64    /// Returns true if this mode is sequential.
65    #[inline]
66    pub fn is_sequential(&self) -> bool {
67        match self {
68            Par::Seq => true,
69            #[cfg(feature = "parallel")]
70            _ => !is_parallelism_enabled(),
71        }
72    }
73
74    /// Returns the number of threads to use.
75    #[cfg(feature = "parallel")]
76    pub fn num_threads(&self) -> usize {
77        if !is_parallelism_enabled() {
78            return 1;
79        }
80
81        match self {
82            Par::Seq => 1,
83            Par::Rayon => rayon::current_num_threads(),
84            Par::RayonWith(n) => *n,
85        }
86    }
87
88    /// Returns the number of threads to use (always 1 without parallel feature).
89    #[cfg(not(feature = "parallel"))]
90    pub fn num_threads(&self) -> usize {
91        1
92    }
93}
94
/// Threshold configuration for parallel operations.
#[derive(Debug, Clone, Copy)]
pub struct ParThreshold {
    /// Minimum number of elements for parallelization.
    pub min_elements: usize,
    /// Minimum work per thread (elements).
    pub min_work_per_thread: usize,
}

impl Default for ParThreshold {
    /// Defaults sized so parallelism only kicks in when there is enough
    /// work to amortize thread-coordination overhead.
    fn default() -> Self {
        Self {
            min_elements: 4096,
            min_work_per_thread: 256,
        }
    }
}
112
113impl ParThreshold {
114    /// Creates a new threshold configuration.
115    pub const fn new(min_elements: usize, min_work_per_thread: usize) -> Self {
116        ParThreshold {
117            min_elements,
118            min_work_per_thread,
119        }
120    }
121
122    /// Returns true if parallelization should be used for the given work size.
123    #[inline]
124    pub fn should_parallelize(&self, total_work: usize, par: Par) -> bool {
125        if par.is_sequential() {
126            return false;
127        }
128
129        if total_work < self.min_elements {
130            return false;
131        }
132
133        let threads = par.num_threads();
134        if threads <= 1 {
135            return false;
136        }
137
138        total_work / threads >= self.min_work_per_thread
139    }
140}
141
/// Work range for parallel iteration.
#[derive(Debug, Clone, Copy)]
pub struct WorkRange {
    /// Start index (inclusive).
    pub start: usize,
    /// End index (exclusive).
    pub end: usize,
}

impl WorkRange {
    /// Creates a new half-open range `[start, end)`.
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        WorkRange { start, end }
    }

    /// Returns the number of indices in the range.
    #[inline]
    pub const fn len(&self) -> usize {
        self.end - self.start
    }

    /// Returns true if the range contains no indices.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.end <= self.start
    }
}

/// Partitions `total` units of work into at most `num_threads` contiguous
/// chunks of (nearly) equal size.
pub fn partition_work(total: usize, num_threads: usize) -> Vec<WorkRange> {
    if total == 0 || num_threads == 0 {
        return Vec::new();
    }

    if num_threads == 1 {
        return vec![WorkRange::new(0, total)];
    }

    // Ceiling division so no more than `num_threads` chunks are produced
    // and only the last chunk can be short.
    let chunk = total.div_ceil(num_threads);
    (0..total)
        .step_by(chunk)
        .map(|start| WorkRange::new(start, (start + chunk).min(total)))
        .collect()
}
193
194/// Executes a closure in parallel over work ranges.
195///
196/// If parallelism is disabled or the work is too small, executes sequentially.
197#[inline]
198pub fn for_each_range<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
199where
200    F: Fn(WorkRange) + Send + Sync,
201{
202    if !threshold.should_parallelize(total, par) {
203        f(WorkRange::new(0, total));
204        return;
205    }
206
207    #[cfg(feature = "parallel")]
208    {
209        let ranges = partition_work(total, par.num_threads());
210        ranges.into_par_iter().for_each(|range| {
211            f(range);
212        });
213    }
214
215    #[cfg(not(feature = "parallel"))]
216    {
217        f(WorkRange::new(0, total));
218    }
219}
220
221/// Parallel map-reduce operation.
222///
223/// Maps each work range to a value, then reduces all values.
224#[allow(unused_variables)]
225pub fn map_reduce<T, Map, Reduce>(
226    total: usize,
227    par: Par,
228    threshold: &ParThreshold,
229    identity: T,
230    map: Map,
231    reduce: Reduce,
232) -> T
233where
234    T: Clone + Send + Sync,
235    Map: Fn(WorkRange) -> T + Send + Sync,
236    Reduce: Fn(T, T) -> T + Send + Sync,
237{
238    if !threshold.should_parallelize(total, par) {
239        return map(WorkRange::new(0, total));
240    }
241
242    #[cfg(feature = "parallel")]
243    {
244        let ranges = partition_work(total, par.num_threads());
245        ranges
246            .into_par_iter()
247            .map(map)
248            .reduce(|| identity.clone(), reduce)
249    }
250
251    #[cfg(not(feature = "parallel"))]
252    {
253        map(WorkRange::new(0, total))
254    }
255}
256
257/// Parallel for_each with index.
258pub fn for_each_indexed<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
259where
260    F: Fn(usize) + Send + Sync,
261{
262    if !threshold.should_parallelize(total, par) {
263        for i in 0..total {
264            f(i);
265        }
266        return;
267    }
268
269    #[cfg(feature = "parallel")]
270    {
271        (0..total).into_par_iter().for_each(f);
272    }
273
274    #[cfg(not(feature = "parallel"))]
275    {
276        for i in 0..total {
277            f(i);
278        }
279    }
280}
281
282// =============================================================================
283// Custom thread pool support
284// =============================================================================
285
/// Trait for custom thread pool implementations.
///
/// This allows using thread pools other than rayon's global pool.
pub trait ThreadPool: Send + Sync {
    /// Returns the number of threads in the pool.
    fn num_threads(&self) -> usize;

    /// Executes a closure on the thread pool.
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static;

    /// Joins two closures, executing them potentially in parallel.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send;

    /// Parallel for_each over a range.
    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync;

    /// Parallel map-reduce over a range.
    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync;
}

/// A single-threaded "pool" that runs everything inline on the calling thread.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialPool;

impl ThreadPool for SequentialPool {
    #[inline]
    fn num_threads(&self) -> usize {
        1
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // No worker threads: run immediately.
        f();
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        // `a` completes before `b` starts, which is a valid schedule of
        // the parallel contract.
        let left = a();
        let right = b();
        (left, right)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        range.for_each(f);
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        range.fold(identity, |acc, i| reduce(acc, map(i)))
    }
}
382
/// Wrapper for rayon's global thread pool.
#[cfg(feature = "parallel")]
#[derive(Debug, Clone, Copy, Default)]
pub struct RayonGlobalPool;

#[cfg(feature = "parallel")]
impl ThreadPool for RayonGlobalPool {
    #[inline]
    fn num_threads(&self) -> usize {
        rayon::current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Fire-and-forget onto the global pool.
        rayon::spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        rayon::join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        range.into_par_iter().for_each(|i| f(i));
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        range
            .into_par_iter()
            .map(|i| map(i))
            .reduce(|| identity.clone(), reduce)
    }
}
439
/// Wrapper for a custom rayon thread pool.
#[cfg(feature = "parallel")]
pub struct CustomRayonPool {
    pool: rayon::ThreadPool,
}

#[cfg(feature = "parallel")]
impl CustomRayonPool {
    /// Creates a new custom rayon pool with the specified number of threads.
    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        // Delegate to the general builder entry point.
        Self::with_builder(|builder| builder.num_threads(num_threads))
    }

    /// Creates a new custom rayon pool with builder configuration.
    pub fn with_builder<F>(configure: F) -> Result<Self, rayon::ThreadPoolBuildError>
    where
        F: FnOnce(rayon::ThreadPoolBuilder) -> rayon::ThreadPoolBuilder,
    {
        let pool = configure(rayon::ThreadPoolBuilder::new()).build()?;
        Ok(CustomRayonPool { pool })
    }

    /// Installs this pool for the duration of the closure.
    pub fn install<R, F>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }
}
475
#[cfg(feature = "parallel")]
impl ThreadPool for CustomRayonPool {
    #[inline]
    fn num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.pool.spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.pool.join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        // `install` makes the parallel iterator run on *this* pool
        // instead of the global one.
        self.pool.install(|| range.into_par_iter().for_each(|i| f(i)));
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        self.pool.install(|| {
            range
                .into_par_iter()
                .map(|i| map(i))
                .reduce(|| identity.clone(), reduce)
        })
    }
}
531
/// Scoped execution context for a thread pool.
///
/// This provides a convenient way to run operations with a specific thread pool.
pub struct PoolScope<'a, P: ThreadPool> {
    // Borrowed pool that executes the parallel code paths.
    pool: &'a P,
    // Work-size thresholds deciding sequential vs. pool execution.
    threshold: ParThreshold,
}
539
540impl<'a, P: ThreadPool> PoolScope<'a, P> {
541    /// Creates a new pool scope with default threshold.
542    pub fn new(pool: &'a P) -> Self {
543        PoolScope {
544            pool,
545            threshold: ParThreshold::default(),
546        }
547    }
548
549    /// Creates a new pool scope with a custom threshold.
550    pub fn with_threshold(pool: &'a P, threshold: ParThreshold) -> Self {
551        PoolScope { pool, threshold }
552    }
553
554    /// Returns the number of threads in the pool.
555    #[inline]
556    pub fn num_threads(&self) -> usize {
557        self.pool.num_threads()
558    }
559
560    /// Joins two closures.
561    #[inline]
562    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
563    where
564        A: FnOnce() -> RA + Send,
565        B: FnOnce() -> RB + Send,
566        RA: Send,
567        RB: Send,
568    {
569        self.pool.join(a, b)
570    }
571
572    /// Parallel for_each over a range.
573    pub fn for_each<F>(&self, total: usize, f: F)
574    where
575        F: Fn(usize) + Send + Sync,
576    {
577        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
578            for i in 0..total {
579                f(i);
580            }
581        } else {
582            self.pool.for_each(0..total, f);
583        }
584    }
585
586    /// Parallel for_each over work ranges.
587    pub fn for_each_range<F>(&self, total: usize, f: F)
588    where
589        F: Fn(WorkRange) + Send + Sync,
590    {
591        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
592            f(WorkRange::new(0, total));
593        } else {
594            let ranges = partition_work(total, self.pool.num_threads());
595            for range in ranges {
596                f(range);
597            }
598        }
599    }
600
601    /// Parallel map-reduce operation.
602    pub fn map_reduce<T, Map, Reduce>(
603        &self,
604        total: usize,
605        identity: T,
606        map: Map,
607        reduce: Reduce,
608    ) -> T
609    where
610        T: Clone + Send + Sync,
611        Map: Fn(usize) -> T + Send + Sync,
612        Reduce: Fn(T, T) -> T + Send + Sync,
613    {
614        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
615            let mut acc = identity;
616            for i in 0..total {
617                acc = reduce(acc, map(i));
618            }
619            acc
620        } else {
621            self.pool.map_reduce(0..total, identity, map, reduce)
622        }
623    }
624}
625
626/// Gets the default thread pool based on feature flags.
627#[cfg(feature = "parallel")]
628pub fn default_pool() -> RayonGlobalPool {
629    RayonGlobalPool
630}
631
632/// Gets the default thread pool (sequential without parallel feature).
633#[cfg(not(feature = "parallel"))]
634pub fn default_pool() -> SequentialPool {
635    SequentialPool
636}
637
638/// Executes work with the default pool.
639///
640/// This is a convenience wrapper that creates a PoolScope with the default pool.
641#[cfg(feature = "parallel")]
642pub fn with_default_pool<R, F>(f: F) -> R
643where
644    F: FnOnce(PoolScope<'_, RayonGlobalPool>) -> R,
645{
646    let pool = RayonGlobalPool;
647    f(PoolScope::new(&pool))
648}
649
650/// Executes work with the default pool (sequential version).
651#[cfg(not(feature = "parallel"))]
652pub fn with_default_pool<R, F>(f: F) -> R
653where
654    F: FnOnce(PoolScope<'_, SequentialPool>) -> R,
655{
656    let pool = SequentialPool;
657    f(PoolScope::new(&pool))
658}
659
660// =============================================================================
661// Thread-local accumulation
662// =============================================================================
663
/// Thread-local accumulator for parallel reduction.
///
/// This is useful for operations like parallel summation where each thread
/// maintains its own accumulator to avoid synchronization.
#[cfg(feature = "parallel")]
pub struct ThreadLocalAccum<T> {
    /// One slot per rayon worker thread, indexed by thread index.
    values: Vec<std::sync::Mutex<T>>,
}

#[cfg(feature = "parallel")]
impl<T: Clone + Send> ThreadLocalAccum<T> {
    /// Creates one accumulator slot per rayon worker thread, each
    /// initialized to a clone of `identity`.
    pub fn new(identity: T) -> Self {
        let slots = rayon::current_num_threads();
        let mut values = Vec::with_capacity(slots);
        for _ in 0..slots {
            values.push(std::sync::Mutex::new(identity.clone()));
        }
        ThreadLocalAccum { values }
    }

    /// Locks and returns the accumulator slot for the current thread.
    ///
    /// Threads outside the rayon pool (index `None`) share slot 0; the
    /// modulo guards against an index beyond the slot count.
    pub fn get(&self) -> std::sync::MutexGuard<'_, T> {
        let idx = rayon::current_thread_index().unwrap_or(0) % self.values.len();
        self.values[idx].lock().unwrap()
    }

    /// Consumes the accumulator, folding every slot into one value with `f`.
    pub fn reduce<F>(self, f: F) -> T
    where
        F: Fn(T, T) -> T,
    {
        let mut slots = self.values.into_iter().map(|m| m.into_inner().unwrap());
        let first = slots.next().unwrap();
        slots.fold(first, f)
    }
}
702
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_partition_work() {
        let ranges = partition_work(100, 4);
        assert_eq!(ranges.len(), 4);

        // Every index in 0..100 must be covered exactly once.
        let mut hits = vec![0usize; 100];
        for range in &ranges {
            for slot in &mut hits[range.start..range.end] {
                *slot += 1;
            }
        }
        for (i, &count) in hits.iter().enumerate() {
            assert_eq!(count, 1, "index {} covered {} times", i, count);
        }
    }

    #[test]
    fn test_partition_work_uneven() {
        // Chunk lengths must sum back to the original total.
        let covered: usize = partition_work(10, 4).iter().map(WorkRange::len).sum();
        assert_eq!(covered, 10);
    }

    #[test]
    fn test_partition_work_single() {
        let ranges = partition_work(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!((ranges[0].start, ranges[0].end), (0, 100));
    }

    #[test]
    fn test_threshold() {
        let threshold = ParThreshold::new(100, 10);

        assert!(!threshold.should_parallelize(50, Par::Seq));
        assert!(!threshold.should_parallelize(50, Par::default()));

        // Only meaningful with the parallel feature enabled.
        #[cfg(feature = "parallel")]
        assert!(threshold.should_parallelize(1000, Par::Rayon));
    }

    #[test]
    fn test_global_parallelism() {
        let was_enabled = is_parallelism_enabled();

        disable_global_parallelism();
        assert!(!is_parallelism_enabled());

        enable_global_parallelism();
        assert!(is_parallelism_enabled());

        // Put the global flag back the way we found it.
        if !was_enabled {
            disable_global_parallelism();
        }
    }

    #[test]
    fn test_sequential_map_reduce() {
        let total = map_reduce(
            100,
            Par::Seq,
            &ParThreshold::default(),
            0usize,
            |range| range.len(),
            |a, b| a + b,
        );
        assert_eq!(total, 100);
    }

    // Thread pool tests
    #[test]
    fn test_sequential_pool() {
        let pool = SequentialPool;
        assert_eq!(pool.num_threads(), 1);

        let (left, right) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!((left, right), (2, 4));

        let sum = std::sync::atomic::AtomicUsize::new(0);
        pool.for_each(0..10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);

        assert_eq!(pool.map_reduce(0..10, 0, |i| i, |a, b| a + b), 45);
    }

    #[test]
    fn test_pool_scope() {
        let pool = SequentialPool;
        let scope = PoolScope::new(&pool);
        assert_eq!(scope.num_threads(), 1);

        assert_eq!(
            scope.map_reduce(100, 0usize, |i| i, |a, b| a + b),
            (0..100).sum::<usize>()
        );

        let sum = std::sync::atomic::AtomicUsize::new(0);
        scope.for_each(10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);
    }

    #[test]
    fn test_pool_scope_with_threshold() {
        let pool = SequentialPool;
        let scope = PoolScope::with_threshold(&pool, ParThreshold::new(50, 10));

        // Behaves identically for a sequential pool.
        assert_eq!(
            scope.map_reduce(100, 0usize, |i| i, |a, b| a + b),
            (0..100).sum::<usize>()
        );
    }

    #[test]
    fn test_default_pool() {
        assert!(default_pool().num_threads() >= 1);
    }

    #[test]
    fn test_with_default_pool() {
        assert!(with_default_pool(|scope| scope.num_threads()) >= 1);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rayon_global_pool() {
        let pool = RayonGlobalPool;
        assert!(pool.num_threads() >= 1);

        let (left, right) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!((left, right), (2, 4));

        assert_eq!(
            pool.map_reduce(0..100, 0, |i| i, |a, b| a + b),
            (0..100).sum::<usize>()
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool() {
        let pool = CustomRayonPool::new(2).expect("Failed to create pool");
        assert_eq!(pool.num_threads(), 2);

        assert_eq!(
            pool.map_reduce(0..100, 0, |i| i, |a, b| a + b),
            (0..100).sum::<usize>()
        );

        assert_eq!(
            pool.install(|| (0..100).into_par_iter().sum::<usize>()),
            (0..100).sum()
        );
    }
}