flag_bearer_queue/
lib.rs

1#![no_std]
2#![warn(
3    unsafe_op_in_unsafe_fn,
4    clippy::missing_safety_doc,
5    clippy::multiple_unsafe_ops_per_block,
6    clippy::undocumented_unsafe_blocks
7)]
8
9#[cfg(test)]
10extern crate std;
11
12use core::{hint::unreachable_unchecked, task::Waker};
13
14use closeable::{Closeable, IsCloseable};
15use flag_bearer_core::SemaphoreState;
16use pin_list::PinList;
17
18pub mod acquire;
19pub mod closeable;
20
21mod loom;
22
/// A queue that manages the acquisition of permits from a [`SemaphoreState`], or queues tasks
/// if no permits are available.
///
/// The `C` parameter controls closing: `close` is only provided when `C` is [`Closeable`].
// The `Params` and `Permit` type parameters are not meant to be chosen by users: the bound on
// `S` forces them to equal `S::Params` and `S::Permit`, which are also their defaults.
// NOTE(review): they appear to exist so the `queue` field can name
// `PinQueue<Params, Permit, C>` without projecting through `S`; the exact compiler limitation
// being worked around is not recorded here. Confirm before simplifying these bounds.
pub struct SemaphoreQueue<
    S: SemaphoreState<Params = Params, Permit = Permit> + ?Sized,
    C: IsCloseable,
    Params = <S as SemaphoreState>::Params,
    Permit = <S as SemaphoreState>::Permit,
> {
    /// The list of waiting tasks while the queue is open; `Err` once the queue has been closed.
    // The nested generic field type trips `clippy::type_complexity`, so the lint is silenced
    // for this field only.
    #[allow(clippy::type_complexity)]
    queue: Result<PinList<PinQueue<Params, Permit, C>>, C::Closed<()>>,
    /// The [`SemaphoreState`] that decides whether a waiter's params can be turned into a
    /// permit.
    state: S,
}
36
37impl<S: SemaphoreState + core::fmt::Debug, C: IsCloseable> core::fmt::Debug
38    for SemaphoreQueue<S, C>
39{
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        let mut d = f.debug_struct("SemaphoreQueue");
42        d.field("state", &self.state);
43        d.finish_non_exhaustive()
44    }
45}
46
/// The [`pin_list::Types`] configuration used for the nodes of the wait queue.
///
/// While linked, each node holds its pending `Params` together with the task's `Waker`.
/// When a node is removed, it receives one of two values:
/// - `Ok(permit)` when a permit was acquired for it.
/// - `Err(params)` when the queue was closed, with its params handed back inside `C::Closed`.
type PinQueue<Params, Permit, C> = dyn pin_list::Types<
        Id = pin_list::id::DebugChecked,
        Protected = (
            // Some(params) -> Pending
            // None -> Invalid state (the params were taken for an acquire attempt and never
            // put back, e.g. because `SemaphoreState::acquire` panicked).
            Option<Params>,
            Waker,
        ),
        Removed = Result<
            // Ok(permit) -> Ready
            Permit,
            // Err(Some(params)) -> Closed
            // Err(None) -> Closed, Invalid state
            <C as closeable::private::Sealed>::Closed<Option<Params>>,
        >,
        Unprotected = (),
    >;
64
65impl<S: SemaphoreState, C: IsCloseable> SemaphoreQueue<S, C> {
66    /// Construct a new semaphore queue, with the given [`SemaphoreState`].
67    pub fn new(state: S) -> Self {
68        Self {
69            state,
70            // Safety: during acquire, we ensure that nodes in this queue
71            // will never attempt to use a different queue to read the nodes.
72            queue: Ok(PinList::new(unsafe { pin_list::id::DebugChecked::new() })),
73        }
74    }
75}
76
77impl<S: SemaphoreState + ?Sized, C: IsCloseable> SemaphoreQueue<S, C> {
78    /// Access the state with mutable access.
79    ///
80    /// This gives direct access to the state, be careful not to
81    /// break any of your own state invariants. You can use this
82    /// to peek at the current state, or to modify it, eg to add or
83    /// remove permits from the semaphore.
84    pub fn with_state<T>(&mut self, f: impl FnOnce(&mut S) -> T) -> T {
85        let res = f(&mut self.state);
86        self.check();
87        res
88    }
89
90    #[inline]
91    fn check(&mut self) {
92        let Ok(queue) = &mut self.queue else { return };
93        let mut leader = queue.cursor_front_mut();
94        while let Some(p) = leader.protected_mut() {
95            let params = p.0.take().expect(
96                "params should be in place. possibly the SemaphoreState::acquire method panicked",
97            );
98            match self.state.acquire(params) {
99                Ok(permit) => match leader.remove_current(Ok(permit)) {
100                    Ok((_, waker)) => waker.wake(),
101                    // Safety: with protected_mut, we have just made sure it is in the list
102                    Err(_) => unsafe { unreachable_unchecked() },
103                },
104                Err(params) => {
105                    p.0 = Some(params);
106                    break;
107                }
108            }
109        }
110    }
111
112    /// Check if the queue is closed
113    pub fn is_closed(&self) -> bool {
114        self.queue.is_err()
115    }
116}
117
118impl<S: SemaphoreState + ?Sized> SemaphoreQueue<S, Closeable> {
119    /// Close the semaphore queue.
120    ///
121    /// All tasks currently waiting to acquire a token will immediately stop.
122    /// No new acquire attempts will succeed.
123    pub fn close(&mut self) {
124        let Ok(queue) = &mut self.queue else {
125            return;
126        };
127
128        let mut cursor = queue.cursor_front_mut();
129        while cursor.remove_current_with_or(
130            |(params, waker)| {
131                waker.wake();
132
133                Err(params)
134            },
135            || Err(None),
136        ) {}
137
138        debug_assert!(queue.is_empty());
139
140        // It's important that we only mark the queue as closed when we have ensured that
141        // all linked nodes are removed.
142        // If we did this early, we could panic and not dequeue every node.
143        self.queue = Err(());
144    }
145}
146
#[cfg(all(test, loom))]
mod loom_tests {
    use std::sync::Arc;

    use crate::{SemaphoreQueue, acquire::FairOrder, closeable::Closeable};

    /// A semaphore state that never has a permit to give out, so any acquire
    /// stays pending until the queue is closed.
    #[derive(Debug)]
    struct NeverSucceeds;

    impl crate::SemaphoreState for NeverSucceeds {
        type Params = ();
        type Permit = ();

        fn acquire(&mut self, params: Self::Params) -> Result<Self::Permit, Self::Params> {
            Err(params)
        }

        fn release(&mut self, _permit: Self::Permit) {}
    }

    /// Closing the queue while another thread is (or is about to be) waiting on
    /// it must make that waiter's acquire fail instead of hanging.
    #[test]
    fn concurrent_closed() {
        loom::model(|| {
            let queue = Arc::new(crate::loom::Mutex::<parking_lot::RawMutex, _>::new(
                SemaphoreQueue::<NeverSucceeds, Closeable>::new(NeverSucceeds),
            ));

            let waiter = {
                let queue = Arc::clone(&queue);
                loom::thread::spawn(move || {
                    loom::future::block_on(async move {
                        SemaphoreQueue::acquire(&queue, (), FairOrder::Fifo)
                            .await
                            .unwrap_err()
                    })
                })
            };

            queue.lock().close();

            waiter.join().unwrap();
        });
    }
}