// async_resource/util/dropshot.rs

1use std::cell::UnsafeCell;
2use std::fmt;
3use std::future::Future;
4use std::mem::MaybeUninit;
5use std::pin::Pin;
6use std::sync::{
7    atomic::{AtomicU8, Ordering},
8    Arc,
9};
10use std::task::{Context, Poll, Waker};
11use std::thread;
12
13use option_lock::OptionLock;
14
15use super::thread_waker;
16
// A lighter-weight alternative to `futures::channel::oneshot`.
//
// Unlike the futures oneshot, the sender cannot poll for receiver
// cancellation (there is no `poll_canceled`). Supporting it would cost one
// additional waker per message; that extra waker could also be used to
// confirm delivery of a message and to pull it back out on failure.
//
// The constants below are the lifecycle states of the single message slot,
// stored in `Inner::state`.

/// No value has been stored yet and neither side has hung up.
const INIT: u8 = 0;
/// The sender won `INIT -> LOAD` and is writing the value into the slot.
const LOAD: u8 = 1;
/// The value has been written and is waiting for the receiver to take it.
const READY: u8 = 2;
/// The receiver won `READY -> SENT` and has taken the value.
const SENT: u8 = 3;
/// The channel was canceled by either side (see `cancel_recv` / `cancel_send`).
const CANCEL: u8 = 4;
27
/// Error handed to the receiving side when no value can be delivered.
///
/// The channel was canceled before a value was sent, or the single value
/// has already been received.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Canceled;

impl std::error::Error for Canceled {}

