lamellar/array/iterator/local_iterator/consumer/
count.rs

1use crate::active_messaging::LamellarArcLocalAm;
2use crate::array::iterator::local_iterator::LocalIterator;
3use crate::array::iterator::{consumer::*, private::*, IterLockFuture};
4use crate::array::r#unsafe::private::UnsafeArrayInner;
5use crate::lamellar_request::LamellarRequest;
6use crate::lamellar_task_group::TaskGroupLocalAmHandle;
7use crate::lamellar_team::LamellarTeamRT;
8use crate::scheduler::LamellarTask;
9use crate::warnings::RuntimeWarning;
10
11use futures_util::{ready, Future};
12use pin_project::{pin_project, pinned_drop};
13use std::collections::VecDeque;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18#[derive(Clone, Debug)]
19pub(crate) struct Count<I> {
20    pub(crate) iter: I,
21}
22
23impl<I: InnerIter> InnerIter for Count<I> {
24    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
25        None
26    }
27    fn iter_clone(&self, _s: Sealed) -> Self {
28        Count {
29            iter: self.iter.iter_clone(Sealed),
30        }
31    }
32}
33
34impl<I> IterConsumer for Count<I>
35where
36    I: LocalIterator,
37{
38    type AmOutput = usize;
39    type Output = usize;
40    type Item = I::Item;
41    type Handle = InnerLocalIterCountHandle;
42    fn init(&self, start: usize, cnt: usize) -> Self {
43        Count {
44            iter: self.iter.init(start, cnt, Sealed),
45        }
46    }
47    fn next(&mut self) -> Option<Self::Item> {
48        self.iter.next()
49    }
50    fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
51        Arc::new(CountAm {
52            iter: self.iter_clone(Sealed),
53            schedule,
54        })
55    }
56    fn create_handle(
57        self,
58        _team: Pin<Arc<LamellarTeamRT>>,
59        reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
60    ) -> InnerLocalIterCountHandle {
61        InnerLocalIterCountHandle {
62            reqs,
63            state: InnerState::ReqsPending(0),
64            spawned: false,
65        }
66    }
67    fn max_elems(&self, in_elems: usize) -> usize {
68        self.iter.elems(in_elems)
69    }
70}
71
//#[doc(hidden)]
/// Future that waits on every outstanding count request and sums their results.
///
/// It is built by `Count::create_handle`. It resolves to the sum of the `usize`
/// values returned by the requests in `reqs`.
#[pin_project]
pub(crate) struct InnerLocalIterCountHandle {
    /// Outstanding request handles; each one yields a partial count.
    pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<usize>>,
    /// Holds the running total of the partial counts collected so far.
    state: InnerState,
    /// Set by the first poll after it has called `ready_or_set_waker` on every request.
    spawned: bool,
}

