lamellar/array/iterator/local_iterator/consumer/
sum.rs1use crate::active_messaging::{LamellarArcLocalAm, SyncSend};
2use crate::array::iterator::local_iterator::LocalIterator;
3use crate::array::iterator::private::*;
4use crate::array::iterator::{consumer::*, IterLockFuture};
5use crate::array::r#unsafe::private::UnsafeArrayInner;
6use crate::lamellar_request::LamellarRequest;
7use crate::lamellar_task_group::TaskGroupLocalAmHandle;
8use crate::lamellar_team::LamellarTeamRT;
9use crate::scheduler::LamellarTask;
10use crate::warnings::RuntimeWarning;
11
12use futures_util::{ready, Future};
13use pin_project::{pin_project, pinned_drop};
14use std::collections::VecDeque;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
/// Consumer adapter that sums the items produced by the wrapped local iterator `iter`.
#[derive(Clone, Debug)]
pub(crate) struct Sum<I> {
    /// The local iterator whose items are summed.
    pub(crate) iter: I,
}
23
24impl<I: InnerIter> InnerIter for Sum<I> {
25 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
26 None
27 }
28 fn iter_clone(&self, _s: Sealed) -> Self {
29 Sum {
30 iter: self.iter.iter_clone(Sealed),
31 }
32 }
33}
34
35impl<I> IterConsumer for Sum<I>
36where
37 I: LocalIterator + 'static,
38 I::Item: SyncSend + for<'a> std::iter::Sum<&'a I::Item> + std::iter::Sum<I::Item>,
39{
40 type AmOutput = I::Item;
41 type Output = I::Item;
42 type Item = I::Item;
43 type Handle = InnerLocalIterSumHandle<I::Item>;
44 fn init(&self, start: usize, cnt: usize) -> Self {
45 Sum {
46 iter: self.iter.init(start, cnt, Sealed),
47 }
48 }
49 fn next(&mut self) -> Option<Self::Item> {
50 self.iter.next()
51 }
52 fn into_am(&self, schedule: IterSchedule) -> LamellarArcLocalAm {
53 Arc::new(SumAm {
54 iter: self.iter_clone(Sealed),
55 schedule,
56 })
57 }
58 fn create_handle(
59 self,
60 _team: Pin<Arc<LamellarTeamRT>>,
61 reqs: VecDeque<TaskGroupLocalAmHandle<Self::AmOutput>>,
62 ) -> Self::Handle {
63 InnerLocalIterSumHandle {
64 reqs,
65 state: InnerState::ReqsPending(None),
66 spawned: false,
67 }
68 }
69 fn max_elems(&self, in_elems: usize) -> usize {
70 self.iter.elems(in_elems)
71 }
72}
73
/// Future that waits on every request handle in `reqs` and folds their results into one sum.
#[pin_project]
pub(crate) struct InnerLocalIterSumHandle<T> {
    /// Outstanding request handles; each one resolves to a partial sum.
    pub(crate) reqs: VecDeque<TaskGroupLocalAmHandle<T>>,
    /// Running total accumulated from the requests that have already completed.
    state: InnerState<T>,
    /// Set after the first poll has registered the waker on every request.
    spawned: bool,
}
81
/// Progress state of an `InnerLocalIterSumHandle`.
enum InnerState<T> {
    /// Still collecting results; holds the running sum of the results seen so far,
    /// or `None` if no request has been folded in yet.
    ReqsPending(Option<T>),
}
85
86impl<T> Future for InnerLocalIterSumHandle<T>
87where
88 T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
89{
90 type Output = T;
91 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
92 if !self.spawned {
93 for req in self.reqs.iter_mut() {
94 req.ready_or_set_waker(cx.waker());
95 }
96 self.spawned = true;
97 }
98 let mut this = self.project();
99 match &mut this.state {
100 InnerState::ReqsPending(local_sum) => {
101 while let Some(mut req) = this.reqs.pop_front() {
102 if !req.ready_or_set_waker(cx.waker()) {
103 this.reqs.push_front(req);
104 return Poll::Pending;
105 }
106 match local_sum {
107 Some(sum) => {
108 *sum = [sum, &req.val()].into_iter().sum::<T>();
109 }
110 None => {
111 *local_sum = Some(req.val());
112 }
113 }
114 }
115
116 Poll::Ready(local_sum.take().expect("Value should be Present"))
117 }
118 }
119 }
120}
/// Handle for the local-iterator sum consumer; as a `Future` it resolves to the summed value.
///
/// Drive it with `.await`, `spawn()`, or `block()`. If it is dropped without having been
/// launched, a `RuntimeWarning::DroppedHandle` warning is reported.
#[pin_project(PinnedDrop)]
pub struct LocalIterSumHandle<T> {
    /// Array the iteration runs over; used to `block_on` or `spawn` this handle.
    array: UnsafeArrayInner,
    /// Set once the handle has been polled, blocked on, or spawned.
    launched: bool,
    /// Current stage of the operation.
    #[pin]
    state: State<T>,
}
129
130#[pinned_drop]
131impl<T> PinnedDrop for LocalIterSumHandle<T> {
132 fn drop(self: Pin<&mut Self>) {
133 if !self.launched {
134 let mut this = self.project();
135 RuntimeWarning::disable_warnings();
136 *this.state = State::Dropped;
137 RuntimeWarning::enable_warnings();
138 RuntimeWarning::DroppedHandle("a LocalIterSumHandle").print();
139 }
140 }
141}
142
143impl<T> LocalIterSumHandle<T>
144where
145 T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
146{
147 pub(crate) fn new(
148 lock: Option<IterLockFuture>,
149 inner: Pin<Box<dyn Future<Output = InnerLocalIterSumHandle<T>> + Send>>,
150 array: &UnsafeArrayInner,
151 ) -> Self {
152 Self {
153 array: array.clone(),
154 launched: false,
155 state: State::Init(lock, inner),
156 }
157 }
158
159 pub fn block(mut self) -> T {
161 self.launched = true;
162 RuntimeWarning::BlockingCall(
163 "LocalIterSumHandle::block",
164 "<handle>.spawn() or <handle>.await",
165 )
166 .print();
167 self.array.clone().block_on(self)
168 }
169 #[must_use = "this function returns a future used to poll for completion and retrieve the result. Call '.await' on the future otherwise, if it is ignored (via ' let _ = *.spawn()') or dropped the only way to ensure completion is calling 'wait_all()' on the world or array. Alternatively it may be acceptable to call '.block()' instead of 'spawn()'"]
174 pub fn spawn(mut self) -> LamellarTask<T> {
175 self.launched = true;
176
177 self.array.clone().spawn(self)
178 }
179}
180
/// Stages of a `LocalIterSumHandle`.
#[pin_project(project = StateProj)]
enum State<T> {
    /// Not started: an optional lock future to await first, then the future that yields
    /// the inner result-gathering handle.
    Init(
        Option<IterLockFuture>,
        Pin<Box<dyn Future<Output = InnerLocalIterSumHandle<T>> + Send>>,
    ),
    /// Waiting on the inner handle to finish folding the per-request results.
    Reqs(#[pin] InnerLocalIterSumHandle<T>),
    /// The handle was dropped without being launched; polling in this state panics.
    Dropped,
}
190impl<T> Future for LocalIterSumHandle<T>
191where
192 T: SyncSend + for<'a> std::iter::Sum<&'a T> + std::iter::Sum<T> + 'static,
193{
194 type Output = T;
195 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196 self.launched = true;
197 let mut this = self.project();
198 match this.state.as_mut().project() {
199 StateProj::Init(lock, inner) => {
200 if let Some(lock) = lock {
201 ready!(lock.as_mut().poll(cx));
202 }
203 let mut inner = ready!(Future::poll(inner.as_mut(), cx));
204 match Pin::new(&mut inner).poll(cx) {
205 Poll::Ready(val) => Poll::Ready(val),
206 Poll::Pending => {
207 *this.state = State::Reqs(inner);
208 Poll::Pending
209 }
210 }
211 }
212 StateProj::Reqs(inner) => {
213 let val = ready!(inner.poll(cx));
214 Poll::Ready(val)
215 }
216 StateProj::Dropped => panic!("called `Future::poll()` on a dropped future."),
217 }
218 }
219}
220
// Local active message that sums the items produced when `schedule` initializes the
// wrapped consumer's iterator (see its `exec`).
#[lamellar_impl::AmLocalDataRT(Clone)]
pub(crate) struct SumAm<I> {
    // Sum consumer wrapping the iterator to be summed.
    pub(crate) iter: Sum<I>,
    // Schedule passed to `init_iter` to initialize the iterator for this message.
    pub(crate) schedule: IterSchedule,
}
226
227impl<I: InnerIter> InnerIter for SumAm<I> {
228 fn lock_if_needed(&self, _s: Sealed) -> Option<IterLockFuture> {
229 None
230 }
231 fn iter_clone(&self, _s: Sealed) -> Self {
232 SumAm {
233 iter: self.iter.iter_clone(Sealed),
234 schedule: self.schedule.clone(),
235 }
236 }
237}
238
239#[lamellar_impl::rt_am_local]
240impl<I> LamellarAm for SumAm<I>
241where
242 I: LocalIterator + 'static,
243 I::Item: SyncSend + for<'a> std::iter::Sum<&'a I::Item> + std::iter::Sum<I::Item>,
244{
245 async fn exec(&self) -> I::Item {
246 let iter = self.schedule.init_iter(self.iter.iter_clone(Sealed));
247 iter.sum::<I::Item>()
248 }
249}