wnf/
wait_async.rs

1//! Methods for asynchronously waiting for state updates
2
3#![deny(unsafe_code)]
4
5use std::borrow::Borrow;
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::sync::{Arc, Mutex};
10use std::task::{Context, Poll, Waker};
11
12use crate::data::OpaqueData;
13use crate::predicate::{ChangedPredicate, Predicate, PredicateStage};
14use crate::read::Read;
15use crate::state::{BorrowedState, OwnedState, RawState};
16use crate::subscribe::{DataAccessor, SeenChangeStamp, StateListener, Subscription};
17
18impl<T> OwnedState<T>
19where
20    T: ?Sized,
21{
22    /// Waits until this state is updated
23    ///
24    /// This waits for *any* update to the state regardless of the value, even if the value is the same as the previous
25    /// one. In order to wait until the state data satisfy a certain condition, use
26    /// [`wait_until_async`](OwnedState::wait_until_async).
27    ///
28    /// Use this method if you want to wait for a state update *once*. In order to execute some logic on every state
29    /// update, use the [`subscribe`](OwnedState::subscribe) method.
30    ///
31    /// This is an async method. If you are in a sync context, use [`wait_blocking`](OwnedState::wait_blocking).
32    ///
33    /// This method does not make any assumptions on what async executor you use. Note that in contrast to
34    /// [`wait_blocking`](OwnedState::wait_blocking), it does not expect a timeout as an argument. In order to
35    /// implement a timeout, wrap it in the appropriate helper function provided by your executor. For instance,
36    /// with [`tokio`](https://docs.rs/tokio/1/tokio/), use
37    /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
38    /// ```
39    /// # #[tokio::main]
40    /// # async fn main() {
41    /// use std::io::{self, ErrorKind};
42    /// use std::time::Duration;
43    ///
44    /// use tokio::time;
45    /// use wnf::OwnedState;
46    ///
47    /// async fn wait() -> io::Result<()> {
48    ///     let state = OwnedState::<u32>::create_temporary()?;
49    ///     time::timeout(Duration::from_millis(100), state.wait_async()).await?
50    /// }
51    ///
52    /// let result = wait().await;
53    /// assert!(result.is_err());
54    /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
55    /// # }
56    /// ```
57    ///
58    /// The returned future is [`Send`] and thus can be used with multi-threaded executors.
59    ///
60    /// # Errors
61    /// Returns an error if querying, subscribing to or unsubscribing from the state fails
62    pub fn wait_async(&self) -> Wait<'_> {
63        self.raw.wait_async()
64    }
65}
66
67impl<T> OwnedState<T>
68where
69    T: Read<T>,
70{
71    /// Waits until the data of this state satisfy a given predicate, returning the data
72    ///
73    /// This returns immediately if the current data already satisfy the predicate. Otherwise, it waits until the state
74    /// is updated with data that satisfy the predicate. If you want to unconditionally wait until the state is updated,
75    /// use [`wait_async`](OwnedState::wait_async).
76    ///
77    /// This returns the data for which the predicate returned `true`, causing the wait to finish. It produces an owned
78    /// `T` on the stack and hence requires `T: Sized`. In order to produce a `Box<T>` for `T: ?Sized`, use the
79    /// [`wait_until_boxed_async`](OwnedState::wait_until_boxed_async) method.
80    ///
81    /// For example, to wait until the value of a state reaches a given minimum:
82    /// ```
83    /// use std::error::Error;
84    /// use std::sync::Arc;
85    /// use std::time::Duration;
86    /// use std::{io, thread};
87    ///
88    /// use tokio::time;
89    /// use wnf::{AsState, OwnedState};
90    ///
91    /// async fn wait_until_at_least<S>(state: S, min_value: u32) -> io::Result<u32>
92    /// where
93    ///     S: AsState<Data = u32>,
94    /// {
95    ///     state.as_state().wait_until_async(|value| *value >= min_value).await
96    /// }
97    ///
98    /// #[tokio::main]
99    /// async fn main() -> Result<(), Box<dyn Error>> {
100    ///     let state = Arc::new(OwnedState::create_temporary()?);
101    ///     state.set(&0)?;
102    ///
103    ///     {
104    ///         let state = Arc::clone(&state);
105    ///         tokio::spawn(async move {
106    ///             loop {
107    ///                 state.apply(|value| value + 1).unwrap();
108    ///                 time::sleep(Duration::from_millis(10)).await;
109    ///             }
110    ///         });
111    ///     }
112    ///
113    ///     let value = wait_until_at_least(&state, 10).await?;
114    ///     assert!(value >= 10);
115    ///
116    ///     Ok(())
117    /// }
118    /// ```
119    ///
120    /// This is an async method. If you are in a sync context, use
121    /// [`wait_until_blocking`](OwnedState::wait_until_blocking).
122    ///
123    /// This method does not make any assumptions on what async executor you use. Note that in contrast to
124    /// [`wait_until_blocking`](OwnedState::wait_until_blocking), it does not expect a timeout as an argument. In order
125    /// to implement a timeout, wrap it in the appropriate helper function provided by your executor. For instance,
126    /// with [`tokio`](https://docs.rs/tokio/1/tokio/), use
127    /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
128    /// ```
129    /// # #[tokio::main]
130    /// # async fn main() {
131    /// use std::io::{self, ErrorKind};
132    /// use std::time::Duration;
133    ///
134    /// use tokio::time;
135    /// use wnf::OwnedState;
136    ///
137    /// async fn wait() -> io::Result<u32> {
138    ///     let state = OwnedState::<u32>::create_temporary()?;
139    ///     state.set(&42)?;
140    ///     time::timeout(Duration::from_millis(100), state.wait_until_async(|_| false)).await?
141    /// }
142    ///
143    /// let result = wait().await;
144    /// assert!(result.is_err());
145    /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
146    /// # }
147    /// ```
148    ///
149    /// If the predicate type `F` is [`Send`], the returned future is [`Send`] and thus can be used with multi-threaded
150    /// executors. Otherwise you may be able to use constructs such as tokio's
151    /// [`LocalSet`](https://docs.rs/tokio/1/tokio/task/struct.LocalSet.html).
152    ///
153    /// # Errors
154    /// Returns an error if querying, subscribing to or unsubscribing from the state fails
155    pub fn wait_until_async<F>(&self, predicate: F) -> WaitUntil<'_, T, F>
156    where
157        F: FnMut(&T) -> bool,
158    {
159        self.raw.wait_until_async(predicate)
160    }
161}
162
163impl<T> OwnedState<T>
164where
165    T: Read<Box<T>> + ?Sized,
166{
167    /// Waits until the data of this state satisfy a given predicate, returning the data as a box
168    ///
169    /// This returns immediately if the current data already satisfy the predicate. Otherwise, it waits until the state
170    /// is updated with data that satisfy the predicate. If you want to unconditionally wait until the state is updated,
171    /// use [`wait_async`](OwnedState::wait_async).
172    ///
173    /// This returns the data for which the predicate returned `true`, causing the wait to finish. It produces a
174    /// [`Box<T>`]. In order to produce an owned `T` on the stack (requiring `T: Sized`), use the
175    /// [`wait_until_async`](OwnedState::wait_until_async) method.
176    ///
177    /// For example, to wait until the length of a slice reaches a given minimum:
178    /// ```
179    /// use std::error::Error;
180    /// use std::sync::Arc;
181    /// use std::time::Duration;
182    /// use std::{io, thread};
183    ///
184    /// use tokio::time;
185    /// use wnf::{AsState, OwnedState};
186    ///
187    /// async fn wait_until_len_at_least<S>(state: S, min_len: usize) -> io::Result<usize>
188    /// where
189    ///     S: AsState<Data = [u32]>,
190    /// {
191    ///     state
192    ///         .as_state()
193    ///         .wait_until_boxed_async(|slice| slice.len() >= min_len)
194    ///         .await
195    ///         .map(|slice| slice.len())
196    /// }
197    ///
198    /// #[tokio::main]
199    /// async fn main() -> Result<(), Box<dyn Error>> {
200    ///     let state = Arc::new(OwnedState::<[u32]>::create_temporary()?);
201    ///     state.set(&[])?;
202    ///
203    ///     {
204    ///         let state = Arc::clone(&state);
205    ///         tokio::spawn(async move {
206    ///             loop {
207    ///                 state
208    ///                     .apply_boxed(|slice| {
209    ///                         let mut vec = slice.into_vec();
210    ///                         vec.push(0);
211    ///                         vec
212    ///                     })
213    ///                     .unwrap();
214    ///
215    ///                 time::sleep(Duration::from_millis(10)).await;
216    ///             }
217    ///         });
218    ///     }
219    ///
220    ///     let len = wait_until_len_at_least(&state, 10).await?;
221    ///     assert!(len >= 10);
222    ///
223    ///     Ok(())
224    /// }
225    /// ```
226    ///
227    /// This is an async method. If you are in a sync context, use
228    /// [`wait_until_boxed_blocking`](OwnedState::wait_until_boxed_blocking).
229    ///
230    /// This method does not make any assumptions on what async executor you use. Note that in contrast to
231    /// [`wait_until_boxed_blocking`](OwnedState::wait_until_boxed_blocking), it does not expect a timeout as an
232    /// argument. In order to implement a timeout, wrap it in the appropriate helper function provided by your
233    /// executor. For instance, with [`tokio`](https://docs.rs/tokio/1/tokio), use
234    /// [`tokio::time::timeout`](https://docs.rs/tokio/1/tokio/time/fn.timeout.html):
235    /// ```
236    /// # #[tokio::main]
237    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
238    /// use std::error::Error;
239    /// use std::io::{self, ErrorKind};
240    /// use std::time::Duration;
241    ///
242    /// use tokio::time;
243    /// use wnf::OwnedState;
244    ///
245    /// async fn wait() -> io::Result<Box<[u32]>> {
246    ///     let state = OwnedState::<[u32]>::create_temporary()?;
247    ///     state.set(&[])?;
248    ///     time::timeout(Duration::from_millis(100), state.wait_until_boxed_async(|_| false)).await?
249    /// }
250    ///
251    /// let result = wait().await;
252    /// assert!(result.is_err());
253    /// assert_eq!(result.unwrap_err().kind(), ErrorKind::TimedOut);
254    /// # Ok(()) }
255    /// ```
256    ///
257    /// If the predicate type `F` is [`Send`], the returned future is [`Send`] and thus can be used with multi-threaded
258    /// executors. Otherwise you may be able to use constructs such as tokio's
259    /// [`LocalSet`](https://docs.rs/tokio/1/tokio/task/struct.LocalSet.html).
260    ///
261    /// # Errors
262    /// Returns an error if querying, subscribing to or unsubscribing from the state fails
263    pub fn wait_until_boxed_async<F>(&self, predicate: F) -> WaitUntilBoxed<'_, T, F>
264    where
265        F: FnMut(&T) -> bool,
266    {
267        self.raw.wait_until_boxed_async(predicate)
268    }
269}
270
271impl<'a, T> BorrowedState<'a, T>
272where
273    T: ?Sized,
274{
275    /// Waits until this state is updated
276    ///
277    /// See [`OwnedState::wait_async`]
278    pub fn wait_async(self) -> Wait<'a> {
279        self.raw.wait_async()
280    }
281}
282
283impl<'a, T> BorrowedState<'a, T>
284where
285    T: Read<T>,
286{
287    /// Waits until the data of this state satisfy a given predicate, returning the data
288    ///
289    /// See [`OwnedState::wait_until_async`]
290    pub fn wait_until_async<F>(self, predicate: F) -> WaitUntil<'a, T, F>
291    where
292        F: FnMut(&T) -> bool,
293    {
294        self.raw.wait_until_async(predicate)
295    }
296}
297
298impl<'a, T> BorrowedState<'a, T>
299where
300    T: Read<Box<T>> + ?Sized,
301{
302    /// Waits until the data of this state satisfy a given predicate, returning the data as a box
303    ///
304    /// See [`OwnedState::wait_until_boxed_async`]
305    pub fn wait_until_boxed_async<F>(self, predicate: F) -> WaitUntilBoxed<'a, T, F>
306    where
307        F: FnMut(&T) -> bool,
308    {
309        self.raw.wait_until_boxed_async(predicate)
310    }
311}
312
313impl<T> RawState<T>
314where
315    T: ?Sized,
316{
317    /// Waits until this state is updated
318    fn wait_async<'a>(self) -> Wait<'a> {
319        Wait::new(self)
320    }
321}
322
323impl<T> RawState<T>
324where
325    T: Read<T>,
326{
327    /// Waits until the data of this state satisfy a given predicate, returning the data
328    fn wait_until_async<'a, F>(self, predicate: F) -> WaitUntil<'a, T, F>
329    where
330        F: FnMut(&T) -> bool,
331    {
332        WaitUntil::new(self, predicate)
333    }
334}
335
336impl<T> RawState<T>
337where
338    T: Read<Box<T>> + ?Sized,
339{
340    /// Waits until the data of this state satisfy a given predicate, returning the data as a box
341    fn wait_until_boxed_async<'a, F>(self, predicate: F) -> WaitUntilBoxed<'a, T, F>
342    where
343        F: FnMut(&T) -> bool,
344    {
345        WaitUntilBoxed::new(self, predicate)
346    }
347}
348
349/// The future returned by [`wait_async`](`OwnedState::wait_async`) methods
350#[derive(Debug)]
351#[must_use = "futures do nothing unless you `.await` or poll them"]
352pub struct Wait<'a> {
353    inner: WaitUntilInternal<'a, OpaqueData, OpaqueData, ChangedPredicate>,
354}
355
356impl Wait<'_> {
357    /// Creates a new [`Wait<'_>`] future for the given raw state
358    const fn new<T>(state: RawState<T>) -> Self
359    where
360        T: ?Sized,
361    {
362        Self {
363            inner: WaitUntilInternal::new(state.cast(), ChangedPredicate),
364        }
365    }
366}
367
368impl Future for Wait<'_> {
369    type Output = io::Result<()>;
370
371    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372        let inner_pinned = Pin::new(&mut self.get_mut().inner);
373        inner_pinned.poll(cx).map_ok(|_| ())
374    }
375}
376
377/// The future returned by [`wait_until_async`](`OwnedState::wait_until_async`) methods
378#[derive(Debug)]
379#[must_use = "futures do nothing unless you `.await` or poll them"]
380pub struct WaitUntil<'a, T, F> {
381    inner: WaitUntilInternal<'a, T, T, F>,
382}
383
384impl<F, T> WaitUntil<'_, T, F> {
385    /// Creates a new [`WaitUntil<'_, T, F>`] future for the given raw state and predicate
386    const fn new(state: RawState<T>, predicate: F) -> Self {
387        Self {
388            inner: WaitUntilInternal::new(state, predicate),
389        }
390    }
391}
392
393impl<F, T> Future for WaitUntil<'_, T, F>
394where
395    F: FnMut(&T) -> bool,
396    T: Read<T>,
397{
398    type Output = io::Result<T>;
399
400    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401        let inner_pinned = Pin::new(&mut self.get_mut().inner);
402        inner_pinned.poll(cx)
403    }
404}
405
406/// The future returned by [`wait_until_boxed_async`](`OwnedState::wait_until_boxed_async`) methods
407#[derive(Debug)]
408#[must_use = "futures do nothing unless you `.await` or poll them"]
409pub struct WaitUntilBoxed<'a, T, F>
410where
411    T: ?Sized,
412{
413    inner: WaitUntilInternal<'a, T, Box<T>, F>,
414}
415
416impl<F, T> WaitUntilBoxed<'_, T, F>
417where
418    T: ?Sized,
419{
420    /// Creates a new [`WaitUntilBoxed<'_, T, F>`](WaitUntilBoxed) future for the given raw state and predicate
421    const fn new(state: RawState<T>, predicate: F) -> Self {
422        Self {
423            inner: WaitUntilInternal::new(state, predicate),
424        }
425    }
426}
427
428impl<F, T> Future for WaitUntilBoxed<'_, T, F>
429where
430    F: FnMut(&T) -> bool,
431    T: Read<Box<T>> + ?Sized,
432{
433    type Output = io::Result<Box<T>>;
434
435    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436        let inner_pinned = Pin::new(&mut self.get_mut().inner);
437        inner_pinned.poll(cx)
438    }
439}
440
441/// Future generalizing the behavior of [`Wait<'_>`](Wait), [`WaitUntil<'_, T, F>`](WaitUntil) and [`WaitUntilBoxed<'_,
442/// T, F>`](WaitUntilBoxed)
443#[derive(Debug)]
444#[must_use = "futures do nothing unless you `.await` or poll them"]
445struct WaitUntilInternal<'a, T, D, F>
446where
447    T: ?Sized,
448{
449    future_state: Option<FutureState<'a, T, D, F>>,
450}
451
452// This is not auto-implemented because `F` might be `!Unpin`
453// We can implement it manually because `F` is never pinned, i.e. pinning is non-structural for `F`
454// See <https://doc.rust-lang.org/std/pin/index.html#pinning-is-not-structural-for-field>
455impl<D, F, T> Unpin for WaitUntilInternal<'_, T, D, F> where T: ?Sized {}
456
457/// State of the [`WaitUntilInternal<'a, T, D, F>`](WaitUntilInternal) future
458#[derive(Debug)]
459enum FutureState<'a, T, D, F>
460where
461    T: ?Sized,
462{
463    /// Future has not been polled
464    Initial { state: RawState<T>, predicate: F },
465
466    /// Future is waiting for state update
467    Waiting {
468        predicate: F,
469        shared_state: Arc<Mutex<SharedState<D>>>,
470        subscription: Subscription<'a, WaitListener<D>>,
471    },
472}
473
/// Data exchanged between the thread polling the future and the thread delivering state updates
///
/// `result` carries the outcome of the most recent data access performed by the listener, if any. `waker` is the
/// waker the listener wakes after storing a result.
#[derive(Debug)]
struct SharedState<D> {
    result: Option<io::Result<D>>,
    waker: Waker,
}

