use super::{RangeStrategy, ThreadCount};
use crate::core::pipeline::{IterPipelineImpl, Pipeline, UpperBoundedPipelineImpl};
use crate::core::range::{
FixedRangeFactory, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
};
use crate::iter::{Accumulator, ExactSizeAccumulator, GenericThreadPool, SourceCleanup};
use crossbeam_utils::CachePadded;
use rayon_core::{Scope, ThreadPool};
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
/// Handle for running parallel pipelines on a rayon thread pool.
///
/// Wraps either the global rayon pool (`thread_pool == None`) or a borrowed
/// [`ThreadPool`], together with the task count and the range-splitting
/// strategy used each time a pipeline runs.
pub struct RayonThreadPool<'a> {
    /// Pool to run on; `None` means rayon's global thread pool.
    thread_pool: Option<&'a ThreadPool>,
    /// Number of parallel tasks to spawn per pipeline invocation.
    num_tasks: ThreadCount,
    /// How the input range is split across tasks (fixed vs. work-stealing).
    range_strategy: RangeStrategy,
}
impl RayonThreadPool<'static> {
pub fn new_global(num_tasks: ThreadCount, range_strategy: RangeStrategy) -> Self {
Self {
thread_pool: None,
num_tasks,
range_strategy,
}
}
}
impl<'a> RayonThreadPool<'a> {
pub fn new(
thread_pool: &'a ThreadPool,
num_tasks: ThreadCount,
range_strategy: RangeStrategy,
) -> Self {
Self {
thread_pool: Some(thread_pool),
num_tasks,
range_strategy,
}
}
}
impl RayonThreadPool<'_> {
    /// Returns the number of parallel tasks this pool spawns per pipeline
    /// run, as resolved from the configured [`ThreadCount`].
    pub fn num_tasks(&self) -> NonZeroUsize {
        self.num_tasks.count()
    }
}
// NOTE(review): both methods simply build a strategy-specific
// `RayonThreadPoolEnum` for the call and delegate to it. The `unsafe`
// obligation of `GenericThreadPool` is assumed to be discharged by the
// pipeline implementations — confirm against the trait's documented contract.
unsafe impl GenericThreadPool for &RayonThreadPool<'_> {
    /// Runs an upper-bounded pipeline over `0..input_len`, delegating to the
    /// implementation selected by the configured range strategy.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        RayonThreadPoolEnum::new(self.thread_pool, self.num_tasks, self.range_strategy)
            .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce, cleanup)
    }

    /// Runs an accumulator-based pipeline over `0..input_len`, delegating to
    /// the implementation selected by the configured range strategy.
    fn iter_pipeline<Output, Accum: Send>(
        self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        RayonThreadPoolEnum::new(self.thread_pool, self.num_tasks, self.range_strategy)
            .iter_pipeline(input_len, accum, reduce, cleanup)
    }
}
/// Runtime dispatch between the two range strategies: each variant holds a
/// [`RayonThreadPoolImpl`] monomorphized for one range factory.
enum RayonThreadPoolEnum<'a> {
    /// Ranges produced by [`FixedRangeFactory`].
    Fixed(RayonThreadPoolImpl<'a, FixedRangeFactory>),
    /// Ranges produced by [`WorkStealingRangeFactory`].
    WorkStealing(RayonThreadPoolImpl<'a, WorkStealingRangeFactory>),
}
impl<'a> RayonThreadPoolEnum<'a> {
    /// Builds the strategy-specific implementation for the given pool,
    /// task count and range strategy.
    fn new(
        thread_pool: Option<&'a ThreadPool>,
        num_tasks: ThreadCount,
        range_strategy: RangeStrategy,
    ) -> Self {
        // Resolve the configured thread count to a plain usize.
        let num_tasks = usize::from(num_tasks.count());
        match range_strategy {
            RangeStrategy::Fixed => Self::Fixed(RayonThreadPoolImpl::new(
                thread_pool,
                num_tasks,
                FixedRangeFactory::new(num_tasks),
            )),
            RangeStrategy::WorkStealing => Self::WorkStealing(RayonThreadPoolImpl::new(
                thread_pool,
                num_tasks,
                WorkStealingRangeFactory::new(num_tasks),
            )),
        }
    }

    /// Forwards `upper_bounded_pipeline` to the active strategy implementation.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        match self {
            Self::Fixed(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce, cleanup)
            }
            Self::WorkStealing(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce, cleanup)
            }
        }
    }

    /// Forwards `iter_pipeline` to the active strategy implementation.
    fn iter_pipeline<Output, Accum: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        match self {
            Self::Fixed(inner) => inner.iter_pipeline(input_len, accum, reduce, cleanup),
            Self::WorkStealing(inner) => inner.iter_pipeline(input_len, accum, reduce, cleanup),
        }
    }
}
/// Shared implementation behind both strategy variants.
///
/// The `F: RangeFactory` bound is needed on the struct to name the
/// `F::Orchestrator` and `F::Range` associated types.
struct RayonThreadPoolImpl<'a, F: RangeFactory> {
    /// Pool to run on; `None` means rayon's global thread pool.
    thread_pool: Option<&'a ThreadPool>,
    /// Coordinates the per-task ranges (reset for each pipeline run).
    range_orchestrator: F::Orchestrator,
    /// One range per task; `ranges.len()` is the task count.
    ranges: Box<[F::Range]>,
}
impl<'a, F: RangeFactory> RayonThreadPoolImpl<'a, F> {
    /// Creates an implementation with one range per task, all produced by
    /// `range_factory`, plus the factory's orchestrator.
    fn new(thread_pool: Option<&'a ThreadPool>, num_tasks: usize, range_factory: F) -> Self {
        // Allocate one range per task id before consuming the factory.
        let ranges: Box<[F::Range]> = (0..num_tasks)
            .map(|task_id| range_factory.range(task_id))
            .collect();
        let range_orchestrator = range_factory.orchestrator();
        Self {
            thread_pool,
            range_orchestrator,
            ranges,
        }
    }
}
impl<F: RangeFactory> RayonThreadPoolImpl<'_, F> {
    /// Runs `op` inside a rayon scope, on the configured pool if one was
    /// provided and on rayon's global pool otherwise.
    fn scope<'scope, OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&Scope<'scope>) -> R + Send,
        R: Send,
    {
        if let Some(pool) = self.thread_pool {
            pool.scope(op)
        } else {
            rayon_core::scope(op)
        }
    }
}
impl<F: RangeFactory> RayonThreadPoolImpl<'_, F>
where
    F::Range: Sync,
{
    /// Runs a pipeline over `0..input_len`, spawning one scoped task per
    /// range and reducing the per-task outputs into a single value.
    ///
    /// "Upper bounded" refers to the shared `bound` atomic (initialized to
    /// `usize::MAX`), which `UpperBoundedPipelineImpl` presumably uses to
    /// propagate early termination via `ControlFlow` — confirm in the
    /// pipeline implementation.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        // Re-arm the per-task ranges for this input length.
        self.range_orchestrator.reset_ranges(input_len);
        let num_tasks = self.ranges.len();
        // One output slot per task, shared with the pipeline via Arc; each
        // slot is expected to be filled by the task's `run` call.
        let outputs = (0..num_tasks)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
        // usize::MAX serves as the initial "no bound" sentinel. The atomic is
        // cache-padded to avoid false sharing with neighboring fields.
        let bound = AtomicUsize::new(usize::MAX);
        let pipeline = &UpperBoundedPipelineImpl {
            bound: CachePadded::new(bound),
            outputs: outputs.clone(),
            init,
            process_item,
            finalize,
            cleanup,
        };
        let ranges = &self.ranges;
        // The rayon scope joins all spawned tasks before returning, so
        // borrowing `pipeline` and `ranges` across the spawns is sound.
        self.scope({
            |scope| {
                for (id, range) in ranges.iter().enumerate() {
                    scope.spawn(move |_| {
                        pipeline.run(id, range);
                    });
                }
            }
        });
        // All tasks have completed here. The final `unwrap` holds because
        // `num_tasks >= 1` (the count originates from a `NonZeroUsize`), so
        // `reduce` always yields `Some`.
        outputs
            .iter()
            .map(move |output| output.lock().unwrap().take().unwrap())
            .reduce(reduce)
            .unwrap()
    }

    /// Runs an accumulator-based pipeline over `0..input_len`: one scoped
    /// task per range accumulates into its own slot, then `reduce` combines
    /// the exact number of per-task results into the final output.
    fn iter_pipeline<Output, Accum: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Accum> + Sync,
        reduce: impl ExactSizeAccumulator<Accum, Output>,
        cleanup: &(impl SourceCleanup + Sync),
    ) -> Output {
        // Re-arm the per-task ranges for this input length.
        self.range_orchestrator.reset_ranges(input_len);
        let num_tasks = self.ranges.len();
        // One output slot per task, shared with the pipeline via Arc.
        let outputs = (0..num_tasks)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
        let pipeline = &IterPipelineImpl {
            outputs: outputs.clone(),
            accum,
            cleanup,
        };
        let ranges = &self.ranges;
        // The scope joins all tasks before returning (see above).
        self.scope({
            |scope| {
                for (id, range) in ranges.iter().enumerate() {
                    scope.spawn(move |_| {
                        pipeline.run(id, range);
                    });
                }
            }
        });
        // Every slot is expected to be filled once the scope has joined.
        reduce.accumulate_exact(
            outputs
                .iter()
                .map(move |output| output.lock().unwrap().take().unwrap()),
        )
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `num_tasks()` reflects the configured `ThreadCount` for
    /// both range strategies, for the "available parallelism" and fixed-count
    /// configurations.
    #[test]
    fn test_num_tasks() {
        for range_strategy in [RangeStrategy::Fixed, RangeStrategy::WorkStealing] {
            let thread_pool =
                RayonThreadPool::new_global(ThreadCount::AvailableParallelism, range_strategy);
            assert_eq!(
                thread_pool.num_tasks(),
                std::thread::available_parallelism().unwrap()
            );
            let thread_pool =
                RayonThreadPool::new_global(ThreadCount::try_from(4).unwrap(), range_strategy);
            assert_eq!(thread_pool.num_tasks(), NonZeroUsize::try_from(4).unwrap());
        }
    }

    // These tests spawn real threads over a large input, so they are
    // excluded when running under Miri.
    #[cfg(not(miri))]
    mod not_miri {
        use super::*;
        use crate::iter::{ExactParallelSourceExt, IntoExactParallelSource, ParallelIteratorExt};
        use std::ops::Range;

        /// N-ary tree used to exercise nested (recursive) parallel pipelines.
        enum Tree<T> {
            Leaf(T),
            Node(Vec<Tree<T>>),
        }

        /// Builds a tree whose leaves are `build(i)` for each `i` in `range`,
        /// splitting `range` into `arity` roughly equal sub-ranges per node.
        ///
        /// Panics if `range` is empty.
        fn build_tree<T>(
            arity: usize,
            range: Range<usize>,
            build: &impl Fn(usize) -> T,
        ) -> Tree<T> {
            assert!(!range.is_empty());
            let len = range.end - range.start;
            if len == 1 {
                // Single item: a leaf.
                Tree::Leaf(build(range.start))
            } else if len <= arity {
                // Few enough items that each becomes a direct leaf child.
                Tree::Node(range.map(|i| Tree::Leaf(build(i))).collect())
            } else {
                // Split into `arity` contiguous sub-ranges; the i*len/arity
                // endpoints partition the range without gaps or overlaps.
                Tree::Node(
                    (0..arity)
                        .map(|i| {
                            let start = range.start + i * len / arity;
                            let end = range.start + (i + 1) * len / arity;
                            build_tree(arity, start..end, build)
                        })
                        .collect(),
                )
            }
        }

        /// Recursively reduces the tree: leaves are mapped with `convert`,
        /// and each node's children are reduced with a *nested* parallel
        /// iterator on `thread_pool`, starting from `U::default()`.
        fn reduce_tree<T: Send, U: Default + Send>(
            thread_pool: &RayonThreadPool,
            tree: Tree<T>,
            convert: &(impl Fn(T) -> U + Sync),
            reduce_op: &(impl Fn(U, U) -> U + Sync),
        ) -> U {
            match tree {
                Tree::Leaf(t) => convert(t),
                Tree::Node(children) => children
                    .into_par_iter()
                    .with_thread_pool(thread_pool)
                    .map(|child| reduce_tree(thread_pool, child, convert, reduce_op))
                    .reduce(U::default, reduce_op),
            }
        }

        const INPUT_LEN: u64 = 100_000;

        /// Recursive parallel reduction of 0..INPUT_LEN must equal the
        /// closed-form sum n(n-1)/2, both for plain and boxed leaf values.
        #[test]
        fn test_recursion() {
            let thread_pool = RayonThreadPool::new_global(
                ThreadCount::AvailableParallelism,
                RangeStrategy::WorkStealing,
            );
            let tree: Tree<u64> = build_tree(10, 0..INPUT_LEN as usize, &|i| i as u64);
            let sum: u64 = reduce_tree(&thread_pool, tree, &|x| x, &|x, y| x + y);
            assert_eq!(sum, INPUT_LEN * (INPUT_LEN - 1) / 2);
            let tree: Tree<Box<u64>> =
                build_tree(10, 0..INPUT_LEN as usize, &|i| Box::new(i as u64));
            let sum: u64 = reduce_tree(&thread_pool, tree, &|x| *x, &|x, y| x + y);
            assert_eq!(sum, INPUT_LEN * (INPUT_LEN - 1) / 2);
        }
    }
}