Skip to main content

async_flag/
lib.rs

1use std::{
2    pin::Pin,
3    sync::{
4        mpsc::{self, Receiver, Sender},
5        Arc, Mutex, TryLockError,
6    },
7};
8
9use atomic::{Atomic, Ordering};
10use futures::{
11    future::Future,
12    task::{Context, Poll, Waker},
13};
14use pin_project::pin_project;
15
/// A shared, lockable slot holding one `Waiter`'s most recent waker.
///
/// The `Waiter` writes its current `Waker` into the slot when it is polled; the `Setter` takes
/// the waker back out (leaving `None`) when it wakes that waiter.
type WrappedWaker = Arc<Mutex<Option<Waker>>>;
17
18/// The error type returned when the setter is dropped.
19#[derive(Debug, thiserror::Error, PartialEq)]
20#[error("Setter dropped without setting the flag")]
21pub struct SetterDropped;
22
/// The value stored in the shared atomic flag.
///
/// It starts as `NotSet` and is moved to exactly one of the two final states by the `Setter`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
    /// Neither set nor dropped yet; waiters stay pending.
    NotSet,
    /// `Setter::set` was called.
    Set,
    /// The `Setter` was dropped without being set.
    Dropped,
}
29
30struct SetterInner {
31    f: Arc<Atomic<State>>,
32    waiters: Receiver<WrappedWaker>,
33}
34
35/// The setting half of the flag.  Setting the flag will wake all `Waiter`'s.
36pub struct Setter {
37    i: Option<SetterInner>,
38}
39
40/// A cloneable waiter implementing `Future` with an `Output` type of `Result<(), SetterDropped>`
41/// that will become ready when the associated `Setter` is set or dropped.
42#[pin_project]
43pub struct Waiter {
44    f: Arc<Atomic<State>>,
45    wait_sender: Sender<WrappedWaker>,
46    waiter: Option<WrappedWaker>,
47}
48
49/// Create a `Setter`, `Waiter` pair.  The `Waiter` can be cloned any number of times.
50pub fn flag() -> (Setter, Waiter) {
51    let f = Arc::new(Atomic::new(State::NotSet));
52    let (wait_sender, waiters) = mpsc::channel();
53    (
54        Setter {
55            i: Some(SetterInner {
56                f: f.clone(),
57                waiters,
58            }),
59        },
60        Waiter {
61            f,
62            wait_sender,
63            waiter: None,
64        },
65    )
66}
67
68impl State {
69    fn to_poll(self) -> Poll<Result<(), SetterDropped>> {
70        match self {
71            State::NotSet => Poll::Pending,
72            State::Set => Poll::Ready(Ok(())),
73            State::Dropped => Poll::Ready(Err(SetterDropped {})),
74        }
75    }
76}
77
78impl Setter {
79    /// Set the flag and wake all `Waiter`s that are waiting on it.
80    pub fn set(mut self) {
81        self.i
82            .take()
83            .expect("Inner missing, should be impossible")
84            .set_state(State::Set)
85    }
86}
87
88impl SetterInner {
89    fn set_state(self, state: State) {
90        self.f.store(state, Ordering::Release);
91        for waiter in self.waiters.try_iter() {
92            match waiter.try_lock() {
93                Ok(mut w) => w.take().expect("Empty option, should be impossible").wake(),
94                Err(TryLockError::WouldBlock) => (), // They'll check state again before returning
95                Err(TryLockError::Poisoned(_)) => panic!("Lock was poisoned, should be impossible"),
96            }
97        }
98    }
99}
100
101impl Waiter {
102    /// Check if the flag is currently set.  This only returns the current value, and if it's not
103    /// set there's no guarantees for how long it will stay unset, it may even be set by the time
104    /// the function returns.
105    pub fn is_set(&self) -> bool {
106        self.f.load(Ordering::Acquire) == State::Set
107    }
108
109    /// Check if the flag was dropped.  This only returns the current value, and if it's not
110    /// set there's no guarantees for how long it will stay unset, it may even be set by the time
111    /// the function returns.
112    pub fn is_dropped(&self) -> bool {
113        self.f.load(Ordering::Acquire) == State::Dropped
114    }
115
116    /// Check if the flag was set to some value, either by being set or by being dropped.  This
117    /// only returns the current value, and if it's not set there's no guarantees for how long it
118    /// will stay unset, it may even be set by the time the function returns.
119    pub fn is_finished(&self) -> bool {
120        self.f.load(Ordering::Acquire) != State::NotSet
121    }
122}
123
124impl Drop for Setter {
125    fn drop(&mut self) {
126        if let Some(i) = self.i.take() {
127            i.set_state(State::Dropped)
128        }
129    }
130}
131
132impl Future for Waiter {
133    type Output = Result<(), SetterDropped>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
136        let this = self.project();
137        match this.f.load(Ordering::Acquire).to_poll() {
138            Poll::Ready(r) => return Poll::Ready(r),
139            Poll::Pending => (),
140        }
141
142        if let Some(waiter) = this.waiter {
143            match waiter.try_lock() {
144                Ok(mut w) => *w = Some(cx.waker().clone()),
145                Err(TryLockError::WouldBlock) => (), // We've raced with the Setter, check the state again.
146                Err(TryLockError::Poisoned(_)) => panic!("Lock was poisoned, should be impossible"),
147            }
148        } else {
149            let waiter = Arc::new(Mutex::new(Some(cx.waker().clone())));
150            *this.waiter = Some(waiter.clone());
151            let _ = this.wait_sender.send(waiter);
152        }
153
154        this.f.load(Ordering::Acquire).to_poll()
155    }
156}
157
158impl Clone for Waiter {
159    fn clone(&self) -> Self {
160        Waiter {
161            f: self.f.clone(),
162            wait_sender: self.wait_sender.clone(),
163            waiter: None,
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    /// Setting the flag before awaiting resolves the waiter with `Ok(())`.
    #[tokio::test(core_threads = 4)]
    async fn test_simple() {
        let (set, wait) = flag();

        set.set();

        assert_eq!(wait.await, Ok(()));
    }

    /// Dropping the setter without setting resolves the waiter with `Err(SetterDropped)`.
    #[tokio::test(core_threads = 4)]
    async fn test_dropped() {
        let (set, wait) = flag();

        drop(set);

        assert_eq!(wait.await, Err(SetterDropped {}));
    }

    /// Clones awaited on separate spawned tasks are all resolved by a single `set`.
    #[tokio::test(core_threads = 4)]
    async fn test_multiple() {
        let (set, wait) = flag();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let w = wait.clone();
                tokio::spawn(async move { w.await.unwrap() })
            })
            .collect();

        // Give the spawned tasks a chance to poll and register their wakers before setting.
        tokio::time::delay_for(Duration::from_millis(100)).await;

        set.set();

        // A `JoinHandle` can be awaited by value; no pinning is needed.
        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(wait.await, Ok(()));
    }

    /// Test helper: wakes its own task every time the inner future returns `Pending`.
    ///
    /// The inner future is therefore re-polled continuously, which stresses the races between
    /// waiters and the `Setter`.
    #[pin_project]
    struct AlwaysWake<T> {
        #[pin]
        t: T,
    }

    impl<T: Future> Future for AlwaysWake<T> {
        type Output = T::Output;
        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let this = self.project();
            let r = this.t.poll(cx);
            if r.is_pending() {
                cx.waker().wake_by_ref();
            }
            r
        }
    }

    /// Repeatedly races `set` against many continuously re-polled waiters.
    #[tokio::test(core_threads = 4)]
    async fn test_racing() {
        for _ in 0..10 {
            let (set, wait) = flag();

            let handles: Vec<_> = (0..50)
                .map(|_| {
                    let w = wait.clone();
                    tokio::spawn(AlwaysWake {
                        t: async move { w.await.unwrap() },
                    })
                })
                .collect();

            tokio::time::delay_for(Duration::from_millis(10)).await;

            set.set();

            for h in handles {
                h.await.unwrap();
            }

            assert_eq!(wait.await, Ok(()));
        }
    }
}