jxl_threadpool/
lib.rs

//! Abstraction of thread pools, intended to be used in jxl-oxide.
//!
//! [`JxlThreadPool`] is re-exported by `jxl-oxide`.
4
5/// Thread pool wrapper.
6///
7/// This struct wraps internal thread pool implementation and provides interfaces to access it. If
8/// `rayon` feature is enabled, users can create an actual thread pool backed by Rayon, or use
9/// global Rayon thread pool; if not, this struct won't have any multithreading capability, and
10/// every spawn operation will just run the given closure in place.
11#[derive(Debug, Clone)]
12pub struct JxlThreadPool(JxlThreadPoolImpl);
13
/// Backend selected for a [`JxlThreadPool`].
#[derive(Clone, Debug)]
enum JxlThreadPoolImpl {
    /// A dedicated Rayon thread pool, shared between clones through the `Arc`.
    #[cfg(feature = "rayon")]
    Rayon(std::sync::Arc<rayon_core::ThreadPool>),
    /// Rayon's global thread pool.
    #[cfg(feature = "rayon")]
    RayonGlobal,
    /// No thread pool at all; work runs on the calling thread.
    None,
}
22
23/// Fork-join scope created by thread pool.
24#[derive(Debug, Copy, Clone)]
25pub struct JxlScope<'r, 'scope>(JxlScopeInner<'r, 'scope>);
26
/// Backend-specific state of a [`JxlScope`].
#[derive(Clone, Copy, Debug)]
enum JxlScopeInner<'r, 'scope> {
    /// A borrowed Rayon scope.
    #[cfg(feature = "rayon")]
    Rayon(&'r rayon_core::Scope<'scope>),
    /// Single-threaded scope; the marker only exists so both lifetimes are used.
    None(std::marker::PhantomData<&'r &'scope ()>),
}
33
34impl JxlThreadPool {
35    /// Creates a "fake" thread pool without any multithreading capability.
36    ///
37    /// Every spawn operation on this thread poll will just run the closure in current thread.
38    pub const fn none() -> Self {
39        Self(JxlThreadPoolImpl::None)
40    }
41
42    /// Creates a thread pool backed by Rayon [`ThreadPool`][rayon_core::ThreadPool].
43    #[cfg(feature = "rayon")]
44    pub fn with_rayon_thread_pool(pool: std::sync::Arc<rayon_core::ThreadPool>) -> Self {
45        Self(JxlThreadPoolImpl::Rayon(pool))
46    }
47
48    /// Creates a thread pool backed by Rayon.
49    ///
50    /// If `num_threads_requested` is `None` or zero, this method queries available paralleism and
51    /// uses it.
52    #[cfg(feature = "rayon")]
53    pub fn rayon(num_threads_requested: Option<usize>) -> Self {
54        let num_threads_requested = num_threads_requested.unwrap_or(0);
55
56        let num_threads = if num_threads_requested == 0 {
57            let num_threads = std::thread::available_parallelism();
58            match num_threads {
59                Ok(num_threads) => num_threads.into(),
60                Err(e) => {
61                    tracing::warn!(%e, "Failed to query available parallelism; falling back to single-threaded");
62                    return Self::none();
63                }
64            }
65        } else {
66            num_threads_requested
67        };
68
69        let inner = rayon_core::ThreadPoolBuilder::new()
70            .num_threads(num_threads)
71            .build()
72            .map(|pool| JxlThreadPoolImpl::Rayon(std::sync::Arc::new(pool)));
73
74        match inner {
75            Ok(inner) => {
76                tracing::debug!(num_threads, "Initialized Rayon thread pool");
77                Self(inner)
78            }
79            Err(e) => {
80                tracing::warn!(%e, "Failed to initialize thread pool; falling back to single-threaded");
81                Self::none()
82            }
83        }
84    }
85
86    /// Creates a `JxlThreadPool` backed by global Rayon thread pool.
87    #[cfg(feature = "rayon")]
88    pub const fn rayon_global() -> Self {
89        Self(JxlThreadPoolImpl::RayonGlobal)
90    }
91
92    /// Returns the reference to Rayon thread pool, if exists.
93    ///
94    /// Returns `None` for thread pools created using [`rayon_global`], as they don't have Rayon
95    /// thread pool references.
96    ///
97    /// [`rayon_global`]: JxlThreadPool::rayon_global
98    #[cfg(feature = "rayon")]
99    pub fn as_rayon_pool(&self) -> Option<&rayon_core::ThreadPool> {
100        match &self.0 {
101            JxlThreadPoolImpl::Rayon(pool) => Some(&**pool),
102            JxlThreadPoolImpl::RayonGlobal | JxlThreadPoolImpl::None => None,
103        }
104    }
105
106    /// Returns if the thread pool is capable of multithreading.
107    pub fn is_multithreaded(&self) -> bool {
108        match self.0 {
109            #[cfg(feature = "rayon")]
110            JxlThreadPoolImpl::Rayon(_) | JxlThreadPoolImpl::RayonGlobal => true,
111            JxlThreadPoolImpl::None => false,
112        }
113    }
114}
115
116impl JxlThreadPool {
117    /// Runs the given closure on the thread pool.
118    pub fn spawn(&self, op: impl FnOnce() + Send + 'static) {
119        match &self.0 {
120            #[cfg(feature = "rayon")]
121            JxlThreadPoolImpl::Rayon(pool) => pool.spawn(op),
122            #[cfg(feature = "rayon")]
123            JxlThreadPoolImpl::RayonGlobal => rayon_core::spawn(op),
124            JxlThreadPoolImpl::None => op(),
125        }
126    }
127
128    /// Creates a fork-join scope of tasks.
129    pub fn scope<'scope, R: Send>(
130        &'scope self,
131        op: impl for<'r> FnOnce(JxlScope<'r, 'scope>) -> R + Send,
132    ) -> R {
133        match &self.0 {
134            #[cfg(feature = "rayon")]
135            JxlThreadPoolImpl::Rayon(pool) => pool.scope(|scope| {
136                let scope = JxlScope(JxlScopeInner::Rayon(scope));
137                op(scope)
138            }),
139            #[cfg(feature = "rayon")]
140            JxlThreadPoolImpl::RayonGlobal => rayon_core::scope(|scope| {
141                let scope = JxlScope(JxlScopeInner::Rayon(scope));
142                op(scope)
143            }),
144            JxlThreadPoolImpl::None => op(JxlScope(JxlScopeInner::None(Default::default()))),
145        }
146    }
147
148    /// Consumes the `Vec`, and runs a job for each element of the `Vec`.
149    pub fn for_each_vec<T: Send>(&self, v: Vec<T>, op: impl Fn(T) + Send + Sync) {
150        match &self.0 {
151            #[cfg(feature = "rayon")]
152            JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each(v, op)),
153            #[cfg(feature = "rayon")]
154            JxlThreadPoolImpl::RayonGlobal => par_for_each(v, op),
155            JxlThreadPoolImpl::None => v.into_iter().for_each(op),
156        }
157    }
158
159    /// Consumes the `Vec`, and runs a job for each element of the `Vec`.
160    pub fn for_each_vec_with<T: Send, U: Send + Clone>(
161        &self,
162        v: Vec<T>,
163        init: U,
164        op: impl Fn(&mut U, T) + Send + Sync,
165    ) {
166        match &self.0 {
167            #[cfg(feature = "rayon")]
168            JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each_with(v, init, op)),
169            #[cfg(feature = "rayon")]
170            JxlThreadPoolImpl::RayonGlobal => par_for_each_with(v, init, op),
171            JxlThreadPoolImpl::None => {
172                let mut init = init;
173                v.into_iter().for_each(|item| op(&mut init, item))
174            }
175        }
176    }
177
178    /// Runs a job for each element of the mutable slice.
179    pub fn for_each_mut_slice<'a, T: Send>(
180        &self,
181        v: &'a mut [T],
182        op: impl Fn(&'a mut T) + Send + Sync,
183    ) {
184        match &self.0 {
185            #[cfg(feature = "rayon")]
186            JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each(v, op)),
187            #[cfg(feature = "rayon")]
188            JxlThreadPoolImpl::RayonGlobal => par_for_each(v, op),
189            JxlThreadPoolImpl::None => v.iter_mut().for_each(op),
190        }
191    }
192
193    /// Runs a job for each element of the mutable slice.
194    pub fn for_each_mut_slice_with<'a, T: Send, U: Send + Clone>(
195        &self,
196        v: &'a mut [T],
197        init: U,
198        op: impl Fn(&mut U, &'a mut T) + Send + Sync,
199    ) {
200        match &self.0 {
201            #[cfg(feature = "rayon")]
202            JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each_with(v, init, op)),
203            #[cfg(feature = "rayon")]
204            JxlThreadPoolImpl::RayonGlobal => par_for_each_with(v, init, op),
205            JxlThreadPoolImpl::None => {
206                let mut init = init;
207                v.iter_mut().for_each(|item| op(&mut init, item))
208            }
209        }
210    }
211}
212
/// Drives `op` over every item of `it` with a Rayon parallel iterator.
///
/// The work runs on whichever Rayon pool is current. Callers wrap this in
/// `ThreadPool::install` to target a dedicated pool.
#[cfg(feature = "rayon")]
fn par_for_each<T: Send>(
    it: impl rayon::iter::IntoParallelIterator<Item = T>,
    op: impl Fn(T) + Send + Sync,
) {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let par_iter = IntoParallelIterator::into_par_iter(it);
    ParallelIterator::for_each(par_iter, op);
}
221
/// Drives `op` over every item of `it` with a Rayon parallel iterator, passing mutable state
/// seeded from `init`.
///
/// The work runs on whichever Rayon pool is current. Callers wrap this in
/// `ThreadPool::install` to target a dedicated pool.
#[cfg(feature = "rayon")]
fn par_for_each_with<T: Send, U: Send + Clone>(
    it: impl rayon::iter::IntoParallelIterator<Item = T>,
    init: U,
    op: impl Fn(&mut U, T) + Send + Sync,
) {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let par_iter = IntoParallelIterator::into_par_iter(it);
    ParallelIterator::for_each_with(par_iter, init, op);
}
231
232impl<'scope> JxlScope<'_, 'scope> {
233    /// Spanws the given closure in current fork-join scope.
234    pub fn spawn(&self, op: impl for<'r> FnOnce(JxlScope<'r, 'scope>) + Send + 'scope) {
235        match self.0 {
236            #[cfg(feature = "rayon")]
237            JxlScopeInner::Rayon(scope) => scope.spawn(|scope| {
238                let scope = JxlScope(JxlScopeInner::Rayon(scope));
239                op(scope)
240            }),
241            JxlScopeInner::None(_) => op(JxlScope(JxlScopeInner::None(Default::default()))),
242        }
243    }
244}