lamellar/array/iterator/local_iterator/consumer/
count.rs1use crate::active_messaging::LamellarArcLocalAm;
2use crate::array::iterator::local_iterator::LocalIterator;
3use crate::array::iterator::{consumer::*, private::*, IterLockFuture};
4use crate::array::r#unsafe::private::UnsafeArrayInner;
5use crate::lamellar_request::LamellarRequest;
6use crate::lamellar_task_group::TaskGroupLocalAmHandle;
7use crate::lamellar_team::LamellarTeamRT;
8use crate::scheduler::LamellarTask;
9use crate::warnings::RuntimeWarning;
10
11use futures_util::{ready, Future};
12use pin_project::{pin_project, pinned_drop};
13use std::collections::VecDeque;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
/// Iterator consumer that counts the elements yielded by the wrapped
/// iterator `I`.
#[derive(Clone, Debug)]
pub(crate) struct Count<I> {
    /// The iterator whose elements are counted.
    pub(crate) iter: I,
}
22
23impl<I: InnerIter> InnerIter for Count<I> {
24 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
25 None
26 }
27 fn iter_clone(&self, _s: Sealed) -> Self {
28 Count {
29 iter: self.iter.iter_clone(Sealed),
30 }
31 }
32}
33
34impl<I> IterConsumer for Count<I>
35where
36 I: LocalIterator,
37{
38 type AmOutput = usize;
39 type Output = usize;
40 type Item = I::Item;
41 type Handle = InnerLocalIterCountHandle;
42 fn init(&self, start: usize, cnt: usize) -> Self {
43 Count {
44 iter: self.iter.init(start, cnt, Sealed),
45 }
46 }
47 fn next(&mut self) -> Option<Self::Item> {
48 self.iter.next()
49 }
50 fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
51 Arc::new(CountAm {
52 iter: self.iter_clone(Sealed),
53 schedule,
54 })
55 }
56 fn create_handle(
57 self,
58 _team: Pin<Arc<LamellarTeamRT>>,
59 reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
60 ) -> InnerLocalIterCountHandle {
61 InnerLocalIterCountHandle {
62 reqs,
63 state: InnerState::ReqsPending(0),
64 spawned: false,
65 }
66 }
67 fn max_elems(&self, in_elems: usize) -> usize {
68 self.iter.elems(in_elems)
69 }
70}
71
72#[pin_project]
74pub(crate) struct InnerLocalIterCountHandle {
75 pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<usize>>,
76 state: InnerState,
77 spawned: bool,
78}
79
/// Progress of an [`InnerLocalIterCountHandle`].
enum InnerState {
    /// Requests may still be outstanding; the payload is the total
    /// accumulated from the requests that have completed so far.
    ReqsPending(usize),
}
83
84impl Future for InnerLocalIterCountHandle {
85 type Output = usize;
86 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
87 if !self.spawned {
88 for req in self.reqs.iter_mut() {
89 req.ready_or_set_waker(cx.waker());
90 }
91 self.spawned = true;
92 }
93 let mut this = self.project();
94 match &mut this.state {
95 InnerState::ReqsPending(cnt) => {
96 while let Some(mut req) = this.reqs.pop_front() {
97 if !req.ready_or_set_waker(cx.waker()) {
98 this.reqs.push_front(req);
99 return Poll::Pending;
100 }
101 *cnt += req.val();
102 }
103 Poll::Ready(*cnt)
104 }
105 }
106 }
107}
108
109#[pin_project(PinnedDrop)]
111pub struct LocalIterCountHandle {
112 array: UnsafeArrayInner,
113 launched: bool,
114 #[pin]
115 state: State,
116}
117
118#[pinned_drop]
119impl PinnedDrop for LocalIterCountHandle {
120 fn drop(self: Pin<&mut Self>) {
121 if !self.launched {
122 let mut this = self.project();
123 RuntimeWarning::disable_warnings();
124 *this.state = State::Dropped;
125 RuntimeWarning::enable_warnings();
126 RuntimeWarning::DroppedHandle("a LocalIterCountHandle").print();
127 }
128 }
129}
130
131impl LocalIterCountHandle {
132 pub(crate) fn new(
133 lock: Option<IterLockFuture>,
134 inner: Pin<Box<dyn Future<Output = InnerLocalIterCountHandle> + Send>>,
135 array: &UnsafeArrayInner,
136 ) -> Self {
137 Self {
138 array: array.clone(),
139 launched: false,
140 state: State::Init(lock, inner),
141 }
142 }
143
144 pub fn block(mut self) -> usize {
146 self.launched = true;
147 RuntimeWarning::BlockingCall(
148 "LocalIterCountHandle::block",
149 "<handle>.spawn() or <handle>.await",
150 )
151 .print();
152 self.array.clone().block_on(self)
153 }
154
155 #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
160 pub fn spawn(mut self) -> LamellarTask<usize> {
161 self.launched = true;
162 self.array.clone().spawn(self)
163 }
164}
165
166#[pin_project(project = StateProj)]
167enum State {
168 Init(
169 Option<IterLockFuture>,
170 Pin<Box<dyn Future<Output = InnerLocalIterCountHandle> + Send>>,
171 ),
172 Reqs(#[pin] InnerLocalIterCountHandle),
173 Dropped,
174}
175impl Future for LocalIterCountHandle {
176 type Output = usize;
177 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178 self.launched = true;
179 let mut this = self.project();
180 match this.state.as_mut().project() {
181 StateProj::Init(lock, inner) => {
182 if let Some(lock) = lock {
183 ready!(lock.as_mut().poll(cx));
184 }
185 let mut inner = ready!(Future::poll(inner.as_mut(), cx));
186 match Pin::new(&mut inner).poll(cx) {
187 Poll::Ready(val) => Poll::Ready(val),
188 Poll::Pending => {
189 *this.state = State::Reqs(inner);
190 Poll::Pending
191 }
192 }
193 }
194 StateProj::Reqs(inner) => {
195 let val = ready!(inner.poll(cx));
196 Poll::Ready(val)
197 }
198 StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
199 }
200 }
201}
202
203#[lamellar_impl::AmLocalDataRT(Clone)]
204pub(crate) struct CountAm<I> {
205 pub(crate) iter: Count<I>,
206 pub(crate) schedule: IterSchedule,
207}
208
209impl<I> InnerIter for CountAm<I>
210where
211 I: InnerIter,
212{
213 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
214 None
215 }
216 fn iter_clone(&self, _s: Sealed) -> Self {
217 CountAm {
218 iter: self.iter.iter_clone(Sealed),
219 schedule: self.schedule.clone(),
220 }
221 }
222}
223
224#[lamellar_impl::rt_am_local]
225impl<I> LamellarAm for CountAm<I>
226where
227 I: LocalIterator + 'static,
228{
229 async fn exec(&self) -> usize {
230 let mut iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
231 let mut count: usize = 0;
232 while let Some(_) = iter.next() {
233 count += 1;
234 }
235 count
237 }
238}