use std::cell::Cell;
use std::mem::ManuallyDrop;
use crate::wait_list::WaitList;
/// A counting semaphore for async code, using `Cell` for its counters.
///
/// Tasks acquire some number of permits and get back a [`Permit`] guard that returns them
/// when dropped. Tasks that cannot be satisfied immediately wait in `waiters`.
#[derive(Debug)]
pub struct Semaphore {
/// Queue of tasks waiting for permits. The input is the number of permits each waiter wants;
/// the output tells the woken waiter whether permits were handed to it directly.
waiters: WaitList<usize, WakeUp>,
/// Number of permits currently available to be acquired.
permits: Cell<usize>,
/// Number of permits this semaphore manages (available plus held), minus any leaked ones.
total_permits: Cell<usize>,
}
/// How a queued waiter was woken, i.e. who is responsible for claiming its permits.
#[derive(Debug, Clone, Copy)]
enum WakeUp {
/// The permits were already deducted from the pool on the waiter's behalf; it owns them.
Fair,
/// The permits were left in the pool; the waiter must try to take them itself.
Unfair,
}
impl Semaphore {
    /// Creates a semaphore that starts with `permits` available permits.
    #[must_use]
    pub const fn new(permits: usize) -> Self {
        Self {
            waiters: WaitList::new(),
            permits: Cell::new(permits),
            total_permits: Cell::new(permits),
        }
    }

    /// Returns the number of permits that can currently be acquired.
    #[must_use]
    pub fn available_permits(&self) -> usize {
        self.permits.get()
    }

    /// Returns the number of permits managed by this semaphore, whether available or held by a
    /// [`Permit`]. Permits removed with [`Permit::leak`] are no longer counted.
    #[must_use]
    pub fn total_permits(&self) -> usize {
        self.total_permits.get()
    }

    /// Adds `new_permits` permits and wakes every queued task they can satisfy.
    ///
    /// The permits are left in the pool, so a woken task still has to claim them and another
    /// caller may take them first.
    ///
    /// # Panics
    ///
    /// Panics if the total number of permits would overflow `usize`.
    pub fn add_permits(&self, new_permits: usize) {
        self.grow_total_permits(new_permits);
        self.release_permits(new_permits, WakeUp::Unfair);
    }

    /// Adds `new_permits` permits, handing them directly to queued tasks in queue order.
    ///
    /// # Panics
    ///
    /// Panics if the total number of permits would overflow `usize`.
    pub fn add_permits_fair(&self, new_permits: usize) {
        self.grow_total_permits(new_permits);
        self.release_permits(new_permits, WakeUp::Fair);
    }

    /// Increases `total_permits` by `new_permits`, panicking on overflow.
    ///
    /// Shared by both `add_permits*` methods so they fail identically.
    fn grow_total_permits(&self, new_permits: usize) {
        self.total_permits.set(
            self.total_permits
                .get()
                .checked_add(new_permits)
                .expect("number of permits overflowed"),
        );
    }

