// async_weighted_semaphore/acquire.rs

1use crate::state::{Waiter, AcquireStep, Permits};
2use std::cell::UnsafeCell;
3use std::marker::{PhantomPinned, PhantomData};
4use crate::{Semaphore, SemaphoreGuard, SemaphoreGuardArc};
5use std::sync::Arc;
6use std::panic::{UnwindSafe, RefUnwindSafe};
7use std::fmt::{Debug, Formatter};
8use std::{fmt};
9use std::task::{Context, Poll};
10use crate::state::AcquireState::{Available, Queued};
11use std::ptr::null;
12use std::pin::Pin;
13use std::future::Future;
14use crate::waker::{WakerResult};
15use crate::errors::PoisonError;
16use std::sync::atomic::Ordering::Acquire;
17
18/// A [`Future`] returned by [`Semaphore::acquire`] that produces a [`SemaphoreGuard`].
19pub struct AcquireFuture<'a>(pub(crate) UnsafeCell<Waiter>, pub(crate) PhantomData<&'a Semaphore>, pub(crate) PhantomPinned);
20
21/// A [`Future`] returned by [`Semaphore::acquire_arc`] that produces a [`SemaphoreGuardArc`].
22pub struct AcquireFutureArc {
23    pub(crate) arc: Arc<Semaphore>,
24    pub(crate) inner: AcquireFuture<'static>,
25}
26
27unsafe impl<'a> Sync for AcquireFuture<'a> {}
28
29unsafe impl<'a> Send for AcquireFuture<'a> {}
30
31impl<'a> UnwindSafe for AcquireFuture<'a> {}
32
33impl<'a> RefUnwindSafe for AcquireFuture<'a> {}
34
35impl<'a> Debug for AcquireFuture<'a> {
36    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37        f.debug_tuple("AcquireFuture").field(&unsafe { self.waiter() }.amount).finish()
38    }
39}
40
41impl Debug for AcquireFutureArc {
42    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
43        f.debug_tuple("AcquireFutureArc").field(&unsafe { self.inner.waiter() }.amount).finish()
44    }
45}
46
47impl<'a> AcquireFuture<'a> {
48    unsafe fn waiter(&self) -> &Waiter {
49        &*self.0.get()
50    }
51    // Try to acquire or add to queue.
52    unsafe fn poll_enter(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<SemaphoreGuard<'a>, PoisonError>> {
53        let acquire = &(*self.waiter().semaphore).acquire;
54        let mut current = acquire.load(Acquire);
55        loop {
56            let (available, back) = match current {
57                Queued(back) => (0, back),
58                Available(available) => {
59                    let available = match available.into_usize() {
60                        None => {
61                            *self.waiter().step.get() = AcquireStep::Done;
62                            return Poll::Ready(Err(PoisonError));
63                        }
64                        Some(available) => available,
65                    };
66                    if self.waiter().amount <= available {
67                        if !acquire.cmpxchg_weak_acqrel(&mut current, Available(Permits::new(available - self.waiter().amount))) { continue; }
68                        *self.waiter().step.get() = AcquireStep::Done;
69                        return Poll::Ready(Ok(SemaphoreGuard::new(
70                            &*self.waiter().semaphore, self.waiter().amount)));
71                    } else {
72                        (available, null())
73                    }
74                }
75            };
76            assert!(self.waiter().waker.poll(cx).is_pending());
77            *self.waiter().prev.get() = back;
78            if !acquire.cmpxchg_weak_acqrel(&mut current, Queued(self.0.get())) { continue; }
79            *self.waiter().step.get() = AcquireStep::Waiting;
80            // Even if available==0, this is necessary to set release to LockedDirty.
81            (*self.waiter().semaphore).release(available);
82            return Poll::Pending;
83        }
84    }
85
86    unsafe fn poll_waiting(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<SemaphoreGuard<'a>, PoisonError>> {
87        match self.waiter().waker.poll(cx) {
88            Poll::Pending => Poll::Pending,
89            Poll::Ready(poisoned) => {
90                *self.waiter().step.get() = AcquireStep::Done;
91                if poisoned {
92                    Poll::Ready(Err(PoisonError))
93                } else {
94                    Poll::Ready(Ok(SemaphoreGuard::new(&*self.waiter().semaphore, self.waiter().amount)))
95                }
96            }
97        }
98    }
99}
100
101impl<'a> Future for AcquireFuture<'a> {
102    type Output = Result<SemaphoreGuard<'a>, PoisonError>;
103
104    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105        unsafe {
106            match *(*self.0.get()).step.get() {
107                AcquireStep::Entering => {
108                    self.poll_enter(cx)
109                }
110                AcquireStep::Waiting => {
111                    self.poll_waiting(cx)
112                }
113                AcquireStep::Done => panic!("Polling completed future.")
114            }
115        }
116    }
117}
118
119impl Future for AcquireFutureArc {
120    type Output = Result<SemaphoreGuardArc, PoisonError>;
121
122    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        unsafe {
124            let this = self.get_unchecked_mut();
125            match Pin::new_unchecked(&mut this.inner).poll(cx) {
126                Poll::Ready(guard) => {
127                    let result =
128                        SemaphoreGuardArc::new(this.arc.clone(), guard?.forget());
129                    Poll::Ready(Ok(result))
130                }
131                Poll::Pending => Poll::Pending,
132            }
133        }
134    }
135}
136
137impl<'a> Drop for AcquireFuture<'a> {
138    fn drop(&mut self) {
139        unsafe {
140            match *self.waiter().step.get() {
141                AcquireStep::Waiting => {
142                    // Decide whether the finish or cancel wins if there is a race.
143                    match self.waiter().waker.start_cancel() {
144                        WakerResult::Cancelling => {
145                            // Push onto the cancel queue.
146                            let next_cancel = &(*self.waiter().semaphore).next_cancel;
147                            let mut current = next_cancel.load(Acquire);
148                            loop {
149                                *self.waiter().next_cancel.get() = current;
150                                if next_cancel.cmpxchg_weak_acqrel(&mut current, self.0.get()) { break; }
151                            }
152                            // Ensure a flush of the cancel queue is completed or at least scheduled.
153                            (*self.waiter().semaphore).release(0);
154                            // Wait for a notification that the node can be dropped
155                            self.waiter().waker.wait_cancel();
156                        }
157                        WakerResult::Finished { poisoned } => {
158                            // The acquire finished before it could be cancelled. Pretend like
159                            // nothing happened and release the acquired permits.
160                            if !poisoned {
161                                (*self.waiter().semaphore).release(self.waiter().amount);
162                            }
163                        }
164                    }
165                }
166                AcquireStep::Entering { .. } => {}
167                AcquireStep::Done => {}
168            }
169        }
170    }
171}