// oxiblas_core/parallel.rs
1//! Parallelization primitives for OxiBLAS.
2//!
3//! This module provides:
4//! - Parallel execution modes
5//! - Work partitioning utilities
6//! - Thread-local accumulation patterns
7
#[cfg(not(feature = "std"))]
use alloc::string::String;
#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use core::sync::atomic::{AtomicBool, Ordering};

#[cfg(feature = "parallel")]
use rayon::prelude::*;
17
/// Global flag recording whether parallel dispatch has been turned off.
static PARALLELISM_DISABLED: AtomicBool = AtomicBool::new(false);

/// Writes the disabled flag with sequentially-consistent ordering.
#[inline]
fn store_parallelism_disabled(disabled: bool) {
    PARALLELISM_DISABLED.store(disabled, Ordering::SeqCst);
}

/// Disables global parallelism.
///
/// This can be useful for debugging or when running in environments
/// where threading is problematic.
pub fn disable_global_parallelism() {
    store_parallelism_disabled(true);
}

/// Enables global parallelism.
pub fn enable_global_parallelism() {
    store_parallelism_disabled(false);
}

/// Returns true if parallelism is enabled.
pub fn is_parallelism_enabled() -> bool {
    let disabled = PARALLELISM_DISABLED.load(Ordering::SeqCst);
    !disabled
}
38
/// Parallelization mode.
///
/// Passed to the parallel helpers in this module to select how work is
/// dispatched.  The `Rayon*` variants only exist when the `parallel`
/// feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Par {
    /// Sequential execution on the calling thread.
    Seq,
    /// Parallel execution with the default thread pool.
    #[cfg(feature = "parallel")]
    Rayon,
    /// Parallel execution with a specific number of threads.
    #[cfg(feature = "parallel")]
    RayonWith(usize),
}
51
// Manual impl because the default variant depends on feature flags
// (Rayon when "parallel" is enabled, Seq otherwise)
#[allow(clippy::derivable_impls)]
impl Default for Par {
    /// Defaults to the most parallel mode the build supports.
    fn default() -> Self {
        // Exactly one of these cfg blocks is compiled in; the surviving
        // block becomes the tail expression of the function.
        #[cfg(feature = "parallel")]
        {
            Par::Rayon
        }
        #[cfg(not(feature = "parallel"))]
        {
            Par::Seq
        }
    }
}
67
impl Par {
    /// Returns true if this mode resolves to sequential execution.
    ///
    /// `Seq` is always sequential; the rayon modes become sequential
    /// when [`disable_global_parallelism`] has been called.
    #[inline]
    pub fn is_sequential(&self) -> bool {
        match self {
            Par::Seq => true,
            // Rayon / RayonWith: sequential only when globally disabled.
            #[cfg(feature = "parallel")]
            _ => !is_parallelism_enabled(),
        }
    }

    /// Returns the number of threads to use.
    ///
    /// Reports 1 when global parallelism is disabled, regardless of mode.
    #[cfg(feature = "parallel")]
    pub fn num_threads(&self) -> usize {
        if !is_parallelism_enabled() {
            return 1;
        }

        match self {
            Par::Seq => 1,
            Par::Rayon => rayon::current_num_threads(),
            Par::RayonWith(n) => *n,
        }
    }

    /// Returns the number of threads to use (always 1 without parallel feature).
    #[cfg(not(feature = "parallel"))]
    pub fn num_threads(&self) -> usize {
        1
    }
}
99
100/// Threshold configuration for parallel operations.
101#[derive(Debug, Clone, Copy)]
102pub struct ParThreshold {
103    /// Minimum number of elements for parallelization.
104    pub min_elements: usize,
105    /// Minimum work per thread (elements).
106    pub min_work_per_thread: usize,
107}
108
109impl Default for ParThreshold {
110    fn default() -> Self {
111        ParThreshold {
112            min_elements: 4096,
113            min_work_per_thread: 256,
114        }
115    }
116}
117
118impl ParThreshold {
119    /// Creates a new threshold configuration.
120    pub const fn new(min_elements: usize, min_work_per_thread: usize) -> Self {
121        ParThreshold {
122            min_elements,
123            min_work_per_thread,
124        }
125    }
126
127    /// Returns true if parallelization should be used for the given work size.
128    #[inline]
129    pub fn should_parallelize(&self, total_work: usize, par: Par) -> bool {
130        if par.is_sequential() {
131            return false;
132        }
133
134        if total_work < self.min_elements {
135            return false;
136        }
137
138        let threads = par.num_threads();
139        if threads <= 1 {
140            return false;
141        }
142
143        total_work / threads >= self.min_work_per_thread
144    }
145}
146
/// Work range for parallel iteration.
#[derive(Debug, Clone, Copy)]
pub struct WorkRange {
    /// Start index (inclusive).
    pub start: usize,
    /// End index (exclusive).
    pub end: usize,
}

