lamellar/array/iterator/distributed_iterator/consumer/reduce.rs

1use crate::active_messaging::{LamellarArcLocalAm, SyncSend};
2use crate::array::iterator::distributed_iterator::DistributedIterator;
3use crate::array::iterator::one_sided_iterator::OneSidedIterator;
4use crate::array::iterator::private::*;
5use crate::array::iterator::{consumer::*, IterLockFuture};
6use crate::array::r#unsafe::private::UnsafeArrayInner;
7use crate::array::{ArrayOps, Distribution, UnsafeArray};
8use crate::barrier::BarrierHandle;
9use crate::lamellar_request::LamellarRequest;
10use crate::lamellar_task_group::TaskGroupLocalAmHandle;
11use crate::lamellar_team::LamellarTeamRT;
12use crate::scheduler::LamellarTask;
13use crate::warnings::RuntimeWarning;
14use crate::Dist;
15
16use futures_util::{ready, Future, StreamExt};
17use pin_project::{pin_project, pinned_drop};
18use std::collections::VecDeque;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23
24
/// Distributed-iterator consumer that combines the items of `iter` pairwise
/// with `op`, producing a single value (or `None` when there are no items).
#[derive(Clone, Debug)]
pub(crate) struct Reduce<I, F> {
    /// The iterator whose items are reduced.
    pub(crate) iter: I,
    /// Binary combining function, applied as `op(accumulator, next_item)`.
    pub(crate) op: F,
}
30
31impl<I: InnerIter, F: Clone> InnerIter for Reduce<I, F> {
32    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
33        None
34    }
35    fn iter_clone(&self, _s: Sealed) -> Self {
36        Reduce {
37            iter: self.iter.iter_clone(Sealed),
38            op: self.op.clone(),
39        }
40    }
41}
42
43impl<I, F> IterConsumer for Reduce<I, F>
44where
45    I: DistributedIterator + 'static,
46    I::Item: Dist + ArrayOps,
47    F: Fn(I::Item, I::Item) -> I::Item + SyncSend + Clone + 'static,
48{
49    type AmOutput = Option<I::Item>;
50    type Output = Option<I::Item>;
51    type Item = I::Item;
52    type Handle = InnerDistIterReduceHandle<I::Item, F>;
53    fn init(&self, start: usize, cnt: usize) -> Self {
54        Reduce {
55            iter: self.iter.init(start, cnt, Sealed),
56            op: self.op.clone(),
57        }
58    }
59    fn next(&mut self) -> Option<Self::Item> {
60        self.iter.next()
61    }
62    fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
63        Arc::new(ReduceAm {
64            iter: self.iter_clone(Sealed),
65            op: self.op.clone(),
66            schedule,
67        })
68    }
69    fn create_handle(
70        self,
71        team: Pin<Arc<LamellarTeamRT>>,
72        reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
73    ) -> Self::Handle {
74        InnerDistIterReduceHandle {
75            op: self.op,
76            reqs,
77            team,
78            state: InnerState::ReqsPending(None),
79            spawned: false,
80        }
81    }
82    fn max_elems(&self, in_elems: usize) -> usize {
83        self.iter.elems(in_elems)
84    }
85}
// Future that first combines the results of this PE's local reduce AMs and
// then combines that value with the values of the other PEs in `team`.
#[pin_project]
pub(crate) struct InnerDistIterReduceHandle<T, F> {
    // Outstanding local reduce AMs; each resolves to an optional partial result.
    pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<Option<T>>>,
    // Binary op used for both the local and the cross-PE combination.
    pub(crate) op: F,
    // Team over which the final cross-PE reduction runs.
    pub(crate) team: Pin<Arc<LamellarTeamRT>>,
    // Current phase of the reduction.
    state: InnerState<T>,
    // Set on the first poll, after wakers were registered with every request.
    spawned: bool,
}
95
// Phase of an `InnerDistIterReduceHandle`.
enum InnerState<T> {
    // Waiting on the local AM requests; carries an `Option<T>` accumulator
    // slot (constructed as `None` by `create_handle`).
    ReqsPending(Option<T>),
    // Local requests are done; awaiting the boxed cross-PE reduction future.
    Reducing(Pin<Box<dyn Future<Output = Option<T>> + Send + 'static>>),
}
100
101impl<T, F> InnerDistIterReduceHandle<T, F>
102where
103    T: Dist + ArrayOps,
104    F: Fn(T, T) -> T + SyncSend + Clone + 'static,
105{
106    async fn async_reduce_remote_vals(
107        local_val: T,
108        team: Pin<Arc<LamellarTeamRT>>,
109        op: F,
110    ) -> Option<T> {
111        let local_vals = UnsafeArray::<T>::async_new(
112            &team,
113            team.num_pes,
114            Distribution::Block,
115            crate::darc::DarcMode::UnsafeArray,
116        )
117        .await;
118        unsafe {
119            local_vals.local_as_mut_slice()[0] = local_val;
120        };
121        local_vals.async_barrier().await;
122        let buffered_iter = unsafe { local_vals.buffered_onesided_iter(team.num_pes) };
123        let mut stream = buffered_iter.into_stream();
124        let first = stream.next().await?;
125
126        Some(
127            stream
128                .fold(*first, |a, &b| {
129                    let val = op(a, b);
130                    async move { val }
131                })
132                .await,
133        )
134    }
135
136    // fn reduce_remote_vals(&self, local_val: T) -> Option<T> {
137    //     // self.team.tasking_barrier();
138    //     let local_vals =
139    //         UnsafeArray::<T>::new(&self.team, self.team.num_pes, Distribution::Block).block();
140    //     unsafe {
141    //         local_vals.local_as_mut_slice()[0] = local_val;
142    //     };
143    //     local_vals.tasking_barrier();
144    //     let buffered_iter = unsafe { local_vals.buffered_onesided_iter(self.team.num_pes) };
145    //     buffered_iter
146    //         .into_iter()
147    //         .map(|&x| x)
148    //         .reduce(self.op.clone())
149    // }
150}
151
152impl<T, F> Future for InnerDistIterReduceHandle<T, F>
153where
154    T: Dist + ArrayOps,
155    F: Fn(T, T) -> T + SyncSend + Clone + 'static,
156{
157    type Output = Option<T>;
158    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        if !self.spawned {
160            for req in self.reqs.iter_mut() {
161                req.ready_or_set_waker(cx.waker());
162            }
163            self.spawned = true;
164        }
165        let mut this = self.project();
166        match &mut this.state {
167            InnerState::ReqsPending(mut val) => {
168                while let Some(mut req) = this.reqs.pop_front() {
169                    if !req.ready_or_set_waker(cx.waker()) {
170                        this.reqs.push_front(req);
171                        return Poll::Pending;
172                    }
173                    match val {
174                        None => val = req.val(),
175                        Some(val1) => {
176                            if let Some(val2) = req.val() {
177                                val = Some((this.op)(val1, val2));
178                            }
179                        }
180                    }
181                }
182                if let Some(val) = val {
183                    let mut reducing = Box::pin(Self::async_reduce_remote_vals(
184                        val.clone(),
185                        this.team.clone(),
186                        this.op.clone(),
187                    ));
188                    match Future::poll(reducing.as_mut(), cx) {
189                        Poll::Ready(val) => Poll::Ready(val),
190                        Poll::Pending => {
191                            *this.state = InnerState::Reducing(reducing);
192                            Poll::Pending
193                        }
194                    }
195                } else {
196                    Poll::Ready(None)
197                }
198            }
199            InnerState::Reducing(reducing) => {
200                let val = ready!(Future::poll(reducing.as_mut(), cx));
201                Poll::Ready(val)
202            }
203        }
204    }
205}
206
207/// This handle allows you to wait for the completion of a local iterator reduce operation
208#[pin_project(PinnedDrop)]
209pub struct DistIterReduceHandle<T, F> {
210    array: UnsafeArrayInner,
211    launched: bool,
212    #[pin]
213    state: State<T, F>,
214}
215
216#[pinned_drop]
217impl<T, F> PinnedDrop for DistIterReduceHandle<T, F> {
218    fn drop(self: Pin<&mut Self>) {
219        if !self.launched {
220            let mut this = self.project();
221            RuntimeWarning::disable_warnings();
222            *this.state = State::Dropped;
223            RuntimeWarning::enable_warnings();
224            RuntimeWarning::DroppedHandle("a DistIterReduceHandle").print();
225        }
226    }
227}
228
229impl<T, F> DistIterReduceHandle<T, F>
230where
231    T: Dist + ArrayOps,
232    F: Fn(T, T) -> T + SyncSend + Clone + 'static,
233{
234    pub(crate) fn new(
235        lock: Option<IterLockFuture>,
236        reqs: Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>,
237        array: &UnsafeArrayInner,
238    ) -> Self {
239        let state = match lock {
240            Some(inner_lock) => State::Lock(inner_lock, Some(reqs)),
241            None => State::Barrier(array.barrier_handle(), reqs),
242        };
243        Self {
244            array: array.clone(),
245            launched: false,
246            state,
247        }
248    }
249
250    /// This method will block until the associated Reduce operation completes and returns the result
251    pub fn block(mut self) -> Option<T> {
252        self.launched = true;
253        RuntimeWarning::BlockingCall(
254            "DistIterReduceHandle::block",
255            "<handle>.spawn() or <handle>.await",
256        )
257        .print();
258        self.array.clone().block_on(self)
259    }
260
261    /// This method will spawn the associated Reduce Operation on the work queue,
262    /// initiating the remote operation.
263    ///
264    /// This function returns a handle that can be used to wait for the operation to complete
265    #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if  it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
266    pub fn spawn(mut self) -> LamellarTask<Option<T>> {
267        self.launched = true;
268        self.array.clone().spawn(self)
269    }
270}
271
// Stages of a `DistIterReduceHandle`: `Lock` (optional) -> `Barrier` -> `Reqs`.
// `Dropped` is written by the drop impl; polling in that state panics.
#[pin_project(project = StateProj)]
enum State<T, F> {
    // Waiting on the iterator lock. The future that yields the inner handle is
    // kept in an `Option` so it can be moved out on the transition to `Barrier`.
    Lock(
        #[pin] IterLockFuture,
        Option<Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>>,
    ),
    // Waiting on the array barrier, then on the future yielding the inner handle.
    Barrier(
        #[pin] BarrierHandle,
        Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>,
    ),
    // Driving the inner handle to completion.
    Reqs(#[pin] InnerDistIterReduceHandle<T, F>),
    // Terminal state set when the handle was dropped without being launched.
    Dropped,
}
// Advances the handle through `Lock` (optional) -> `Barrier` -> `Reqs` and
// resolves to the reduced value.
impl<T, F> Future for DistIterReduceHandle<T, F>
where
    T: Dist + ArrayOps,
    F: Fn(T, T) -> T + SyncSend + Clone + 'static,
{
    type Output = Option<T>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Once polled, the handle counts as launched, so the drop impl will not
        // print the dropped-handle warning.
        self.launched = true;
        let mut this = self.project();
        match this.state.as_mut().project() {
            StateProj::Lock(lock, inner) => {
                // Wait for the iterator lock, then switch to the barrier stage.
                ready!(lock.poll(cx));
                let barrier = this.array.barrier_handle();
                *this.state = State::Barrier(
                    barrier,
                    inner.take().expect("reqs should still be in this state"),
                );
                // Request an immediate re-poll so the barrier stage starts right away.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            StateProj::Barrier(barrier, inner) => {
                // NOTE(review): if the barrier completes but `inner` returns
                // Pending, the next poll polls the completed `barrier` again —
                // confirm `BarrierHandle` tolerates being polled after Ready.
                ready!(barrier.poll(cx));
                let mut inner = ready!(Future::poll(inner.as_mut(), cx));
                // Poll the freshly obtained inner handle once; store it only if
                // it has not finished yet.
                match Pin::new(&mut inner).poll(cx) {
                    Poll::Ready(val) => Poll::Ready(val),
                    Poll::Pending => {
                        *this.state = State::Reqs(inner);
                        Poll::Pending
                    }
                }
            }
            StateProj::Reqs(inner) => {
                let val = ready!(inner.poll(cx));
                Poll::Ready(val)
            }
            StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
        }
    }
}
324
// Local active message that walks its share of the iterator (as initialized by
// `schedule.init_iter`) and combines the items with `op`.
#[lamellar_impl::AmLocalDataRT(Clone)]
pub(crate) struct ReduceAm<I, F> {
    // Binary combining function.
    pub(crate) op: F,
    // The reduce consumer whose iterator is walked.
    pub(crate) iter: Reduce<I, F>,
    // Selects which portion of the iterator this AM processes.
    pub(crate) schedule: IterSchedule,
}
331
332impl<I: InnerIter, F: Clone> InnerIter for ReduceAm<I, F> {
333    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
334        None
335    }
336    fn iter_clone(&self, _s: Sealed) -> Self {
337        ReduceAm {
338            op: self.op.clone(),
339            iter: self.iter.iter_clone(Sealed),
340            schedule: self.schedule.clone(),
341        }
342    }
343}
344
345#[lamellar_impl::rt_am_local]
346impl<I, F> LamellarAm for ReduceAm<I, F>
347where
348    I: DistributedIterator + 'static,
349    I::Item: Dist + ArrayOps,
350    F: Fn(I::Item, I::Item) -> I::Item + SyncSend + Clone + 'static,
351{
352    async fn exec(&self) -> Option<I::Item> {
353        let mut iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
354        match iter.next() {
355            Some(mut accum) => {
356                while let Some(elem) = iter.next() {
357                    accum = (self.op)(accum, elem);
358                }
359                Some(accum)
360            }
361            None => None,
362        }
363    }
364}