use std::{
cell::RefCell,
collections::{HashMap, VecDeque},
num::NonZero,
sync::{
Once,
atomic::{AtomicU64, Ordering}
},
thread::available_parallelism
};
use rayon::{
ThreadPoolBuilder, broadcast,
iter::{
ParallelIterator,
plumbing::{Folder, Reducer, UnindexedConsumer, UnindexedProducer}
},
join, scope
};
use super::{CanBuildHistogram, CanIterate, EvaluationState, Histogram};
use crate::{EvaluationError, Evaluator, Function};
impl Histogram
{
    /// Folds `other` into `self`: each count in `other` is added to the entry with
    /// the same key in `self`, and keys missing from `self` start at zero.
    fn merge(&mut self, other: Self)
    {
        other.into_iter().for_each(|(bucket, count)| {
            let slot = self.entry(bucket).or_insert(0);
            *slot += count;
        });
    }
}
/// Builds histograms of evaluation outcomes, running the evaluation on rayon's
/// global thread pool.
#[derive(Debug, Clone)]
pub struct HistogramBuilder
{
    /// Function taken from the `Evaluator` this builder was created from.
    function: Function,
    /// Environment taken from the `Evaluator` this builder was created from.
    environment: HashMap<usize, i32>,
    /// Value drawn from `GENERATION` when the builder was created.
    generation: u64
}
/// Guarantees that rayon's global thread pool is configured at most once per process.
static THREAD_POOL_INITIALIZER: Once = Once::new();
impl HistogramBuilder
{
    /// Configures rayon's global pool with one thread per unit of available
    /// parallelism, or a single thread if that cannot be queried.
    ///
    /// Only the first configuration request in the process has any effect, whether
    /// it comes from this function or from [`HistogramBuilder::set_parallelism`].
    pub fn use_available_parallelism()
    {
        Self::configure_global_pool_once(|| available_parallelism().map_or(1, NonZero::get));
    }
    /// Configures rayon's global pool with `num_threads` threads.
    ///
    /// This is a no-op if the pool has already been configured. Creating an iterator
    /// configures it implicitly through
    /// [`HistogramBuilder::use_available_parallelism`], so call this first.
    pub fn set_parallelism(&self, num_threads: usize)
    {
        Self::configure_global_pool_once(|| num_threads);
    }
    /// Builds the global pool with the size produced by `thread_count`, but only for
    /// the first caller in the process. `thread_count` is not evaluated at all for
    /// later callers.
    fn configure_global_pool_once(thread_count: impl FnOnce() -> usize)
    {
        THREAD_POOL_INITIALIZER.call_once(|| {
            // `build_global` fails if the global pool already exists, for example
            // because some rayon operation ran before this call. That is tolerated:
            // the configuration is best effort.
            let _ = ThreadPoolBuilder::new()
                .num_threads(thread_count())
                .build_global();
        });
    }
}
/// Counts one occurrence of `outcome` in the calling thread's private tally for
/// `generation`. The tally is created on first use.
fn increment(generation: u64, outcome: i32)
{
    HISTOGRAM.with_borrow_mut(|per_generation| {
        let tally = per_generation.entry(generation).or_default();
        let count = tally.entry(outcome).or_insert(0);
        *count += 1;
    });
}
/// Gathers every pool thread's tally for `generation` and merges them into one
/// histogram. Each tally is removed from its thread's local storage along the way.
///
/// Threads that never recorded anything for `generation` contribute an empty
/// histogram.
///
/// # Panics
///
/// Panics if `broadcast` yields no results. Rayon runs the closure once per pool
/// thread, so this does not happen with a non-empty pool.
fn finish(generation: u64) -> Histogram
{
    let mut partials = broadcast(|_| {
        HISTOGRAM
            .with_borrow_mut(|per_generation| per_generation.remove(&generation))
            .unwrap_or_default()
    })
    .into_iter();
    let mut merged = partials.next().unwrap();
    for partial in partials
    {
        merged.merge(partial);
    }
    merged
}
impl<'inst> super::HistogramBuilder<'inst, EvaluationStateIterator<'inst>>
    for HistogramBuilder
{
    /// Creates a builder that takes ownership of the evaluator's function and
    /// environment.
    fn new(evaluator: Evaluator) -> Self
    {
        HistogramBuilder {
            function: evaluator.function,
            environment: evaluator.environment,
            generation: GENERATION.fetch_add(1, Ordering::Relaxed)
        }
    }
    /// Evaluates in parallel and returns the histogram of outcomes.
    ///
    /// Outcomes are counted only while `condition` holds. Because the search is
    /// parallel, which outcomes have already been counted when `condition` first
    /// fails is unspecified (see rayon's `take_any_while`).
    ///
    /// # Errors
    ///
    /// Returns the error from `CanBuildHistogram::iter` if `args` cannot be turned
    /// into an initial evaluation state.
    fn build_while(
        &self,
        args: impl IntoIterator<Item = i32> + Send,
        condition: impl Fn(&i32) -> bool + Send + Sync
    ) -> Result<Histogram, EvaluationError<'_>>
    {
        // Give the per-thread tallies of *this call* a fresh, process-unique key.
        // Reusing `self.generation` would let two concurrent calls on the same
        // builder mix their tallies, so one call's `finish` steals the other's.
        // The same applies to clones, since `Clone` copies the field. Leftovers
        // from a call that panicked mid-way would also leak into the next call.
        // `fetch_add` is an atomic read-modify-write, so keys are unique even
        // with `Relaxed`.
        // NOTE(review): a panicking call still leaves its entries in the
        // thread-locals; they are now orphaned rather than corrupting later calls.
        let generation = GENERATION.fetch_add(1, Ordering::Relaxed);
        // `scope` makes the closure run on a rayon worker, so `finish` broadcasts
        // to the same pool that executed the `increment` calls.
        scope(|_| {
            CanBuildHistogram::iter(self, args)?
                .flat_map(|state| state.result)
                .take_any_while(condition)
                .for_each(|outcome| increment(generation, outcome));
            Ok(finish(generation))
        })
    }
    /// Delegates to `CanBuildHistogram::iter`.
    #[inline]
    fn iter(
        &'inst self,
        args: impl IntoIterator<Item = i32>
    ) -> Result<EvaluationStateIterator<'inst>, EvaluationError<'inst>>
    {
        CanBuildHistogram::iter(self, args)
    }
}
impl<'inst> CanBuildHistogram<'inst, EvaluationStateIterator<'inst>>
    for HistogramBuilder
{
    /// The function this builder evaluates.
    #[inline]
    fn function(&self) -> &Function
    {
        &self.function
    }
    /// The environment captured from the originating `Evaluator`.
    #[inline]
    fn environment(&self) -> &HashMap<usize, i32>
    {
        &self.environment
    }
    /// Makes sure the global pool is configured, then wraps `initial_state` in a new
    /// iterator whose only pending state is `initial_state` and which has nothing
    /// completed yet.
    fn create_iterator(
        &'inst self,
        initial_state: EvaluationState<'inst>
    ) -> EvaluationStateIterator<'inst>
    {
        Self::use_available_parallelism();
        let pending = VecDeque::from([initial_state]);
        EvaluationStateIterator {
            generation: self.generation,
            states: pending,
            completed: VecDeque::new()
        }
    }
}
/// Monotonic counter supplying unique generation numbers.
static GENERATION: AtomicU64 = AtomicU64::new(0);
thread_local! {
    /// This thread's partial histograms, keyed by generation.
    /// `increment` adds to them; `finish` removes them.
    pub static HISTOGRAM: RefCell<HashMap<u64, Histogram>> = RefCell::new(HashMap::new());
}
/// Parallel iterator that explores evaluation states breadth-first and yields the
/// states whose evaluation produced an outcome.
#[derive(Debug, Clone)]
pub struct EvaluationStateIterator<'inst>
{
    /// Copied from the builder that created this iterator.
    /// NOTE(review): it is carried through `split` but not otherwise read in this file.
    generation: u64,
    /// States waiting to be evaluated, oldest first.
    states: VecDeque<EvaluationState<'inst>>,
    /// Evaluated states that produced an outcome and have not yet been handed to a
    /// consumer.
    completed: VecDeque<EvaluationState<'inst>>
}
impl<'inst> CanIterate<'inst> for EvaluationStateIterator<'inst>
{
    /// Removes and returns the oldest pending state, if there is one.
    fn next_state(&mut self) -> Option<EvaluationState<'inst>>
    {
        self.states.pop_front()
    }
    /// Moves the successors of `state`, if it has any, to the back of the pending
    /// queue. This leaves `state.successors` empty but still present.
    fn inject_successors(&mut self, state: &mut EvaluationState<'inst>)
    {
        let Some(children) = state.successors.as_mut()
        else
        {
            return
        };
        self.states.extend(children.drain(..));
    }
}
impl<'inst> ParallelIterator for EvaluationStateIterator<'inst>
{
    type Item = EvaluationState<'inst>;
    /// Drives the breadth-first search in parallel. Every state whose evaluation
    /// produces an outcome is fed to `consumer`.
    ///
    /// States are evaluated on the current thread until more than
    /// `EvaluationState::MAX_SUCCESSORS` are pending. The pending queue is then
    /// split in half and both halves are driven recursively through [`join`].
    /// Evaluation stops as soon as `consumer` reports that it is full.
    fn drive_unindexed<C>(mut self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>
    {
        loop
        {
            if consumer.full()
            {
                return consumer.into_folder().complete()
            }
            // Sequential phase: evaluate until there is enough pending work to split.
            while self.states.len() <= EvaluationState::MAX_SUCCESSORS
            {
                let Some(mut state) = self.next_state()
                else
                {
                    // This branch is exhausted: hand over what it completed.
                    return UnindexedProducer::fold_with(self, consumer.into_folder())
                        .complete()
                };
                let outcome = state.evaluate();
                self.inject_successors(&mut state);
                if outcome.is_some()
                {
                    self.completed.push_back(state);
                }
                // A sibling branch may have made the consumer full in the meantime
                // (e.g. `take_any_while` hit its stop condition). Stop now instead of
                // exhausting the rest of this branch, which may never grow large
                // enough to split and be re-checked.
                if consumer.full()
                {
                    return consumer.into_folder().complete()
                }
            }
            match self.split()
            {
                (left_producer, Some(right_producer)) =>
                {
                    // Same order as rayon's bridge: reducer first, then the left
                    // consumer, and the original consumer goes to the right half.
                    let reducer = consumer.to_reducer();
                    let left_consumer = consumer.split_off_left();
                    let (left_result, right_result) = join(
                        || left_producer.drive_unindexed(left_consumer),
                        || right_producer.drive_unindexed(consumer)
                    );
                    return reducer.reduce(left_result, right_result)
                },
                // Not reached in practice: the loop above only exits with more than
                // `MAX_SUCCESSORS` pending states, which `split` always divides.
                (producer, None) =>
                {
                    self = producer;
                }
            }
        }
    }
}
impl<'inst> UnindexedProducer for EvaluationStateIterator<'inst>
{
    type Item = EvaluationState<'inst>;
    /// Splits the pending states in half once there are more than
    /// `EvaluationState::MAX_SUCCESSORS` of them; otherwise the producer is not split.
    ///
    /// The left half keeps any already-completed states; the right half starts with
    /// none.
    fn split(mut self) -> (Self, Option<Self>)
    {
        match self.states.len()
        {
            0..=EvaluationState::MAX_SUCCESSORS =>
            {
                (self, None)
            },
            n =>
            {
                let right = self.states.split_off(n / 2);
                let right = EvaluationStateIterator {
                    generation: self.generation,
                    states: right,
                    completed: VecDeque::new()
                };
                (self, Some(right))
            }
        }
    }
    /// Feeds all of this producer's items to `folder`, stopping early once it is
    /// full, as the `UnindexedProducer` contract requires.
    ///
    /// Already-completed states go first. Any states that are still pending are
    /// then evaluated sequentially, and those that produce an outcome are fed in
    /// turn. Without that second step, driving this producer through rayon's
    /// generic `bridge_unindexed` would silently drop everything not yet evaluated.
    fn fold_with<F>(mut self, mut folder: F) -> F
    where
        F: Folder<Self::Item>
    {
        while let Some(state) = self.completed.pop_front()
        {
            folder = folder.consume(state);
            if folder.full()
            {
                return folder
            }
        }
        // Fullness can also change without a consume here: for example,
        // `take_any_while` shares a stop flag across threads.
        while !folder.full()
        {
            let Some(mut state) = self.next_state()
            else
            {
                break
            };
            let outcome = state.evaluate();
            self.inject_successors(&mut state);
            if outcome.is_some()
            {
                folder = folder.consume(state);
            }
        }
        folder
    }
}