impl WorkRange {
    /// Creates a new work range covering `start..end`.
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        WorkRange { start, end }
    }

    /// Returns the number of elements in the range.
    ///
    /// Uses a saturating subtraction so an inverted range (`start > end`)
    /// reports length 0, consistent with [`Self::is_empty`], instead of
    /// overflowing (panic in debug builds, wraparound in release).
    #[inline]
    pub const fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Returns true if the range contains no elements.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.start >= self.end
    }
}

/// Partitions `total` units of work into contiguous chunks for parallel
/// execution.
///
/// Produces at most `num_threads` ranges of `ceil(total / num_threads)`
/// elements each (the last range absorbs the remainder and may be
/// shorter).  Returns an empty vector when there is no work or no thread.
pub fn partition_work(total: usize, num_threads: usize) -> Vec<WorkRange> {
    if num_threads == 0 || total == 0 {
        return Vec::new();
    }

    // num_threads >= 1 and total >= 1 here, so chunk_size >= 1 and the
    // loop below always terminates.
    let chunk_size = total.div_ceil(num_threads);
    let mut ranges = Vec::with_capacity(num_threads);

    let mut start = 0;
    while start < total {
        let end = (start + chunk_size).min(total);
        ranges.push(WorkRange::new(start, end));
        start = end;
    }

    ranges
}
198
/// Executes a closure in parallel over work ranges.
///
/// If parallelism is disabled or the work is too small per `threshold`,
/// `f` is invoked once on the calling thread with the whole range
/// `[0, total)`; otherwise the work is split into one chunk per thread
/// and dispatched through rayon.
#[inline]
pub fn for_each_range<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(WorkRange) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        f(WorkRange::new(0, total));
        return;
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges.into_par_iter().for_each(|range| {
            f(range);
        });
    }

    // Unreachable in practice: without the `parallel` feature,
    // `should_parallelize` is always false and the early return fires.
    #[cfg(not(feature = "parallel"))]
    {
        f(WorkRange::new(0, total));
    }
}
225
/// Parallel map-reduce operation.
///
/// Maps each work range to a value, then reduces all values.  On the
/// sequential path the whole range is mapped in a single call, so no
/// reduction is needed and `identity` goes unused (hence the
/// `allow(unused_variables)` for builds without the `parallel` feature).
/// With rayon, `identity` seeds each partial reduction, so it should be a
/// neutral element of `reduce`.
#[allow(unused_variables)]
pub fn map_reduce<T, Map, Reduce>(
    total: usize,
    par: Par,
    threshold: &ParThreshold,
    identity: T,
    map: Map,
    reduce: Reduce,
) -> T
where
    T: Clone + Send + Sync,
    Map: Fn(WorkRange) -> T + Send + Sync,
    Reduce: Fn(T, T) -> T + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        // Single chunk: map directly, no reduction required.
        return map(WorkRange::new(0, total));
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }

    // Unreachable in practice: without the `parallel` feature,
    // `should_parallelize` is always false and the early return fires.
    #[cfg(not(feature = "parallel"))]
    {
        map(WorkRange::new(0, total))
    }
}
261
/// Parallel for_each with index.
///
/// Calls `f(i)` for every `i` in `0..total` — in parallel when the mode
/// and `threshold` allow it, otherwise sequentially in ascending order.
pub fn for_each_indexed<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(usize) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        for i in 0..total {
            f(i);
        }
        return;
    }

    #[cfg(feature = "parallel")]
    {
        (0..total).into_par_iter().for_each(f);
    }

    // Unreachable in practice: without the `parallel` feature,
    // `should_parallelize` is always false and the early return fires.
    #[cfg(not(feature = "parallel"))]
    {
        for i in 0..total {
            f(i);
        }
    }
}
286
// =============================================================================
// Custom thread pool support
// =============================================================================

