agnostic_lite/async_std/
after.rs

1use core::{
2  pin::Pin,
3  sync::atomic::Ordering,
4  task::{Context, Poll},
5};
6
7use std::sync::Arc;
8
9use async_std::channel::{
10  mpsc::{unbounded, UnboundedSender},
11  oneshot::{channel, Sender},
12};
13use atomic_time::AtomicOptionDuration;
14use futures_util::{FutureExt, StreamExt};
15
16use crate::{
17  spawner::{AfterHandle, AfterHandleSignals, Canceled},
18  time::AsyncSleep,
19  AfterHandleError, AsyncAfterSpawner,
20};
21
22use super::{super::RuntimeLite, *};
23
24pub(crate) struct Resetter {
25  duration: Arc<AtomicOptionDuration>,
26  tx: UnboundedSender<()>,
27}
28
29impl Resetter {
30  pub(crate) fn new(duration: Arc<AtomicOptionDuration>, tx: UnboundedSender<()>) -> Self {
31    Self { duration, tx }
32  }
33
34  pub(crate) fn reset(&self, duration: Duration) {
35    self.duration.store(Some(duration), Ordering::Release);
36  }
37}
38
/// Expands to the code that spawns a delayed task and evaluates to its
/// `AsyncStdAfterHandle`.
///
/// Invocation shape: `spawn_after!(spawn_fn, sleep_fn(SleepTrait) -> (instant, future))`
///
/// - `$spawn`: associated function of `AsyncStdRuntime` used to spawn the driver task.
/// - `$sleep`: associated function of `AsyncStdRuntime` that builds the delay from `$instant`.
/// - `$trait`: trait whose `reset` re-arms the pinned delay.
/// - `$instant` / `$future`: identifiers naming the deadline and the future to run.
///
/// The driver task resolves to `Err(Canceled)` when a value arrives on either
/// one-shot stop channel, and to `Ok(output)` once the wrapped future completes.
macro_rules! spawn_after {
  ($spawn:ident, $sleep:ident($trait:ident) -> ($instant:ident, $future:ident)) => {{
    // Two one-shot stop channels; both sending halves are stored on the handle.
    let (cancel_tx, cancel_rx) = channel::<()>();
    let (abort_tx, abort_rx) = channel::<()>();
    let signals = Arc::new(AfterHandleSignals::new());
    // Unbounded channel that tells the task a new duration was stored.
    let (reset_tx, mut reset_rx) = unbounded::<()>();
    let pending_reset = Arc::new(AtomicOptionDuration::none());
    let resetter = Resetter::new(Arc::clone(&pending_reset), reset_tx);
    let task_signals = Arc::clone(&signals);
    let handle = AsyncStdRuntime::$spawn(async move {
      let delay = AsyncStdRuntime::$sleep($instant);
      let task = $future.fuse();
      futures_util::pin_mut!(delay);
      futures_util::pin_mut!(cancel_rx);
      futures_util::pin_mut!(abort_rx);
      futures_util::pin_mut!(task);
      loop {
        // `select_biased!` polls its arms in the order they are written, so
        // the stop channels take priority over reset signals and the delay.
        futures_util::select_biased! {
          aborted = abort_rx => {
            if let Ok(()) = aborted {
              return Err(Canceled);
            }
            // The receiver resolved without a value (presumably the sender was
            // dropped): wait out the delay and run the future to completion.
            delay.await;
            let output = task.await;
            task_signals.set_finished();
            return Ok(output);
          }
          canceled = cancel_rx => {
            if let Ok(()) = canceled {
              return Err(Canceled);
            }

            // Same fallback as above for the other stop channel.
            delay.await;
            let output = task.await;
            task_signals.set_finished();
            return Ok(output);
          }
          signal = reset_rx.next() => {
            if signal.is_none() {
              // The reset stream ended: no further resets can arrive, so
              // finish the normal delay-then-run sequence.
              delay.await;
              let output = task.await;
              task_signals.set_finished();
              return Ok(output);
            }

            if let Some(d) = pending_reset.load(Ordering::Acquire) {
              match $instant.checked_sub(d) {
                // NOTE(review): whenever `$instant - d` is representable the
                // task is treated as already expired and the future is driven
                // right away; confirm this is the intended reset semantics.
                Some(_) => {
                  task_signals.set_expired();

                  futures_util::select_biased! {
                    output = &mut task => {
                      task_signals.set_finished();
                      return Ok(output);
                    }
                    canceled = &mut cancel_rx => {
                      if canceled.is_ok() {
                        return Err(Canceled);
                      }
                      delay.await;
                      task_signals.set_expired();
                      let output = task.await;
                      task_signals.set_finished();
                      return Ok(output);
                    }
                  }
                }
                None => match d.checked_sub($instant.elapsed()) {
                  // Part of the requested duration is still left: re-arm the
                  // delay to fire after the remainder and keep looping.
                  Some(remaining) => {
                    $trait::reset(delay.as_mut(), Instant::now() + remaining);
                  }
                  // Nothing is left of the requested duration: mark the task
                  // expired and drive the future while still honoring cancel.
                  None => {
                    task_signals.set_expired();

                    futures_util::select_biased! {
                      output = &mut task => {
                        task_signals.set_finished();
                        return Ok(output);
                      }
                      canceled = &mut cancel_rx => {
                        if canceled.is_ok() {
                          return Err(Canceled);
                        }
                        delay.await;
                        task_signals.set_expired();
                        let output = task.await;
                        task_signals.set_finished();
                        return Ok(output);
                      }
                    }
                  }
                },
              }
            }
          }
          _ = delay.as_mut().fuse() => {
            // The delay fired: record expiry, then run the future while still
            // listening on both stop channels.
            task_signals.set_expired();
            futures_util::select_biased! {
              aborted = abort_rx => {
                if aborted.is_ok() {
                  return Err(Canceled);
                }
                let output = task.await;
                task_signals.set_finished();
                return Ok(output);
              }
              canceled = cancel_rx => {
                if canceled.is_ok() {
                  return Err(Canceled);
                }
                let output = task.await;
                task_signals.set_finished();
                return Ok(output);
              }
              output = task => {
                task_signals.set_finished();
                return Ok(output);
              }
            }
          }
        }
      }
    });

    AsyncStdAfterHandle {
      handle,
      resetter,
      signals,
      abort_tx,
      tx: cancel_tx,
    }
  }};
}
178
179/// The handle return by [`RuntimeLite::spawn_after`] or [`RuntimeLite::spawn_after_at`]
180#[pin_project::pin_project]
181pub struct AsyncStdAfterHandle<O>
182where
183  O: 'static,
184{
185  #[pin]
186  handle: JoinHandle<Result<O, Canceled>>,
187  signals: Arc<AfterHandleSignals>,
188  resetter: Resetter,
189  abort_tx: Sender<()>,
190  tx: Sender<()>,
191}
192
193impl<O: 'static> Future for AsyncStdAfterHandle<O> {
194  type Output = Result<O, AfterHandleError<JoinError>>;
195
196  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197    let this = self.project();
198    match this.handle.poll(cx) {
199      Poll::Ready(v) => match v {
200        Ok(v) => Poll::Ready(v.map_err(|_| AfterHandleError::Canceled)),
201        Err(_) => Poll::Ready(Err(AfterHandleError::Canceled)),
202      },
203      Poll::Pending => Poll::Pending,
204    }
205  }
206}
207
208impl<O> AfterHandle<O> for AsyncStdAfterHandle<O>
209where
210  O: Send + 'static,
211{
212  type JoinError = AfterHandleError<JoinError>;
213
214  async fn cancel(self) -> Option<Result<O, Self::JoinError>> {
215    if AfterHandle::is_finished(&self) {
216      return Some(self.handle.await.map_err(AfterHandleError::Join)
217      .and_then(|v| v.map_err(|_| AfterHandleError::Canceled)));
218    }
219
220    let _ = self.tx.send(());
221    None
222  }
223
224  fn reset(&self, duration: Duration) {
225    self.resetter.reset(duration);
226    let _ = self.resetter.tx.unbounded_send(());
227  }
228
229  #[inline]
230  fn abort(self) {
231    let _ = self.tx.send(());
232  }
233
234  #[inline]
235  fn is_expired(&self) -> bool {
236    self.signals.is_expired()
237  }
238
239  #[inline]
240  fn is_finished(&self) -> bool {
241    self.signals.is_finished()
242  }
243}
244
245impl AsyncAfterSpawner for AsyncStdSpawner {
246  type Instant = Instant;
247  type JoinHandle<F>
248    = AsyncStdAfterHandle<F>
249  where
250    F: Send + 'static;
251
252  fn spawn_after<F>(duration: core::time::Duration, future: F) -> Self::JoinHandle<F::Output>
253  where
254    F::Output: Send + 'static,
255    F: Future + Send + 'static,
256  {
257    Self::spawn_after_at(Instant::now() + duration, future)
258  }
259
260  fn spawn_after_at<F>(instant: Instant, future: F) -> Self::JoinHandle<F::Output>
261  where
262    F::Output: Send + 'static,
263    F: Future + Send + 'static,
264  {
265    spawn_after!(spawn, sleep_until(AsyncSleep) -> (instant, future))
266  }
267}
268
#[cfg(test)]
mod tests {
  use super::*;

  /// Generates one `#[test]` per `name => helper` pair. Each test runs the
  /// named `crate::tests` helper, instantiated for `AsyncStdRuntime`, to
  /// completion on `futures::executor::block_on`.
  macro_rules! runtime_tests {
    ($($name:ident => $helper:ident),+ $(,)?) => {
      $(
        #[test]
        fn $name() {
          futures::executor::block_on(async {
            crate::tests::$helper::<AsyncStdRuntime>().await;
          });
        }
      )+
    };
  }

  runtime_tests! {
    test_after_handle => spawn_after_unittest,
    test_after_drop => spawn_after_drop_unittest,
    test_after_cancel => spawn_after_cancel_unittest,
    test_after_abort => spawn_after_abort_unittest,
    test_after_reset_to_pass => spawn_after_reset_to_pass_unittest,
    test_after_reset_to_future => spawn_after_reset_to_future_unittest,
  }
}