lamellar/array/iterator/local_iterator/consumer/sum.rs

1use crate::active_messaging::{LamellarArcLocalAm, SyncSend};
2use crate::array::iterator::local_iterator::LocalIterator;
3use crate::array::iterator::private::*;
4use crate::array::iterator::{consumer::*, IterLockFuture};
5use crate::array::r#unsafe::private::UnsafeArrayInner;
6use crate::lamellar_request::LamellarRequest;
7use crate::lamellar_task_group::TaskGroupLocalAmHandle;
8use crate::lamellar_team::LamellarTeamRT;
9use crate::scheduler::LamellarTask;
10use crate::warnings::RuntimeWarning;
11
12use futures_util::{ready, Future};
13use pin_project::{pin_project, pinned_drop};
14use std::collections::VecDeque;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
/// Iterator consumer that reduces the wrapped local iterator to the sum of its items.
#[derive(Clone, Debug)]
pub(crate) struct Sum<I> {
    /// The local iterator whose items are added together.
    pub(crate) iter: I,
}
23
24impl<I: InnerIter> InnerIter for Sum<I> {
25    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
26        None
27    }
28    fn iter_clone(&self, _s: Sealed) -> Self {
29        Sum {
30            iter: self.iter.iter_clone(Sealed),
31        }
32    }
33}
34
35impl<I> IterConsumer for Sum<I>
36where
37    I: LocalIterator + 'static,
38    I::Item: SyncSend + for<'a> std::iter::Sum<&'a I::Item> + std::iter::Sum<I::Item>,
39{
40    type AmOutput = I::Item;
41    type Output = I::Item;
42    type Item = I::Item;
43    type Handle = InnerLocalIterSumHandle<I::Item>;
44    fn init(&self, start: usize, cnt: usize) -> Self {
45        Sum {
46            iter: self.iter.init(start, cnt, Sealed),
47        }
48    }
49    fn next(&mut self) -> Option<Self::Item> {
50        self.iter.next()
51    }
52    fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
53        Arc::new(SumAm {
54            iter: self.iter_clone(Sealed),
55            schedule,
56        })
57    }
58    fn create_handle(
59        self,
60        _team: Pin<Arc<LamellarTeamRT>>,
61        reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
62    ) -> Self::Handle {
63        InnerLocalIterSumHandle {
64            reqs,
65            state: InnerState::ReqsPending(None),
66            spawned: false,
67        }
68    }
69    fn max_elems(&self, in_elems: usize) -> usize {
70        self.iter.elems(in_elems)
71    }
72}
73
74//#[doc(hidden)]
75#[pin_project]
76pub(crate) struct InnerLocalIterSumHandle<T> {
77    pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<T>>,
78    state: InnerState<T>,
79    spawned: bool,
80}
81
/// Progress of an `InnerLocalIterSumHandle`.
enum InnerState<T> {
    /// Requests are still being drained. Holds the sum accumulated so far,
    /// which is `None` until the first completed request is folded in.
    ReqsPending(Option<T>),
}
85
86impl<T> Future for InnerLocalIterSumHandle<T>
87where
88    T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
89{
90    type Output = T;
91    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
92        if !self.spawned {
93            for req in self.reqs.iter_mut() {
94                req.ready_or_set_waker(cx.waker());
95            }
96            self.spawned = true;
97        }
98        let mut this = self.project();
99        match &mut this.state {
100            InnerState::ReqsPending(local_sum) => {
101                while let Some(mut req) = this.reqs.pop_front() {
102                    if !req.ready_or_set_waker(cx.waker()) {
103                        this.reqs.push_front(req);
104                        return Poll::Pending;
105                    }
106                    match local_sum {
107                        Some(sum) => {
108                            *sum = [sum, &req.val()].into_iter().sum::<T>();
109                        }
110                        None => {
111                            *local_sum = Some(req.val());
112                        }
113                    }
114                }
115
116                Poll::Ready(local_sum.take().expect("Value should be Present"))
117            }
118        }
119    }
120}
121/// This handle allows you to wait for the completion of a local iterator sum operation.
122#[pin_project(PinnedDrop)]
123pub struct LocalIterSumHandle<T> {
124    array: UnsafeArrayInner,
125    launched: bool,
126    #[pin]
127    state: State<T>,
128}
129
130#[pinned_drop]
131impl<T> PinnedDrop for LocalIterSumHandle<T> {
132    fn drop(self: Pin<&mut Self>) {
133        if !self.launched {
134            let mut this = self.project();
135            RuntimeWarning::disable_warnings();
136            *this.state = State::Dropped;
137            RuntimeWarning::enable_warnings();
138            RuntimeWarning::DroppedHandle("a LocalIterSumHandle").print();
139        }
140    }
141}
142
143impl<T> LocalIterSumHandle<T>
144where
145    T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
146{
147    pub(crate) fn new(
148        lock: Option<IterLockFuture>,
149        inner: Pin<Box<dyn Future<Output = InnerLocalIterSumHandle<T>> + Send>>,
150        array: &UnsafeArrayInner,
151    ) -> Self {
152        Self {
153            array: array.clone(),
154            launched: false,
155            state: State::Init(lock, inner),
156        }
157    }
158
159    /// This method will block until the associated Sumoperation completes and returns the result
160    pub fn block(mut self) -> T {
161        self.launched = true;
162        RuntimeWarning::BlockingCall(
163            "LocalIterSumHandle::block",
164            "<handle>.spawn() or <handle>.await",
165        )
166        .print();
167        self.array.clone().block_on(self)
168    }
169    /// This method will spawn the associated Sum Operation on the work queue,
170    /// initiating the remote operation.
171    ///
172    /// This function returns a handle that can be used to wait for the operation to complete
173    #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if  it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
174    pub fn spawn(mut self) -> LamellarTask<T> {
175        self.launched = true;
176
177        self.array.clone().spawn(self)
178    }
179}
180
181#[pin_project(project = StateProj)]
182enum State<T> {
183    Init(
184        Option<IterLockFuture>,
185        Pin<Box<dyn Future<Output = InnerLocalIterSumHandle<T>> + Send>>,
186    ),
187    Reqs(#[pin] InnerLocalIterSumHandle<T>),
188    Dropped,
189}
190impl<T> Future for LocalIterSumHandle<T>
191where
192    T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
193{
194    type Output = T;
195    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196        self.launched = true;
197        let mut this = self.project();
198        match this.state.as_mut().project() {
199            StateProj::Init(lock, inner) => {
200                if let Some(lock) = lock {
201                    ready!(lock.as_mut().poll(cx));
202                }
203                let mut inner = ready!(Future::poll(inner.as_mut(), cx));
204                match Pin::new(&mut inner).poll(cx) {
205                    Poll::Ready(val) => Poll::Ready(val),
206                    Poll::Pending => {
207                        *this.state = State::Reqs(inner);
208                        Poll::Pending
209                    }
210                }
211            }
212            StateProj::Reqs(inner) => {
213                let val = ready!(inner.poll(cx));
214                Poll::Ready(val)
215            }
216            StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
217        }
218    }
219}
220
221#[lamellar_impl::AmLocalDataRT(Clone)]
222pub(crate) struct SumAm<I> {
223    pub(crate) iter: Sum<I>,
224    pub(crate) schedule: IterSchedule,
225}
226
227impl<I: InnerIter> InnerIter for SumAm<I> {
228    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
229        None
230    }
231    fn iter_clone(&self, _s: Sealed) -> Self {
232        SumAm {
233            iter: self.iter.iter_clone(Sealed),
234            schedule: self.schedule.clone(),
235        }
236    }
237}
238
239#[lamellar_impl::rt_am_local]
240impl<I> LamellarAm for SumAm<I>
241where
242    I: LocalIterator + 'static,
243    I::Item: SyncSend + for<'a> std::iter::Sum<&'a I::Item> + std::iter::Sum<I::Item>,
244{
245    async fn exec(&self) -> I::Item {
246        let iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
247        iter.sum::<I::Item>()
248    }
249}