/// Trait for custom thread pool implementations.
///
/// This allows using thread pools other than rayon's global pool.
/// Implemented by [`SequentialPool`] and, with the `parallel` feature,
/// [`RayonGlobalPool`] and [`CustomRayonPool`].
pub trait ThreadPool: Send + Sync {
    /// Returns the number of threads in the pool.
    fn num_threads(&self) -> usize;

    /// Executes a closure on the thread pool.
    ///
    /// Fire-and-forget: no handle is returned to await completion.
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static;

    /// Joins two closures, executing them potentially in parallel, and
    /// returns both results.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send;

    /// Parallel for_each over a range of indices.
    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync;

    /// Parallel map-reduce over a range of indices.
    ///
    /// `identity` seeds partial reductions; the order in which partial
    /// results are combined is unspecified, so `identity` should be a
    /// neutral element of `reduce`.
    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync;
}
329
330/// A single-threaded "pool" for sequential execution.
331#[derive(Debug, Clone, Copy, Default)]
332pub struct SequentialPool;
333
334impl ThreadPool for SequentialPool {
335    #[inline]
336    fn num_threads(&self) -> usize {
337        1
338    }
339
340    #[inline]
341    fn execute<F>(&self, f: F)
342    where
343        F: FnOnce() + Send + 'static,
344    {
345        f();
346    }
347
348    #[inline]
349    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
350    where
351        A: FnOnce() -> RA + Send,
352        B: FnOnce() -> RB + Send,
353        RA: Send,
354        RB: Send,
355    {
356        (a(), b())
357    }
358
359    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
360    where
361        F: Fn(usize) + Send + Sync,
362    {
363        for i in range {
364            f(i);
365        }
366    }
367
368    fn map_reduce<T, Map, Reduce>(
369        &self,
370        range: core::ops::Range<usize>,
371        identity: T,
372        map: Map,
373        reduce: Reduce,
374    ) -> T
375    where
376        T: Clone + Send + Sync,
377        Map: Fn(usize) -> T + Send + Sync,
378        Reduce: Fn(T, T) -> T + Send + Sync,
379    {
380        let mut acc = identity;
381        for i in range {
382            acc = reduce(acc, map(i));
383        }
384        acc
385    }
386}
387
/// Wrapper for rayon's global thread pool.
///
/// Zero-sized: every method delegates to rayon's process-wide pool.
#[cfg(feature = "parallel")]
#[derive(Debug, Clone, Copy, Default)]
pub struct RayonGlobalPool;

#[cfg(feature = "parallel")]
impl ThreadPool for RayonGlobalPool {
    #[inline]
    fn num_threads(&self) -> usize {
        rayon::current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Fire-and-forget on the global pool.
        rayon::spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        rayon::join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        range.into_par_iter().for_each(f);
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        // Each rayon split seeds its partial reduction with a clone of
        // `identity`, so `identity` should be neutral for `reduce`.
        range
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }
}
444
/// Wrapper for a custom rayon thread pool.
///
/// Unlike [`RayonGlobalPool`], this owns a dedicated `rayon::ThreadPool`
/// with its own worker threads and configuration.
#[cfg(feature = "parallel")]
pub struct CustomRayonPool {
    // The owned rayon pool all work is dispatched to.
    pool: rayon::ThreadPool,
}

