paralight/core/thread_pool.rs

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! A thread pool implementing parallelism at a lightweight cost.

use super::range::{
    FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
};
use super::sync::{make_lending_group, Borrower, Lender, WorkerState};
use super::util::LifetimeParameterized;
use crate::iter::Accumulator;
use crate::macros::{log_debug, log_error, log_warn};
use crossbeam_utils::CachePadded;
// Platforms that support `libc::sched_setaffinity()`.
#[cfg(all(
    not(miri),
    any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux"
    )
))]
use nix::{
    sched::{sched_setaffinity, CpuSet},
    unistd::Pid,
};
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;

/// Number of threads to spawn in a thread pool.
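///
/// A brief illustrative example of the two ways to specify a thread count:
///
/// ```
/// # use paralight::ThreadCount;
/// # use std::convert::TryFrom;
/// # use std::num::NonZeroUsize;
/// let _auto = ThreadCount::AvailableParallelism;
/// let four = ThreadCount::try_from(4).unwrap();
/// assert_eq!(four, ThreadCount::Count(NonZeroUsize::new(4).unwrap()));
/// ```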
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadCount {
    /// Spawn the number of threads returned by
    /// [`std::thread::available_parallelism()`].
    AvailableParallelism,
    /// Spawn the given number of threads.
    Count(NonZeroUsize),
}

impl TryFrom<usize> for ThreadCount {
    type Error = <NonZeroUsize as TryFrom<usize>>::Error;

    fn try_from(thread_count: usize) -> Result<Self, Self::Error> {
        let count = NonZeroUsize::try_from(thread_count)?;
        Ok(ThreadCount::Count(count))
    }
}

/// Strategy to distribute ranges of work items among threads.
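///
/// As a rule of thumb (guidance, not a guarantee): [`Fixed`](Self::Fixed) has
/// the least coordination overhead and works well when all items take roughly
/// the same time, while [`WorkStealing`](Self::WorkStealing) balances the load
/// better when processing times vary between items.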
#[derive(Clone, Copy)]
pub enum RangeStrategy {
    /// Each thread processes a fixed range of items.
    Fixed,
    /// Threads can steal work from each other.
    WorkStealing,
}

/// Policy to pin worker threads to CPUs.
#[derive(Clone, Copy)]
pub enum CpuPinningPolicy {
    /// Don't pin worker threads to CPUs.
    No,
    /// Pin each worker thread to a CPU, if CPU pinning is supported and
    /// implemented on this platform.
    IfSupported,
    /// Pin each worker thread to a CPU. If CPU pinning isn't supported on this
    /// platform (or not implemented), building a thread pool will panic.
    Always,
}

/// A builder for [`ThreadPool`].
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn in the pool.
    pub num_threads: ThreadCount,
    /// Strategy to distribute ranges of work items among threads.
    pub range_strategy: RangeStrategy,
    /// Policy to pin worker threads to CPUs.
    pub cpu_pinning: CpuPinningPolicy,
}

impl ThreadPoolBuilder {
    /// Spawns a thread pool.
    ///
    /// ```
    /// # use paralight::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
    /// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
    /// let pool_builder = ThreadPoolBuilder {
    ///     num_threads: ThreadCount::AvailableParallelism,
    ///     range_strategy: RangeStrategy::WorkStealing,
    ///     cpu_pinning: CpuPinningPolicy::No,
    /// };
    /// let mut thread_pool = pool_builder.build();
    ///
    /// let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    /// let sum = input
    ///     .par_iter()
    ///     .with_thread_pool(&mut thread_pool)
    ///     .sum::<i32>();
    /// assert_eq!(sum, 5 * 11);
    /// ```
    pub fn build(&self) -> ThreadPool {
        ThreadPool::new(self)
    }
}

/// A thread pool that can execute parallel pipelines.
///
/// This type doesn't expose any public methods other than
/// [`num_threads()`](Self::num_threads). Instead, create a thread pool with
/// the [`ThreadPoolBuilder::build()`] function, and attach it to a parallel
/// iterator with the
/// [`with_thread_pool()`](crate::iter::ParallelSourceExt::with_thread_pool)
/// method.
pub struct ThreadPool {
    inner: ThreadPoolEnum,
}

impl ThreadPool {
    /// Creates a new thread pool using the given parameters.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        Self {
            inner: ThreadPoolEnum::new(builder),
        }
    }

    /// Returns the number of worker threads that have been spawned in this
    /// thread pool.
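    ///
    /// An illustrative example (mirroring the builder usage shown in
    /// [`ThreadPoolBuilder::build()`]):
    ///
    /// ```
    /// # use paralight::{CpuPinningPolicy, RangeStrategy, ThreadCount, ThreadPoolBuilder};
    /// # use std::convert::TryFrom;
    /// let thread_pool = ThreadPoolBuilder {
    ///     num_threads: ThreadCount::try_from(2).unwrap(),
    ///     range_strategy: RangeStrategy::Fixed,
    ///     cpu_pinning: CpuPinningPolicy::No,
    /// }
    /// .build();
    /// assert_eq!(thread_pool.num_threads().get(), 2);
    /// ```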
    pub fn num_threads(&self) -> NonZeroUsize {
        self.inner.num_threads()
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    ///
    /// With this variant, the pipeline may skip processing items at larger
    /// indices whenever a call to `process_item` returns
    /// [`ControlFlow::Break`].
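    ///
    /// For intuition, a short-circuiting "find any matching index" pipeline
    /// could pass closures along these lines (a hypothetical sketch, not an
    /// actual call site in this crate):
    ///
    /// ```rust,ignore
    /// let init = || None;
    /// let process_item = |acc, i| {
    ///     if is_match(i) {
    ///         // Stop looking at larger indices once a match is found.
    ///         ControlFlow::Break(Some(i))
    ///     } else {
    ///         ControlFlow::Continue(acc)
    ///     }
    /// };
    /// let finalize = |acc: Option<usize>| acc;
    /// let reduce = |a: Option<usize>, b| a.or(b);
    /// ```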
    pub(crate) fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        self.inner
            .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    pub(crate) fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        self.inner.iter_pipeline(input_len, accum, reduce)
    }
}

/// Underlying [`ThreadPool`] implementation, dispatching over the
/// [`RangeStrategy`].
enum ThreadPoolEnum {
    Fixed(ThreadPoolImpl<FixedRangeFactory>),
    WorkStealing(ThreadPoolImpl<WorkStealingRangeFactory>),
}

impl ThreadPoolEnum {
    /// Creates a new thread pool using the given parameters.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        let num_threads: NonZeroUsize = match builder.num_threads {
            ThreadCount::AvailableParallelism => std::thread::available_parallelism()
                .expect("Getting the available parallelism failed"),
            ThreadCount::Count(count) => count,
        };
        let num_threads: usize = num_threads.into();
        match builder.range_strategy {
            RangeStrategy::Fixed => ThreadPoolEnum::Fixed(ThreadPoolImpl::new(
                num_threads,
                FixedRangeFactory::new(num_threads),
                builder.cpu_pinning,
            )),
            RangeStrategy::WorkStealing => ThreadPoolEnum::WorkStealing(ThreadPoolImpl::new(
                num_threads,
                WorkStealingRangeFactory::new(num_threads),
                builder.cpu_pinning,
            )),
        }
    }

    /// Returns the number of worker threads that have been spawned in this
    /// thread pool.
    fn num_threads(&self) -> NonZeroUsize {
        match self {
            ThreadPoolEnum::Fixed(inner) => inner.num_threads(),
            ThreadPoolEnum::WorkStealing(inner) => inner.num_threads(),
        }
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    ///
    /// With this variant, the pipeline may skip processing items at larger
    /// indices whenever a call to `process_item` returns
    /// [`ControlFlow::Break`].
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
            }
            ThreadPoolEnum::WorkStealing(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
            }
        }
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(inner) => inner.iter_pipeline(input_len, accum, reduce),
            ThreadPoolEnum::WorkStealing(inner) => inner.iter_pipeline(input_len, accum, reduce),
        }
    }
}

/// Underlying [`ThreadPool`] implementation, specialized to a
/// [`RangeStrategy`].
struct ThreadPoolImpl<F: RangeFactory> {
    /// Handles to all the worker threads in the pool.
    threads: Vec<WorkerThreadHandle>,
    /// Orchestrator for the work ranges distributed to the threads.
    range_orchestrator: F::Orchestrator,
    /// Pipeline to map and reduce inputs into an output.
    pipeline: Lender<DynLifetimeSyncPipeline<F::Range>>,
}

/// Handle to a worker thread in a thread pool.
struct WorkerThreadHandle {
    /// Thread handle object.
    handle: JoinHandle<()>,
}

impl<F: RangeFactory> ThreadPoolImpl<F> {
    /// Creates a new thread pool using the given parameters.
    fn new(num_threads: usize, range_factory: F, cpu_pinning: CpuPinningPolicy) -> Self
    where
        F::Range: Send + 'static,
    {
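        // One lender, kept by the pool, paired with one borrower per worker
        // thread: the pool later lends each parallel pipeline to the workers
        // through this group.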
        let (lender, borrowers) = make_lending_group(num_threads);

        #[cfg(any(
            miri,
            not(any(
                target_os = "android",
                target_os = "dragonfly",
                target_os = "freebsd",
                target_os = "linux"
            ))
        ))]
        match cpu_pinning {
            CpuPinningPolicy::No => (),
            CpuPinningPolicy::IfSupported => {
                log_warn!("Pinning threads to CPUs is not implemented on this platform.")
            }
            CpuPinningPolicy::Always => {
                panic!("Pinning threads to CPUs is not implemented on this platform.")
            }
        }

        let threads = borrowers
            .into_iter()
            .enumerate()
            .map(|(id, borrower)| {
                let mut context = ThreadContext {
                    id,
                    range: range_factory.range(id),
                    pipeline: borrower,
                };
                WorkerThreadHandle {
                    handle: std::thread::spawn(move || {
                        #[cfg(all(
                            not(miri),
                            any(
                                target_os = "android",
                                target_os = "dragonfly",
                                target_os = "freebsd",
                                target_os = "linux"
                            )
                        ))]
                        match cpu_pinning {
                            CpuPinningPolicy::No => (),
                            CpuPinningPolicy::IfSupported => {
                                let mut cpu_set = CpuSet::new();
                                if let Err(_e) = cpu_set.set(id) {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else if let Err(_e) =
                                    sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                            CpuPinningPolicy::Always => {
                                let mut cpu_set = CpuSet::new();
                                if let Err(e) = cpu_set.set(id) {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                        }
                        context.run()
                    }),
                }
            })
            .collect();
        log_debug!("[main thread] Spawned threads");

        Self {
            threads,
            range_orchestrator: range_factory.orchestrator(),
            pipeline: lender,
        }
    }

    /// Returns the number of worker threads that have been spawned in this
    /// thread pool.
    fn num_threads(&self) -> NonZeroUsize {
        self.threads.len().try_into().unwrap()
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    ///
    /// With this variant, the pipeline may skip processing items at larger
    /// indices whenever a call to `process_item` returns
    /// [`ControlFlow::Break`].
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        self.range_orchestrator.reset_ranges(input_len);

        let num_threads = self.threads.len();
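        // One output slot per worker thread: each worker stores its partial
        // result in its own slot, and the main thread reduces the slots below.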
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
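        // Shared upper bound on the indices still worth processing: it starts
        // at usize::MAX (no bound) and is lowered whenever a worker breaks.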
        let bound = AtomicUsize::new(usize::MAX);

        self.pipeline.lend(&UpperBoundedPipelineImpl {
            bound: CachePadded::new(bound),
            outputs: outputs.clone(),
            init,
            process_item,
            finalize,
        });

        outputs
            .iter()
            .map(move |output| output.lock().unwrap().take().unwrap())
            .reduce(reduce)
            .unwrap()
    }

    /// Processes an input of the given length in parallel and returns the
    /// aggregated output.
    fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        self.range_orchestrator.reset_ranges(input_len);

        let num_threads = self.threads.len();
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();

        self.pipeline.lend(&IterPipelineImpl {
            outputs: outputs.clone(),
            accum,
        });

        reduce.accumulate(
            outputs
                .iter()
                .map(move |output| output.lock().unwrap().take().unwrap()),
        )
    }
}

impl<F: RangeFactory> Drop for ThreadPoolImpl<F> {
    /// Joins all the threads in the pool.
    #[allow(clippy::single_match, clippy::unused_enumerate_index)]
    fn drop(&mut self) {
        self.pipeline.finish_workers();

        log_debug!("[main thread] Joining threads in the pool...");
        for (_i, t) in self.threads.drain(..).enumerate() {
            let result = t.handle.join();
            match result {
                Ok(_) => log_debug!("[main thread] Thread {_i} joined with result: {result:?}"),
                Err(_) => log_error!("[main thread] Thread {_i} joined with result: {result:?}"),
            }
        }
        log_debug!("[main thread] Joined threads.");

        #[cfg(feature = "log_parallelism")]
        self.range_orchestrator.print_statistics();
    }
}

/// A parallel pipeline, run by each worker thread over its own range of item
/// indices.
trait Pipeline<R: Range> {
    /// Processes the given range of item indices on behalf of the given
    /// worker thread.
    fn run(&self, worker_id: usize, range: &R);
}

/// An intermediate struct representing a `dyn Pipeline<R> + Sync` with variable
/// lifetime. Because Rust doesn't directly support higher-kinded types, we use
/// the generic associated type of the [`LifetimeParameterized`] trait as a
/// proxy.
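///
/// Concretely, for any lifetime `'a`,
/// `<DynLifetimeSyncPipeline<R> as LifetimeParameterized>::T<'a>` is
/// `dyn Pipeline<R> + Sync + 'a`, which is what the lender/borrower pair
/// passes between the main thread and the worker threads.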
struct DynLifetimeSyncPipeline<R: Range>(PhantomData<R>);

impl<R: Range> LifetimeParameterized for DynLifetimeSyncPipeline<R> {
    type T<'a> = dyn Pipeline<R> + Sync + 'a;
}

struct UpperBoundedPipelineImpl<
    Output,
    Accum,
    Init: Fn() -> Accum,
    ProcessItem: Fn(Accum, usize) -> ControlFlow<Accum, Accum>,
    Finalize: Fn(Accum) -> Output,
> {
    bound: CachePadded<AtomicUsize>,
    outputs: Arc<[Mutex<Option<Output>>]>,
    init: Init,
    process_item: ProcessItem,
    finalize: Finalize,
}

impl<R, Output, Accum, Init, ProcessItem, Finalize> Pipeline<R>
    for UpperBoundedPipelineImpl<Output, Accum, Init, ProcessItem, Finalize>
where
    R: Range,
    Init: Fn() -> Accum,
    ProcessItem: Fn(Accum, usize) -> ControlFlow<Accum, Accum>,
    Finalize: Fn(Accum) -> Output,
{
    fn run(&self, worker_id: usize, range: &R) {
        let mut accumulator = (self.init)();
        for i in range.upper_bounded_iter(&self.bound) {
            let acc = (self.process_item)(accumulator, i);
            accumulator = match acc {
                ControlFlow::Continue(acc) => acc,
                ControlFlow::Break(acc) => {
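                    // Lower the shared bound to this index so that the
                    // `upper_bounded_iter()` of every worker stops yielding
                    // larger indices.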
                    self.bound.fetch_min(i, Ordering::Relaxed);
                    acc
                }
            };
        }
        let output = (self.finalize)(accumulator);
        *self.outputs[worker_id].lock().unwrap() = Some(output);
    }
}

struct IterPipelineImpl<Output, Accum: Accumulator<usize, Output>> {
    outputs: Arc<[Mutex<Option<Output>>]>,
    accum: Accum,
}

impl<R, Output, Accum> Pipeline<R> for IterPipelineImpl<Output, Accum>
where
    R: Range,
    Accum: Accumulator<usize, Output>,
{
    fn run(&self, worker_id: usize, range: &R) {
        let output = self.accum.accumulate(range.iter());
        *self.outputs[worker_id].lock().unwrap() = Some(output);
    }
}

/// Context object owned by a worker thread.
struct ThreadContext<R: Range> {
    /// Thread index.
    id: usize,
    /// Range of items that this worker thread needs to process.
    range: R,
    /// Pipeline to map and reduce inputs into the output.
    pipeline: Borrower<DynLifetimeSyncPipeline<R>>,
}

impl<R: Range> ThreadContext<R> {
    /// Main function run by this thread.
    fn run(&mut self) {
        loop {
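            // Borrow the next pipeline lent by the main thread and run it over
            // this worker's range, until the main thread signals that no more
            // pipelines will come.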
            match self.pipeline.borrow(|pipeline| {
                pipeline.run(self.id, &self.range);
            }) {
                WorkerState::Finished => break,
                WorkerState::Ready => continue,
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};

    #[test]
    fn test_thread_count_try_from_usize() {
        assert!(ThreadCount::try_from(0).is_err());
        assert_eq!(
            ThreadCount::try_from(1),
            Ok(ThreadCount::Count(NonZeroUsize::try_from(1).unwrap()))
        );
    }

    #[test]
    fn test_build_thread_pool_available_parallelism() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_fixed_thread_count() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::try_from(4).unwrap(),
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_cpu_pinning_if_supported() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::IfSupported,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[cfg(all(
        not(miri),
        any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        )
    ))]
    #[test]
    fn test_build_thread_pool_cpu_pinning_always() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[cfg(any(
        miri,
        not(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        ))
    ))]
    #[test]
    #[should_panic = "Pinning threads to CPUs is not implemented on this platform."]
    fn test_build_thread_pool_cpu_pinning_always_not_supported() {
        ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        }
        .build();
    }

    #[test]
    fn test_num_threads() {
        let thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();
        assert_eq!(
            thread_pool.num_threads(),
            std::thread::available_parallelism().unwrap()
        );

        let thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::try_from(4).unwrap(),
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();
        assert_eq!(
            thread_pool.num_threads(),
            NonZeroUsize::try_from(4).unwrap()
        );
    }
}