use std::ops::Add;
use num_traits::Zero;
#[cfg(feature = "parallel")]
use rayon::iter::ParallelIterator;
#[cfg(feature = "parallel")]
use super::DotConsumer;
/// An accumulator that reduces a stream of `(F, F)` pairs to a single `F`.
///
/// Each pair is folded into the accumulator through its `Add<(F, F)>` impl.
/// How a pair is combined is left to the implementor; presumably it is a
/// multiply-and-add, given the trait's name. The final value is read out
/// with [`DotAccumulator::dot`].
pub trait DotAccumulator<F>: Add<(F, F), Output = Self> + From<F> + Clone {
    /// Builds the starting accumulator by converting `F::zero()` via `From<F>`.
    fn zero() -> Self
    where
        F: Zero,
    {
        let start = F::zero();
        <Self as From<F>>::from(start)
    }

    /// Consumes the accumulator and returns its final value.
    fn dot(self) -> F;

    /// Adds every pair yielded by `it` to the accumulator, in iteration order,
    /// and returns the updated accumulator.
    fn absorb<I>(self, it: I) -> Self
    where
        I: IntoIterator<Item = (F, F)>,
    {
        let mut running = self;
        for pair in it {
            running = running + pair;
        }
        running
    }
}
/// Reduces `self` to a single `F` using a caller-chosen accumulator type.
///
/// The accumulator type `Acc` decides how each `(F, F)` pair is combined and
/// how the final value is produced.
pub trait DotWithAccumulator<F> {
    /// Runs the reduction with accumulator `Acc` and returns its final value.
    fn dot_with_accumulator<Acc>(self) -> F
    where
        F: Zero,
        Acc: DotAccumulator<F>;
}
/// Blanket impl: anything that can be turned into an iterator of `(F, F)`
/// pairs can be reduced with [`DotWithAccumulator::dot_with_accumulator`].
impl<I, F> DotWithAccumulator<F> for I
where
    I: IntoIterator<Item = (F, F)>,
{
    /// Seeds `Acc` with `Acc::zero()`, absorbs every pair from `self`, then
    /// finalizes with `dot`.
    fn dot_with_accumulator<Acc>(self) -> F
    where
        Acc: DotAccumulator<F>,
        F: Zero,
    {
        let seed = <Acc as DotAccumulator<F>>::zero();
        let filled = seed.absorb(self);
        filled.dot()
    }
}
/// A [`DotAccumulator`] that can take part in rayon-driven parallel reduction.
///
/// On top of `DotAccumulator<F>`, two accumulators of the same type can be
/// added together via `Add<Self>`. This is presumably how partial results from
/// different splits are merged; the merging itself lives in `DotConsumer`.
/// The accumulator must also be `Send`.
#[cfg(feature = "parallel")]
pub trait ParallelDotAccumulator<F>
where
    Self: DotAccumulator<F> + Add<Self, Output = Self> + Send + Sized,
{
    /// Wraps this accumulator in the parent module's `DotConsumer`, which is
    /// used as the consumer for `ParallelIterator::drive_unindexed`.
    #[inline]
    fn into_consumer(self) -> DotConsumer<Self> {
        DotConsumer(self)
    }
}
/// Blanket impl: every `DotAccumulator<F>` that can be added to another value
/// of its own type and is `Send` is a [`ParallelDotAccumulator`].
///
/// `Add<Output = Acc>` uses `Add`'s default right-hand side (`Self`).
/// `Sized` is implicit on the type parameter `Acc`.
#[cfg(feature = "parallel")]
impl<Acc, F> ParallelDotAccumulator<F> for Acc
where
    Acc: DotAccumulator<F> + Add<Output = Acc> + Send,
{
}
/// Parallel counterpart of [`DotWithAccumulator`] for rayon parallel iterators
/// over `(F, F)` pairs.
#[cfg(feature = "parallel")]
pub trait ParallelDotWithAccumulator<F>: ParallelIterator<Item = (F, F)>
where
    F: Send,
{
    /// Drives `self` into a consumer seeded with `Acc::zero()` and finalizes
    /// the resulting accumulator with `dot`.
    ///
    /// How partial accumulators from different splits are combined is defined
    /// by `DotConsumer` (not visible here).
    fn parallel_dot_with_accumulator<Acc>(self) -> F
    where
        Acc: ParallelDotAccumulator<F>,
        F: Zero,
    {
        let seed = <Acc as DotAccumulator<F>>::zero();
        let consumer = seed.into_consumer();
        let reduced = self.drive_unindexed(consumer);
        reduced.dot()
    }
}
/// Blanket impl: every rayon `ParallelIterator` over `(F, F)` pairs gets
/// [`ParallelDotWithAccumulator::parallel_dot_with_accumulator`].
#[cfg(feature = "parallel")]
impl<T, F> ParallelDotWithAccumulator<F> for T
where
    T: ParallelIterator<Item = (F, F)>,
    // Only `F: Send` is needed to satisfy the trait's own where clause.
    // `F: Zero` is intentionally not required here: the method already
    // demands it. This mirrors the sequential `DotWithAccumulator` blanket impl.
    F: Send,
{
}