#[cfg(feature = "parallel")]
impl CustomRayonPool {
    /// Creates a new custom rayon pool with the specified number of threads.
    ///
    /// Returns an error if rayon fails to build the pool.
    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()?;
        Ok(CustomRayonPool { pool })
    }

    /// Creates a new custom rayon pool with the specified number of threads.
    ///
    /// This is an alias for [`CustomRayonPool::new`] that matches the naming
    /// convention used in the `OxiblasThreadConfig` builder API.
    pub fn with_num_threads(n: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        Self::new(n)
    }

    /// Creates a new custom rayon pool with builder configuration.
    ///
    /// `configure` receives a fresh `ThreadPoolBuilder` and returns the
    /// (possibly modified) builder to build from.
    pub fn with_builder<F>(configure: F) -> Result<Self, rayon::ThreadPoolBuildError>
    where
        F: FnOnce(rayon::ThreadPoolBuilder) -> rayon::ThreadPoolBuilder,
    {
        let builder = rayon::ThreadPoolBuilder::new();
        let pool = configure(builder).build()?;
        Ok(CustomRayonPool { pool })
    }

    /// Installs this pool for the duration of the closure, so rayon
    /// parallel iterators inside `f` run on this pool's workers.
    pub fn install<R, F>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }
}
488
#[cfg(feature = "parallel")]
impl ThreadPool for CustomRayonPool {
    #[inline]
    fn num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Fire-and-forget on this pool's workers.
        self.pool.spawn(f);
    }

    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.pool.join(a, b)
    }

    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        // `install` makes the parallel iterator run on this pool's
        // workers instead of rayon's global pool.
        self.pool.install(|| {
            range.into_par_iter().for_each(f);
        });
    }

    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        // Same as above: run the reduction inside this pool.
        self.pool.install(|| {
            range
                .into_par_iter()
                .map(map)
                .reduce(|| identity.clone(), reduce)
        })
    }
}
544
/// Scoped execution context for a thread pool.
///
/// This provides a convenient way to run operations with a specific thread
/// pool, falling back to sequential execution for small workloads.
pub struct PoolScope<'a, P: ThreadPool> {
    // Pool that work is dispatched to.
    pool: &'a P,
    // Size thresholds gating parallel dispatch.
    threshold: ParThreshold,
}
552
553impl<'a, P: ThreadPool> PoolScope<'a, P> {
554    /// Creates a new pool scope with default threshold.
555    pub fn new(pool: &'a P) -> Self {
556        PoolScope {
557            pool,
558            threshold: ParThreshold::default(),
559        }
560    }
561
562    /// Creates a new pool scope with a custom threshold.
563    pub fn with_threshold(pool: &'a P, threshold: ParThreshold) -> Self {
564        PoolScope { pool, threshold }
565    }
566
567    /// Returns the number of threads in the pool.
568    #[inline]
569    pub fn num_threads(&self) -> usize {
570        self.pool.num_threads()
571    }
572
573    /// Joins two closures.
574    #[inline]
575    pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
576    where
577        A: FnOnce() -> RA + Send,
578        B: FnOnce() -> RB + Send,
579        RA: Send,
580        RB: Send,
581    {
582        self.pool.join(a, b)
583    }
584
585    /// Parallel for_each over a range.
586    pub fn for_each<F>(&self, total: usize, f: F)
587    where
588        F: Fn(usize) + Send + Sync,
589    {
590        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
591            for i in 0..total {
592                f(i);
593            }
594        } else {
595            self.pool.for_each(0..total, f);
596        }
597    }
598
599    /// Parallel for_each over work ranges.
600    pub fn for_each_range<F>(&self, total: usize, f: F)
601    where
602        F: Fn(WorkRange) + Send + Sync,
603    {
604        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
605            f(WorkRange::new(0, total));
606        } else {
607            let ranges = partition_work(total, self.pool.num_threads());
608            for range in ranges {
609                f(range);
610            }
611        }
612    }
613
614    /// Parallel map-reduce operation.
615    pub fn map_reduce<T, Map, Reduce>(
616        &self,
617        total: usize,
618        identity: T,
619        map: Map,
620        reduce: Reduce,
621    ) -> T
622    where
623        T: Clone + Send + Sync,
624        Map: Fn(usize) -> T + Send + Sync,
625        Reduce: Fn(T, T) -> T + Send + Sync,
626    {
627        if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
628            let mut acc = identity;
629            for i in 0..total {
630                acc = reduce(acc, map(i));
631            }
632            acc
633        } else {
634            self.pool.map_reduce(0..total, identity, map, reduce)
635        }
636    }
637}
638
/// Gets the default thread pool based on feature flags.
///
/// With the `parallel` feature this is a handle to rayon's global pool.
#[cfg(feature = "parallel")]
pub fn default_pool() -> RayonGlobalPool {
    RayonGlobalPool
}

