1#[derive(Debug, Clone)]
12pub struct JxlThreadPool(JxlThreadPoolImpl);
13
/// Execution backend selected for a [`JxlThreadPool`].
#[derive(Clone, Debug)]
enum JxlThreadPoolImpl {
    /// A dedicated Rayon pool, shared between clones of the handle.
    #[cfg(feature = "rayon")]
    Rayon(std::sync::Arc<rayon_core::ThreadPool>),
    /// Rayon's global pool (no handle is stored).
    #[cfg(feature = "rayon")]
    RayonGlobal,
    /// No pool: every operation runs inline on the calling thread.
    None,
}
22
23#[derive(Debug, Copy, Clone)]
25pub struct JxlScope<'r, 'scope>(JxlScopeInner<'r, 'scope>);
26
/// Backend of a [`JxlScope`].
#[derive(Clone, Copy, Debug)]
enum JxlScopeInner<'r, 'scope> {
    /// A borrowed Rayon scope.
    #[cfg(feature = "rayon")]
    Rayon(&'r rayon_core::Scope<'scope>),
    /// Inline (single-threaded) scope. The `PhantomData` keeps both lifetime parameters in
    /// use when the `rayon` feature is disabled, since unused lifetime parameters are rejected
    /// by the compiler.
    None(std::marker::PhantomData<&'r &'scope ()>),
}
33
34impl JxlThreadPool {
35 pub const fn none() -> Self {
39 Self(JxlThreadPoolImpl::None)
40 }
41
42 #[cfg(feature = "rayon")]
44 pub fn with_rayon_thread_pool(pool: std::sync::Arc<rayon_core::ThreadPool>) -> Self {
45 Self(JxlThreadPoolImpl::Rayon(pool))
46 }
47
48 #[cfg(feature = "rayon")]
53 pub fn rayon(num_threads_requested: Option<usize>) -> Self {
54 let num_threads_requested = num_threads_requested.unwrap_or(0);
55
56 let num_threads = if num_threads_requested == 0 {
57 let num_threads = std::thread::available_parallelism();
58 match num_threads {
59 Ok(num_threads) => num_threads.into(),
60 Err(e) => {
61 tracing::warn!(%e, "Failed to query available parallelism; falling back to single-threaded");
62 return Self::none();
63 }
64 }
65 } else {
66 num_threads_requested
67 };
68
69 let inner = rayon_core::ThreadPoolBuilder::new()
70 .num_threads(num_threads)
71 .build()
72 .map(|pool| JxlThreadPoolImpl::Rayon(std::sync::Arc::new(pool)));
73
74 match inner {
75 Ok(inner) => {
76 tracing::debug!(num_threads, "Initialized Rayon thread pool");
77 Self(inner)
78 }
79 Err(e) => {
80 tracing::warn!(%e, "Failed to initialize thread pool; falling back to single-threaded");
81 Self::none()
82 }
83 }
84 }
85
86 #[cfg(feature = "rayon")]
88 pub const fn rayon_global() -> Self {
89 Self(JxlThreadPoolImpl::RayonGlobal)
90 }
91
92 #[cfg(feature = "rayon")]
99 pub fn as_rayon_pool(&self) -> Option<&rayon_core::ThreadPool> {
100 match &self.0 {
101 JxlThreadPoolImpl::Rayon(pool) => Some(&**pool),
102 JxlThreadPoolImpl::RayonGlobal | JxlThreadPoolImpl::None => None,
103 }
104 }
105
106 pub fn is_multithreaded(&self) -> bool {
108 match self.0 {
109 #[cfg(feature = "rayon")]
110 JxlThreadPoolImpl::Rayon(_) | JxlThreadPoolImpl::RayonGlobal => true,
111 JxlThreadPoolImpl::None => false,
112 }
113 }
114}
115
116impl JxlThreadPool {
117 pub fn spawn(&self, op: impl FnOnce() + Send + 'static) {
119 match &self.0 {
120 #[cfg(feature = "rayon")]
121 JxlThreadPoolImpl::Rayon(pool) => pool.spawn(op),
122 #[cfg(feature = "rayon")]
123 JxlThreadPoolImpl::RayonGlobal => rayon_core::spawn(op),
124 JxlThreadPoolImpl::None => op(),
125 }
126 }
127
128 pub fn scope<'scope, R: Send>(
130 &'scope self,
131 op: impl for<'r> FnOnce(JxlScope<'r, 'scope>) -> R + Send,
132 ) -> R {
133 match &self.0 {
134 #[cfg(feature = "rayon")]
135 JxlThreadPoolImpl::Rayon(pool) => pool.scope(|scope| {
136 let scope = JxlScope(JxlScopeInner::Rayon(scope));
137 op(scope)
138 }),
139 #[cfg(feature = "rayon")]
140 JxlThreadPoolImpl::RayonGlobal => rayon_core::scope(|scope| {
141 let scope = JxlScope(JxlScopeInner::Rayon(scope));
142 op(scope)
143 }),
144 JxlThreadPoolImpl::None => op(JxlScope(JxlScopeInner::None(Default::default()))),
145 }
146 }
147
148 pub fn for_each_vec<T: Send>(&self, v: Vec<T>, op: impl Fn(T) + Send + Sync) {
150 match &self.0 {
151 #[cfg(feature = "rayon")]
152 JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each(v, op)),
153 #[cfg(feature = "rayon")]
154 JxlThreadPoolImpl::RayonGlobal => par_for_each(v, op),
155 JxlThreadPoolImpl::None => v.into_iter().for_each(op),
156 }
157 }
158
159 pub fn for_each_vec_with<T: Send, U: Send + Clone>(
161 &self,
162 v: Vec<T>,
163 init: U,
164 op: impl Fn(&mut U, T) + Send + Sync,
165 ) {
166 match &self.0 {
167 #[cfg(feature = "rayon")]
168 JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each_with(v, init, op)),
169 #[cfg(feature = "rayon")]
170 JxlThreadPoolImpl::RayonGlobal => par_for_each_with(v, init, op),
171 JxlThreadPoolImpl::None => {
172 let mut init = init;
173 v.into_iter().for_each(|item| op(&mut init, item))
174 }
175 }
176 }
177
178 pub fn for_each_mut_slice<'a, T: Send>(
180 &self,
181 v: &'a mut [T],
182 op: impl Fn(&'a mut T) + Send + Sync,
183 ) {
184 match &self.0 {
185 #[cfg(feature = "rayon")]
186 JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each(v, op)),
187 #[cfg(feature = "rayon")]
188 JxlThreadPoolImpl::RayonGlobal => par_for_each(v, op),
189 JxlThreadPoolImpl::None => v.iter_mut().for_each(op),
190 }
191 }
192
193 pub fn for_each_mut_slice_with<'a, T: Send, U: Send + Clone>(
195 &self,
196 v: &'a mut [T],
197 init: U,
198 op: impl Fn(&mut U, &'a mut T) + Send + Sync,
199 ) {
200 match &self.0 {
201 #[cfg(feature = "rayon")]
202 JxlThreadPoolImpl::Rayon(pool) => pool.install(|| par_for_each_with(v, init, op)),
203 #[cfg(feature = "rayon")]
204 JxlThreadPoolImpl::RayonGlobal => par_for_each_with(v, init, op),
205 JxlThreadPoolImpl::None => {
206 let mut init = init;
207 v.iter_mut().for_each(|item| op(&mut init, item))
208 }
209 }
210 }
211}
212
/// Drives `it` as a Rayon parallel iterator, calling `op` on every item.
///
/// Runs on whichever Rayon pool is current for the calling thread. Callers use
/// `ThreadPool::install` to select a dedicated pool.
#[cfg(feature = "rayon")]
fn par_for_each<T: Send>(
    it: impl rayon::iter::IntoParallelIterator<Item = T>,
    op: impl Fn(T) + Send + Sync,
) {
    use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};

    let parallel = it.into_par_iter();
    parallel.for_each(op);
}
221
/// Drives `it` as a Rayon parallel iterator, calling `op` with per-job mutable state and each
/// item.
///
/// `init` is cloned as needed by Rayon's `for_each_with`. This runs on whichever Rayon pool is
/// current for the calling thread.
#[cfg(feature = "rayon")]
fn par_for_each_with<T: Send, U: Send + Clone>(
    it: impl rayon::iter::IntoParallelIterator<Item = T>,
    init: U,
    op: impl Fn(&mut U, T) + Send + Sync,
) {
    use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};

    let parallel = it.into_par_iter();
    parallel.for_each_with(init, op);
}
231
232impl<'scope> JxlScope<'_, 'scope> {
233 pub fn spawn(&self, op: impl for<'r> FnOnce(JxlScope<'r, 'scope>) + Send + 'scope) {
235 match self.0 {
236 #[cfg(feature = "rayon")]
237 JxlScopeInner::Rayon(scope) => scope.spawn(|scope| {
238 let scope = JxlScope(JxlScopeInner::Rayon(scope));
239 op(scope)
240 }),
241 JxlScopeInner::None(_) => op(JxlScope(JxlScopeInner::None(Default::default()))),
242 }
243 }
244}