no_std_async/
semaphore.rs

1use core::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll, Waker},
5};
6
7use pin_list::PinList;
8use pin_project::{pin_project, pinned_drop};
9use spin::Mutex;
10
11type PinListTypes = dyn pin_list::Types<
12    Id = pin_list::id::Unchecked,
13    Protected = Waker,
14    Removed = (),
15    Unprotected = usize,
16>;
17
18/// A type that asynchronously distributes "permits."
19///
20/// A permit is a token that allows the holder to perform some action.
21/// The semaphore itself does not dictate this action, but instead
22/// handles only the distribution of permits.
23/// This is useful for cases where you want to limit the number of
24/// concurrent operations, such as in a mutex.
25///
26/// `n` permits can be acquired through [`acquire`](Self::acquire).
27/// They can later be released through [`release`](Self::release).
28///
29/// # Examples
30/// For examples, look at the implementations of [`Mutex`](crate::Mutex) and [`RwLock`](crate::RwLock).
31/// [`Mutex`](crate::Mutex) uses a semaphore with a maximum of 1 permit to allow a single lock at a time.
32/// [`RwLock`](crate::RwLock) uses a semaphore with a maximum of `max_readers` permits to allow any number of readers.
33/// When a `write` call is encountered, it acquires all of the permits, blocking any new readers from locking.
34pub struct Semaphore {
35    inner: Mutex<SemaphoreInner>,
36}
37impl Semaphore {
38    /// Creates a new [`Semaphore`] with the given initial number of permits.
39    pub const fn new(initial_count: usize) -> Self {
40        Self {
41            inner: Mutex::new(SemaphoreInner {
42                count: initial_count,
43                waiters: PinList::new(unsafe { pin_list::id::Unchecked::new() }),
44            }),
45        }
46    }
47
48    /// Acquires `n` permits from the semaphore.
49    /// These permits should be [`release`](Self::release)d later,
50    /// or they will be permanently removed.
51    pub fn acquire(&self, n: usize) -> Acquire<'_> {
52        #[cfg(test)]
53        println!("acquire({})", n);
54        Acquire {
55            semaphore: self,
56            n,
57            node: pin_list::Node::new(),
58        }
59    }
60
61    /// Releases `n` permits back to the semaphore.
62    pub fn release(&self, n: usize) {
63        let mut lock = self.inner.lock();
64        lock.count += n;
65        match lock.waiters.cursor_front_mut().unprotected().copied() {
66            Some(count) if lock.count >= count => {
67                let waker = lock.waiters.cursor_front_mut().remove_current(()).unwrap();
68                drop(lock);
69                waker.wake();
70            }
71            _ => {}
72        }
73    }
74
75    /// Returns the number of remaining permits.
76    pub fn remaining(&self) -> usize {
77        self.inner.lock().count
78    }
79}
80
81struct SemaphoreInner {
82    count: usize,
83    waiters: PinList<PinListTypes>,
84}
85
86/// A future that acquires a permit from a [`Semaphore`].
87/// This future should not be dropped before completion,
88/// otherwise the permit will not be acquired.
89#[must_use]
90#[pin_project(PinnedDrop)]
91pub struct Acquire<'a> {
92    semaphore: &'a Semaphore,
93    n: usize,
94    #[pin]
95    node: pin_list::Node<PinListTypes>,
96}
97impl Future for Acquire<'_> {
98    type Output = ();
99    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
100        let mut projected = self.project();
101
102        let mut lock = projected.semaphore.inner.lock();
103
104        if let Some(node) = projected.node.as_mut().initialized_mut() {
105            if let Err(e) = node.take_removed(&lock.waiters) {
106                // Someone has polled us again, but we haven't been woken yet.
107                // We update the waker, then go back to sleep.
108                *e.protected_mut(&mut lock.waiters)
109                    .unwrap() = cx.waker().clone();
110                return Poll::Pending;
111            }
112        }
113
114        if lock.count >= *projected.n {
115            lock.count -= *projected.n;
116            if lock.count > 0 {
117                // There's still more for others to take, give it to the next task in line.
118                if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
119                    drop(lock);
120                    waker.wake();
121                }
122            }
123            return Poll::Ready(());
124        }
125
126        lock.waiters.cursor_back_mut().insert_after(
127            projected.node,
128            cx.waker().clone(),
129            *projected.n,
130        );
131
132        Poll::Pending
133    }
134}
135#[pinned_drop]
136impl PinnedDrop for Acquire<'_> {
137    fn drop(self: Pin<&mut Self>) {
138        let projected = self.project();
139        let node = match projected.node.initialized_mut() {
140            Some(node) => node,
141            None => return, // We're either already done or never started. In either case, we can just return.
142        };
143
144        let mut lock = projected.semaphore.inner.lock();
145
146        match node.reset(&mut lock.waiters) {
147            (pin_list::NodeData::Linked(_waker), _) => {} // We've been cancelled before ever being woken.
148            (pin_list::NodeData::Removed(()), _) => {
149                // Oops, we were already woken! We need to wake the next task in line.
150                if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
151                    drop(lock);
152                    waker.wake();
153                }
154            }
155        }
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// Exercises immediate acquisition, queueing, and hand-off between waiters
    /// using one thread per pending `acquire` call.
    #[test]
    fn semaphore() {
        static SEMAPHORE: Semaphore = Semaphore::new(10);

        /// Blocks a fresh thread on acquiring `n` permits from the shared semaphore.
        fn spawn_acquire(n: usize) -> thread::JoinHandle<()> {
            thread::spawn(move || pollster::block_on(SEMAPHORE.acquire(n)))
        }

        /// Gives the spawned threads a moment to make progress.
        fn settle() {
            thread::sleep(std::time::Duration::from_millis(10));
        }

        // All ten initial permits are available, so this should complete instantly.
        let take_10 = spawn_acquire(10);
        settle();
        assert!(take_10.is_finished());

        // No permits are left; these three queue up in order.
        let take_1 = spawn_acquire(1);
        settle();
        let take_30 = spawn_acquire(30);
        settle();
        let take_5 = spawn_acquire(5);
        settle();

        SEMAPHORE.release(30);
        settle();
        assert!(take_1.is_finished());
        assert!(!take_30.is_finished()); // we only have 29 now
        assert!(!take_5.is_finished()); // take_30 waits at the start of the line and doesn't notify 5

        SEMAPHORE.release(6);
        settle();
        assert!(take_30.is_finished());
        assert!(take_5.is_finished());
    }
}