/// Gets the default thread pool (sequential without parallel feature).
#[cfg(not(feature = "parallel"))]
pub fn default_pool() -> SequentialPool {
    SequentialPool
}
650
/// Executes work with the default pool.
///
/// This is a convenience wrapper that creates a [`PoolScope`] (with the
/// default [`ParThreshold`]) over the default pool and passes it to `f`.
#[cfg(feature = "parallel")]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, RayonGlobalPool>) -> R,
{
    let pool = RayonGlobalPool;
    f(PoolScope::new(&pool))
}

/// Executes work with the default pool (sequential version).
#[cfg(not(feature = "parallel"))]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, SequentialPool>) -> R,
{
    let pool = SequentialPool;
    f(PoolScope::new(&pool))
}
672
673// =============================================================================
674// Global thread pool management
675// =============================================================================
676
/// Configuration for the OxiBLAS thread pool.
///
/// `OxiblasThreadConfig` gathers all knobs that influence how OxiBLAS
/// chooses threads for parallel operations.  Build one with the fluent
/// builder methods, then apply it via [`set_global_thread_pool`] or
/// [`with_thread_count`].
///
/// # Example
///
/// ```rust
/// use oxiblas_core::parallel::OxiblasThreadConfig;
///
/// let cfg = OxiblasThreadConfig::new()
///     .num_threads(4)
///     .stack_size(2 * 1024 * 1024);
/// println!("threads: {}", cfg.num_threads);
/// ```
#[derive(Debug, Clone, Default)]
pub struct OxiblasThreadConfig {
    /// Number of worker threads.  `0` means "use all logical CPUs".
    pub num_threads: usize,
    /// Per-thread stack size in bytes.  `0` means "use OS default".
    pub stack_size: usize,
    /// Human-readable name prefix for spawned threads.
    pub thread_name: Option<String>,
}

impl OxiblasThreadConfig {
    /// Creates a new configuration with all defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the desired thread count.  Pass `0` for "all CPUs".
    pub fn num_threads(mut self, n: usize) -> Self {
        self.num_threads = n;
        self
    }

    /// Sets the per-thread stack size.  Pass `0` for the OS default.
    pub fn stack_size(mut self, bytes: usize) -> Self {
        self.stack_size = bytes;
        self
    }

    /// Sets a human-readable name prefix for spawned threads.
    pub fn thread_name(mut self, name: impl Into<String>) -> Self {
        self.thread_name = Some(name.into());
        self
    }

    /// Returns the effective thread count, substituting the available
    /// logical CPU count when `num_threads` is `0`.
    ///
    /// Fix: the CPU query lives behind `cfg(feature = "std")` so this
    /// module still builds for `no_std` targets (the rest of the file
    /// gates its `std` usage the same way); without `std` a `0` request
    /// conservatively resolves to 1.
    pub fn effective_threads(&self) -> usize {
        if self.num_threads != 0 {
            return self.num_threads;
        }
        #[cfg(feature = "std")]
        {
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1)
        }
        #[cfg(not(feature = "std"))]
        {
            1
        }
    }

    /// Builds a [`CustomRayonPool`] from this configuration.
    ///
    /// Returns an error if rayon fails to construct the pool.
    #[cfg(feature = "parallel")]
    pub fn build_pool(&self) -> Result<CustomRayonPool, rayon::ThreadPoolBuildError> {
        let mut builder = rayon::ThreadPoolBuilder::new().num_threads(self.effective_threads());
        if self.stack_size > 0 {
            builder = builder.stack_size(self.stack_size);
        }
        if let Some(name) = &self.thread_name {
            // Each worker is named "<prefix>-<index>".
            let name = name.clone();
            builder = builder.thread_name(move |i| format!("{name}-{i}"));
        }
        let pool = builder.build()?;
        Ok(CustomRayonPool { pool })
    }
}
757
// ---------------------------------------------------------------------------
// Global pool registry
// ---------------------------------------------------------------------------