impl fmt::Display for Canceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("dropshot canceled")
    }
}
38
39pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
40    let inner = Arc::new(Inner::new());
41    let receiver = Receiver {
42        inner: inner.clone(),
43    };
44    let sender = Sender { inner };
45    (sender, receiver)
46}
47
48struct Inner<T> {
49    data: UnsafeCell<MaybeUninit<T>>,
50    recv_waker: OptionLock<Waker>,
51    state: AtomicU8,
52}
53
54unsafe impl<T> Sync for Inner<T> {}
55
56impl<T> Inner<T> {
57    pub const fn new() -> Self {
58        Self {
59            data: UnsafeCell::new(MaybeUninit::uninit()),
60            recv_waker: OptionLock::new(),
61            state: AtomicU8::new(INIT),
62        }
63    }
64
65    pub fn cancel_recv(&self) -> Option<T> {
66        match self.state.swap(CANCEL, Ordering::SeqCst) {
67            READY => Some(self.take()),
68            _ => None,
69        }
70    }
71
72    pub fn cancel_send(&self) -> bool {
73        if self.state.compare_and_swap(INIT, CANCEL, Ordering::SeqCst) == INIT {
74            if let Ok(waker) = self.recv_waker.try_take() {
75                waker.wake();
76            }
77            true
78        } else {
79            false
80        }
81    }
82
83    pub fn is_canceled(&self) -> bool {
84        self.state.load(Ordering::Acquire) == CANCEL
85    }
86
87    pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
88        loop {
89            match self.try_recv() {
90                Ok(Some(val)) => return Poll::Ready(Ok(val)),
91                Ok(None) => {
92                    let waker = cx.waker().clone();
93                    if let Ok(mut guard) = self.recv_waker.try_lock() {
94                        guard.replace(waker);
95                    } else {
96                        // the sender is already trying to wake us, so the
97                        // value has been stored
98                        continue;
99                    }
100
101                    // check the state again, in case the sender
102                    // failed to get a lock on the waker because we were storing it
103                    match self.state.load(Ordering::Acquire) {
104                        INIT => {
105                            return Poll::Pending;
106                        }
107                        CANCEL => {
108                            // sender dropped
109                            return Poll::Ready(Err(Canceled));
110                        }
111                        LOAD => {
112                            // sender was interrupted while setting the value, spin
113                            thread::yield_now();
114                            continue;
115                        }
116                        READY => {
117                            // sender completed concurrently
118                            continue;
119                        }
120                        _ => {
121                            panic!("Invalid state for dropshot");
122                        }
123                    }
124                }
125                Err(err) => return Poll::Ready(Err(err)),
126            }
127        }
128    }
129
130    pub fn try_recv(&self) -> Result<Option<T>, Canceled> {
131        loop {
132            match self
133                .state
134                .compare_exchange_weak(READY, SENT, Ordering::AcqRel, Ordering::Acquire)
135            {
136                Ok(_) => {
137                    return Ok(Some(self.take()));
138                }
139                Err(INIT) => {
140                    return Ok(None);
141                }
142                Err(CANCEL) => {
143                    // sender dropped
144                    return Err(Canceled);
145                }
146                Err(LOAD) => {
147                    // sender was interrupted while setting the value, spin
148                    thread::yield_now();
149                    continue;
150                }
151                Err(READY) => {
152                    // spurious failure
153                    continue;
154                }
155                Err(SENT) => {
156                    // receive was called after taking the value
157                    return Err(Canceled);
158                }
159                Err(_) => {
160                    panic!("Invalid state for dropshot");
161                }
162            }
163        }
164    }
165
166    pub fn send(&self, value: T) -> Result<(), T> {
167        loop {
168            match self
169                .state
170                .compare_exchange_weak(INIT, LOAD, Ordering::AcqRel, Ordering::Acquire)
171            {
172                Ok(_) => {
173                    unsafe { self.data.get().write(MaybeUninit::new(value)) };
174                    match self.state.compare_exchange(
175                        LOAD,
176                        READY,
177                        Ordering::AcqRel,
178                        Ordering::Acquire,
179                    ) {
180                        Ok(_) => {
181                            if let Ok(waker) = self.recv_waker.try_take() {
182                                waker.wake();
183                            }
184                            return Ok(());
185                        }
186                        Err(CANCEL) => {
187                            // receiver dropped mid-send
188                            return Err(self.take());
189                        }
190                        _ => panic!("Invalid state for dropshot"),
191                    }
192                }
193                Err(INIT) => {
194                    // spurious failure
195                    continue;
196                }
197                Err(CANCEL) | Err(LOAD) | Err(READY) | Err(SENT) => {
198                    // receiver hung up, or send was called repeatedly
199                    return Err(value);
200                }
201                Err(_) => {
202                    panic!("Invalid state for dropshot");
203                }
204            }
205        }
206    }
207
208    #[inline]
209    fn take(&self) -> T {
210        unsafe { self.data.get().read().assume_init() }
211    }
212}
213
214pub struct Receiver<T> {
215    inner: Arc<Inner<T>>,
216}
217
218impl<T> Receiver<T> {
219    pub fn cancel(&mut self) -> Option<T> {
220        self.inner.cancel_recv()
221    }
222
223    pub fn recv(&mut self) -> Result<T, Canceled> {
224        for _ in 0..20 {
225            match self.inner.try_recv() {
226                Ok(Some(value)) => return Ok(value),
227                Ok(None) => {
228                    thread::yield_now();
229                }
230                Err(err) => return Err(err),
231            }
232        }
233        loop {
234            let (waker, waiter) = thread_waker::pair();
235            let task_waker = waker.task_waker();
236            let mut context = Context::from_waker(&task_waker);
237            match self.inner.poll_recv(&mut context) {
238                Poll::Ready(result) => return result,
239                Poll::Pending => {
240                    waiter.wait();
241                }
242            }
243        }
244    }
245
246    pub fn try_recv(&mut self) -> Result<Option<T>, Canceled> {
247        self.inner.try_recv()
248    }
249}
250
251impl<T> Future for Receiver<T> {
252    type Output = Result<T, Canceled>;
253
254    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
255        self.inner.poll_recv(cx)
256    }
257}
258
259impl<T> Drop for Receiver<T> {
260    fn drop(&mut self) {
261        self.inner.cancel_recv();
262    }
263}
264
265pub struct Sender<T> {
266    inner: Arc<Inner<T>>,
267}
268
269impl<T> Sender<T> {
270    pub fn cancel(&self) -> bool {
271        self.inner.cancel_send()
272    }
273
274    pub fn is_canceled(&self) -> bool {
275        self.inner.is_canceled()
276    }
277
278    pub fn send(&self, data: T) -> Result<(), T> {
279        self.inner.send(data)
280    }
281}
282
283impl<T> Drop for Sender<T> {
284    fn drop(&mut self) {
285        self.inner.cancel_send();
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::task::{waker_ref, ArcWake};
    use std::sync::atomic::AtomicUsize;

    /// Waker that only counts how many times it has been woken.
    struct TestWaker {
        calls: AtomicUsize,
    }

    impl TestWaker {
        pub fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        pub fn count(&self) -> usize {
            self.calls.load(Ordering::Acquire)
        }
    }

    impl ArcWake for TestWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn dropshot_send_normal() {
        let (sender, mut receiver) = channel();
        let waker = Arc::new(TestWaker::new());
        let wr = waker_ref(&waker);
        let mut cx = Context::from_waker(&wr);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Pending);
        assert_eq!(waker.count(), 0);
        assert!(sender.send(1u32).is_ok());
        assert_eq!(waker.count(), 1);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Ready(Ok(1u32)));
        drop(sender);
        assert_eq!(waker.count(), 1);
        assert_eq!(
            Pin::new(&mut receiver).poll(&mut cx),
            Poll::Ready(Err(Canceled))
        );
        assert_eq!(waker.count(), 1);
    }

    #[test]
    fn dropshot_sender_dropped() {
        let (sender, mut receiver) = channel::<u32>();
        let waker = Arc::new(TestWaker::new());
        let wr = waker_ref(&waker);
        let mut cx = Context::from_waker(&wr);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Pending);
        drop(sender);
        assert_eq!(waker.count(), 1);
        assert_eq!(
            Pin::new(&mut receiver).poll(&mut cx),
            Poll::Ready(Err(Canceled))
        );
        assert_eq!(waker.count(), 1);
    }

    #[test]
    fn dropshot_receiver_dropped() {
        let (sender, receiver) = channel();
        drop(receiver);
        assert_eq!(sender.send(1u32), Err(1u32));
    }

    #[test]
    fn dropshot_test_future() {
        use futures_executor::block_on;
        let (sender, receiver) = channel::<u32>();
        sender.send(5).unwrap();
        assert_eq!(block_on(receiver), Ok(5));
    }

    #[test]
    fn dropshot_try_recv() {
        let (sender, mut receiver) = channel::<u32>();
        assert_eq!(receiver.try_recv(), Ok(None));
        assert_eq!(sender.send(3), Ok(()));
        assert_eq!(receiver.try_recv(), Ok(Some(3)));
        // the single value has already been consumed
        assert_eq!(receiver.try_recv(), Err(Canceled));
    }

    #[test]
    fn dropshot_send_twice() {
        let (sender, mut receiver) = channel::<u32>();
        assert_eq!(sender.send(1), Ok(()));
        assert_eq!(sender.send(2), Err(2));
        assert_eq!(receiver.try_recv(), Ok(Some(1)));
    }

    #[test]
    fn dropshot_receiver_cancel_returns_value() {
        let (sender, mut receiver) = channel::<u32>();
        assert!(!sender.is_canceled());
        assert_eq!(sender.send(7), Ok(()));
        assert_eq!(receiver.cancel(), Some(7));
        assert_eq!(receiver.cancel(), None);
        assert!(sender.is_canceled());
    }

    #[test]
    fn dropshot_sender_cancel() {
        let (sender, mut receiver) = channel::<u32>();
        assert!(sender.cancel());
        assert!(!sender.cancel());
        assert!(sender.is_canceled());
        assert_eq!(sender.send(1), Err(1));
        assert_eq!(receiver.try_recv(), Err(Canceled));
    }

    #[test]
    fn dropshot_blocking_recv_ready() {
        let (sender, mut receiver) = channel::<u32>();
        sender.send(9).unwrap();
        assert_eq!(receiver.recv(), Ok(9));
    }
}