impl<D> SharedState<D> {
    /// Builds a [`SharedState<D>`] that holds the given waker and no result yet
    const fn from_waker(waker: Waker) -> Self {
        SharedState { waker, result: None }
    }
}
487
488impl<D, F, T> WaitUntilInternal<'_, T, D, F>
489where
490    T: ?Sized,
491{
492    /// Creates a new [`WaitUntilInternal<'_, T, D, F>`](WaitUntilInternal) future for the given raw state and predicate
493    const fn new(state: RawState<T>, predicate: F) -> Self {
494        Self {
495            future_state: Some(FutureState::Initial { state, predicate }),
496        }
497    }
498}
499
500impl<D, F, T> Future for WaitUntilInternal<'_, T, D, F>
501where
502    D: Borrow<T> + Send + 'static,
503    F: Predicate<T>,
504    T: Read<D> + ?Sized,
505{
506    type Output = io::Result<D>;
507
508    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
509        self.future_state = Some(
510            match self.future_state.take().expect("future polled after it has completed") {
511                FutureState::Initial { state, mut predicate } => {
512                    let (data, change_stamp) = state.query_as()?.into_data_change_stamp();
513
514                    if predicate.check(data.borrow(), PredicateStage::Initial) {
515                        return Poll::Ready(Ok(data));
516                    }
517
518                    let shared_state = Arc::new(Mutex::new(SharedState::from_waker(cx.waker().clone())));
519                    let subscription = state.subscribe(
520                        WaitListener::new(Arc::clone(&shared_state)),
521                        SeenChangeStamp::Value(change_stamp),
522                    )?;
523
524                    FutureState::Waiting {
525                        predicate,
526                        shared_state,
527                        subscription,
528                    }
529                }
530
531                FutureState::Waiting {
532                    mut predicate,
533                    shared_state,
534                    subscription,
535                } => {
536                    let mut guard = shared_state.lock().unwrap();
537                    let SharedState { result, waker } = &mut *guard;
538
539                    let ready_result = match result.take() {
540                        Some(Ok(data)) if !predicate.check(data.borrow(), PredicateStage::Changed) => None,
541                        None => None,
542                        result => result,
543                    };
544
545                    match ready_result {
546                        Some(result) => {
547                            subscription.unsubscribe()?;
548                            return Poll::Ready(Ok(result?));
549                        }
550
551                        None => {
552                            if !waker.will_wake(cx.waker()) {
553                                waker.clone_from(cx.waker());
554                            }
555                        }
556                    }
557
558                    drop(guard);
559
560                    FutureState::Waiting {
561                        predicate,
562                        shared_state,
563                        subscription,
564                    }
565                }
566            },
567        );
568
569        Poll::Pending
570    }
571}
572
573/// State listener that saves the result of accessing the state data and wakes a waker
574///
575/// This is a type that can be named rather than an anonymous closure type so that it can be stored in a
576/// [`FutureState<'_, T, D, F>`](FutureState) without using a trait object.
577#[derive(Debug)]
578struct WaitListener<D> {
579    shared_state: Arc<Mutex<SharedState<D>>>,
580}
581
582impl<D> WaitListener<D> {
583    /// Creates a new [`WaitListener<D>`] with the given shared state
584    const fn new(shared_state: Arc<Mutex<SharedState<D>>>) -> Self {
585        Self { shared_state }
586    }
587}
588
589impl<D, T> StateListener<T> for WaitListener<D>
590where
591    D: Send + 'static,
592    T: Read<D> + ?Sized,
593{
594    fn call(&mut self, accessor: DataAccessor<'_, T>) {
595        let SharedState { result, ref waker } = &mut *self.shared_state.lock().unwrap();
596        *result = Some(accessor.get_as());
597        waker.wake_by_ref();
598    }
599}
600
#[cfg(test)]
mod tests {
    // NOTE(review): presumably silences warnings about items only used in compile-time assertions - confirm before
    // removing
    #![allow(dead_code)]

