async_resource/util/
dropshot.rs1use std::cell::UnsafeCell;
2use std::fmt;
3use std::future::Future;
4use std::mem::MaybeUninit;
5use std::pin::Pin;
6use std::sync::{
7 atomic::{AtomicU8, Ordering},
8 Arc,
9};
10use std::task::{Context, Poll, Waker};
11use std::thread;
12
13use option_lock::OptionLock;
14
15use super::thread_waker;
16
/// Initial state: nothing has been sent and neither side has canceled.
const INIT: u8 = 0;
/// The sender has claimed the slot and is writing the value into it.
const LOAD: u8 = 1;
/// A value is stored and waiting to be received.
const READY: u8 = 2;
/// The receiver has taken the stored value.
const SENT: u8 = 3;
/// The channel was canceled by the sender or the receiver.
const CANCEL: u8 = 4;
27
/// Error reported when a dropshot channel can no longer deliver a value:
/// either side canceled it, or the value was already taken.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Canceled;

impl fmt::Display for Canceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("dropshot canceled")
    }
}

impl std::error::Error for Canceled {}
38
39pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
40 let inner = Arc::new(Inner::new());
41 let receiver = Receiver {
42 inner: inner.clone(),
43 };
44 let sender = Sender { inner };
45 (sender, receiver)
46}
47
48struct Inner<T> {
49 data: UnsafeCell<MaybeUninit<T>>,
50 recv_waker: OptionLock<Waker>,
51 state: AtomicU8,
52}
53
54unsafe impl<T> Sync for Inner<T> {}
55
56impl<T> Inner<T> {
57 pub const fn new() -> Self {
58 Self {
59 data: UnsafeCell::new(MaybeUninit::uninit()),
60 recv_waker: OptionLock::new(),
61 state: AtomicU8::new(INIT),
62 }
63 }
64
65 pub fn cancel_recv(&self) -> Option<T> {
66 match self.state.swap(CANCEL, Ordering::SeqCst) {
67 READY => Some(self.take()),
68 _ => None,
69 }
70 }
71
72 pub fn cancel_send(&self) -> bool {
73 if self.state.compare_and_swap(INIT, CANCEL, Ordering::SeqCst) == INIT {
74 if let Ok(waker) = self.recv_waker.try_take() {
75 waker.wake();
76 }
77 true
78 } else {
79 false
80 }
81 }
82
83 pub fn is_canceled(&self) -> bool {
84 self.state.load(Ordering::Acquire) == CANCEL
85 }
86
87 pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
88 loop {
89 match self.try_recv() {
90 Ok(Some(val)) => return Poll::Ready(Ok(val)),
91 Ok(None) => {
92 let waker = cx.waker().clone();
93 if let Ok(mut guard) = self.recv_waker.try_lock() {
94 guard.replace(waker);
95 } else {
96 continue;
99 }
100
101 match self.state.load(Ordering::Acquire) {
104 INIT => {
105 return Poll::Pending;
106 }
107 CANCEL => {
108 return Poll::Ready(Err(Canceled));
110 }
111 LOAD => {
112 thread::yield_now();
114 continue;
115 }
116 READY => {
117 continue;
119 }
120 _ => {
121 panic!("Invalid state for dropshot");
122 }
123 }
124 }
125 Err(err) => return Poll::Ready(Err(err)),
126 }
127 }
128 }
129
130 pub fn try_recv(&self) -> Result<Option<T>, Canceled> {
131 loop {
132 match self
133 .state
134 .compare_exchange_weak(READY, SENT, Ordering::AcqRel, Ordering::Acquire)
135 {
136 Ok(_) => {
137 return Ok(Some(self.take()));
138 }
139 Err(INIT) => {
140 return Ok(None);
141 }
142 Err(CANCEL) => {
143 return Err(Canceled);
145 }
146 Err(LOAD) => {
147 thread::yield_now();
149 continue;
150 }
151 Err(READY) => {
152 continue;
154 }
155 Err(SENT) => {
156 return Err(Canceled);
158 }
159 Err(_) => {
160 panic!("Invalid state for dropshot");
161 }
162 }
163 }
164 }
165
166 pub fn send(&self, value: T) -> Result<(), T> {
167 loop {
168 match self
169 .state
170 .compare_exchange_weak(INIT, LOAD, Ordering::AcqRel, Ordering::Acquire)
171 {
172 Ok(_) => {
173 unsafe { self.data.get().write(MaybeUninit::new(value)) };
174 match self.state.compare_exchange(
175 LOAD,
176 READY,
177 Ordering::AcqRel,
178 Ordering::Acquire,
179 ) {
180 Ok(_) => {
181 if let Ok(waker) = self.recv_waker.try_take() {
182 waker.wake();
183 }
184 return Ok(());
185 }
186 Err(CANCEL) => {
187 return Err(self.take());
189 }
190 _ => panic!("Invalid state for dropshot"),
191 }
192 }
193 Err(INIT) => {
194 continue;
196 }
197 Err(CANCEL) | Err(LOAD) | Err(READY) | Err(SENT) => {
198 return Err(value);
200 }
201 Err(_) => {
202 panic!("Invalid state for dropshot");
203 }
204 }
205 }
206 }
207
208 #[inline]
209 fn take(&self) -> T {
210 unsafe { self.data.get().read().assume_init() }
211 }
212}
213
214pub struct Receiver<T> {
215 inner: Arc<Inner<T>>,
216}
217
218impl<T> Receiver<T> {
219 pub fn cancel(&mut self) -> Option<T> {
220 self.inner.cancel_recv()
221 }
222
223 pub fn recv(&mut self) -> Result<T, Canceled> {
224 for _ in 0..20 {
225 match self.inner.try_recv() {
226 Ok(Some(value)) => return Ok(value),
227 Ok(None) => {
228 thread::yield_now();
229 }
230 Err(err) => return Err(err),
231 }
232 }
233 loop {
234 let (waker, waiter) = thread_waker::pair();
235 let task_waker = waker.task_waker();
236 let mut context = Context::from_waker(&task_waker);
237 match self.inner.poll_recv(&mut context) {
238 Poll::Ready(result) => return result,
239 Poll::Pending => {
240 waiter.wait();
241 }
242 }
243 }
244 }
245
246 pub fn try_recv(&mut self) -> Result<Option<T>, Canceled> {
247 self.inner.try_recv()
248 }
249}
250
251impl<T> Future for Receiver<T> {
252 type Output = Result<T, Canceled>;
253
254 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, Canceled>> {
255 self.inner.poll_recv(cx)
256 }
257}
258
259impl<T> Drop for Receiver<T> {
260 fn drop(&mut self) {
261 self.inner.cancel_recv();
262 }
263}
264
265pub struct Sender<T> {
266 inner: Arc<Inner<T>>,
267}
268
269impl<T> Sender<T> {
270 pub fn cancel(&self) -> bool {
271 self.inner.cancel_send()
272 }
273
274 pub fn is_canceled(&self) -> bool {
275 self.inner.is_canceled()
276 }
277
278 pub fn send(&self, data: T) -> Result<(), T> {
279 self.inner.send(data)
280 }
281}
282
283impl<T> Drop for Sender<T> {
284 fn drop(&mut self) {
285 self.inner.cancel_send();
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::task::{waker_ref, ArcWake};
    use std::sync::atomic::AtomicUsize;

    /// Waker that counts how many times it has been woken.
    struct TestWaker {
        calls: AtomicUsize,
    }

    impl TestWaker {
        pub fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of wake-ups observed so far.
        pub fn count(&self) -> usize {
            self.calls.load(Ordering::Acquire)
        }
    }

    impl ArcWake for TestWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn dropshot_send_normal() {
        let (sender, mut receiver) = channel();
        let waker = Arc::new(TestWaker::new());
        let wr = waker_ref(&waker);
        let mut cx = Context::from_waker(&wr);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Pending);
        assert_eq!(waker.count(), 0);
        assert!(sender.send(1u32).is_ok());
        assert_eq!(waker.count(), 1);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Ready(Ok(1u32)));
        drop(sender);
        assert_eq!(waker.count(), 1);
        assert_eq!(
            Pin::new(&mut receiver).poll(&mut cx),
            Poll::Ready(Err(Canceled))
        );
        assert_eq!(waker.count(), 1);
    }

    #[test]
    fn dropshot_sender_dropped() {
        let (sender, mut receiver) = channel::<u32>();
        let waker = Arc::new(TestWaker::new());
        let wr = waker_ref(&waker);
        let mut cx = Context::from_waker(&wr);
        assert_eq!(Pin::new(&mut receiver).poll(&mut cx), Poll::Pending);
        drop(sender);
        assert_eq!(waker.count(), 1);
        assert_eq!(
            Pin::new(&mut receiver).poll(&mut cx),
            Poll::Ready(Err(Canceled))
        );
        assert_eq!(waker.count(), 1);
    }

    #[test]
    fn dropshot_receiver_dropped() {
        let (sender, receiver) = channel();
        drop(receiver);
        assert_eq!(sender.send(1u32), Err(1u32));
    }

    #[test]
    fn dropshot_test_future() {
        use futures_executor::block_on;
        let (sender, receiver) = channel::<u32>();
        sender.send(5).unwrap();
        assert_eq!(block_on(receiver), Ok(5));
    }

    /// Canceling the receiver after a send hands the buffered value back.
    #[test]
    fn dropshot_receiver_cancel_returns_value() {
        let (sender, mut receiver) = channel();
        assert!(!sender.is_canceled());
        assert!(sender.send(7u32).is_ok());
        assert!(!sender.is_canceled());
        assert_eq!(receiver.cancel(), Some(7u32));
        assert!(sender.is_canceled());
        assert_eq!(receiver.try_recv(), Err(Canceled));
        assert_eq!(receiver.cancel(), None);
    }

    /// Explicit sender cancellation succeeds once and blocks later sends.
    #[test]
    fn dropshot_sender_cancel() {
        let (sender, mut receiver) = channel::<u32>();
        assert!(sender.cancel());
        assert!(!sender.cancel());
        assert!(sender.is_canceled());
        assert_eq!(receiver.try_recv(), Err(Canceled));
        assert_eq!(sender.send(3), Err(3));
    }

    /// Blocking `recv` returns an already-sent value without waiting, then
    /// reports `Canceled` once the value has been consumed.
    #[test]
    fn dropshot_blocking_recv_ready() {
        let (sender, mut receiver) = channel();
        sender.send(11u32).unwrap();
        assert_eq!(receiver.recv(), Ok(11u32));
        drop(sender);
        assert_eq!(receiver.recv(), Err(Canceled));
    }
}