use core::fmt;
use core::mem;
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Poll;

use alloc::sync::Arc;

use event_listener::{Event, EventListener, IntoNotification};
use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
/// A counter of available permits that tasks take and give back.
///
/// A task that finds no permit available registers on `event` and retries after
/// being notified; permits are returned by dropping a guard or via `add_permits`.
#[derive(Debug)]
pub struct Semaphore {
    // Number of permits currently available; never decremented below zero.
    count: AtomicUsize,
    // Notified whenever permits are returned to `count`.
    event: Event,
}
impl Semaphore {
pub const fn new(n: usize) -> Semaphore {
Semaphore {
count: AtomicUsize::new(n),
event: Event::new(),
}
}
pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
let mut count = self.count.load(Ordering::Acquire);
loop {
if count == 0 {
return None;
}
match self.count.compare_exchange_weak(
count,
count - 1,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return Some(SemaphoreGuard(self)),
Err(c) => count = c,
}
}
}
pub fn acquire(&self) -> Acquire<'_> {
Acquire::_new(AcquireInner {
semaphore: self,
listener: EventListener::new(),
})
}
#[cfg(all(feature = "std", not(target_family = "wasm")))]
#[inline]
pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> {
self.acquire().wait()
}
pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> {
let mut count = self.count.load(Ordering::Acquire);
loop {
if count == 0 {
return None;
}
match self.count.compare_exchange_weak(
count,
count - 1,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => return Some(SemaphoreGuardArc(Some(self.clone()))),
Err(c) => count = c,
}
}
}
pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
AcquireArc::_new(AcquireArcInner {
semaphore: self.clone(),
listener: EventListener::new(),
})
}
#[cfg(all(feature = "std", not(target_family = "wasm")))]
#[inline]
pub fn acquire_arc_blocking(self: &Arc<Self>) -> SemaphoreGuardArc {
self.acquire_arc().wait()
}
pub fn add_permits(&self, n: usize) {
self.count.fetch_add(n, Ordering::AcqRel);
self.event.notify(n);
}
}
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire`].
    pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>);
    // Blocking entry point used by `Semaphore::acquire_blocking`.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
pin_project_lite::pin_project! {
    // State of an in-progress `Semaphore::acquire` call.
    struct AcquireInner<'a> {
        // The semaphore a permit is being taken from.
        semaphore: &'a Semaphore,
        // Registration on `semaphore.event`; structurally pinned because it is polled
        // in place via `as_mut()` in `poll_with_strategy`.
        #[pin]
        listener: EventListener,
    }
}
impl fmt::Debug for Acquire<'_> {
    /// Prints only the type name; the wrapped future's internal state is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Acquire {{ .. }}")
    }
}
impl<'a> EventListenerFuture for AcquireInner<'a> {
    type Output = SemaphoreGuard<'a>;

    /// Repeatedly tries to take a permit, waiting on the semaphore's event between
    /// attempts.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<Self::Output> {
        let mut proj = self.project();

        loop {
            // Fast path: a permit is free right now.
            if let Some(guard) = proj.semaphore.try_acquire() {
                return Poll::Ready(guard);
            }

            if proj.listener.is_listening() {
                // Already registered: wait for a notification, then retry.
                ready!(strategy.poll(proj.listener.as_mut(), cx));
            } else {
                // Register first and retry on the next iteration, so a permit released
                // between the failed attempt and registration is still observed.
                proj.listener.as_mut().listen(&proj.semaphore.event);
            }
        }
    }
}
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire_arc`].
    pub struct AcquireArc(AcquireArcInner => SemaphoreGuardArc);
    // Blocking entry point used by `Semaphore::acquire_arc_blocking`.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
pin_project_lite::pin_project! {
    // State of an in-progress `Semaphore::acquire_arc` call.
    struct AcquireArcInner {
        // Owned handle to the semaphore a permit is being taken from.
        semaphore: Arc<Semaphore>,
        // Registration on `semaphore.event`; structurally pinned because it is polled
        // in place via `as_mut()` in `poll_with_strategy`.
        #[pin]
        listener: EventListener,
    }
}
impl fmt::Debug for AcquireArc {
    /// Prints only the type name; the wrapped future's internal state is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "AcquireArc {{ .. }}")
    }
}
impl EventListenerFuture for AcquireArcInner {
    type Output = SemaphoreGuardArc;

    /// Repeatedly tries to take a permit through the owned `Arc`, waiting on the
    /// semaphore's event between attempts.
    fn poll_with_strategy<'x, S: Strategy<'x>>(
        self: Pin<&mut Self>,
        strategy: &mut S,
        cx: &mut S::Context,
    ) -> Poll<Self::Output> {
        let mut proj = self.project();

        loop {
            // Fast path: a permit is free right now.
            if let Some(guard) = proj.semaphore.try_acquire_arc() {
                return Poll::Ready(guard);
            }

            if proj.listener.is_listening() {
                // Already registered: wait for a notification, then retry.
                ready!(strategy.poll(proj.listener.as_mut(), cx));
            } else {
                // Register first and retry on the next iteration, so a permit released
                // between the failed attempt and registration is still observed.
                proj.listener.as_mut().listen(&proj.semaphore.event);
            }
        }
    }
}
/// A guard holding one permit borrowed from a [`Semaphore`].
///
/// Dropping the guard returns the permit and notifies a waiter;
/// [`SemaphoreGuard::forget`] consumes it without returning the permit.
#[clippy::has_significant_drop]
#[derive(Debug)]
pub struct SemaphoreGuard<'a>(&'a Semaphore);
impl<'a> SemaphoreGuard<'a> {
    /// Consumes the guard without returning its permit to the semaphore.
    ///
    /// The destructor is skipped, so the semaphore keeps one fewer permit unless more
    /// are added later with `Semaphore::add_permits`.
    #[inline]
    pub fn forget(self) {
        let leaked: Self = self;
        mem::forget(leaked);
    }
}
impl Drop for SemaphoreGuard<'_> {
fn drop(&mut self) {
self.0.count.fetch_add(1, Ordering::AcqRel);
self.0.event.notify(1);
}
}
/// An owned guard holding one permit from an `Arc<Semaphore>`.
///
/// The `Option` is `Some` for the guard's whole life; `forget` empties it to release
/// the `Arc` and then skips the permit-returning destructor.
#[clippy::has_significant_drop]
#[derive(Debug)]
pub struct SemaphoreGuardArc(Option<Arc<Semaphore>>);
impl SemaphoreGuardArc {
    /// Consumes the guard without returning its permit to the semaphore.
    ///
    /// The guard's `Arc` reference is still released; only the permit is leaked.
    #[inline]
    pub fn forget(mut self) {
        // Overwriting the field drops the held `Arc` immediately.
        self.0 = None;
        // Skip `Drop`, which would otherwise hand the permit back.
        mem::forget(self);
    }
}
impl Drop for SemaphoreGuardArc {
fn drop(&mut self) {
let opt = self.0.take().unwrap();
opt.count.fetch_add(1, Ordering::AcqRel);
opt.event.notify(1);
}
}