mod sync;
mod util;
use super::{RangeStrategy, ThreadCount};
use crate::core::pipeline::{IterPipelineImpl, Pipeline, UpperBoundedPipelineImpl};
use crate::core::range::{
FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
};
use crate::iter::{Accumulator, ExactSizeAccumulator, GenericThreadPool, SourceCleanup};
use crate::macros::{log_debug, log_error, log_warn};
use crossbeam_utils::CachePadded;
use sync::{make_lending_group, Borrower, Lender, WorkerState};
use util::LifetimeParameterized;
#[cfg(all(
not(miri),
any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "linux"
)
))]
use nix::{
sched::{sched_setaffinity, CpuSet},
unistd::Pid,
};
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
#[cfg(all(not(miri), target_os = "windows"))]
use windows_sys::Win32::{
Foundation::GetLastError,
System::Threading::{GetCurrentThread, SetThreadAffinityMask},
};
/// Policy controlling whether worker threads are pinned to CPUs.
#[derive(Clone, Copy)]
pub enum CpuPinningPolicy {
    /// Don't pin worker threads to CPUs.
    No,
    /// Pin each worker thread to a CPU if the platform supports it; on
    /// unsupported platforms (or on pinning failure) log a warning and
    /// continue unpinned.
    IfSupported,
    /// Pin each worker thread to a CPU, panicking if pinning fails or is not
    /// implemented on this platform.
    Always,
}
/// Configuration for building a [`ThreadPool`].
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn.
    pub num_threads: ThreadCount,
    /// Strategy used to distribute input items among the worker threads.
    pub range_strategy: RangeStrategy,
    /// Whether and how worker threads are pinned to CPUs.
    pub cpu_pinning: CpuPinningPolicy,
}
impl ThreadPoolBuilder {
    /// Spawns a thread pool configured according to this builder.
    ///
    /// Panics if `cpu_pinning` is [`CpuPinningPolicy::Always`] on a platform
    /// where pinning is not implemented (see `ThreadPoolImpl::new`).
    pub fn build(&self) -> ThreadPool {
        ThreadPool::new(self)
    }
}
/// A pool of worker threads that parallel iterators are dispatched onto.
pub struct ThreadPool {
    // Strategy-specific implementation, selected at construction time.
    inner: ThreadPoolEnum,
}
impl ThreadPool {
    /// Creates a pool from the given builder configuration.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        let inner = ThreadPoolEnum::new(builder);
        Self { inner }
    }

    /// Returns the number of worker threads owned by this pool.
    pub fn num_threads(&self) -> NonZeroUsize {
        self.inner.num_threads()
    }
}
// SAFETY note: `GenericThreadPool` is an unsafe trait declared in
// `crate::iter`; this impl only forwards to `ThreadPoolEnum`, which is
// assumed to uphold whatever contract the trait documents.
// NOTE(review): confirm against the trait's `# Safety` section.
unsafe impl GenericThreadPool for &mut ThreadPool {
    /// Forwards an upper-bounded pipeline run to the underlying pool
    /// implementation.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        self.inner
            .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce, cleanup)
    }

    /// Forwards an iterator pipeline run to the underlying pool
    /// implementation.
    fn iter_pipeline<Output, Accum: Send>(
        self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        self.inner.iter_pipeline(input_len, accum, reduce, cleanup)
    }
}
/// Pool implementation specialized by range-distribution strategy, so that
/// the strategy is chosen once at construction rather than on every dispatch.
enum ThreadPoolEnum {
    /// Each worker owns a fixed slice of the input.
    Fixed(ThreadPoolImpl<FixedRangeFactory>),
    /// Workers can steal input items from each other.
    WorkStealing(ThreadPoolImpl<WorkStealingRangeFactory>),
}
impl ThreadPoolEnum {
    /// Builds the pool implementation matching the builder's range strategy.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        let thread_count = usize::from(builder.num_threads.count());
        match builder.range_strategy {
            RangeStrategy::Fixed => {
                let factory = FixedRangeFactory::new(thread_count);
                ThreadPoolEnum::Fixed(ThreadPoolImpl::new(
                    thread_count,
                    factory,
                    builder.cpu_pinning,
                ))
            }
            RangeStrategy::WorkStealing => {
                let factory = WorkStealingRangeFactory::new(thread_count);
                ThreadPoolEnum::WorkStealing(ThreadPoolImpl::new(
                    thread_count,
                    factory,
                    builder.cpu_pinning,
                ))
            }
        }
    }

    /// Returns the number of worker threads in the underlying pool.
    fn num_threads(&self) -> NonZeroUsize {
        match self {
            ThreadPoolEnum::Fixed(pool) => pool.num_threads(),
            ThreadPoolEnum::WorkStealing(pool) => pool.num_threads(),
        }
    }

    /// Dispatches an upper-bounded pipeline run to the active implementation.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(pool) => pool.upper_bounded_pipeline(
                input_len,
                init,
                process_item,
                finalize,
                reduce,
                cleanup,
            ),
            ThreadPoolEnum::WorkStealing(pool) => pool.upper_bounded_pipeline(
                input_len,
                init,
                process_item,
                finalize,
                reduce,
                cleanup,
            ),
        }
    }

    /// Dispatches an iterator pipeline run to the active implementation.
    fn iter_pipeline<Output, Accum: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(pool) => {
                pool.iter_pipeline(input_len, accum, reduce, cleanup)
            }
            ThreadPoolEnum::WorkStealing(pool) => {
                pool.iter_pipeline(input_len, accum, reduce, cleanup)
            }
        }
    }
}
/// Thread-pool state parameterized by the range-distribution factory.
struct ThreadPoolImpl<F: RangeFactory> {
    // Handles of the spawned worker threads, joined on drop.
    threads: Vec<WorkerThreadHandle>,
    // Main-thread side of the range distribution; used to reset ranges
    // before each pipeline run.
    range_orchestrator: F::Orchestrator,
    // Lender end of the lending protocol used to hand a pipeline to all
    // workers for the duration of a run (see the `sync` module).
    pipeline: Lender<DynLifetimeSyncPipeline<F::Range>>,
}
/// Handle to a spawned worker thread.
struct WorkerThreadHandle {
    // Join handle; the thread returns no value.
    handle: JoinHandle<()>,
}
impl<F: RangeFactory> ThreadPoolImpl<F> {
    /// Spawns `num_threads` worker threads and returns the pool.
    ///
    /// Each worker receives its own range from `range_factory` and the
    /// borrower end of the lending group created here; the pool keeps the
    /// lender end to hand pipelines to the workers. Depending on
    /// `cpu_pinning`, each worker tries to pin itself to the CPU whose index
    /// equals its thread id before entering its main loop.
    fn new(num_threads: usize, range_factory: F, cpu_pinning: CpuPinningPolicy) -> Self
    where
        F::Range: Send + 'static,
    {
        // One lender for the pool, one borrower per worker thread.
        let (lender, borrowers) = make_lending_group(num_threads);
        // On platforms with no pinning implementation (or under Miri), apply
        // the policy eagerly: warn for `IfSupported`, panic for `Always`.
        #[cfg(any(
            miri,
            not(any(
                target_os = "android",
                target_os = "dragonfly",
                target_os = "freebsd",
                target_os = "linux",
                target_os = "windows"
            ))
        ))]
        match cpu_pinning {
            CpuPinningPolicy::No => (),
            CpuPinningPolicy::IfSupported => {
                log_warn!("Pinning threads to CPUs is not implemented on this platform.")
            }
            CpuPinningPolicy::Always => {
                panic!("Pinning threads to CPUs is not implemented on this platform.")
            }
        }
        let threads = borrowers
            .into_iter()
            .enumerate()
            .map(|(id, borrower)| {
                // Per-worker state moved into the spawned thread below.
                let mut context = ThreadContext {
                    id,
                    range: range_factory.range(id),
                    pipeline: borrower,
                };
                WorkerThreadHandle {
                    handle: std::thread::spawn(move || {
                        // Unix-like targets: pin the calling thread (Pid 0 =
                        // current thread) to CPU `id` via sched_setaffinity.
                        #[cfg(all(
                            not(miri),
                            any(
                                target_os = "android",
                                target_os = "dragonfly",
                                target_os = "freebsd",
                                target_os = "linux"
                            )
                        ))]
                        match cpu_pinning {
                            CpuPinningPolicy::No => (),
                            CpuPinningPolicy::IfSupported => {
                                // Best effort: on failure, log a warning and
                                // keep running unpinned. The `_e` binding is
                                // only read by the log macro (which may
                                // compile to nothing).
                                let mut cpu_set = CpuSet::new();
                                if let Err(_e) = cpu_set.set(id) {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else if let Err(_e) =
                                    sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                            CpuPinningPolicy::Always => {
                                // Mandatory pinning: any failure is fatal.
                                let mut cpu_set = CpuSet::new();
                                if let Err(e) = cpu_set.set(id) {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else if let Err(e) =
                                    sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                        }
                        // Windows: pin the current thread via its affinity
                        // mask; SetThreadAffinityMask returns 0 on failure.
                        // NOTE(review): `1usize << id` overflows when `id`
                        // reaches the usize bit width (e.g. 64+ workers) —
                        // confirm the supported core-count range.
                        #[cfg(all(not(miri), target_os = "windows"))]
                        match cpu_pinning {
                            CpuPinningPolicy::No => (),
                            CpuPinningPolicy::IfSupported => {
                                let affinity_mask = 1usize << id;
                                let thread = unsafe { GetCurrentThread() };
                                let result =
                                    unsafe { SetThreadAffinityMask(thread, affinity_mask) };
                                if result == 0 {
                                    // Best effort: warn and continue unpinned.
                                    let _last_error = unsafe { GetLastError() };
                                    log_warn!("Failed to set CPU affinity for thread #{id}: error code {_last_error}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                            CpuPinningPolicy::Always => {
                                let affinity_mask = 1usize << id;
                                let thread = unsafe { GetCurrentThread() };
                                let result =
                                    unsafe { SetThreadAffinityMask(thread, affinity_mask) };
                                if result == 0 {
                                    // Mandatory pinning: any failure is fatal.
                                    let last_error = unsafe { GetLastError() };
                                    panic!("Failed to set CPU affinity for thread #{id}: error code {last_error}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                        }
                        // Enter the worker's main loop until the pool drops.
                        context.run()
                    }),
                }
            })
            .collect();
        log_debug!("[main thread] Spawned threads");
        Self {
            threads,
            range_orchestrator: range_factory.orchestrator(),
            pipeline: lender,
        }
    }

    /// Returns the number of worker threads in this pool.
    ///
    /// The `unwrap()` cannot fail in practice: the pool is always built from
    /// a `NonZeroUsize` count (see `ThreadPoolEnum::new`).
    fn num_threads(&self) -> NonZeroUsize {
        self.threads.len().try_into().unwrap()
    }

    /// Runs an upper-bounded pipeline over `input_len` items on the worker
    /// threads and reduces the per-thread outputs into one value.
    ///
    /// The shared `bound` starts at `usize::MAX`; presumably it is lowered
    /// when `process_item` short-circuits via `ControlFlow::Break` — confirm
    /// against `UpperBoundedPipelineImpl`.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        // Re-arm the per-thread ranges for this input length.
        self.range_orchestrator.reset_ranges(input_len);
        let num_threads = self.threads.len();
        // One output slot per worker, shared with the workers via Arc.
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
        let bound = AtomicUsize::new(usize::MAX);
        // Lend the pipeline to all workers; each fills its own output slot.
        self.pipeline.lend(&UpperBoundedPipelineImpl {
            // CachePadded to avoid false sharing on the contended counter.
            bound: CachePadded::new(bound),
            outputs: outputs.clone(),
            init,
            process_item,
            finalize,
            cleanup,
        });
        // Every slot is expected to be filled once `lend` returns; a missing
        // output would be a bug in the lending protocol, hence `unwrap()`.
        outputs
            .iter()
            .map(move |output| output.lock().unwrap().take().unwrap())
            .reduce(reduce)
            .unwrap()
    }

    /// Runs an iterator pipeline over `input_len` items on the worker
    /// threads, then folds the per-thread accumulators with `reduce`.
    fn iter_pipeline<Output, Accum: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        // Re-arm the per-thread ranges for this input length.
        self.range_orchestrator.reset_ranges(input_len);
        let num_threads = self.threads.len();
        // One output slot per worker, shared with the workers via Arc.
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
        self.pipeline.lend(&IterPipelineImpl {
            outputs: outputs.clone(),
            accum,
            cleanup,
        });
        // Fold the (exactly num_threads) per-worker accumulators.
        reduce.accumulate_exact(
            outputs
                .iter()
                .map(move |output| output.lock().unwrap().take().unwrap()),
        )
    }
}
impl<F: RangeFactory> Drop for ThreadPoolImpl<F> {
    // Signals all workers to finish, then joins their threads.
    //
    // The allows presumably silence lints that fire when logging is compiled
    // out: `_i` and the `match` arms are then only used by the log macros.
    #[allow(clippy::single_match, clippy::unused_enumerate_index)]
    fn drop(&mut self) {
        // Tell the borrowers there will be no more pipelines to run.
        self.pipeline.finish_workers();
        log_debug!("[main thread] Joining threads in the pool...");
        for (_i, t) in self.threads.drain(..).enumerate() {
            let result = t.handle.join();
            match result {
                Ok(_) => log_debug!("[main thread] Thread {_i} joined with result: {result:?}"),
                Err(_) => log_error!("[main thread] Thread {_i} joined with result: {result:?}"),
            }
        }
        log_debug!("[main thread] Joined threads.");
        #[cfg(feature = "log_parallelism")]
        self.range_orchestrator.print_statistics();
    }
}
/// Marker type naming the `dyn Pipeline<R> + Sync` family of trait objects
/// with a late-bound lifetime, for use with the lending machinery (see
/// `util::LifetimeParameterized`). Never instantiated; `PhantomData` only
/// records the range type parameter.
struct DynLifetimeSyncPipeline<R: Range>(PhantomData<R>);

impl<R: Range> LifetimeParameterized for DynLifetimeSyncPipeline<R> {
    // For any lifetime 'a, the parameterized type is a pipeline trait object
    // valid for 'a.
    type T<'a> = dyn Pipeline<R> + Sync + 'a;
}
/// Per-worker state owned by each spawned thread.
struct ThreadContext<R: Range> {
    // Zero-based index of this worker thread.
    id: usize,
    // This worker's view of the input range, created by the range factory.
    range: R,
    // Borrower end used to receive pipelines lent by the main thread.
    pipeline: Borrower<DynLifetimeSyncPipeline<R>>,
}
impl<R: Range> ThreadContext<R> {
    /// Worker main loop: repeatedly borrows the currently lent pipeline and
    /// runs it over this worker's range, until the lender reports that the
    /// pool is finished.
    fn run(&mut self) {
        loop {
            let state = self
                .pipeline
                .borrow(|pipeline| pipeline.run(self.id, &self.range));
            match state {
                WorkerState::Finished => return,
                WorkerState::Ready => {}
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::iter::{ExactParallelSourceExt, IntoExactParallelRefSource, ParallelIteratorExt};

    // Each test builds a pool with a specific configuration and checks that a
    // parallel sum of 1..=10 yields 55 (= 5 * 11).

    #[test]
    fn test_build_thread_pool_available_parallelism() {
        let builder = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        };
        let mut pool = builder.build();
        let values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let total = values.par_iter().with_thread_pool(&mut pool).sum::<i32>();
        assert_eq!(total, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_fixed_thread_count() {
        let builder = ThreadPoolBuilder {
            num_threads: ThreadCount::try_from(4).unwrap(),
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        };
        let mut pool = builder.build();
        let values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let total = values.par_iter().with_thread_pool(&mut pool).sum::<i32>();
        assert_eq!(total, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_cpu_pinning_if_supported() {
        let builder = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::IfSupported,
        };
        let mut pool = builder.build();
        let values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let total = values.par_iter().with_thread_pool(&mut pool).sum::<i32>();
        assert_eq!(total, 5 * 11);
    }

    // Only meaningful where pinning is actually implemented.
    #[cfg(all(
        not(miri),
        any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux",
            target_os = "windows"
        )
    ))]
    #[test]
    fn test_build_thread_pool_cpu_pinning_always() {
        let builder = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        };
        let mut pool = builder.build();
        let values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let total = values.par_iter().with_thread_pool(&mut pool).sum::<i32>();
        assert_eq!(total, 5 * 11);
    }

    // On unsupported platforms, mandatory pinning must panic at build time.
    #[cfg(any(
        miri,
        not(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux",
            target_os = "windows"
        ))
    ))]
    #[test]
    #[should_panic = "Pinning threads to CPUs is not implemented on this platform."]
    fn test_build_thread_pool_cpu_pinning_always_not_supported() {
        let builder = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        };
        builder.build();
    }

    #[test]
    fn test_num_threads() {
        for strategy in [RangeStrategy::Fixed, RangeStrategy::WorkStealing] {
            let auto_pool = ThreadPoolBuilder {
                num_threads: ThreadCount::AvailableParallelism,
                range_strategy: strategy,
                cpu_pinning: CpuPinningPolicy::No,
            }
            .build();
            assert_eq!(
                auto_pool.num_threads(),
                std::thread::available_parallelism().unwrap()
            );
            let four_pool = ThreadPoolBuilder {
                num_threads: ThreadCount::try_from(4).unwrap(),
                range_strategy: strategy,
                cpu_pinning: CpuPinningPolicy::No,
            }
            .build();
            assert_eq!(four_pool.num_threads(), NonZeroUsize::try_from(4).unwrap());
        }
    }
}