Skip to main content

laddu_core/execution/
thread_pool.rs

1//! Shared thread-pool manager for APIs that accept a per-call thread count.
2
3#[cfg(feature = "rayon")]
4use std::sync::Arc;
5use std::sync::{
6    atomic::{AtomicUsize, Ordering},
7    OnceLock,
8};
9
10#[cfg(feature = "rayon")]
11use parking_lot::RwLock;
12
13#[cfg(feature = "rayon")]
14use crate::LadduError;
15use crate::LadduResult;
16
/// Process-wide default thread count consulted for omitted (`None`) or zero-valued requests.
///
/// A stored value of `0` means "no dedicated default": such requests fall back to the
/// ambient/global Rayon behavior.
static GLOBAL_THREAD_COUNT: AtomicUsize = AtomicUsize::new(0);
18
/// Where a unit of work is executed.
///
/// Shared by [`ThreadPoolManager`] and
/// [`ExecutionContext`](crate::execution::ExecutionContext) so both dispatch work the same way.
#[derive(Clone, Debug, Default)]
pub(crate) enum ThreadExecutor {
    /// Run the work directly on the calling thread, inside whatever ambient/global Rayon
    /// context is active there.
    #[default]
    Ambient,
    /// Run the work inside a dedicated Rayon pool. Cloning the executor clones the `Arc`, so
    /// all clones share the same pool.
    #[cfg(feature = "rayon")]
    Dedicated(Arc<rayon::ThreadPool>),
}
30
31impl ThreadExecutor {
32    /// Create a dedicated executor with `n_threads`.
33    #[cfg(feature = "rayon")]
34    pub(crate) fn dedicated(n_threads: usize) -> LadduResult<Self> {
35        if n_threads == 0 {
36            return Err(LadduError::ExecutionContextError {
37                reason: "Dedicated thread pool size must be >= 1".into(),
38            });
39        }
40
41        Ok(Self::Dedicated(Arc::new(
42            rayon::ThreadPoolBuilder::new()
43                .num_threads(n_threads)
44                .build()?,
45        )))
46    }
47
48    /// Execute work using this executor.
49    #[cfg(feature = "rayon")]
50    pub(crate) fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
51        match self {
52            Self::Ambient => op(),
53            Self::Dedicated(pool) => pool.install(op),
54        }
55    }
56
57    /// Execute work using this executor.
58    #[allow(dead_code)]
59    #[cfg(not(feature = "rayon"))]
60    pub(crate) fn install<R>(&self, op: impl FnOnce() -> R) -> R {
61        op()
62    }
63}
64
/// Reuses a dedicated Rayon pool across APIs that take an optional per-call thread count.
///
/// A request of `None` or `Some(0)` defers to the configured global default. If that default is
/// also `0`, the work runs under the ambient/global Rayon behavior. A positive count is served by
/// a dedicated pool of that size. Only the pool for the most recently requested size is kept
/// for reuse.
#[derive(Default, Debug)]
pub struct ThreadPoolManager {
    /// The single cached dedicated pool, tagged with its thread count. It starts as `None`
    /// (via `Default`).
    #[cfg(feature = "rayon")]
    pub(crate) dedicated_pool: RwLock<Option<(usize, ThreadExecutor)>>,
}
76
77impl ThreadPoolManager {
78    /// Return the process-wide shared pool manager.
79    pub fn shared() -> &'static Self {
80        static THREAD_POOL_MANAGER: OnceLock<ThreadPoolManager> = OnceLock::new();
81        THREAD_POOL_MANAGER.get_or_init(Self::default)
82    }
83
84    /// Set the process-global default thread count used by omitted or zero-valued requests.
85    ///
86    /// A value of `0` resets the default to the ambient/global Rayon behavior.
87    pub fn set_global_thread_count(n_threads: usize) {
88        GLOBAL_THREAD_COUNT.store(n_threads, Ordering::Relaxed);
89    }
90
91    /// Return the process-global default thread count used by omitted or zero-valued requests.
92    ///
93    /// Returns `None` when the default is the ambient/global Rayon behavior.
94    pub fn global_thread_count() -> Option<usize> {
95        Self::normalize_thread_request(Some(GLOBAL_THREAD_COUNT.load(Ordering::Relaxed)))
96    }
97
98    /// Resolve an optional thread request against the process-global default.
99    ///
100    /// `None` and `Some(0)` both use the configured global default. Positive thread counts bypass
101    /// the global default and are returned directly.
102    pub fn resolve_thread_request(requested_threads: Option<usize>) -> Option<usize> {
103        match requested_threads {
104            None | Some(0) => Self::global_thread_count(),
105            Some(n_threads) => Some(n_threads),
106        }
107    }
108
109    /// Execute work using the requested thread-count policy.
110    ///
111    /// `None` and `Some(0)` both use the configured global default. Positive thread counts reuse a
112    /// cached dedicated pool of that size.
113    #[cfg(feature = "rayon")]
114    pub fn install<R: Send>(
115        &self,
116        requested_threads: Option<usize>,
117        op: impl FnOnce() -> R + Send,
118    ) -> LadduResult<R> {
119        match Self::resolve_thread_request(requested_threads) {
120            Some(n_threads) => Ok(self.executor_for_threads(n_threads)?.install(op)),
121            None => Ok(ThreadExecutor::default().install(op)),
122        }
123    }
124
125    /// Execute work using the requested thread-count policy.
126    ///
127    /// Without Rayon, all work runs on the caller thread and the requested thread count is
128    /// ignored.
129    #[cfg(not(feature = "rayon"))]
130    pub fn install<R>(
131        &self,
132        _requested_threads: Option<usize>,
133        op: impl FnOnce() -> R,
134    ) -> LadduResult<R> {
135        Ok(op())
136    }
137
138    fn normalize_thread_request(requested_threads: Option<usize>) -> Option<usize> {
139        requested_threads.filter(|&n_threads| n_threads > 0)
140    }
141
142    #[cfg(feature = "rayon")]
143    pub(crate) fn executor_for_threads(&self, n_threads: usize) -> LadduResult<ThreadExecutor> {
144        if let Some((cached_threads, executor)) = &*self.dedicated_pool.read() {
145            if *cached_threads == n_threads {
146                return Ok(executor.clone());
147            }
148        }
149
150        let executor = ThreadExecutor::dedicated(n_threads)?;
151        let mut dedicated_pool = self.dedicated_pool.write();
152        *dedicated_pool = Some((n_threads, executor.clone()));
153        Ok(executor)
154    }
155}