/// A type-erased, `Send + Sync` trait object for thread pools stored
/// in the global registry.
///
/// Only the thread count is exposed dynamically; the registry does not
/// dispatch work through the stored pool.
#[cfg(feature = "std")]
trait AnyPool: Send + Sync {
    fn num_threads_dyn(&self) -> usize;
}

#[cfg(all(feature = "std", feature = "parallel"))]
impl AnyPool for CustomRayonPool {
    fn num_threads_dyn(&self) -> usize {
        self.num_threads()
    }
}

#[cfg(feature = "std")]
impl AnyPool for SequentialPool {
    fn num_threads_dyn(&self) -> usize {
        1
    }
}

// Write-once storage for the registered pool; see `set_global_thread_pool`.
#[cfg(feature = "std")]
static GLOBAL_POOL: std::sync::OnceLock<Box<dyn AnyPool>> = std::sync::OnceLock::new();
785
/// Sets the global OxiBLAS thread pool.
///
/// The pool is stored in a `OnceLock` so it can only be set **once** per
/// process.  Subsequent calls are silently ignored (the first writer wins).
///
/// # Arguments
///
/// * `pool` – Any value that implements [`ThreadPool`] and is
///   `'static + Send + Sync`.  Typically a [`CustomRayonPool`] built via
///   [`OxiblasThreadConfig::build_pool`] or
///   [`CustomRayonPool::with_num_threads`].
///
/// # Example
///
/// ```rust
/// # #[cfg(feature = "parallel")]
/// # {
/// use oxiblas_core::parallel::{CustomRayonPool, set_global_thread_pool};
/// let pool = CustomRayonPool::with_num_threads(4).expect("build pool");
/// set_global_thread_pool(pool);
/// # }
/// ```
#[cfg(all(feature = "std", feature = "parallel"))]
pub fn set_global_thread_pool(pool: CustomRayonPool) {
    // First writer wins; a later call's pool is dropped unused.
    let _ = GLOBAL_POOL.set(Box::new(pool));
}

/// Sets the global OxiBLAS thread pool to a sequential (single-threaded)
/// pool (available without the `parallel` feature).
#[cfg(all(feature = "std", not(feature = "parallel")))]
pub fn set_global_thread_pool(pool: SequentialPool) {
    // First writer wins; a later call is silently ignored.
    let _ = GLOBAL_POOL.set(Box::new(pool));
}
819
/// Returns the number of threads in the global pool, or `1` if no pool has
/// been registered.
#[cfg(feature = "std")]
pub fn global_num_threads() -> usize {
    match GLOBAL_POOL.get() {
        Some(pool) => pool.num_threads_dyn(),
        None => 1,
    }
}
826
/// Executes `f` inside a temporary rayon pool with exactly `n` threads.
///
/// This is useful for benchmarks or tests that need deterministic
/// parallelism without replacing the global pool.  On platforms without
/// the `parallel` feature the closure is called directly on the current
/// thread.
///
/// # Example
///
/// ```rust
/// use oxiblas_core::parallel::with_thread_count;
///
/// with_thread_count(2, || {
///     // work here runs with (up to) 2 rayon threads
/// });
/// ```
#[cfg(feature = "parallel")]
pub fn with_thread_count(n: usize, f: impl FnOnce() + Send) {
    // A fresh pool is built (and torn down) per call; this is intended
    // for tests/benchmarks, not hot paths.
    let pool = rayon::ThreadPoolBuilder::new().num_threads(n).build();
    match pool {
        Ok(p) => p.install(f),
        Err(_) => f(), // fallback: run sequentially if build fails
    }
}

/// Sequential fallback when the `parallel` feature is disabled.
#[cfg(not(feature = "parallel"))]
pub fn with_thread_count(_n: usize, f: impl FnOnce()) {
    f();
}
857
// =============================================================================
// Thread-local accumulation
// =============================================================================

