lamellar/array/iterator/distributed_iterator/consumer/
sum.rs

1use crate::active_messaging::LamellarArcLocalAm;
2use crate::array::iterator::distributed_iterator::DistributedIterator;
3use crate::array::iterator::private::*;
4use crate::array::iterator::{consumer::*, IterLockFuture};
5use crate::array::r#unsafe::private::UnsafeArrayInner;
6use crate::array::{ArrayOps, Distribution, UnsafeArray};
7use crate::barrier::BarrierHandle;
8use crate::lamellar_request::LamellarRequest;
9use crate::lamellar_task_group::TaskGroupLocalAmHandle;
10use crate::lamellar_team::LamellarTeamRT;
11use crate::scheduler::LamellarTask;
12use crate::warnings::RuntimeWarning;
13use crate::Dist;
14use futures_util::{ready, Future};
15use pin_project::{pin_project, pinned_drop};
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
/// Consumer adapter that sums the items produced by the wrapped iterator.
///
/// The actual summation is performed by `SumAm::exec`, which calls `sum` on the
/// scheduled iterator built from a clone of this value.
#[derive(Clone, Debug)]
pub(crate) struct Sum<I> {
    /// The wrapped iterator whose items are summed.
    pub(crate) iter: I,
}
25
26impl<I: InnerIter> InnerIter for Sum<I> {
27    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
28        None
29    }
30    fn iter_clone(&self, _s: Sealed) -> Self {
31        Sum {
32            iter: self.iter.iter_clone(Sealed),
33        }
34    }
35}
36
37impl<I> IterConsumer for Sum<I>
38where
39    I: DistributedIterator + 'static,
40    I::Item: Dist + ArrayOps + std::iter::Sum,
41{
42    type AmOutput = I::Item;
43    type Output = I::Item;
44    type Item = I::Item;
45    type Handle = InnerDistIterSumHandle<I::Item>;
46    fn init(&self, start: usize, cnt: usize) -> Self {
47        Sum {
48            iter: self.iter.init(start, cnt, Sealed),
49        }
50    }
51    fn next(&mut self) -> Option<Self::Item> {
52        self.iter.next()
53    }
54    fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
55        Arc::new(SumAm {
56            iter: self.iter_clone(Sealed),
57            schedule,
58        })
59    }
60    fn create_handle(
61        self,
62        team: Pin<Arc<LamellarTeamRT>>,
63        reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
64    ) -> Self::Handle {
65        InnerDistIterSumHandle {
66            reqs,
67            team,
68            state: InnerState::ReqsPending(None),
69            spawned: false,
70        }
71    }
72    fn max_elems(&self, in_elems: usize) -> usize {
73        self.iter.elems(in_elems)
74    }
75}
76
//#[doc(hidden)]
/// Future that first folds the results of the outstanding local requests into a
/// single partial sum and then reduces that partial sum across the team.
#[pin_project]
pub(crate) struct InnerDistIterSumHandle<T> {
    // Outstanding request handles; they are popped from the front as they complete.
    pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<T>>,
    // Team passed to the cross-PE reduction once every request has been consumed.
    pub(crate) team: Pin<Arc<LamellarTeamRT>>,
    // Current stage of the operation (see `InnerState`).
    state: InnerState<T>,
    // `false` until the first poll has called `ready_or_set_waker` on every request.
    spawned: bool,
}
85
/// Stages of an `InnerDistIterSumHandle`.
enum InnerState<T> {
    /// Requests are still being drained; holds the partial sum accumulated so far
    /// (`None` until the first request value has been folded in).
    ReqsPending(Option<T>),
    /// All requests were consumed; holds the in-flight reduction future.
    Summing(Pin<Box<dyn Future<Output = T> + Send>>),
}
90
91impl<T> InnerDistIterSumHandle<T>
92where
93    T: Dist + ArrayOps + std::iter::Sum,
94{
95    async fn async_reduce_remote_vals(local_sum: T, team: Pin<Arc<LamellarTeamRT>>) -> T {
96        let local_sums = UnsafeArray::<T>::async_new(
97            &team,
98            team.num_pes,
99            Distribution::Block,
100            crate::darc::DarcMode::UnsafeArray,
101        )
102        .await;
103        unsafe {
104            local_sums.local_as_mut_slice()[0] = local_sum;
105        };
106        local_sums.async_barrier().await;
107        // let buffered_iter = unsafe { local_sums.buffered_onesided_iter(self.team.num_pes) };
108        // buffered_iter.into_iter().map(|&e| e).sum()
109        unsafe {
110            local_sums
111                .sum()
112                .await
113                .expect("array size is greater than zero")
114        }
115    }
116
117    // fn reduce_remote_vals(&self, local_sum: T, local_sums: UnsafeArray<T>) -> T {
118    //     unsafe {
119    //         local_sums.local_as_mut_slice()[0] = local_sum;
120    //     };
121    //     local_sums.tasking_barrier();
122    //     // let buffered_iter = unsafe { local_sums.buffered_onesided_iter(self.team.num_pes) };
123    //     // buffered_iter.into_iter().map(|&e| e).sum()
124    //     unsafe {
125    //         local_sums
126    //             .sum()
127    //             .blocking_wait()
128    //             .expect("array size is greater than zero")
129    //     }
130    // }
131}
132
133impl<T> Future for InnerDistIterSumHandle<T>
134where
135    T: Dist + ArrayOps + std::iter::Sum,
136{
137    type Output = T;
138    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139        if !self.spawned {
140            for req in self.reqs.iter_mut() {
141                req.ready_or_set_waker(cx.waker());
142            }
143            self.spawned = true;
144        }
145        let mut this = self.project();
146        match &mut this.state {
147            InnerState::ReqsPending(local_sum) => {
148                while let Some(mut req) = this.reqs.pop_front() {
149                    if !req.ready_or_set_waker(cx.waker()) {
150                        this.reqs.push_front(req);
151                        return Poll::Pending;
152                    }
153                    match local_sum {
154                        Some(sum) => {
155                            *sum = [*sum, req.val()].into_iter().sum();
156                        }
157                        None => {
158                            *local_sum = Some(req.val());
159                        }
160                    }
161                }
162                let mut sum = Box::pin(Self::async_reduce_remote_vals(
163                    local_sum.unwrap(),
164                    this.team.clone(),
165                ));
166                match Future::poll(sum.as_mut(), cx) {
167                    Poll::Ready(local_sum) => Poll::Ready(local_sum),
168                    Poll::Pending => {
169                        *this.state = InnerState::Summing(sum);
170                        Poll::Pending
171                    }
172                }
173            }
174            InnerState::Summing(sum) => {
175                let local_sum = ready!(Future::poll(sum.as_mut(), cx));
176                Poll::Ready(local_sum)
177            }
178        }
179    }
180}
181
182
183
/// This handle allows you to wait for the completion of a distributed iterator sum operation.
///
/// The operation is driven by calling `block()` or `spawn()` on the handle, or by
/// `.await`ing it. Dropping a handle that was never launched prints a
/// dropped-handle runtime warning.
#[pin_project(PinnedDrop)]
pub struct DistIterSumHandle<T> {
    // Array handle used to create barrier handles and to `block_on`/`spawn` this future.
    array: UnsafeArrayInner,
    // Set by `block`, `spawn` and `poll`; checked when the handle is dropped.
    launched: bool,
    // Current stage of the operation (see `State`).
    #[pin]
    state: State<T>,
}
192
193#[pinned_drop]
194impl<T> PinnedDrop for DistIterSumHandle<T> {
195    fn drop(self: Pin<&mut Self>) {
196        if !self.launched {
197            let mut this = self.project();
198            RuntimeWarning::disable_warnings();
199            *this.state = State::Dropped;
200            RuntimeWarning::enable_warnings();
201            RuntimeWarning::DroppedHandle("a DistIterSumHandle").print();
202        }
203    }
204}
205
206impl<T> DistIterSumHandle<T>
207where
208    T: Dist + ArrayOps + std::iter::Sum,
209{
210    pub(crate) fn new(
211        lock: Option<IterLockFuture>,
212        inner: Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>,
213        array: &UnsafeArrayInner,
214    ) -> Self {
215        let state = match lock {
216            Some(inner_lock) => State::Lock(inner_lock, Some(inner)),
217            None => State::Barrier(array.barrier_handle(), inner),
218        };
219        Self {
220            array: array.clone(),
221            launched: false,
222            state,
223        }
224    }
225
226    /// This method will block until the associated Sum operation completes and returns the result
227    pub fn block(mut self) -> T {
228        self.launched = true;
229        RuntimeWarning::BlockingCall(
230            "DistIterSumHandle::block",
231            "<handle>.spawn() or <handle>.await",
232        )
233        .print();
234        self.array.clone().block_on(self)
235    }
236
237    /// This method will spawn the associated Sum Operation on the work queue,
238    /// initiating the remote operation.
239    ///
240    /// This function returns a handle that can be used to wait for the operation to complete
241    #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if  it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
242    pub fn spawn(mut self) -> LamellarTask<T> {
243        self.launched = true;
244        self.array.clone().spawn(self)
245    }
246}
247
// Stages a `DistIterSumHandle` moves through while it is polled.
#[pin_project(project = StateProj)]
enum State<T> {
    // Waiting on the iterator lock future; the inner future is kept in an `Option`
    // so it can be taken out when moving on to `Barrier`.
    Lock(
        #[pin] IterLockFuture,
        Option<Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>>,
    ),
    // Waiting on the barrier, then on the future that yields the inner handle.
    Barrier(
        #[pin] BarrierHandle,
        Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>,
    ),
    // Waiting on the inner handle to produce the final sum.
    Reqs(#[pin] InnerDistIterSumHandle<T>),
    // Installed by the drop handler when the handle was never launched; polling panics.
    Dropped,
}
impl<T> Future for DistIterSumHandle<T>
where
    T: Dist + ArrayOps + std::iter::Sum,
{
    type Output = T;
    /// Advances the handle through `Lock` -> `Barrier` -> `Reqs` until the sum is ready.
    ///
    /// Polling marks the handle as launched, which suppresses the dropped-handle
    /// warning.
    ///
    /// # Panics
    /// Panics if the handle is in the `Dropped` state.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.launched = true;
        let mut this = self.project();

        match this.state.as_mut().project() {
            StateProj::Lock(lock, inner) => {
                // Wait for the lock future before creating the barrier handle.
                ready!(lock.poll(cx));
                let barrier = this.array.barrier_handle();
                *this.state = State::Barrier(
                    barrier,
                    inner.take().expect("reqs should still be in this state"),
                );
                // Nothing has registered the waker for the new state yet, so request
                // an immediate re-poll to start driving the barrier.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            StateProj::Barrier(barrier, inner) => {
                // NOTE(review): while `inner` stays pending, the barrier is polled again
                // on every call; this relies on `BarrierHandle` tolerating polls after
                // it has completed -- confirm.
                ready!(barrier.poll(cx));
                let mut inner = ready!(Future::poll(inner.as_mut(), cx));
                // Poll the freshly produced handle once; keep it only if it is not done.
                match Pin::new(&mut inner).poll(cx) {
                    Poll::Ready(val) => Poll::Ready(val),
                    Poll::Pending => {
                        *this.state = State::Reqs(inner);
                        Poll::Pending
                    }
                }
            }
            StateProj::Reqs(inner) => {
                let val = ready!(inner.poll(cx));
                Poll::Ready(val)
            }
            StateProj::Dropped => panic!("called `Future::poll()` on a future that was dropped"),
        }
    }
}
300
// Local active message that sums the items of a scheduled iterator.
// `exec` builds the iterator from `schedule` and a clone of `iter`, then sums it.
#[lamellar_impl::AmLocalDataRT(Clone)]
pub(crate) struct SumAm<I> {
    // The `Sum` consumer that is cloned for each execution.
    pub(crate) iter: Sum<I>,
    // Schedule used to initialise the iterator inside `exec`.
    pub(crate) schedule: IterSchedule,
}
306
307impl<I: InnerIter> InnerIter for SumAm<I> {
308    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
309        None
310    }
311    fn iter_clone(&self, _s: Sealed) -> Self {
312        SumAm {
313            iter: self.iter.iter_clone(Sealed),
314            schedule: self.schedule.clone(),
315        }
316    }
317}
318
319#[lamellar_impl::rt_am_local]
320impl<I> LamellarAm for SumAm<I>
321where
322    I: DistributedIterator + 'static,
323    I::Item: Dist + ArrayOps + std::iter::Sum,
324{
325    async fn exec(&self) -> I::Item {
326        let iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
327        iter.sum::<I::Item>()
328    }
329}