lamellar/array/iterator/distributed_iterator/consumer/
reduce.rs1use crate::active_messaging::{LamellarArcLocalAm, SyncSend};
2use crate::array::iterator::distributed_iterator::DistributedIterator;
3use crate::array::iterator::one_sided_iterator::OneSidedIterator;
4use crate::array::iterator::private::*;
5use crate::array::iterator::{consumer::*, IterLockFuture};
6use crate::array::r#unsafe::private::UnsafeArrayInner;
7use crate::array::{ArrayOps, Distribution, UnsafeArray};
8use crate::barrier::BarrierHandle;
9use crate::lamellar_request::LamellarRequest;
10use crate::lamellar_task_group::TaskGroupLocalAmHandle;
11use crate::lamellar_team::LamellarTeamRT;
12use crate::scheduler::LamellarTask;
13use crate::warnings::RuntimeWarning;
14use crate::Dist;
15
16use futures_util::{ready, Future, StreamExt};
17use pin_project::{pin_project, pinned_drop};
18use std::collections::VecDeque;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23
24
25#[derive(Clone, Debug)]
26pub(crate) struct Reduce<I, F> {
27 pub(crate) iter: I,
28 pub(crate) op: F,
29}
30
31impl<I: InnerIter, F: Clone> InnerIter for Reduce<I, F> {
32 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
33 None
34 }
35 fn iter_clone(&self, _s: Sealed) -> Self {
36 Reduce {
37 iter: self.iter.iter_clone(Sealed),
38 op: self.op.clone(),
39 }
40 }
41}
42
43impl<I, F> IterConsumer for Reduce<I, F>
44where
45 I: DistributedIterator + 'static,
46 I::Item: Dist + ArrayOps,
47 F: Fn(I::Item, I::Item) -> I::Item + SyncSend + Clone + 'static,
48{
49 type AmOutput = Option<I::Item>;
50 type Output = Option<I::Item>;
51 type Item = I::Item;
52 type Handle = InnerDistIterReduceHandle<I::Item, F>;
53 fn init(&self, start: usize, cnt: usize) -> Self {
54 Reduce {
55 iter: self.iter.init(start, cnt, Sealed),
56 op: self.op.clone(),
57 }
58 }
59 fn next(&mut self) -> Option<Self::Item> {
60 self.iter.next()
61 }
62 fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
63 Arc::new(ReduceAm {
64 iter: self.iter_clone(Sealed),
65 op: self.op.clone(),
66 schedule,
67 })
68 }
69 fn create_handle(
70 self,
71 team: Pin<Arc<LamellarTeamRT>>,
72 reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
73 ) -> Self::Handle {
74 InnerDistIterReduceHandle {
75 op: self.op,
76 reqs,
77 team,
78 state: InnerState::ReqsPending(None),
79 spawned: false,
80 }
81 }
82 fn max_elems(&self, in_elems: usize) -> usize {
83 self.iter.elems(in_elems)
84 }
85}
86#[pin_project]
88pub(crate) struct InnerDistIterReduceHandle<T, F> {
89 pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<Option<T>>>,
90 pub(crate) op: F,
91 pub(crate) team: Pin<Arc<LamellarTeamRT>>,
92 state: InnerState<T>,
93 spawned: bool,
94}
95
/// Stage of an `InnerDistIterReduceHandle`.
enum InnerState<T> {
    /// Requests are still being drained; holds the value accumulated so far,
    /// if any.
    ReqsPending(Option<T>),
    /// All requests have been drained; holds the future performing the
    /// cross-PE stage of the reduction.
    Reducing(Pin<Box<dyn Future<Output = Option<T>> + Send + 'static>>),
}
100
101impl<T, F> InnerDistIterReduceHandle<T, F>
102where
103 T: Dist + ArrayOps,
104 F: Fn(T, T) -> T + SyncSend + Clone + 'static,
105{
106 async fn async_reduce_remote_vals(
107 local_val: T,
108 team: Pin<Arc<LamellarTeamRT>>,
109 op: F,
110 ) -> Option<T> {
111 let local_vals = UnsafeArray::<T>::async_new(
112 &team,
113 team.num_pes,
114 Distribution::Block,
115 crate::darc::DarcMode::UnsafeArray,
116 )
117 .await;
118 unsafe {
119 local_vals.local_as_mut_slice()[0] = local_val;
120 };
121 local_vals.async_barrier().await;
122 let buffered_iter = unsafe { local_vals.buffered_onesided_iter(team.num_pes) };
123 let mut stream = buffered_iter.into_stream();
124 let first = stream.next().await?;
125
126 Some(
127 stream
128 .fold(*first, |a, &b| {
129 let val = op(a, b);
130 async move { val }
131 })
132 .await,
133 )
134 }
135
136 }
151
152impl<T, F> Future for InnerDistIterReduceHandle<T, F>
153where
154 T: Dist + ArrayOps,
155 F: Fn(T, T) -> T + SyncSend + Clone + 'static,
156{
157 type Output = Option<T>;
158 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 if !self.spawned {
160 for req in self.reqs.iter_mut() {
161 req.ready_or_set_waker(cx.waker());
162 }
163 self.spawned = true;
164 }
165 let mut this = self.project();
166 match &mut this.state {
167 InnerState::ReqsPending(mut val) => {
168 while let Some(mut req) = this.reqs.pop_front() {
169 if !req.ready_or_set_waker(cx.waker()) {
170 this.reqs.push_front(req);
171 return Poll::Pending;
172 }
173 match val {
174 None => val = req.val(),
175 Some(val1) => {
176 if let Some(val2) = req.val() {
177 val = Some((this.op)(val1, val2));
178 }
179 }
180 }
181 }
182 if let Some(val) = val {
183 let mut reducing = Box::pin(Self::async_reduce_remote_vals(
184 val.clone(),
185 this.team.clone(),
186 this.op.clone(),
187 ));
188 match Future::poll(reducing.as_mut(), cx) {
189 Poll::Ready(val) => Poll::Ready(val),
190 Poll::Pending => {
191 *this.state = InnerState::Reducing(reducing);
192 Poll::Pending
193 }
194 }
195 } else {
196 Poll::Ready(None)
197 }
198 }
199 InnerState::Reducing(reducing) => {
200 let val = ready!(Future::poll(reducing.as_mut(), cx));
201 Poll::Ready(val)
202 }
203 }
204 }
205}
206
207#[pin_project(PinnedDrop)]
209pub struct DistIterReduceHandle<T, F> {
210 array: UnsafeArrayInner,
211 launched: bool,
212 #[pin]
213 state: State<T, F>,
214}
215
216#[pinned_drop]
217impl<T, F> PinnedDrop for DistIterReduceHandle<T, F> {
218 fn drop(self: Pin<&mut Self>) {
219 if !self.launched {
220 let mut this = self.project();
221 RuntimeWarning::disable_warnings();
222 *this.state = State::Dropped;
223 RuntimeWarning::enable_warnings();
224 RuntimeWarning::DroppedHandle("a DistIterReduceHandle").print();
225 }
226 }
227}
228
229impl<T, F> DistIterReduceHandle<T, F>
230where
231 T: Dist + ArrayOps,
232 F: Fn(T, T) -> T + SyncSend + Clone + 'static,
233{
234 pub(crate) fn new(
235 lock: Option<IterLockFuture>,
236 reqs: Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>,
237 array: &UnsafeArrayInner,
238 ) -> Self {
239 let state = match lock {
240 Some(inner_lock) => State::Lock(inner_lock, Some(reqs)),
241 None => State::Barrier(array.barrier_handle(), reqs),
242 };
243 Self {
244 array: array.clone(),
245 launched: false,
246 state,
247 }
248 }
249
250 pub fn block(mut self) -> Option<T> {
252 self.launched = true;
253 RuntimeWarning::BlockingCall(
254 "DistIterReduceHandle::block",
255 "<handle>.spawn() or <handle>.await",
256 )
257 .print();
258 self.array.clone().block_on(self)
259 }
260
261 #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
266 pub fn spawn(mut self) -> LamellarTask<Option<T>> {
267 self.launched = true;
268 self.array.clone().spawn(self)
269 }
270}
271
272#[pin_project(project = StateProj)]
273enum State<T, F> {
274 Lock(
275 #[pin] IterLockFuture,
276 Option<Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>>,
277 ),
278 Barrier(
279 #[pin] BarrierHandle,
280 Pin<Box<dyn Future<Output = InnerDistIterReduceHandle<T, F>> + Send>>,
281 ),
282 Reqs(#[pin] InnerDistIterReduceHandle<T, F>),
283 Dropped,
284}
285impl<T, F> Future for DistIterReduceHandle<T, F>
286where
287 T: Dist + ArrayOps,
288 F: Fn(T, T) -> T + SyncSend + Clone + 'static,
289{
290 type Output = Option<T>;
291 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292 self.launched = true;
293 let mut this = self.project();
294 match this.state.as_mut().project() {
295 StateProj::Lock(lock, inner) => {
296 ready!(lock.poll(cx));
297 let barrier = this.array.barrier_handle();
298 *this.state = State::Barrier(
299 barrier,
300 inner.take().expect("reqs should still be in this state"),
301 );
302 cx.waker().wake_by_ref();
303 Poll::Pending
304 }
305 StateProj::Barrier(barrier, inner) => {
306 ready!(barrier.poll(cx));
307 let mut inner = ready!(Future::poll(inner.as_mut(), cx));
308 match Pin::new(&mut inner).poll(cx) {
309 Poll::Ready(val) => Poll::Ready(val),
310 Poll::Pending => {
311 *this.state = State::Reqs(inner);
312 Poll::Pending
313 }
314 }
315 }
316 StateProj::Reqs(inner) => {
317 let val = ready!(inner.poll(cx));
318 Poll::Ready(val)
319 }
320 StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
321 }
322 }
323}
324
325#[lamellar_impl::AmLocalDataRT(Clone)]
326pub(crate) struct ReduceAm<I, F> {
327 pub(crate) op: F,
328 pub(crate) iter: Reduce<I, F>,
329 pub(crate) schedule: IterSchedule,
330}
331
332impl<I: InnerIter, F: Clone> InnerIter for ReduceAm<I, F> {
333 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
334 None
335 }
336 fn iter_clone(&self, _s: Sealed) -> Self {
337 ReduceAm {
338 op: self.op.clone(),
339 iter: self.iter.iter_clone(Sealed),
340 schedule: self.schedule.clone(),
341 }
342 }
343}
344
345#[lamellar_impl::rt_am_local]
346impl<I, F> LamellarAm for ReduceAm<I, F>
347where
348 I: DistributedIterator + 'static,
349 I::Item: Dist + ArrayOps,
350 F: Fn(I::Item, I::Item) -> I::Item + SyncSend + Clone + 'static,
351{
352 async fn exec(&self) -> Option<I::Item> {
353 let mut iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
354 match iter.next() {
355 Some(mut accum) => {
356 while let Some(elem) = iter.next() {
357 accum = (self.op)(accum, elem);
358 }
359 Some(accum)
360 }
361 None => None,
362 }
363 }
364}