wnf/wait_async.rs
1//! Methods for asynchronously waiting for state updates
2
3#![deny(unsafe_code)]
4
5use std::borrow::Borrow;
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::sync::{Arc, Mutex};
10use std::task::{Context, Poll, Waker};
11
12use crate::data::OpaqueData;
13use crate::predicate::{ChangedPredicate, Predicate, PredicateStage};
14use crate::read::Read;
15use crate::state::{BorrowedState, OwnedState, RawState};
16use crate::subscribe::{DataAccessor, SeenChangeStamp, StateListener, Subscription};
17
18impl<T> OwnedState<T>
19where
20 T: ?Sized,
21{
22 /// Waits until this state is updated
23 ///
24 /// This waits for *any* update to the state regardless of the value, even if the value is the same as the previous
25 /// one. In order to wait until the state data satisfy a certain condition, use
26 /// [`wait_until_async`](OwnedState::wait_until_async).
27 ///
28 /// Use this method if you want to wait for a state update *once*. In order to execute some logic on every state
29 /// update, use the [`subscribe`](OwnedState::subscribe) method.
30 ///
31 /// This is an async method. If you are in a sync context, use [`wait_blocking`](OwnedState::wait_blocking).
32 ///
33 /// This method does not make any assumptions on what async executor you use. Note that in contrast to
34 /// [`wait_blocking`](OwnedState::wait_blocking), it does not expect a timeout as an argument. In order to
35 /// implement a timeout, wrap it in the appropriate helper function provided by your executor. For instance,
36 /// with [`tokio`](https://docs.rs/tokio/1/tokio/), use
37 /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
38 /// ```
39 /// # #[tokio::main]
40 /// # async fn main() {
41 /// use std::io::{self, ErrorKind};
42 /// use std::time::Duration;
43 ///
44 /// use tokio::time;
45 /// use wnf::OwnedState;
46 ///
47 /// async fn wait() -> io::Result<()> {
48 /// let state = OwnedState::<u32>::create_temporary()?;
49 /// time::timeout(Duration::from_millis(100), state.wait_async()).await?
50 /// }
51 ///
52 /// let result = wait().await;
53 /// assert!(result.is_err());
54 /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
55 /// # }
56 /// ```
57 ///
58 /// The returned future is [`Send`] and thus can be used with multi-threaded executors.
59 ///
60 /// # Errors
61 /// Returns an error if querying, subscribing to or unsubscribing from the state fails
62 pub fn wait_async(&self) -> Wait<'_> {
63 self.raw.wait_async()
64 }
65}
66
67impl<T> OwnedState<T>
68where
69 T: Read<T>,
70{
71 /// Waits until the data of this state satisfy a given predicate, returning the data
72 ///
73 /// This returns immediately if the current data already satisfy the predicate. Otherwise, it waits until the state
74 /// is updated with data that satisfy the predicate. If you want to unconditionally wait until the state is updated,
75 /// use [`wait_async`](OwnedState::wait_async).
76 ///
77 /// This returns the data for which the predicate returned `true`, causing the wait to finish. It produces an owned
78 /// `T` on the stack and hence requires `T: Sized`. In order to produce a `Box<T>` for `T: ?Sized`, use the
79 /// [`wait_until_boxed_async`](OwnedState::wait_until_boxed_async) method.
80 ///
81 /// For example, to wait until the value of a state reaches a given minimum:
82 /// ```
83 /// use std::error::Error;
84 /// use std::sync::Arc;
85 /// use std::time::Duration;
86 /// use std::{io, thread};
87 ///
88 /// use tokio::time;
89 /// use wnf::{AsState, OwnedState};
90 ///
91 /// async fn wait_until_at_least<S>(state: S, min_value: u32) -> io::Result<u32>
92 /// where
93 /// S: AsState<Data = u32>,
94 /// {
95 /// state.as_state().wait_until_async(|value| *value >= min_value).await
96 /// }
97 ///
98 /// #[tokio::main]
99 /// async fn main() -> Result<(), Box<dyn Error>> {
100 /// let state = Arc::new(OwnedState::create_temporary()?);
101 /// state.set(&0)?;
102 ///
103 /// {
104 /// let state = Arc::clone(&state);
105 /// tokio::spawn(async move {
106 /// loop {
107 /// state.apply(|value| value + 1).unwrap();
108 /// time::sleep(Duration::from_millis(10)).await;
109 /// }
110 /// });
111 /// }
112 ///
113 /// let value = wait_until_at_least(&state, 10).await?;
114 /// assert!(value >= 10);
115 ///
116 /// Ok(())
117 /// }
118 /// ```
119 ///
120 /// This is an async method. If you are in a sync context, use
121 /// [`wait_until_blocking`](OwnedState::wait_until_blocking).
122 ///
123 /// This method does not make any assumptions on what async executor you use. Note that in contrast to
124 /// [`wait_until_blocking`](OwnedState::wait_until_blocking), it does not expect a timeout as an argument. In order
125 /// to implement a timeout, wrap it in the appropriate helper function provided by your executor. For instance,
126 /// with [`tokio`](https://docs.rs/tokio/1/tokio/), use
127 /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
128 /// ```
129 /// # #[tokio::main]
130 /// # async fn main() {
131 /// use std::io::{self, ErrorKind};
132 /// use std::time::Duration;
133 ///
134 /// use tokio::time;
135 /// use wnf::OwnedState;
136 ///
137 /// async fn wait() -> io::Result<u32> {
138 /// let state = OwnedState::<u32>::create_temporary()?;
139 /// state.set(&42)?;
140 /// time::timeout(Duration::from_millis(100), state.wait_until_async(|_| false)).await?
141 /// }
142 ///
143 /// let result = wait().await;
144 /// assert!(result.is_err());
145 /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
146 /// # }
147 /// ```
148 ///
149 /// If the predicate type `F` is [`Send`], the returned future is [`Send`] and thus can be used with multi-threaded
150 /// executors. Otherwise you may be able to use constructs such as tokio's
151 /// [`LocalSet`](https://docs.rs/tokio/1/tokio/task/struct.LocalSet.html).
152 ///
153 /// # Errors
154 /// Returns an error if querying, subscribing to or unsubscribing from the state fails
155 pub fn wait_until_async<F>(&self, predicate: F) -> WaitUntil<'_, T, F>
156 where
157 F: FnMut(&T) -> bool,
158 {
159 self.raw.wait_until_async(predicate)
160 }
161}
162
163impl<T> OwnedState<T>
164where
165 T: Read<Box<T>> + ?Sized,
166{
167 /// Waits until the data of this state satisfy a given predicate, returning the data as a box
168 ///
169 /// This returns immediately if the current data already satisfy the predicate. Otherwise, it waits until the state
170 /// is updated with data that satisfy the predicate. If you want to unconditionally wait until the state is updated,
171 /// use [`wait_async`](OwnedState::wait_async).
172 ///
173 /// This returns the data for which the predicate returned `true`, causing the wait to finish. It produces a
174 /// [`Box<T>`]. In order to produce an owned `T` on the stack (requiring `T: Sized`), use the
175 /// [`wait_until_async`](OwnedState::wait_until_async) method.
176 ///
177 /// For example, to wait until the length of a slice reaches a given minimum:
178 /// ```
179 /// use std::error::Error;
180 /// use std::sync::Arc;
181 /// use std::time::Duration;
182 /// use std::{io, thread};
183 ///
184 /// use tokio::time;
185 /// use wnf::{AsState, OwnedState};
186 ///
187 /// async fn wait_until_len_at_least<S>(state: S, min_len: usize) -> io::Result<usize>
188 /// where
189 /// S: AsState<Data = [u32]>,
190 /// {
191 /// state
192 /// .as_state()
193 /// .wait_until_boxed_async(|slice| slice.len() >= min_len)
194 /// .await
195 /// .map(|slice| slice.len())
196 /// }
197 ///
198 /// #[tokio::main]
199 /// async fn main() -> Result<(), Box<dyn Error>> {
200 /// let state = Arc::new(OwnedState::<[u32]>::create_temporary()?);
201 /// state.set(&[])?;
202 ///
203 /// {
204 /// let state = Arc::clone(&state);
205 /// tokio::spawn(async move {
206 /// loop {
207 /// state
208 /// .apply_boxed(|slice| {
209 /// let mut vec = slice.into_vec();
210 /// vec.push(0);
211 /// vec
212 /// })
213 /// .unwrap();
214 ///
215 /// time::sleep(Duration::from_millis(10)).await;
216 /// }
217 /// });
218 /// }
219 ///
220 /// let len = wait_until_len_at_least(&state, 10).await?;
221 /// assert!(len >= 10);
222 ///
223 /// Ok(())
224 /// }
225 /// ```
226 ///
227 /// This is an async method. If you are in a sync context, use
228 /// [`wait_until_boxed_blocking`](OwnedState::wait_until_boxed_blocking).
229 ///
230 /// This method does not make any assumptions on what async executor you use. Note that in contrast to
231 /// [`wait_until_boxed_blocking`](OwnedState::wait_until_boxed_blocking), it does not expect a timeout as an
232 /// argument. In order to implement a timeout, wrap it in the appropriate helper function provided by your
233 /// executor. For instance, with [`tokio`](https://docs.rs/tokio/1/tokio), use
234 /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
235 /// ```
236 /// # #[tokio::main]
237 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
238 /// use std::error::Error;
239 /// use std::io::{self, ErrorKind};
240 /// use std::time::Duration;
241 ///
242 /// use tokio::time;
243 /// use wnf::OwnedState;
244 ///
245 /// async fn wait() -> io::Result<Box<[u32]>> {
246 /// let state = OwnedState::<[u32]>::create_temporary()?;
247 /// state.set(&[])?;
248 /// time::timeout(Duration::from_millis(100), state.wait_until_boxed_async(|_| false)).await?
249 /// }
250 ///
251 /// let result = wait().await;
252 /// assert!(result.is_err());
253 /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
254 /// # Ok(()) }
255 /// ```
256 ///
257 /// If the predicate type `F` is [`Send`], the returned future is [`Send`] and thus can be used with multi-threaded
258 /// executors. Otherwise you may be able to use constructs such as tokio's
259 /// [`LocalSet`](https://docs.rs/tokio/1/tokio/task/struct.LocalSet.html).
260 ///
261 /// # Errors
262 /// Returns an error if querying, subscribing to or unsubscribing from the state fails
263 pub fn wait_until_boxed_async<F>(&self, predicate: F) -> WaitUntilBoxed<'_, T, F>
264 where
265 F: FnMut(&T) -> bool,
266 {
267 self.raw.wait_until_boxed_async(predicate)
268 }
269}
270
271impl<'a, T> BorrowedState<'a, T>
272where
273 T: ?Sized,
274{
275 /// Waits until this state is updated
276 ///
277 /// See [`OwnedState::wait_async`]
278 pub fn wait_async(self) -> Wait<'a> {
279 self.raw.wait_async()
280 }
281}
282
283impl<'a, T> BorrowedState<'a, T>
284where
285 T: Read<T>,
286{
287 /// Waits until the data of this state satisfy a given predicate, returning the data
288 ///
289 /// See [`OwnedState::wait_until_async`]
290 pub fn wait_until_async<F>(self, predicate: F) -> WaitUntil<'a, T, F>
291 where
292 F: FnMut(&T) -> bool,
293 {
294 self.raw.wait_until_async(predicate)
295 }
296}
297
298impl<'a, T> BorrowedState<'a, T>
299where
300 T: Read<Box<T>> + ?Sized,
301{
302 /// Waits until the data of this state satisfy a given predicate, returning the data as a box
303 ///
304 /// See [`OwnedState::wait_until_boxed_async`]
305 pub fn wait_until_boxed_async<F>(self, predicate: F) -> WaitUntilBoxed<'a, T, F>
306 where
307 F: FnMut(&T) -> bool,
308 {
309 self.raw.wait_until_boxed_async(predicate)
310 }
311}
312
313impl<T> RawState<T>
314where
315 T: ?Sized,
316{
317 /// Waits until this state is updated
318 fn wait_async<'a>(self) -> Wait<'a> {
319 Wait::new(self)
320 }
321}
322
323impl<T> RawState<T>
324where
325 T: Read<T>,
326{
327 /// Waits until the data of this state satisfy a given predicate, returning the data
328 fn wait_until_async<'a, F>(self, predicate: F) -> WaitUntil<'a, T, F>
329 where
330 F: FnMut(&T) -> bool,
331 {
332 WaitUntil::new(self, predicate)
333 }
334}
335
336impl<T> RawState<T>
337where
338 T: Read<Box<T>> + ?Sized,
339{
340 /// Waits until the data of this state satisfy a given predicate, returning the data as a box
341 fn wait_until_boxed_async<'a, F>(self, predicate: F) -> WaitUntilBoxed<'a, T, F>
342 where
343 F: FnMut(&T) -> bool,
344 {
345 WaitUntilBoxed::new(self, predicate)
346 }
347}
348
349/// The future returned by [`wait_async`](`OwnedState::wait_async`) methods
350#[derive(Debug)]
351#[must_use = "futures do nothing unless you `.await` or poll them"]
352pub struct Wait<'a> {
353 inner: WaitUntilInternal<'a, OpaqueData, OpaqueData, ChangedPredicate>,
354}
355
356impl Wait<'_> {
357 /// Creates a new [`Wait<'_>`] future for the given raw state
358 const fn new<T>(state: RawState<T>) -> Self
359 where
360 T: ?Sized,
361 {
362 Self {
363 inner: WaitUntilInternal::new(state.cast(), ChangedPredicate),
364 }
365 }
366}
367
368impl Future for Wait<'_> {
369 type Output = io::Result<()>;
370
371 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372 let inner_pinned = Pin::new(&mut self.get_mut().inner);
373 inner_pinned.poll(cx).map_ok(|_| ())
374 }
375}
376
377/// The future returned by [`wait_until_async`](`OwnedState::wait_until_async`) methods
378#[derive(Debug)]
379#[must_use = "futures do nothing unless you `.await` or poll them"]
380pub struct WaitUntil<'a, T, F> {
381 inner: WaitUntilInternal<'a, T, T, F>,
382}
383
384impl<F, T> WaitUntil<'_, T, F> {
385 /// Creates a new [`WaitUntil<'_, T, F>`] future for the given raw state and predicate
386 const fn new(state: RawState<T>, predicate: F) -> Self {
387 Self {
388 inner: WaitUntilInternal::new(state, predicate),
389 }
390 }
391}
392
393impl<F, T> Future for WaitUntil<'_, T, F>
394where
395 F: FnMut(&T) -> bool,
396 T: Read<T>,
397{
398 type Output = io::Result<T>;
399
400 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401 let inner_pinned = Pin::new(&mut self.get_mut().inner);
402 inner_pinned.poll(cx)
403 }
404}
405
406/// The future returned by [`wait_until_boxed_async`](`OwnedState::wait_until_boxed_async`) methods
407#[derive(Debug)]
408#[must_use = "futures do nothing unless you `.await` or poll them"]
409pub struct WaitUntilBoxed<'a, T, F>
410where
411 T: ?Sized,
412{
413 inner: WaitUntilInternal<'a, T, Box<T>, F>,
414}
415
416impl<F, T> WaitUntilBoxed<'_, T, F>
417where
418 T: ?Sized,
419{
420 /// Creates a new [`WaitUntilBoxed<'_, T, F>`](WaitUntilBoxed) future for the given raw state and predicate
421 const fn new(state: RawState<T>, predicate: F) -> Self {
422 Self {
423 inner: WaitUntilInternal::new(state, predicate),
424 }
425 }
426}
427
428impl<F, T> Future for WaitUntilBoxed<'_, T, F>
429where
430 F: FnMut(&T) -> bool,
431 T: Read<Box<T>> + ?Sized,
432{
433 type Output = io::Result<Box<T>>;
434
435 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436 let inner_pinned = Pin::new(&mut self.get_mut().inner);
437 inner_pinned.poll(cx)
438 }
439}
440
441/// Future generalizing the behavior of [`Wait<'_>`](Wait), [`WaitUntil<'_, T, F>`](WaitUntil) and [`WaitUntilBoxed<'_,
442/// T, F>`](WaitUntilBoxed)
443#[derive(Debug)]
444#[must_use = "futures do nothing unless you `.await` or poll them"]
445struct WaitUntilInternal<'a, T, D, F>
446where
447 T: ?Sized,
448{
449 future_state: Option<FutureState<'a, T, D, F>>,
450}
451
452// This is not auto-implemented because `F` might be `!Unpin`
453// We can implement it manually because `F` is never pinned, i.e. pinning is non-structural for `F`
454// See <https://doc.rust-lang.org/std/pin/index.html#pinning-is-not-structural-for-field>
455impl<D, F, T> Unpin for WaitUntilInternal<'_, T, D, F> where T: ?Sized {}
456
457/// State of the [`WaitUntilInternal<'a, T, D, F>`](WaitUntilInternal) future
458#[derive(Debug)]
459enum FutureState<'a, T, D, F>
460where
461 T: ?Sized,
462{
463 /// Future has not been polled
464 Initial { state: RawState<T>, predicate: F },
465
466 /// Future is waiting for state update
467 Waiting {
468 predicate: F,
469 shared_state: Arc<Mutex<SharedState<D>>>,
470 subscription: Subscription<'a, WaitListener<D>>,
471 },
472}
473
474/// Shared state between the polling thread and the waking thread
475#[derive(Debug)]
476struct SharedState<D> {
477 result: Option<io::Result<D>>,
478 waker: Waker,
479}
480
481impl<D> SharedState<D> {
482 /// Creates a new [`SharedState<D>`] from the given waker
483 const fn from_waker(waker: Waker) -> Self {
484 Self { result: None, waker }
485 }
486}
487
488impl<D, F, T> WaitUntilInternal<'_, T, D, F>
489where
490 T: ?Sized,
491{
492 /// Creates a new [`WaitUntilInternal<'_, T, D, F>`](WaitUntilInternal) future for the given raw state and predicate
493 const fn new(state: RawState<T>, predicate: F) -> Self {
494 Self {
495 future_state: Some(FutureState::Initial { state, predicate }),
496 }
497 }
498}
499
500impl<D, F, T> Future for WaitUntilInternal<'_, T, D, F>
501where
502 D: Borrow<T> + Send + 'static,
503 F: Predicate<T>,
504 T: Read<D> + ?Sized,
505{
506 type Output = io::Result<D>;
507
508 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
509 self.future_state = Some(
510 match self.future_state.take().expect("future polled after it has completed") {
511 FutureState::Initial { state, mut predicate } => {
512 let (data, change_stamp) = state.query_as()?.into_data_change_stamp();
513
514 if predicate.check(data.borrow(), PredicateStage::Initial) {
515 return Poll::Ready(Ok(data));
516 }
517
518 let shared_state = Arc::new(Mutex::new(SharedState::from_waker(cx.waker().clone())));
519 let subscription = state.subscribe(
520 WaitListener::new(Arc::clone(&shared_state)),
521 SeenChangeStamp::Value(change_stamp),
522 )?;
523
524 FutureState::Waiting {
525 predicate,
526 shared_state,
527 subscription,
528 }
529 }
530
531 FutureState::Waiting {
532 mut predicate,
533 shared_state,
534 subscription,
535 } => {
536 let mut guard = shared_state.lock().unwrap();
537 let SharedState { result, waker } = &mut *guard;
538
539 let ready_result = match result.take() {
540 Some(Ok(data)) if !predicate.check(data.borrow(), PredicateStage::Changed) => None,
541 None => None,
542 result => result,
543 };
544
545 match ready_result {
546 Some(result) => {
547 subscription.unsubscribe()?;
548 return Poll::Ready(Ok(result?));
549 }
550
551 None => {
552 if !waker.will_wake(cx.waker()) {
553 waker.clone_from(cx.waker());
554 }
555 }
556 }
557
558 drop(guard);
559
560 FutureState::Waiting {
561 predicate,
562 shared_state,
563 subscription,
564 }
565 }
566 },
567 );
568
569 Poll::Pending
570 }
571}
572
573/// State listener that saves the result of accessing the state data and wakes a waker
574///
575/// This is a type that can be named rather than an anonymous closure type so that it can be stored in a
576/// [`FutureState<'_, T, D, F>`](FutureState) without using a trait object.
577#[derive(Debug)]
578struct WaitListener<D> {
579 shared_state: Arc<Mutex<SharedState<D>>>,
580}
581
582impl<D> WaitListener<D> {
583 /// Creates a new [`WaitListener<D>`] with the given shared state
584 const fn new(shared_state: Arc<Mutex<SharedState<D>>>) -> Self {
585 Self { shared_state }
586 }
587}
588
589impl<D, T> StateListener<T> for WaitListener<D>
590where
591 D: Send + 'static,
592 T: Read<D> + ?Sized,
593{
594 fn call(&mut self, accessor: DataAccessor<'_, T>) {
595 let SharedState { result, ref waker } = &mut *self.shared_state.lock().unwrap();
596 *result = Some(accessor.get_as());
597 waker.wake_by_ref();
598 }
599}
600
601#[cfg(test)]
602mod tests {
603 #![allow(dead_code)]
604
605 use std::cell::Cell;
606 use std::sync::MutexGuard;
607
608 use static_assertions::{assert_impl_all, assert_not_impl_any};
609
610 use super::*;
611
612 #[test]
613 fn wait_future_is_send_and_sync() {
614 assert_impl_all!(Wait<'_>: Send, Sync);
615 }
616
617 #[test]
618 fn wait_until_future_is_send_if_predicate_and_data_type_are_send() {
619 type SendNotSync = Cell<()>;
620 assert_impl_all!(SendNotSync: Send);
621 assert_not_impl_any!(SendNotSync: Sync);
622
623 assert_impl_all!(WaitUntil<'_, SendNotSync, SendNotSync>: Send);
624 }
625
626 #[test]
627 fn wait_until_future_is_sync_if_predicate_is_sync_and_data_type_is_send() {
628 type SyncNotSend = MutexGuard<'static, ()>;
629 assert_impl_all!(SyncNotSend: Sync);
630 assert_not_impl_any!(SyncNotSend: Send);
631
632 type SendNotSync = Cell<()>;
633 assert_impl_all!(SendNotSync: Send);
634 assert_not_impl_any!(SendNotSync: Sync);
635
636 assert_impl_all!(WaitUntil<'_, SendNotSync, SyncNotSend>: Sync);
637 }
638
639 #[test]
640 fn wait_until_boxed_future_is_send_if_predicate_and_data_type_are_send() {
641 type SendNotSync = Cell<()>;
642 assert_impl_all!(SendNotSync: Send);
643 assert_not_impl_any!(SendNotSync: Sync);
644
645 assert_impl_all!(WaitUntilBoxed<'_, SendNotSync, SendNotSync>: Send);
646 }
647
648 #[test]
649 fn wait_until_boxed_future_is_sync_if_predicate_is_sync_and_data_type_is_send() {
650 type SyncNotSend = MutexGuard<'static, ()>;
651 assert_impl_all!(SyncNotSend: Sync);
652 assert_not_impl_any!(SyncNotSend: Send);
653
654 type SendNotSync = Cell<()>;
655 assert_impl_all!(SendNotSync: Send);
656 assert_not_impl_any!(SendNotSync: Sync);
657
658 assert_impl_all!(WaitUntilBoxed<'_, SendNotSync, SyncNotSend>: Sync);
659 }
660}