agnostic_lite/async_std/
after.rs1use core::{
2 pin::Pin,
3 sync::atomic::Ordering,
4 task::{Context, Poll},
5};
6
7use std::sync::Arc;
8
9use async_std::channel::{
10 mpsc::{unbounded, UnboundedSender},
11 oneshot::{channel, Sender},
12};
13use atomic_time::AtomicOptionDuration;
14use futures_util::{FutureExt, StreamExt};
15
16use crate::{
17 spawner::{AfterHandle, AfterHandleSignals, Canceled},
18 time::AsyncSleep,
19 AfterHandleError, AsyncAfterSpawner,
20};
21
22use super::{super::RuntimeLite, *};
23
24pub(crate) struct Resetter {
25 duration: Arc<AtomicOptionDuration>,
26 tx: UnboundedSender<()>,
27}
28
29impl Resetter {
30 pub(crate) fn new(duration: Arc<AtomicOptionDuration>, tx: UnboundedSender<()>) -> Self {
31 Self { duration, tx }
32 }
33
34 pub(crate) fn reset(&self, duration: Duration) {
35 self.duration.store(Some(duration), Ordering::Release);
36 }
37}
38
/// Spawns a delayed task on [`AsyncStdRuntime`] and builds its [`AsyncStdAfterHandle`].
///
/// Arguments:
/// - `$spawn`: name of the `AsyncStdRuntime` associated function used to spawn the task.
/// - `$sleep($trait)`: name of the `AsyncStdRuntime` function that creates the delay
///   future for `$instant`, and the trait whose `reset` is used to move that delay.
/// - `$instant`: the deadline passed to `$sleep`.
/// - `$future`: the future to run once the delay has elapsed.
///
/// Control channels wired into the task:
/// - `tx`/`rx` and `abort_tx`/`abort_rx` (oneshot): receiving `Ok(())` ends the task
///   with `Err(Canceled)`; receiving `Err` (sender dropped without sending) makes the
///   task wait for the delay and then run `$future` to completion.
/// - `reset_tx`/`reset_rx` together with the shared `duration` slot: each message asks
///   the task to re-read `duration` and adjust the delay.
macro_rules! spawn_after {
    ($spawn:ident, $sleep:ident($trait:ident) -> ($instant:ident, $future:ident)) => {{
        // Cancellation channel, sender kept in the handle as `tx`.
        let (tx, rx) = channel::<()>();
        // Second cancellation channel, sender kept in the handle as `abort_tx`.
        let (abort_tx, abort_rx) = channel::<()>();
        // Expired/finished flags shared between the task (`s1`) and the handle.
        let signals = Arc::new(AfterHandleSignals::new());
        // Each message tells the task to re-read `duration`.
        let (reset_tx, mut reset_rx) = unbounded::<()>();
        let duration = Arc::new(AtomicOptionDuration::none());
        let resetter = Resetter::new(duration.clone(), reset_tx);
        let s1 = signals.clone();
        let h = AsyncStdRuntime::$spawn(async move {
            let delay = AsyncStdRuntime::$sleep($instant);
            let future = $future.fuse();
            futures_util::pin_mut!(delay);
            futures_util::pin_mut!(rx);
            futures_util::pin_mut!(abort_rx);
            futures_util::pin_mut!(future);
            loop {
                // Biased select: branches are polled top to bottom, so `abort_rx` is
                // checked before `rx`. NOTE(review): if `tx` fires while `abort_tx`
                // is merely dropped, the `abort_rx` `Err` branch below can win and
                // the future runs instead of being cancelled.
                futures_util::select_biased! {
                    res = abort_rx => {
                        // `Ok` = explicit cancellation. `Err` = sender dropped without
                        // sending (presumably the handle was detached): wait out the
                        // delay, then run the future.
                        if res.is_ok() {
                            return Err(Canceled);
                        }
                        delay.await;
                        let res = future.await;
                        s1.set_finished();
                        return Ok(res);
                    }
                    res = rx => {
                        // Same contract as the `abort_rx` branch above.
                        if res.is_ok() {
                            return Err(Canceled);
                        }

                        delay.await;
                        let res = future.await;
                        s1.set_finished();
                        return Ok(res);
                    }
                    res = reset_rx.next() => {
                        // Reset channel closed: no further resets can arrive, so just
                        // wait out the current delay and run the future.
                        if res.is_none() {
                            delay.await;
                            let res = future.await;
                            s1.set_finished();
                            return Ok(res);
                        }

                        // With no duration stored, nothing changes; loop again.
                        if let Some(d) = duration.load(Ordering::Acquire) {
                            // NOTE(review): every arm of the inner select returns, so this
                            // `if` always diverges, and the `Some(v)` arm of the match
                            // below is unreachable. For a `std::time::Instant` deadline,
                            // `checked_sub` is `Some` in practice, so a reset appears to
                            // expire the delay immediately. Confirm the intended reset
                            // semantics.
                            if $instant.checked_sub(d).is_some() {
                                s1.set_expired();

                                // NOTE(review): only `rx` is watched here, not `abort_rx`.
                                futures_util::select_biased! {
                                    res = &mut future => {
                                        s1.set_finished();
                                        return Ok(res);
                                    }
                                    canceled = &mut rx => {
                                        if canceled.is_ok() {
                                            return Err(Canceled);
                                        }
                                        delay.await;
                                        s1.set_expired();
                                        let res = future.await;
                                        s1.set_finished();
                                        return Ok(res);
                                    }
                                }
                            }

                            match $instant.checked_sub(d) {
                                Some(v) => {
                                    $trait::reset(delay.as_mut(), v);
                                },
                                None => {
                                    // Remaining time = requested duration minus time
                                    // already elapsed since `$instant`.
                                    match d.checked_sub($instant.elapsed()) {
                                        Some(v) => {
                                            $trait::reset(delay.as_mut(), Instant::now() + v);
                                        },
                                        None => {
                                            // Requested duration already passed: expire now.
                                            s1.set_expired();

                                            futures_util::select_biased! {
                                                res = &mut future => {
                                                    s1.set_finished();
                                                    return Ok(res);
                                                }
                                                canceled = &mut rx => {
                                                    if canceled.is_ok() {
                                                        return Err(Canceled);
                                                    }
                                                    delay.await;
                                                    s1.set_expired();
                                                    let res = future.await;
                                                    s1.set_finished();
                                                    return Ok(res);
                                                }
                                            }
                                        },
                                    }
                                },
                            }
                        }
                    }
                    _ = delay.as_mut().fuse() => {
                        // Delay elapsed: run the future, still honouring cancellation
                        // signals that arrive while it runs.
                        s1.set_expired();
                        futures_util::select_biased! {
                            res = abort_rx => {
                                if res.is_ok() {
                                    return Err(Canceled);
                                }
                                let res = future.await;
                                s1.set_finished();
                                return Ok(res);
                            }
                            res = rx => {
                                if res.is_ok() {
                                    return Err(Canceled);
                                }
                                let res = future.await;
                                s1.set_finished();
                                return Ok(res);
                            }
                            res = future => {
                                s1.set_finished();
                                return Ok(res);
                            }
                        }
                    }
                }
            }
        });

        AsyncStdAfterHandle {
            handle: h,
            resetter,
            signals,
            abort_tx,
            tx,
        }
    }};
}
178
179#[pin_project::pin_project]
181pub struct AsyncStdAfterHandle<O>
182where
183 O: 'static,
184{
185 #[pin]
186 handle: JoinHandle<Result<O, Canceled>>,
187 signals: Arc<AfterHandleSignals>,
188 resetter: Resetter,
189 abort_tx: Sender<()>,
190 tx: Sender<()>,
191}
192
193impl<O: 'static> Future for AsyncStdAfterHandle<O> {
194 type Output = Result<O, AfterHandleError<JoinError>>;
195
196 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197 let this = self.project();
198 match this.handle.poll(cx) {
199 Poll::Ready(v) => match v {
200 Ok(v) => Poll::Ready(v.map_err(|_| AfterHandleError::Canceled)),
201 Err(_) => Poll::Ready(Err(AfterHandleError::Canceled)),
202 },
203 Poll::Pending => Poll::Pending,
204 }
205 }
206}
207
208impl<O> AfterHandle<O> for AsyncStdAfterHandle<O>
209where
210 O: Send + 'static,
211{
212 type JoinError = AfterHandleError<JoinError>;
213
214 async fn cancel(self) -> Option<Result<O, Self::JoinError>> {
215 if AfterHandle::is_finished(&self) {
216 return Some(self.handle.await.map_err(AfterHandleError::Join)
217 .and_then(|v| v.map_err(|_| AfterHandleError::Canceled)));
218 }
219
220 let _ = self.tx.send(());
221 None
222 }
223
224 fn reset(&self, duration: Duration) {
225 self.resetter.reset(duration);
226 let _ = self.resetter.tx.unbounded_send(());
227 }
228
229 #[inline]
230 fn abort(self) {
231 let _ = self.tx.send(());
232 }
233
234 #[inline]
235 fn is_expired(&self) -> bool {
236 self.signals.is_expired()
237 }
238
239 #[inline]
240 fn is_finished(&self) -> bool {
241 self.signals.is_finished()
242 }
243}
244
245impl AsyncAfterSpawner for AsyncStdSpawner {
246 type Instant = Instant;
247 type JoinHandle<F>
248 = AsyncStdAfterHandle<F>
249 where
250 F: Send + 'static;
251
252 fn spawn_after<F>(duration: core::time::Duration, future: F) -> Self::JoinHandle<F::Output>
253 where
254 F::Output: Send + 'static,
255 F: Future + Send + 'static,
256 {
257 Self::spawn_after_at(Instant::now() + duration, future)
258 }
259
260 fn spawn_after_at<F>(instant: Instant, future: F) -> Self::JoinHandle<F::Output>
261 where
262 F::Output: Send + 'static,
263 F: Future + Send + 'static,
264 {
265 spawn_after!(spawn, sleep_until(AsyncSleep) -> (instant, future))
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;

    // Each test drives one shared scenario from `crate::tests` on the async-std runtime.

    #[test]
    fn test_after_handle() {
        block_on(crate::tests::spawn_after_unittest::<AsyncStdRuntime>());
    }

    #[test]
    fn test_after_drop() {
        block_on(crate::tests::spawn_after_drop_unittest::<AsyncStdRuntime>());
    }

    #[test]
    fn test_after_cancel() {
        block_on(crate::tests::spawn_after_cancel_unittest::<AsyncStdRuntime>());
    }

    #[test]
    fn test_after_abort() {
        block_on(crate::tests::spawn_after_abort_unittest::<AsyncStdRuntime>());
    }

    #[test]
    fn test_after_reset_to_pass() {
        block_on(crate::tests::spawn_after_reset_to_pass_unittest::<AsyncStdRuntime>());
    }

    #[test]
    fn test_after_reset_to_future() {
        block_on(crate::tests::spawn_after_reset_to_future_unittest::<AsyncStdRuntime>());
    }
}