lamellar/array/iterator/distributed_iterator/consumer/
sum.rs1use crate::active_messaging::LamellarArcLocalAm;
2use crate::array::iterator::distributed_iterator::DistributedIterator;
3use crate::array::iterator::private::*;
4use crate::array::iterator::{consumer::*, IterLockFuture};
5use crate::array::r#unsafe::private::UnsafeArrayInner;
6use crate::array::{ArrayOps, Distribution, UnsafeArray};
7use crate::barrier::BarrierHandle;
8use crate::lamellar_request::LamellarRequest;
9use crate::lamellar_task_group::TaskGroupLocalAmHandle;
10use crate::lamellar_team::LamellarTeamRT;
11use crate::scheduler::LamellarTask;
12use crate::warnings::RuntimeWarning;
13use crate::Dist;
14use futures_util::{ready, Future};
15use pin_project::{pin_project, pinned_drop};
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20
/// Consumer adaptor that reduces a distributed iterator to the sum of its items.
///
/// It only wraps the iterator; the summation is carried out by the
/// `IterConsumer` implementation and the `SumAm` tasks it creates.
#[derive(Debug, Clone)]
pub(crate) struct Sum<I> {
    /// The iterator whose items are summed.
    pub(crate) iter: I,
}
25
26impl<I: InnerIter> InnerIter for Sum<I> {
27 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
28 None
29 }
30 fn iter_clone(&self, _s: Sealed) -> Self {
31 Sum {
32 iter: self.iter.iter_clone(Sealed),
33 }
34 }
35}
36
37impl<I> IterConsumer for Sum<I>
38where
39 I: DistributedIterator + 'static,
40 I::Item: Dist + ArrayOps + std::iter::Sum,
41{
42 type AmOutput = I::Item;
43 type Output = I::Item;
44 type Item = I::Item;
45 type Handle = InnerDistIterSumHandle<I::Item>;
46 fn init(&self, start: usize, cnt: usize) -> Self {
47 Sum {
48 iter: self.iter.init(start, cnt, Sealed),
49 }
50 }
51 fn next(&mut self) -> Option<Self::Item> {
52 self.iter.next()
53 }
54 fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
55 Arc::new(SumAm {
56 iter: self.iter_clone(Sealed),
57 schedule,
58 })
59 }
60 fn create_handle(
61 self,
62 team: Pin<Arc<LamellarTeamRT>>,
63 reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
64 ) -> Self::Handle {
65 InnerDistIterSumHandle {
66 reqs,
67 team,
68 state: InnerState::ReqsPending(None),
69 spawned: false,
70 }
71 }
72 fn max_elems(&self, in_elems: usize) -> usize {
73 self.iter.elems(in_elems)
74 }
75}
76
77#[pin_project]
79pub(crate) struct InnerDistIterSumHandle<T> {
80 pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<T>>,
81 pub(crate) team: Pin<Arc<LamellarTeamRT>>,
82 state: InnerState<T>,
83 spawned: bool,
84}
85
/// Stages of `InnerDistIterSumHandle`.
enum InnerState<T> {
    /// Draining local requests; holds the running partial sum once at least one
    /// request has completed.
    ReqsPending(Option<T>),
    /// Local requests are done; awaiting the team-wide reduction.
    Summing(Pin<Box<dyn Future<Output = T> + Send>>),
}
90
91impl<T> InnerDistIterSumHandle<T>
92where
93 T: Dist + ArrayOps + std::iter::Sum,
94{
95 async fn async_reduce_remote_vals(local_sum: T, team: Pin<Arc<LamellarTeamRT>>) -> T {
96 let local_sums = UnsafeArray::<T>::async_new(
97 &team,
98 team.num_pes,
99 Distribution::Block,
100 crate::darc::DarcMode::UnsafeArray,
101 )
102 .await;
103 unsafe {
104 local_sums.local_as_mut_slice()[0] = local_sum;
105 };
106 local_sums.async_barrier().await;
107 unsafe {
110 local_sums
111 .sum()
112 .await
113 .expect("array size is greater than zero")
114 }
115 }
116
117 }
132
133impl<T> Future for InnerDistIterSumHandle<T>
134where
135 T: Dist + ArrayOps + std::iter::Sum,
136{
137 type Output = T;
138 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 if !self.spawned {
140 for req in self.reqs.iter_mut() {
141 req.ready_or_set_waker(cx.waker());
142 }
143 self.spawned = true;
144 }
145 let mut this = self.project();
146 match &mut this.state {
147 InnerState::ReqsPending(local_sum) => {
148 while let Some(mut req) = this.reqs.pop_front() {
149 if !req.ready_or_set_waker(cx.waker()) {
150 this.reqs.push_front(req);
151 return Poll::Pending;
152 }
153 match local_sum {
154 Some(sum) => {
155 *sum = [*sum, req.val()].into_iter().sum();
156 }
157 None => {
158 *local_sum = Some(req.val());
159 }
160 }
161 }
162 let mut sum = Box::pin(Self::async_reduce_remote_vals(
163 local_sum.unwrap(),
164 this.team.clone(),
165 ));
166 match Future::poll(sum.as_mut(), cx) {
167 Poll::Ready(local_sum) => Poll::Ready(local_sum),
168 Poll::Pending => {
169 *this.state = InnerState::Summing(sum);
170 Poll::Pending
171 }
172 }
173 }
174 InnerState::Summing(sum) => {
175 let local_sum = ready!(Future::poll(sum.as_mut(), cx));
176 Poll::Ready(local_sum)
177 }
178 }
179 }
180}
181
182
183
184#[pin_project(PinnedDrop)]
186pub struct DistIterSumHandle<T> {
187 array: UnsafeArrayInner,
188 launched: bool,
189 #[pin]
190 state: State<T>,
191}
192
193#[pinned_drop]
194impl<T> PinnedDrop for DistIterSumHandle<T> {
195 fn drop(self: Pin<&mut Self>) {
196 if !self.launched {
197 let mut this = self.project();
198 RuntimeWarning::disable_warnings();
199 *this.state = State::Dropped;
200 RuntimeWarning::enable_warnings();
201 RuntimeWarning::DroppedHandle("a DistIterSumHandle").print();
202 }
203 }
204}
205
206impl<T> DistIterSumHandle<T>
207where
208 T: Dist + ArrayOps + std::iter::Sum,
209{
210 pub(crate) fn new(
211 lock: Option<IterLockFuture>,
212 inner: Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>,
213 array: &UnsafeArrayInner,
214 ) -> Self {
215 let state = match lock {
216 Some(inner_lock) => State::Lock(inner_lock, Some(inner)),
217 None => State::Barrier(array.barrier_handle(), inner),
218 };
219 Self {
220 array: array.clone(),
221 launched: false,
222 state,
223 }
224 }
225
226 pub fn block(mut self) -> T {
228 self.launched = true;
229 RuntimeWarning::BlockingCall(
230 "DistIterSumHandle::block",
231 "<handle>.spawn() or <handle>.await",
232 )
233 .print();
234 self.array.clone().block_on(self)
235 }
236
237 #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
242 pub fn spawn(mut self) -> LamellarTask<T> {
243 self.launched = true;
244 self.array.clone().spawn(self)
245 }
246}
247
248#[pin_project(project = StateProj)]
249enum State<T> {
250 Lock(
251 #[pin] IterLockFuture,
252 Option<Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>>,
253 ),
254 Barrier(
255 #[pin] BarrierHandle,
256 Pin<Box<dyn Future<Output = InnerDistIterSumHandle<T>> + Send>>,
257 ),
258 Reqs(#[pin] InnerDistIterSumHandle<T>),
259 Dropped,
260}
261impl<T> Future for DistIterSumHandle<T>
262where
263 T: Dist + ArrayOps + std::iter::Sum,
264{
265 type Output = T;
266 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 self.launched = true;
268 let mut this = self.project();
269
270 match this.state.as_mut().project() {
271 StateProj::Lock(lock, inner) => {
272 ready!(lock.poll(cx));
273 let barrier = this.array.barrier_handle();
274 *this.state = State::Barrier(
275 barrier,
276 inner.take().expect("reqs should still be in this state"),
277 );
278 cx.waker().wake_by_ref();
279 Poll::Pending
280 }
281 StateProj::Barrier(barrier, inner) => {
282 ready!(barrier.poll(cx));
283 let mut inner = ready!(Future::poll(inner.as_mut(), cx));
284 match Pin::new(&mut inner).poll(cx) {
285 Poll::Ready(val) => Poll::Ready(val),
286 Poll::Pending => {
287 *this.state = State::Reqs(inner);
288 Poll::Pending
289 }
290 }
291 }
292 StateProj::Reqs(inner) => {
293 let val = ready!(inner.poll(cx));
294 Poll::Ready(val)
295 }
296 StateProj::Dropped => panic!("called `Future::poll()` on a future that was dropped"),
297 }
298 }
299}
300
301#[lamellar_impl::AmLocalDataRT(Clone)]
302pub(crate) struct SumAm<I> {
303 pub(crate) iter: Sum<I>,
304 pub(crate) schedule: IterSchedule,
305}
306
307impl<I: InnerIter> InnerIter for SumAm<I> {
308 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
309 None
310 }
311 fn iter_clone(&self, _s: Sealed) -> Self {
312 SumAm {
313 iter: self.iter.iter_clone(Sealed),
314 schedule: self.schedule.clone(),
315 }
316 }
317}
318
319#[lamellar_impl::rt_am_local]
320impl<I> LamellarAm for SumAm<I>
321where
322 I: DistributedIterator + 'static,
323 I::Item: Dist + ArrayOps + std::iter::Sum,
324{
325 async fn exec(&self) -> I::Item {
326 let iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
327 iter.sum::<I::Item>()
328 }
329}