    /// Tries to acquire `to_acquire` permits without waiting, respecting the queue.
    ///
    /// Returns `None` if not enough permits are available, or if any task is already queued,
    /// even when enough permits are available, so that callers cannot overtake waiters.
    pub fn try_acquire(&self, to_acquire: usize) -> Option<Permit<'_>> {
        if !self.waiters.borrow().is_empty() {
            return None;
        }
        self.try_acquire_unfair(to_acquire)
    }

    /// Tries to acquire `to_acquire` permits without waiting, ignoring any queued tasks.
    ///
    /// Returns `None` only if fewer than `to_acquire` permits are available.
    pub fn try_acquire_unfair(&self, to_acquire: usize) -> Option<Permit<'_>> {
        let new_permits = self.permits.get().checked_sub(to_acquire)?;
        self.permits.set(new_permits);
        Some(Permit {
            semaphore: self,
            permits: to_acquire,
        })
    }

    /// Acquires `to_acquire` permits, waiting in the queue if necessary.
    ///
    /// The first attempt does not jump ahead of tasks that are already waiting.
    pub async fn acquire(&self, to_acquire: usize) -> Permit<'_> {
        match self.try_acquire(to_acquire) {
            Some(permit) => permit,
            None => self.acquire_queued(to_acquire).await,
        }
    }

    /// Acquires `to_acquire` permits, taking them immediately if available even when other
    /// tasks are queued; otherwise waits in the queue.
    pub async fn acquire_unfair(&self, to_acquire: usize) -> Permit<'_> {
        match self.try_acquire_unfair(to_acquire) {
            Some(permit) => permit,
            None => self.acquire_queued(to_acquire).await,
        }
    }

    /// Waits in the queue until `to_acquire` permits are obtained.
    ///
    /// A waiter is only woken from the head of the queue, so after an unfair wake-up it retries
    /// with `try_acquire_unfair`. Retrying with `try_acquire` would fail whenever anyone else is
    /// still queued, push this task to the back, and leave the released permits idle.
    ///
    /// NOTE(review): if this future is dropped after a `Fair` wake-up but before being polled,
    /// the handed-over permits are only returned if `WaitList` handles that case. Likewise, an
    /// unfair wake-up of a cancelled waiter is not passed on to the next waiter here. Confirm
    /// against `WaitList`'s cancellation semantics.
    async fn acquire_queued(&self, to_acquire: usize) -> Permit<'_> {
        loop {
            match self.waiters.wait(to_acquire).await {
                // `release_permits` already deducted the permits for us.
                WakeUp::Fair => {
                    return Permit {
                        semaphore: self,
                        permits: to_acquire,
                    };
                }
                WakeUp::Unfair => {
                    if let Some(permit) = self.try_acquire_unfair(to_acquire) {
                        return permit;
                    }
                    // Someone else took the permits first; queue up again.
                }
            }
        }
    }

    /// Returns `permits` to the pool and wakes, in queue order, as many waiters as the
    /// available permits can satisfy.
    ///
    /// With `WakeUp::Fair` each woken waiter's permits are deducted from the pool here, so they
    /// are handed over directly. With `WakeUp::Unfair` they stay in the pool and the woken
    /// waiter must claim them itself.
    fn release_permits(&self, permits: usize, fairness: WakeUp) {
        // Available permits never exceed `total_permits`, which is overflow-checked on growth.
        let mut remaining = self.permits.get() + permits;
        self.permits.set(remaining);

        let mut waiters = self.waiters.borrow();
        while let Some(&wanted) = waiters.head_input() {
            remaining = match remaining.checked_sub(wanted) {
                Some(left) => left,
                // The head cannot be satisfied; stop so later waiters don't overtake it.
                None => break,
            };
            if let WakeUp::Fair = fairness {
                self.permits.set(remaining);
            }
            if waiters.wake_one(fairness).is_err() {
                // `head_input` just returned `Some`, so the queue is non-empty.
                unreachable!();
            }
        }
    }
}
/// A guard holding `permits` permits from a [`Semaphore`]; they are returned when it is dropped.
///
/// Marked `#[must_use]` because discarding the guard (e.g. `sem.acquire(1).await;`) releases
/// the permits immediately, which is almost always a bug.
#[must_use = "dropping a `Permit` immediately releases its permits"]
#[derive(Debug)]
pub struct Permit<'semaphore> {
    /// The semaphore the permits were acquired from.
    semaphore: &'semaphore Semaphore,
    /// Number of permits held by this guard.
    permits: usize,
}
impl<'semaphore> Permit<'semaphore> {
    /// The semaphore these permits belong to.
    #[must_use]
    pub fn semaphore(&self) -> &'semaphore Semaphore {
        let owner: &'semaphore Semaphore = self.semaphore;
        owner
    }

    /// How many permits this guard holds.
    #[must_use]
    pub fn permits(&self) -> usize {
        let held = self.permits;
        held
    }

    /// Forgets the permits instead of returning them.
    ///
    /// The semaphore's total permit count is reduced by the number held; the available count
    /// is left untouched, so the permits are gone for good.
    pub fn leak(self) {
        // Suppress `Drop`, which would otherwise hand the permits back.
        let guard = ManuallyDrop::new(self);
        let owner = guard.semaphore;
        let shrunk_total = owner.total_permits.get() - guard.permits;
        owner.total_permits.set(shrunk_total);
    }

    /// Releases the permits, handing them directly to queued tasks in queue order rather than
    /// leaving them in the pool for anyone to grab.
    pub fn release_fair(self) {
        // Suppress `Drop` so the permits are released exactly once, fairly.
        let guard = ManuallyDrop::new(self);
        let (owner, held) = (guard.semaphore, guard.permits);
        owner.release_permits(held, WakeUp::Fair);
    }
}
impl Drop for Permit<'_> {
    /// Gives the held permits back to the pool and wakes queued tasks that can now be
    /// satisfied; the woken tasks still have to claim the permits themselves.
    fn drop(&mut self) {
        let Self { semaphore, permits } = self;
        semaphore.release_permits(*permits, WakeUp::Unfair);
    }
}