1use std::{
2 pin::Pin,
3 sync::{
4 mpsc::{self, Receiver, Sender},
5 Arc, Mutex, TryLockError,
6 },
7};
8
9use atomic::{Atomic, Ordering};
10use futures::{
11 future::Future,
12 task::{Context, Poll, Waker},
13};
14use pin_project::pin_project;
15
/// A shared, lockable slot holding the waker of one pending waiter.
///
/// The slot holds `Some` while a waker is registered; the setter takes the
/// waker out of it when the flag is resolved.
type WrappedWaker = std::sync::Arc<std::sync::Mutex<Option<Waker>>>;
17
18#[derive(Debug, thiserror::Error, PartialEq)]
20#[error("Setter dropped without setting the flag")]
21pub struct SetterDropped;
22
/// Lifecycle of a flag, stored atomically and shared by the setter and all
/// waiters.
///
/// It starts as `NotSet` and is moved to one terminal value by the setter.
#[derive(Clone, Copy, PartialEq)]
enum State {
    /// Not resolved yet; waiters stay pending.
    NotSet,
    /// Fired via `Setter::set`; waiters resolve to `Ok(())`.
    Set,
    /// Setter dropped without firing; waiters resolve to `Err(SetterDropped)`.
    Dropped,
}
29
30struct SetterInner {
31 f: Arc<Atomic<State>>,
32 waiters: Receiver<WrappedWaker>,
33}
34
35pub struct Setter {
37 i: Option<SetterInner>,
38}
39
40#[pin_project]
43pub struct Waiter {
44 f: Arc<Atomic<State>>,
45 wait_sender: Sender<WrappedWaker>,
46 waiter: Option<WrappedWaker>,
47}
48
49pub fn flag() -> (Setter, Waiter) {
51 let f = Arc::new(Atomic::new(State::NotSet));
52 let (wait_sender, waiters) = mpsc::channel();
53 (
54 Setter {
55 i: Some(SetterInner {
56 f: f.clone(),
57 waiters,
58 }),
59 },
60 Waiter {
61 f,
62 wait_sender,
63 waiter: None,
64 },
65 )
66}
67
68impl State {
69 fn to_poll(self) -> Poll<Result<(), SetterDropped>> {
70 match self {
71 State::NotSet => Poll::Pending,
72 State::Set => Poll::Ready(Ok(())),
73 State::Dropped => Poll::Ready(Err(SetterDropped {})),
74 }
75 }
76}
77
78impl Setter {
79 pub fn set(mut self) {
81 self.i
82 .take()
83 .expect("Inner missing, should be impossible")
84 .set_state(State::Set)
85 }
86}
87
88impl SetterInner {
89 fn set_state(self, state: State) {
90 self.f.store(state, Ordering::Release);
91 for waiter in self.waiters.try_iter() {
92 match waiter.try_lock() {
93 Ok(mut w) => w.take().expect("Empty option, should be impossible").wake(),
94 Err(TryLockError::WouldBlock) => (), Err(TryLockError::Poisoned(_)) => panic!("Lock was poisoned, should be impossible"),
96 }
97 }
98 }
99}
100
101impl Waiter {
102 pub fn is_set(&self) -> bool {
106 self.f.load(Ordering::Acquire) == State::Set
107 }
108
109 pub fn is_dropped(&self) -> bool {
113 self.f.load(Ordering::Acquire) == State::Dropped
114 }
115
116 pub fn is_finished(&self) -> bool {
120 self.f.load(Ordering::Acquire) != State::NotSet
121 }
122}
123
124impl Drop for Setter {
125 fn drop(&mut self) {
126 if let Some(i) = self.i.take() {
127 i.set_state(State::Dropped)
128 }
129 }
130}
131
132impl Future for Waiter {
133 type Output = Result<(), SetterDropped>;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
136 let this = self.project();
137 match this.f.load(Ordering::Acquire).to_poll() {
138 Poll::Ready(r) => return Poll::Ready(r),
139 Poll::Pending => (),
140 }
141
142 if let Some(waiter) = this.waiter {
143 match waiter.try_lock() {
144 Ok(mut w) => *w = Some(cx.waker().clone()),
145 Err(TryLockError::WouldBlock) => (), Err(TryLockError::Poisoned(_)) => panic!("Lock was poisoned, should be impossible"),
147 }
148 } else {
149 let waiter = Arc::new(Mutex::new(Some(cx.waker().clone())));
150 *this.waiter = Some(waiter.clone());
151 let _ = this.wait_sender.send(waiter);
152 }
153
154 this.f.load(Ordering::Acquire).to_poll()
155 }
156}
157
158impl Clone for Waiter {
159 fn clone(&self) -> Self {
160 Waiter {
161 f: self.f.clone(),
162 wait_sender: self.wait_sender.clone(),
163 waiter: None,
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::Duration;

    /// Setting before the first poll resolves the waiter with `Ok(())`.
    #[tokio::test(core_threads = 4)]
    async fn test_simple() {
        let (setter, waiter) = flag();

        setter.set();

        let outcome = waiter.await;
        assert_eq!(outcome, Ok(()));
    }

    /// Dropping the setter unfired resolves the waiter with `SetterDropped`.
    #[tokio::test(core_threads = 4)]
    async fn test_dropped() {
        let (setter, waiter) = flag();

        drop(setter);

        let outcome = waiter.await;
        assert_eq!(outcome, Err(SetterDropped));
    }

    /// Cloned waiters awaited on spawned tasks are all resolved by one `set`.
    #[tokio::test(core_threads = 4)]
    async fn test_multiple() {
        let (setter, waiter) = flag();

        let mut handles = Vec::new();
        for _ in 0..10 {
            let w = waiter.clone();
            handles.push(tokio::spawn(async move { w.await.unwrap() }));
        }

        // Give the spawned tasks a chance to poll and register their wakers.
        tokio::time::delay_for(Duration::from_millis(100)).await;

        setter.set();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(waiter.await, Ok(()));
    }

    /// Wrapper that re-schedules its task every time the inner future is
    /// pending, so the inner future is polled over and over.
    #[pin_project]
    struct AlwaysWake<T> {
        #[pin]
        t: T,
    }

    impl<T: Future> Future for AlwaysWake<T> {
        type Output = T::Output;
        fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
            let projected = self.project();
            let polled = projected.t.poll(cx);
            if let Poll::Pending = polled {
                cx.waker().wake_by_ref();
            }
            polled
        }
    }

    /// Repeatedly busy-polls many cloned waiters while `set` runs, over
    /// several rounds, to exercise concurrent polling and setting.
    #[tokio::test(core_threads = 4)]
    async fn test_racing() {
        for _round in 0..10 {
            let (setter, waiter) = flag();

            let handles: Vec<_> = std::iter::repeat_with(|| {
                let w = waiter.clone();
                tokio::spawn(AlwaysWake {
                    t: async move { w.await.unwrap() },
                })
            })
            .take(50)
            .collect();

            tokio::time::delay_for(Duration::from_millis(10)).await;

            setter.set();

            for handle in handles {
                handle.await.unwrap();
            }

            assert_eq!(waiter.await, Ok(()));
        }
    }
}