    use std::cell::Cell;
    use std::sync::MutexGuard;

    use static_assertions::{assert_impl_all, assert_not_impl_any};

    use super::*;

    /// Marker type that is `Send` but not `Sync`
    type SendNotSync = Cell<()>;

    /// Marker type that is `Sync` but not `Send`
    type SyncNotSend = MutexGuard<'static, ()>;

    #[test]
    fn wait_future_is_send_and_sync() {
        assert_impl_all!(Wait<'_>: Send, Sync);
    }

    #[test]
    fn wait_until_future_is_send_if_predicate_and_data_type_are_send() {
        // Sanity-check the marker type first
        assert_impl_all!(SendNotSync: Send);
        assert_not_impl_any!(SendNotSync: Sync);

        assert_impl_all!(WaitUntil<'_, SendNotSync, SendNotSync>: Send);
    }

    #[test]
    fn wait_until_future_is_sync_if_predicate_is_sync_and_data_type_is_send() {
        // Sanity-check the marker types first
        assert_impl_all!(SyncNotSend: Sync);
        assert_not_impl_any!(SyncNotSend: Send);
        assert_impl_all!(SendNotSync: Send);
        assert_not_impl_any!(SendNotSync: Sync);

        assert_impl_all!(WaitUntil<'_, SendNotSync, SyncNotSend>: Sync);
    }

    #[test]
    fn wait_until_boxed_future_is_send_if_predicate_and_data_type_are_send() {
        // Sanity-check the marker type first
        assert_impl_all!(SendNotSync: Send);
        assert_not_impl_any!(SendNotSync: Sync);

        assert_impl_all!(WaitUntilBoxed<'_, SendNotSync, SendNotSync>: Send);
    }

    #[test]
    fn wait_until_boxed_future_is_sync_if_predicate_is_sync_and_data_type_is_send() {
        // Sanity-check the marker types first
        assert_impl_all!(SyncNotSend: Sync);
        assert_not_impl_any!(SyncNotSend: Send);
        assert_impl_all!(SendNotSync: Send);
        assert_not_impl_any!(SendNotSync: Sync);

        assert_impl_all!(WaitUntilBoxed<'_, SendNotSync, SyncNotSend>: Sync);
    }
}