drop_awaiter/
lib.rs

1use std::{
2    pin::Pin,
3    sync::{
4        atomic::{AtomicUsize, Ordering},
5        Arc,
6    },
7    task::{Context, Poll},
8};
9
10use futures::{task::AtomicWaker, Future};
11
12/// Function that spawns pair [`DropNotifier`] and [`DropAwaiter`]
13/// Clone and pass Notifier further into your code.
14/// Once all Notifiers will be dropped, they will notify the Awaiter that it is time to wake up
15///
16/// Usage:
17/// ```
18/// async fn foo() {
19///     let (notifier_1, awaiter) = drop_awaiter::new();
20///     
21///     let notifier_2 = notifier_1.clone();
22///
23///     std::thread::spawn(move || {
24///         // Perform task 1
25///         // ...
26///         drop(notifier_1);    
27///     });
28///     
29///     std::thread::spawn(move || {
30///         // Perform task 2    
31///         // ...
32///         drop(notifier_2);    
33///     });
34///     
35///
36///     awaiter.await
37/// }
38
39pub fn new() -> (DropNotifier, DropAwaiter) {
40    let state = Arc::new(State {
41        awaiter_waker: AtomicWaker::new(),
42        notifiers_count: AtomicUsize::new(1),
43    });
44
45    (
46        DropNotifier {
47            state: state.clone(),
48        },
49        DropAwaiter { state },
50    )
51}
52#[derive(Debug)]
53pub struct DropAwaiter {
54    state: Arc<State>,
55}
56
57#[derive(Debug)]
58pub struct DropNotifier {
59    state: Arc<State>,
60}
61
62impl Future for DropAwaiter {
63    type Output = ();
64
65    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
66        if self.state.notifiers_count.load(Ordering::SeqCst) == 0 {
67            Poll::Ready(())
68        } else {
69            self.state.awaiter_waker.register(cx.waker());
70            Poll::Pending
71        }
72    }
73}
74
75impl Clone for DropNotifier {
76    fn clone(&self) -> Self {
77        self.state.notifiers_count.fetch_add(1, Ordering::Relaxed);
78
79        Self {
80            state: self.state.clone(),
81        }
82    }
83}
84
85impl Drop for DropNotifier {
86    fn drop(&mut self) {
87        if self.state.notifiers_count.fetch_sub(1, Ordering::AcqRel) != 1 {
88            return;
89        }
90
91        self.state.awaiter_waker.wake();
92    }
93}
94
95#[derive(Debug)]
96struct State {
97    notifiers_count: AtomicUsize,
98    awaiter_waker: AtomicWaker,
99}
100
#[cfg(test)]
mod tests {
    use futures::FutureExt;
    use std::{sync::atomic::Ordering, time::Duration};

    #[tokio::test]
    async fn test_awaiter() {
        let (notifier_1, mut awaiter) = crate::new();
        let notifier_2 = notifier_1.clone();
        let notifier_3 = notifier_2.clone();

        assert_eq!(3, awaiter.state.notifiers_count.load(Ordering::SeqCst));

        drop(notifier_1);
        drop(notifier_3);

        assert_eq!(1, awaiter.state.notifiers_count.load(Ordering::SeqCst));

        // A single poll is enough to show the awaiter is still pending; there
        // is no need to race it against a one-second wall-clock sleep.
        assert!(
            (&mut awaiter).now_or_never().is_none(),
            "Awaiter must not complete while a notifier is alive"
        );

        drop(notifier_2);
        awaiter.await;
    }

    /// Drops the notifiers on other threads so the final `wake()` races with
    /// the awaiter's poll. The timeout turns a lost wake-up into a failure
    /// instead of a hung test.
    #[tokio::test]
    async fn wakes_when_last_notifier_dropped_on_another_thread() {
        for _ in 0..100 {
            let (notifier, awaiter) = crate::new();

            let handles: Vec<_> = (0..4)
                .map(|_| notifier.clone())
                .map(|clone| std::thread::spawn(move || drop(clone)))
                .collect();
            drop(notifier);

            tokio::time::timeout(Duration::from_secs(5), awaiter)
                .await
                .expect("awaiter was never woken");

            for handle in handles {
                handle.join().expect("notifier thread panicked");
            }
        }
    }
}