1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
use derive_new::new; use educe::Educe; use replace_with::replace_with_or_abort; use serde::{Deserialize, Serialize}; use std::{iter, marker::PhantomData, mem}; use super::{folder_par_sink, FolderSync, FolderSyncReducer, ParallelPipe, ParallelSink}; #[derive(new)] #[must_use] pub struct Sum<P, B> { pipe: P, marker: PhantomData<fn() -> B>, } impl_par_dist! { impl<P: ParallelPipe<Item>, Item, B> ParallelSink<Item> for Sum<P, B> where B: iter::Sum<P::Output> + iter::Sum<B> + Send + 'static, { folder_par_sink!( SumFolder<B>, SumFolder<B>, self, SumFolder::new(), SumFolder::new() ); } } #[derive(Educe, Serialize, Deserialize, new)] #[educe(Clone)] #[serde(bound = "")] pub struct SumFolder<B> { marker: PhantomData<fn() -> B>, } impl<Item, B> FolderSync<Item> for SumFolder<B> where B: iter::Sum<Item> + iter::Sum<B>, { type Done = B; #[inline(always)] fn zero(&mut self) -> Self::Done { iter::empty::<B>().sum() } #[inline(always)] fn push(&mut self, state: &mut Self::Done, item: Item) { let zero = iter::empty::<B>().sum(); let left = mem::replace(state, zero); let right = iter::once(item).sum::<B>(); *state = B::sum(iter::once(left).chain(iter::once(right))); } } #[derive(Clone, Serialize, Deserialize)] pub struct SumZeroFolder<B> { zero: Option<B>, } impl<B> SumZeroFolder<B> { #[inline(always)] pub(crate) fn new(zero: B) -> Self { Self { zero: Some(zero) } } } impl<Item> FolderSync<Item> for SumZeroFolder<Item> where Option<Item>: iter::Sum<Item>, { type Done = Item; #[inline(always)] fn zero(&mut self) -> Self::Done { self.zero.take().unwrap() } #[inline(always)] fn push(&mut self, state: &mut Self::Done, item: Item) { replace_with_or_abort(state, |left| { let right = iter::once(item).sum::<Option<Item>>().unwrap(); <Option<Item> as iter::Sum<Item>>::sum(iter::once(left).chain(iter::once(right))) .unwrap() }) } }