flag_bearer_queue/
acquire.rs

1use core::{
2    fmt,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use lock_api::RawMutex;
8use pin_list::{Node, NodeData};
9
10use crate::closeable::{IsCloseable, Uncloseable};
11use crate::{SemaphoreQueue, SemaphoreState};
12
13use super::PinQueue;
14
15use crate::loom::Mutex;
16
17pin_project_lite::pin_project! {
18    /// A [`Future`] that acquires a permit from a [`SemaphoreQueue`].
19    pub struct Acquire<'a, S, C, R>
20    where
21        S: ?Sized,
22        S: SemaphoreState,
23        C: IsCloseable,
24        R: RawMutex,
25    {
26        #[pin]
27        node: Node<PinQueue<S::Params, S::Permit, C>>,
28        order: FairOrder,
29        state: &'a Mutex<R, SemaphoreQueue<S, C>>,
30        params: Option<S::Params>,
31    }
32
33    impl<S, C, R> PinnedDrop for Acquire<'_, S, C, R>
34    where
35        S: ?Sized,
36        S: SemaphoreState,
37        C: IsCloseable,
38        R: RawMutex,
39    {
40        fn drop(this: Pin<&mut Self>) {
41            let this = this.project();
42            let Some(node) = this.node.initialized_mut() else {
43                return;
44            };
45            let mut state = this.state.lock();
46            match &mut state.queue {
47                Ok(queue) => {
48                    let (data, _unprotected) = node.reset(queue);
49                    if let NodeData::Removed(Ok(permit)) = data {
50                        state.state.release(permit);
51                        state.check();
52                    }
53                }
54                Err(_closed) => {
55                    // Safety: If the semaphore is closed (meaning we have no queue)
56                    // then there's no way this node could be queued in the queue,
57                    // therefore it must be removed.
58                    let (permit, ()) = unsafe { node.take_removed_unchecked() };
59                    if let Ok(permit) = permit {
60                        state.state.release(permit);
61                    }
62                }
63            }
64        }
65    }
66}
67
68impl<S: SemaphoreState + ?Sized, C: IsCloseable, R: RawMutex> Future for Acquire<'_, S, C, R> {
69    type Output = Result<S::Permit, C::AcquireError<S::Params>>;
70
71    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        let mut this = self.project();
73        let mut state = this.state.lock();
74
75        let Some(init) = this.node.as_mut().initialized_mut() else {
76            // first time polling.
77            let params = this.params.take().unwrap();
78            let node = this.node.as_mut();
79
80            return match state.try_acquire(params, Fairness::Fair(*this.order)) {
81                Ok(permit) => Poll::Ready(Ok(permit)),
82                Err(TryAcquireError::Closed(params)) => Poll::Ready(Err(params)),
83                Err(TryAcquireError::NoPermits(params)) => {
84                    let queue = match &mut state.queue {
85                        Ok(queue) => queue,
86                        // Safety: if the queue was closed, we would get a `Closed` error type.
87                        // It was not closed, thus it still isn't closed.
88                        Err(_closed) => unsafe { core::hint::unreachable_unchecked() },
89                    };
90
91                    // no permit or we are not the leader, so we register into the queue.
92                    let mut cursor = queue.cursor_ghost_mut();
93                    let protected = (Some(params), cx.waker().clone());
94                    let unprotected = ();
95                    match *this.order {
96                        FairOrder::Lifo => cursor.insert_after(node, protected, unprotected),
97                        FairOrder::Fifo => cursor.insert_before(node, protected, unprotected),
98                    };
99                    Poll::Pending
100                }
101            };
102        };
103
104        if let Ok(queue) = &mut state.queue {
105            if let Some((_, waker)) = init.protected_mut(queue) {
106                // spurious wakeup
107                waker.clone_from(cx.waker());
108                return Poll::Pending;
109            }
110        }
111
112        // Safety: Either there is no queue, then we are guaranteed to be removed from it
113        // Or there was a queue, but we were removed from it anyway (protected_mut returned None).
114        let (permit, ()) = unsafe { init.take_removed_unchecked() };
115        let permit = permit.map_err(|params| {
116            C::map_err(params, |params| {
117                params.expect(
118                    "params should be set. likely the SemaphoreState::acquire method panicked",
119                )
120            })
121        });
122        Poll::Ready(permit)
123    }
124}
125
/// The order in which an [`Acquire`] joins the wait queue.
///
/// It also selects the fairness rule used for the initial, non-queued acquisition
/// attempt (see [`Fairness::Fair`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FairOrder {
    /// Last in, first out: new waiters join at the front of the queue.
    /// Increases tail latencies, but can have better average performance.
    Lifo,
    /// First in, first out: new waiters join at the back of the queue.
    /// Fairer option, but can have cascading failures if queue processing is slow.
    Fifo,
}
137
138#[derive(Debug, Clone, Copy)]
139#[non_exhaustive]
140/// Which fairness property [`SemaphoreQueue::try_acquire`] should respect
141pub enum Fairness {
142    /// [`SemaphoreQueue::try_acquire`] will be fair.
143    Fair(FairOrder),
144    /// [`SemaphoreQueue::try_acquire`] will be unfair.
145    Unfair,
146}
147
148impl<S: SemaphoreState + ?Sized, C: IsCloseable> SemaphoreQueue<S, C> {
149    /// Acquire a permit, or join the queue if not currently available.
150    ///
151    /// * If the order is [`FairOrder::Lifo`], then we enqueue at the front of the queue.
152    /// * If the order is [`FairOrder::Fifo`], then we enqueue at the back of the queue.
153    #[inline]
154    pub fn acquire<R: RawMutex>(
155        this: &Mutex<R, Self>,
156        params: S::Params,
157        order: FairOrder,
158    ) -> Acquire<'_, S, C, R> {
159        Acquire {
160            node: Node::new(),
161            order,
162            state: this,
163            params: Some(params),
164        }
165    }
166
167    /// Try acquire a permit without joining the queue.
168    ///
169    /// * If the fairness is [`Fairness::Unfair`], or [`Fairness::Fair(FairOrder::Lifo)`](FairOrder::Lifo), then we always try acquire a permit.
170    /// * If the fairness is [`Fairness::Fair(FairOrder::Fifo)`](FairOrder::Fifo), then we only try acquire a permit if the queue is empty.
171    ///
172    /// # Errors
173    ///
174    /// If there are currently not enough permits available for the given request,
175    /// then [`TryAcquireError::NoPermits`] is returned.
176    ///
177    /// If this is a [`Fairness::Fair(FairOrder::Fifo)`](FairOrder::Fifo) semaphore queue,
178    /// and there are other tasks waiting for permits,
179    /// then [`TryAcquireError::NoPermits`] is returned.
180    ///
181    /// If this semaphore [`is_closed`](SemaphoreQueue::is_closed), then [`TryAcquireError::Closed`] is returned.
182    #[inline]
183    pub fn try_acquire(
184        &mut self,
185        params: S::Params,
186        fairness: Fairness,
187    ) -> Result<S::Permit, TryAcquireError<S::Params, C>> {
188        let queue = match &mut self.queue {
189            Ok(queue) => queue,
190            Err(_closed) => {
191                return Err(TryAcquireError::Closed(C::new_err(params)));
192            }
193        };
194
195        let is_leader = match fairness {
196            // if first-in-first-out, we are only the leader if the queue is empty.
197            Fairness::Fair(FairOrder::Fifo) => queue.is_empty(),
198
199            // if unfair, then we don't care who the leader is.
200            // if last-in-first-out, we are the last in and thus the leader.
201            Fairness::Unfair | Fairness::Fair(FairOrder::Lifo) => true,
202        };
203
204        if !is_leader {
205            return Err(TryAcquireError::NoPermits(params));
206        }
207
208        match self.state.acquire(params) {
209            Ok(permit) => Ok(permit),
210            Err(p) => Err(TryAcquireError::NoPermits(p)),
211        }
212    }
213}
214
215/// The error returned by [`SemaphoreQueue::try_acquire`].
216#[derive(Debug, PartialEq, Eq)]
217pub enum TryAcquireError<P, C: IsCloseable> {
218    /// The semaphore had no permits to give out right now.
219    NoPermits(P),
220    /// The semaphore is closed.
221    Closed(C::AcquireError<P>),
222}
223
224impl<P, C: IsCloseable> fmt::Display for TryAcquireError<P, C> {
225    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
226        match self {
227            TryAcquireError::Closed(_) => write!(fmt, "semaphore closed"),
228            TryAcquireError::NoPermits(_) => write!(fmt, "no permits available"),
229        }
230    }
231}
232
/// Error returned by [`Acquire`] when the semaphore queue has been closed.
///
/// It carries the params of the failed request back to the caller.
///
/// ```
/// struct Counter(usize);
///
/// impl flag_bearer::SemaphoreState for Counter {
///     type Params = ();
///     type Permit = ();
///
///     fn acquire(&mut self, _: Self::Params) -> Result<Self::Permit, Self::Params> {
///         if self.0 > 0 {
///             self.0 -= 1;
///             Ok(())
///         } else {
///             Err(())
///         }
///     }
///
///     fn release(&mut self, _: Self::Permit) {
///         self.0 += 1;
///     }
/// }
///
/// # pollster::block_on(async move {
/// let s = flag_bearer::new_fifo().closeable().with_state(Counter(1));
///
/// // closing the semaphore makes all current and new acquire() calls return an error.
/// s.close();
///
/// let _err = s.acquire(()).await.unwrap_err();
/// # });
/// ```
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct AcquireError<P> {
    /// The params that were passed to the failed acquire request.
    pub params: P,
}
271
272impl AcquireError<Uncloseable> {
273    /// Since the [`SemaphoreQueue`] is [`Uncloseable`], there can
274    /// never be an acquire error. This allows for unwrapping with type-safety.
275    pub fn never(self) -> ! {
276        match self.params {}
277    }
278}
279
280impl<P> fmt::Display for AcquireError<P> {
281    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
282        write!(fmt, "semaphore closed")
283    }
284}