use core::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};
use pin_list::PinList;
use pin_project::{pin_project, pinned_drop};
use spin::Mutex;
/// Type configuration for the semaphore's intrusive wait queue.
///
/// Each queued node carries the waiting task's `Waker` as its protected data
/// (reachable only through the list, which lives behind the semaphore's lock)
/// and the number of permits it requested as its unprotected data. A node that
/// has been removed from the queue carries no extra payload.
type PinListTypes = dyn pin_list::Types<
    Id = pin_list::id::Unchecked,
    Unprotected = usize,
    Protected = Waker,
    Removed = (),
>;
/// An asynchronous counting semaphore.
///
/// Permits are taken with [`Semaphore::acquire`] and handed back with
/// [`Semaphore::release`]. Acquiring does not produce a guard, so callers are
/// responsible for releasing the permits they took.
pub struct Semaphore {
    inner: Mutex<SemaphoreInner>,
}

impl Semaphore {
    /// Creates a semaphore holding `initial_count` available permits.
    pub const fn new(initial_count: usize) -> Self {
        Self {
            inner: Mutex::new(SemaphoreInner {
                count: initial_count,
                // SAFETY: the `Unchecked` id disables pin_list's check that a node
                // is only ever used with the list it was linked into. Every
                // `Acquire` node is only linked into / reset against the `waiters`
                // list of the one semaphore it borrows, so that requirement holds.
                waiters: PinList::new(unsafe { pin_list::id::Unchecked::new() }),
            }),
        }
    }

    /// Returns a future that resolves once `n` permits have been taken from
    /// this semaphore.
    ///
    /// The future is lazy: nothing is taken or queued until it is first polled.
    pub fn acquire(&self, n: usize) -> Acquire<'_> {
        Acquire {
            semaphore: self,
            n,
            node: pin_list::Node::new(),
        }
    }

    /// Returns `n` permits to the semaphore.
    ///
    /// Only the waiter at the front of the queue is considered here: it is
    /// removed and woken if the permits now available cover its request.
    ///
    /// # Panics
    ///
    /// Panics if the number of available permits would overflow `usize`.
    pub fn release(&self, n: usize) {
        let mut inner = self.inner.lock();
        inner.count = inner
            .count
            .checked_add(n)
            .expect("semaphore permit count overflowed usize");

        let front_needs = inner.waiters.cursor_front_mut().unprotected().copied();
        let woken = match front_needs {
            Some(want) if want <= inner.count => {
                inner.waiters.cursor_front_mut().remove_current(()).ok()
            }
            _ => None,
        };

        // Wake only after unlocking: the waker may run code that locks this
        // semaphore again, and the spin lock is not reentrant.
        drop(inner);
        if let Some(waker) = woken {
            waker.wake();
        }
    }

    /// Returns the number of permits currently available.
    ///
    /// This is a snapshot taken under the lock; it may be stale by the time the
    /// caller looks at it.
    pub fn remaining(&self) -> usize {
        self.inner.lock().count
    }
}
/// Lock-protected state of a [`Semaphore`].
struct SemaphoreInner {
    /// Number of permits currently available: increased by `Semaphore::release`,
    /// decreased when an `Acquire` future takes its permits.
    count: usize,
    /// Queue of pending `Acquire` futures. Each entry stores the waiting task's
    /// waker and the number of permits it requested; new entries are added at
    /// the back and wakeups are taken from the front.
    waiters: PinList<PinListTypes>,
}
#[must_use]
#[pin_project(PinnedDrop)]
pub struct Acquire<'a> {
semaphore: &'a Semaphore,
n: usize,
#[pin]
node: pin_list::Node<PinListTypes>,
}
impl Future for Acquire<'_> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let mut projected = self.project();
let mut lock = projected.semaphore.inner.lock();
if let Some(node) = projected.node.as_mut().initialized_mut() {
if let Err(e) = node.take_removed(&lock.waiters) {
*e.protected_mut(&mut lock.waiters)
.unwrap() = cx.waker().clone();
return Poll::Pending;
}
}
if lock.count >= *projected.n {
lock.count -= *projected.n;
if lock.count > 0 {
if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
drop(lock);
waker.wake();
}
}
return Poll::Ready(());
}
lock.waiters.cursor_back_mut().insert_after(
projected.node,
cx.waker().clone(),
*projected.n,
);
Poll::Pending
}
}
#[pinned_drop]
impl PinnedDrop for Acquire<'_> {
fn drop(self: Pin<&mut Self>) {
let projected = self.project();
let node = match projected.node.initialized_mut() {
Some(node) => node,
None => return, };
let mut lock = projected.semaphore.inner.lock();
match node.reset(&mut lock.waiters) {
(pin_list::NodeData::Linked(_waker), _) => {} (pin_list::NodeData::Removed(()), _) => {
if let Ok(waker) = lock.waiters.cursor_front_mut().remove_current(()) {
drop(lock);
waker.wake();
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::{self, JoinHandle};
    use std::time::Duration;

    /// Pause long enough for spawned threads to reach their next blocking point.
    fn settle() {
        thread::sleep(Duration::from_millis(10));
    }

    /// Spawn a thread that blocks until it has taken `n` permits from `semaphore`.
    fn spawn_acquire(semaphore: &'static Semaphore, n: usize) -> JoinHandle<()> {
        thread::spawn(move || pollster::block_on(semaphore.acquire(n)))
    }

    #[test]
    fn semaphore() {
        static SEMAPHORE: Semaphore = Semaphore::new(10);

        // Uncontended: all ten initial permits are available immediately.
        let take_10 = spawn_acquire(&SEMAPHORE, 10);
        settle();
        assert!(take_10.is_finished());

        // Queue three waiters, in a known order, while no permits are free.
        let take_1 = spawn_acquire(&SEMAPHORE, 1);
        settle();
        let take_30 = spawn_acquire(&SEMAPHORE, 30);
        settle();
        let take_5 = spawn_acquire(&SEMAPHORE, 5);
        settle();

        // Thirty permits: `take_1` is served; neither `take_30` nor `take_5`
        // is expected to complete yet.
        SEMAPHORE.release(30);
        settle();
        assert!(take_1.is_finished());
        assert!(!take_30.is_finished());
        assert!(!take_5.is_finished());

        // Six more permits let both remaining waiters through.
        SEMAPHORE.release(6);
        settle();
        assert!(take_30.is_finished());
        assert!(take_5.is_finished());
    }
}