/// Thread-local accumulator for parallel reduction.
///
/// This is useful for operations like parallel summation where each thread
/// maintains its own accumulator to avoid synchronization.
///
/// Requires the `parallel` feature (which implies `std`).
#[cfg(feature = "parallel")]
pub struct ThreadLocalAccum<T> {
    // One slot per rayon worker thread, as counted at construction time.
    values: Vec<std::sync::Mutex<T>>,
}

#[cfg(feature = "parallel")]
impl<T: Clone + Send> ThreadLocalAccum<T> {
    /// Creates a new thread-local accumulator.
    ///
    /// Allocates one clone of `identity` per current rayon thread.
    pub fn new(identity: T) -> Self {
        let num_threads = rayon::current_num_threads();
        let values = (0..num_threads)
            .map(|_| std::sync::Mutex::new(identity.clone()))
            .collect();
        ThreadLocalAccum { values }
    }

    /// Gets or initializes the accumulator for the current thread.
    ///
    /// Threads outside the rayon pool (index `None`) share slot 0, and the
    /// modulo guards against an index beyond the slots allocated in `new`;
    /// a poisoned mutex is recovered rather than propagated.
    pub fn get(&self) -> std::sync::MutexGuard<'_, T> {
        let thread_idx = rayon::current_thread_index().unwrap_or(0) % self.values.len();
        self.values[thread_idx]
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Reduces all thread-local values into a single result.
    ///
    /// Consumes the accumulator.  The `expect` cannot fire in practice:
    /// `rayon::current_num_threads()` is at least 1, so `new` always
    /// allocates at least one slot.
    pub fn reduce<F>(self, f: F) -> T
    where
        F: Fn(T, T) -> T,
    {
        self.values
            .into_iter()
            .map(|m| {
                m.into_inner()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
            })
            .reduce(f)
            .expect("ThreadLocalAccum should have at least one value")
    }
}
907
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_partition_work() {
        let ranges = partition_work(100, 4);
        assert_eq!(ranges.len(), 4);

        // Check that ranges cover everything exactly once (no gaps, no overlap)
        let mut covered = [false; 100];
        for range in &ranges {
            for (offset, covered_elem) in covered[range.start..range.end].iter_mut().enumerate() {
                let i = range.start + offset;
                assert!(!*covered_elem, "Overlap at {}", i);
                *covered_elem = true;
            }
        }
        assert!(covered.iter().all(|&x| x), "Not all elements covered");
    }

    #[test]
    fn test_partition_work_uneven() {
        // 10 elements over 4 threads: chunk sizes differ, but...
        let ranges = partition_work(10, 4);

        // Total should equal original
        let total: usize = ranges.iter().map(|r| r.len()).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_partition_work_single() {
        // One thread gets the whole range in a single chunk.
        let ranges = partition_work(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[0].end, 100);
    }

    #[test]
    fn test_threshold() {
        let threshold = ParThreshold::new(100, 10);

        // Below min_elements: never parallelize, regardless of mode.
        assert!(!threshold.should_parallelize(50, Par::Seq));
        assert!(!threshold.should_parallelize(50, Par::default()));

        #[cfg(feature = "parallel")]
        {
            // Only tests with parallel feature
            // NOTE(review): this assumes global parallelism is enabled and
            // rayon has > 1 thread; test_global_parallelism mutates the same
            // global flag, so cross-test interference is possible — confirm.
            assert!(threshold.should_parallelize(1000, Par::Rayon));
        }
    }

    #[test]
    fn test_global_parallelism() {
        // Save current state
        let was_enabled = is_parallelism_enabled();

        disable_global_parallelism();
        assert!(!is_parallelism_enabled());

        enable_global_parallelism();
        assert!(is_parallelism_enabled());

        // Restore
        if !was_enabled {
            disable_global_parallelism();
        }
    }

    #[test]
    fn test_sequential_map_reduce() {
        // Par::Seq forces the single-chunk path: one map call, no reduce.
        let result = map_reduce(
            100,
            Par::Seq,
            &ParThreshold::default(),
            0usize,
            |range| range.len(),
            |a, b| a + b,
        );
        assert_eq!(result, 100);
    }

    // Thread pool tests
    #[test]
    fn test_sequential_pool() {
        let pool = SequentialPool;

        assert_eq!(pool.num_threads(), 1);

        // Test join
        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        // Test for_each (sum of 0..10 == 45)
        let sum = std::sync::atomic::AtomicUsize::new(0);
        pool.for_each(0..10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);

        // Test map_reduce
        let result = pool.map_reduce(0..10, 0, |i| i, |a, b| a + b);
        assert_eq!(result, 45);
    }

    #[test]
    fn test_pool_scope() {
        let pool = SequentialPool;
        let scope = PoolScope::new(&pool);

        assert_eq!(scope.num_threads(), 1);

        // Test map_reduce
        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        // Test for_each
        let sum = std::sync::atomic::AtomicUsize::new(0);
        scope.for_each(10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);
    }

    #[test]
    fn test_pool_scope_with_threshold() {
        let pool = SequentialPool;
        let threshold = ParThreshold::new(50, 10);
        let scope = PoolScope::with_threshold(&pool, threshold);

        // Should work the same for sequential pool
        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[test]
    fn test_default_pool() {
        let pool = default_pool();
        // Should have at least 1 thread
        assert!(pool.num_threads() >= 1);
    }

    #[test]
    fn test_with_default_pool() {
        let result = with_default_pool(|scope| scope.num_threads());
        assert!(result >= 1);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rayon_global_pool() {
        let pool = RayonGlobalPool;

        // Should have multiple threads on most systems
        assert!(pool.num_threads() >= 1);

        // Test join
        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        // Test map_reduce
        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool() {
        let pool = CustomRayonPool::new(2).expect("Failed to create pool");

        assert_eq!(pool.num_threads(), 2);

        // Test map_reduce
        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        // Test install: parallel iterators inside run on this pool
        let result = pool.install(|| (0..100).into_par_iter().sum::<usize>());
        assert_eq!(result, (0..100).sum());
    }

    // ---- OxiblasThreadConfig tests ------------------------------------------

    #[test]
    fn test_thread_config_default() {
        let cfg = OxiblasThreadConfig::default();
        assert_eq!(cfg.num_threads, 0);
        assert_eq!(cfg.stack_size, 0);
        assert!(cfg.thread_name.is_none());
    }

    #[test]
    fn test_thread_config_builder() {
        let cfg = OxiblasThreadConfig::new()
            .num_threads(4)
            .stack_size(1024 * 1024)
            .thread_name("oxiblas-worker");
        assert_eq!(cfg.num_threads, 4);
        assert_eq!(cfg.stack_size, 1024 * 1024);
        assert_eq!(cfg.thread_name.as_deref(), Some("oxiblas-worker"));
    }

    #[test]
    fn test_thread_config_effective_threads_zero() {
        let cfg = OxiblasThreadConfig::new().num_threads(0);
        // effective_threads should fall back to available parallelism (>= 1)
        assert!(cfg.effective_threads() >= 1);
    }

    #[test]
    fn test_thread_config_effective_threads_explicit() {
        let cfg = OxiblasThreadConfig::new().num_threads(3);
        assert_eq!(cfg.effective_threads(), 3);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool_with_num_threads() {
        let pool = CustomRayonPool::with_num_threads(2).expect("build pool");
        assert_eq!(pool.num_threads(), 2);
        let sum: usize = pool.map_reduce(0..50, 0, |i| i, |a, b| a + b);
        assert_eq!(sum, (0..50).sum::<usize>());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_oxiblas_thread_config_build_pool() {
        let cfg = OxiblasThreadConfig::new().num_threads(2);
        let pool = cfg.build_pool().expect("build pool");
        assert_eq!(pool.num_threads(), 2);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_with_thread_count() {
        // Run inside a 2-thread pool and verify rayon sees 2 threads.
        with_thread_count(2, || {
            assert_eq!(rayon::current_num_threads(), 2);
        });
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_with_thread_count_sequential() {
        // Without parallel feature, should just call the closure directly.
        let mut called = false;
        with_thread_count(4, || {
            called = true;
        });
        assert!(called);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_global_num_threads_default() {
        // Before any pool is registered the answer must be at least 1.
        // (May be > 1 if a sibling test already set the global pool.)
        assert!(global_num_threads() >= 1);
    }
}