use crate::AssignElem;
use crate::{Array, ArrayRef, Dimension, IntoNdProducer, NdProducer, Zip};
use super::send_producer::SendProducer;
use crate::parallel::par::ParallelSplits;
use crate::parallel::prelude::*;
use crate::partial::Partial;
impl<A, D> ArrayRef<A, D>
where
    D: Dimension,
    A: Send + Sync,
{
    /// Parallel counterpart of `map_inplace`.
    ///
    /// Calls `f` with a mutable reference to every element, updating the array
    /// in place. The elements are processed in parallel via `into_par_iter`, so
    /// no particular visiting order should be assumed.
    pub fn par_map_inplace<F>(&mut self, f: F)
    where
        F: Fn(&mut A) + Sync + Send,
    {
        let elements = self.view_mut().into_par_iter();
        elements.for_each(f);
    }

    /// Parallel counterpart of `mapv_inplace`.
    ///
    /// Replaces every element `x` with `f(x.clone())`, i.e. `f` receives each
    /// element **by value** and its return value is written back in place.
    /// The elements are processed in parallel, so no particular visiting order
    /// should be assumed.
    pub fn par_mapv_inplace<F>(&mut self, f: F)
    where
        F: Fn(A) -> A + Sync + Send,
        A: Clone,
    {
        // Reuse the by-reference variant: clone the current value, map it, and
        // store the result back into the same slot.
        self.par_map_inplace(move |slot| *slot = f(slot.clone()));
    }
}
/// Value passed as `ParallelSplits::max_splits` by `par_map_collect`, bounding
/// how far the zip (together with its output view) is split for parallel work.
/// NOTE(review): presumably this keeps the chunks from becoming too small to
/// be worth scheduling — confirm against `ParallelSplits`.
const COLLECT_MAX_SPLITS: usize = 10;
// Generates the parallel `Zip` methods for every producer arity listed in the
// `zip_impl!` invocation. Each entry has the form `[flag P1 .. Pn]`:
// - `flag` (`true`/`false`) is forwarded to `expand_if!` and decides whether the
//   methods wrapped in that call are emitted for this arity;
// - `P1 .. Pn` are the producer type parameters of the `Zip`.
macro_rules! zip_impl {
($([$notlast:ident $($p:ident)*],)+) => {
$(
// The producer type parameters `$p` are also reused as closure binding names
// (e.g. `|($($p,)*)|` below), which would otherwise trigger `non_snake_case`.
#[allow(non_snake_case)]
impl<D, $($p),*> Zip<($($p,)*), D>
where $($p::Item : Send , )*
$($p : Send , )*
D: Dimension,
$($p: NdProducer<Dim=D> ,)*
{
/// Parallel version of `for_each`: calls `function` with one item from each
/// zipped producer, for every position of the zip.
///
/// The work is distributed via `into_par_iter`, so no particular visiting
/// order should be assumed.
pub fn par_for_each<F>(self, function: F)
where F: Fn($($p::Item),*) + Sync + Send
{
// The parallel iterator yields a tuple of items; unpack it into separate arguments.
self.into_par_iter().for_each(move |($($p,)*)| function($($p),*))
}
// Everything inside this `expand_if!` is only emitted when `$notlast` is `true`.
// NOTE(review): presumably because `par_map_collect` / `par_map_assign_into`
// zip in one extra producer via `self.and(..)`, which is impossible at the
// largest supported arity — confirm the arity limit of `Zip`.
expand_if!(@bool [$notlast]
/// Map the zipped items through `f` in parallel and collect the results into
/// a new array with the zip's dimension type `D`.
///
/// The output is allocated with `uninitialized_for_current_layout`, i.e.
/// (judging by the name) following the zip's current memory layout — confirm.
///
/// # Panics
///
/// Panics if `R` has drop glue and the number of elements written by the
/// parallel chunks does not equal the output length. A panic inside `f` is
/// presumably propagated by rayon.
pub fn par_map_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
-> Array<R, D>
where R: Send
{
let mut output = self.uninitialized_for_current_layout::<R>();
let total_len = output.len();
// SAFETY (review): `SendProducer::new` wraps a raw view of `output` so it can be
// zipped with `self` and moved across threads. Soundness presumably relies on
// `ParallelSplits` splitting `self` and the output view identically, so that
// each chunk writes only its own disjoint part of `output` — confirm in `par.rs`.
let splits = unsafe {
ParallelSplits {
iter: self.and(SendProducer::new(output.raw_view_mut().cast::<R>())),
// Bound on how far the zip is split; see `COLLECT_MAX_SPLITS`.
max_splits: COLLECT_MAX_SPLITS,
}
};
let collect_result = splits.map(move |zip| {
// SAFETY (review): `collect_with_partial` writes `f`'s results into this chunk's
// part of the uninitialized output and returns a `Partial` describing what was
// written. It presumably requires that this output region is not accessed by
// anyone else meanwhile — confirm its contract.
unsafe {
zip.collect_with_partial(&f)
}
})
// Combine the per-chunk `Partial`s into one, starting from `Partial::stub`.
.reduce(Partial::stub, Partial::try_merge);
// The write count is only verified when `R` needs dropping.
// NOTE(review): presumably `Partial` only tracks its length precisely in that
// case — confirm in `partial.rs`.
if std::mem::needs_drop::<R>() {
// The debug assertion gives a more detailed message in debug builds; the
// `assert!` keeps the check in release builds.
debug_assert_eq!(total_len, collect_result.len,
"collect len is not correct, expected {}", total_len);
assert!(collect_result.len == total_len,
"Collect: Expected number of writes not completed");
}
// The elements are about to be owned by `output`; release the partial result's
// ownership, presumably so it does not also drop them.
collect_result.release_ownership();
// SAFETY (review): all `total_len` elements were written by the chunks above
// (checked explicitly when `R` needs dropping); a panic in `f` would
// presumably have been propagated before reaching this point.
unsafe {
output.assume_init()
}
}
/// Map the zipped items through `f` in parallel and assign each result into
/// the corresponding element of `into`, using `AssignElem::assign_elem`.
///
/// `into` is attached to the zip with `Zip::and`, which presumably panics if
/// its shape does not match the zip's — confirm.
pub fn par_map_assign_into<R, Q>(self, into: Q, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
where Q: IntoNdProducer<Dim=D>,
Q::Item: AssignElem<R> + Send,
Q::Output: Send,
{
// The output producer becomes the last zipped item (`output_`).
self.and(into)
.par_for_each(move |$($p, )* output_| {
output_.assign_elem(f($($p ),*));
});
}
/// Parallel fold over the zipped items.
///
/// rayon's `fold` builds one accumulator per chunk of work, each seeded with
/// `identity()` and updated with `fold(acc, items..)`. Those accumulators are
/// then combined with `reduce`, again seeded with `identity()`. Because the
/// number and grouping of chunks depends on how the work is split, `reduce`
/// should be associative and `identity()` a neutral element if the result is
/// to be deterministic.
pub fn par_fold<ID, F, R, T>(self, identity: ID, fold: F, reduce: R) -> T
where
// `Clone` because `identity` seeds both the fold and the reduce step.
ID: Fn() -> T + Send + Sync + Clone,
F: Fn(T, $($p::Item),*) -> T + Send + Sync,
R: Fn(T, T) -> T + Send + Sync,
T: Send
{
self.into_par_iter()
.fold(identity.clone(), move |accumulator, ($($p,)*)| {
fold(accumulator, $($p),*)
})
.reduce(identity, reduce)
}
);
}
)+
};
}
// One row per `Zip` arity, from 1 to 6 producers. The leading flag selects whether
// the `expand_if!`-guarded methods (`par_map_collect`, `par_map_assign_into`,
// `par_fold`) are generated for that arity; it is `false` only for the
// 6-producer row.
// NOTE(review): presumably 6 is the largest supported `Zip` arity, so methods
// that zip in one more producer via `and` cannot exist there — confirm.
zip_impl! {
[true P1],
[true P1 P2],
[true P1 P2 P3],
[true P1 P2 P3 P4],
[true P1 P2 P3 P4 P5],
[false P1 P2 P3 P4 P5 P6],
}