/// Progress of an `InnerLocalIterCountHandle`.
enum InnerState {
    /// Requests are still being drained. The value is the sum of the counts received so far.
    ReqsPending(usize),
}
83
84impl Future for InnerLocalIterCountHandle {
85    type Output = usize;
86    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
87        if !self.spawned {
88            for req in self.reqs.iter_mut() {
89                req.ready_or_set_waker(cx.waker());
90            }
91            self.spawned = true;
92        }
93        let mut this = self.project();
94        match &mut this.state {
95            InnerState::ReqsPending(cnt) => {
96                while let Some(mut req) = this.reqs.pop_front() {
97                    if !req.ready_or_set_waker(cx.waker()) {
98                        this.reqs.push_front(req);
99                        return Poll::Pending;
100                    }
101                    *cnt += req.val();
102                }
103                Poll::Ready(*cnt)
104            }
105        }
106    }
107}
108
/// This handle allows you to wait for the completion of a local iterator count operation.
///
/// Resolve it with `.await`, `.block()` or `.spawn()`. Dropping the handle without
/// doing any of these prints a dropped-handle runtime warning.
#[pin_project(PinnedDrop)]
pub struct LocalIterCountHandle {
    /// Array handle whose `block_on`/`spawn` is used to drive this future.
    array: UnsafeArrayInner,
    /// True once the handle has been polled, blocked on, or spawned.
    /// While false, dropping the handle emits a warning.
    launched: bool,
    /// Current stage of the count operation.
    #[pin]
    state: State,
}
117
118#[pinned_drop]
119impl PinnedDrop for LocalIterCountHandle {
120    fn drop(self: Pin<&mut Self>) {
121        if !self.launched {
122            let mut this = self.project();
123            RuntimeWarning::disable_warnings();
124            *this.state = State::Dropped;
125            RuntimeWarning::enable_warnings();
126            RuntimeWarning::DroppedHandle("a LocalIterCountHandle").print();
127        }
128    }
129}
130
131impl LocalIterCountHandle {
132    pub(crate) fn new(
133        lock: Option<IterLockFuture>,
134        inner: Pin<Box<dyn Future<Output = InnerLocalIterCountHandle> + Send>>,
135        array: &UnsafeArrayInner,
136    ) -> Self {
137        Self {
138            array: array.clone(),
139            launched: false,
140            state: State::Init(lock, inner),
141        }
142    }
143
144    /// This method will block until the associated Count operation completes and returns the result
145    pub fn block(mut self) -> usize {
146        self.launched = true;
147        RuntimeWarning::BlockingCall(
148            "LocalIterCountHandle::block",
149            "<handle>.spawn() or <handle>.await",
150        )
151        .print();
152        self.array.clone().block_on(self)
153    }
154
155    /// This method will spawn the associated Count Operation on the work queue,
156    /// initiating the remote operation.
157    ///
158    /// This function returns a handle that can be used to wait for the operation to complete
159    #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if  it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
160    pub fn spawn(mut self) -> LamellarTask<usize> {
161        self.launched = true;
162        self.array.clone().spawn(self)
163    }
164}
165
/// Stages of a `LocalIterCountHandle`.
#[pin_project(project = StateProj)]
enum State {
    /// Initial stage. Holds an optional lock future, which is awaited first, and the
    /// future that yields the `InnerLocalIterCountHandle`.
    Init(
        Option<IterLockFuture>,
        Pin<Box<dyn Future<Output = InnerLocalIterCountHandle> + Send>>,
    ),
    /// Waiting for the inner handle to finish summing the per-request counts.
    Reqs(#[pin] InnerLocalIterCountHandle),
    /// Set when the handle is dropped without being launched. Polling in this state panics.
    Dropped,
}
175impl Future for LocalIterCountHandle {
176    type Output = usize;
177    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178        self.launched = true;
179        let mut this = self.project();
180        match this.state.as_mut().project() {
181            StateProj::Init(lock, inner) => {
182                if let Some(lock) = lock {
183                    ready!(lock.as_mut().poll(cx));
184                }
185                let mut inner = ready!(Future::poll(inner.as_mut(), cx));
186                match Pin::new(&mut inner).poll(cx) {
187                    Poll::Ready(val) => Poll::Ready(val),
188                    Poll::Pending => {
189                        *this.state = State::Reqs(inner);
190                        Poll::Pending
191                    }
192                }
193            }
194            StateProj::Reqs(inner) => {
195                let val = ready!(inner.poll(cx));
196                Poll::Ready(val)
197            }
198            StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
199        }
200    }
201}
202
// Active message carrying a `Count` consumer and the `IterSchedule` used to
// initialize its iterator. Its `exec` returns the number of items it iterated.
#[lamellar_impl::AmLocalDataRT(Clone)]
pub(crate) struct CountAm<I> {
    // Consumer whose wrapped iterator is counted.
    pub(crate) iter: Count<I>,
    // Schedule passed to `init_iter` before counting.
    pub(crate) schedule: IterSchedule,
}
208
209impl<I> InnerIter for CountAm<I>
210where
211    I: InnerIter,
212{
213    fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
214        None
215    }
216    fn iter_clone(&self, _s: Sealed) -> Self {
217        CountAm {
218            iter: self.iter.iter_clone(Sealed),
219            schedule: self.schedule.clone(),
220        }
221    }
222}
223
224#[lamellar_impl::rt_am_local]
225impl<I> LamellarAm for CountAm<I>
226where
227    I: LocalIterator + 'static,
228{
229    async fn exec(&self) -> usize {
230        let mut iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
231        let mut count: usize = 0;
232        while let Some(_) = iter.next() {
233            count += 1;
234        }
235        // println!("count: {} {:?}", count, std::thread::current().id());
236        count
237    }
238}