// no_std_async/semaphore.rs
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll, Waker},
};

use pin_list::PinList;
use pin_project::{pin_project, pinned_drop};
use spin::Mutex;
10
11type PinListTypes = dyn pin_list::Types<
12 Id = pin_list::id::Unchecked,
13 Protected = Waker,
14 Removed = (),
15 Unprotected = usize,
16>;
17
18pub struct Semaphore {
35 inner: Mutex<SemaphoreInner>,
36}
37impl Semaphore {
38 pub const fn new(initial_count: usize) -> Self {
40 Self {
41 inner: Mutex::new(SemaphoreInner {
42 count: initial_count,
43 waiters: PinList::new(unsafe { pin_list::id::Unchecked::new() }),
44 }),
45 }
46 }
47
48 pub fn acquire(&self, n: usize) -> Acquire<'_> {
52 #[cfg(test)]
53 println!("acquire({})", n);
54 Acquire {
55 semaphore: self,
56 n,
57 node: pin_list::Node::new(),
58 }
59 }
60
61 pub fn release(&self, n: usize) {
63 let mut lock = self.inner.lock();
64 lock.count += n;
65 match lock.waiters.cursor_front_mut().unprotected().copied() {
66 Some(count) if lock.count >= count => {
67 let waker = lock.waiters.cursor_front_mut().remove_current(()).unwrap();
68 drop(lock);
69 waker.wake();
70 }
71 _ => {}
72 }
73 }
74
75 pub fn remaining(&self) -> usize {
77 self.inner.lock().count
78 }
79}
80
81struct SemaphoreInner {
82 count: usize,
83 waiters: PinList<PinListTypes>,
84}
85
86#[must_use]
90#[pin_project(PinnedDrop)]
91pub struct Acquire<'a> {
92 semaphore: &'a Semaphore,
93 n: usize,
94 #[pin]
95 node: pin_list::Node<PinListTypes>,
96}
97impl Future for Acquire<'_> {
98 type Output = ();
99 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
100 let mut projected = self.project();
101
102 let mut lock = projected.semaphore.inner.lock();
103
104 if let Some(node) = projected.node.as_mut().initialized_mut() {
105 if let Err(e) = node.take_removed(&lock.waiters) {
106 *e.protected_mut(&mut lock.waiters)
109 .unwrap() = cx.waker().clone();
110 return Poll::Pending;
111 }
112 }
113
114 if lock.count >= *projected.n {
115 lock.count -= *projected.n;
116 if lock.count > 0 {
117 if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
119 drop(lock);
120 waker.wake();
121 }
122 }
123 return Poll::Ready(());
124 }
125
126 lock.waiters.cursor_back_mut().insert_after(
127 projected.node,
128 cx.waker().clone(),
129 *projected.n,
130 );
131
132 Poll::Pending
133 }
134}
135#[pinned_drop]
136impl PinnedDrop for Acquire<'_> {
137 fn drop(self: Pin<&mut Self>) {
138 let projected = self.project();
139 let node = match projected.node.initialized_mut() {
140 Some(node) => node,
141 None => return, };
143
144 let mut lock = projected.semaphore.inner.lock();
145
146 match node.reset(&mut lock.waiters) {
147 (pin_list::NodeData::Linked(_waker), _) => {} (pin_list::NodeData::Removed(()), _) => {
149 if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
151 drop(lock);
152 waker.wake();
153 }
154 }
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Gives the spawned threads time to run up to their next suspension point.
    fn settle() {
        thread::sleep(Duration::from_millis(10));
    }

    #[test]
    fn semaphore() {
        static SEMAPHORE: Semaphore = Semaphore::new(10);

        // All ten initial permits are free, so this completes on its first poll.
        let take_10 = thread::spawn(|| pollster::block_on(SEMAPHORE.acquire(10)));
        settle();
        assert!(take_10.is_finished());

        // No permits are left: these three queue up in this order.
        let take_1 = thread::spawn(|| pollster::block_on(SEMAPHORE.acquire(1)));
        settle();
        let take_30 = thread::spawn(|| pollster::block_on(SEMAPHORE.acquire(30)));
        settle();
        let take_5 = thread::spawn(|| pollster::block_on(SEMAPHORE.acquire(5)));
        settle();

        // 30 permits: `take_1` is served, which leaves 29.
        SEMAPHORE.release(30);
        settle();
        assert!(take_1.is_finished());
        // 29 permits are not enough for `take_30`.
        assert!(!take_30.is_finished());
        // `take_5` queued up after `take_30` and must not have completed yet.
        assert!(!take_5.is_finished());

        // 29 + 6 = 35 permits: exactly enough for both remaining waiters.
        SEMAPHORE.release(6);
        settle();
        assert!(take_30.is_finished());
        assert!(take_5.is_finished());
    }
}