use crossbeam_queue::SegQueue;
use crossbeam_utils::Backoff;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::task::{Context, Poll, Waker};
pub struct Semaphore {
permits: AtomicUsize,
waiters: SegQueue<Waker>,
}
impl Semaphore {
pub const fn new(permits: usize) -> Self {
debug_assert!(permits >= 1);
Self {
permits: AtomicUsize::new(permits),
waiters: SegQueue::new(),
}
}
pub async fn acquire(&self) -> SemaphorePermit<'_> {
Acquire::new(self).await
}
pub fn available_permits(&self) -> usize {
self.permits.load(Ordering::Acquire)
}
pub fn try_acquire(&self) -> Option<SemaphorePermit> {
let backoff = Backoff::new();
let mut permits = self.permits.load(Ordering::Relaxed);
loop {
if permits == 0 {
return None;
}
match self.permits.compare_exchange_weak(
permits,
permits - 1,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => return Some(SemaphorePermit::new(self)),
Err(changed) => permits = changed,
}
backoff.spin();
}
}
}
pub struct SemaphorePermit<'a> {
semaphore: &'a Semaphore,
}
impl Drop for SemaphorePermit<'_> {
fn drop(&mut self) {
self.semaphore.permits.fetch_add(1, Ordering::Release);
if let Some(waker) = self.semaphore.waiters.pop() {
waker.wake();
}
}
}
impl<'a> SemaphorePermit<'a> {
const fn new(semaphore: &'a Semaphore) -> Self {
Self { semaphore }
}
}
struct Acquire<'a> {
semaphore: &'a Semaphore,
waiting: AtomicBool,
}
impl<'a> Future for Acquire<'a> {
type Output = SemaphorePermit<'a>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.semaphore.try_acquire() {
Some(permit) => Poll::Ready(permit),
None => {
if self
.waiting
.compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
.is_ok()
{
self.semaphore.waiters.push(cx.waker().clone());
}
Poll::Pending
}
}
}
}
impl<'a> Acquire<'a> {
const fn new(semaphore: &'a Semaphore) -> Self {
Self {
semaphore,
waiting: AtomicBool::new(false),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    // Checks two things about a waiter whose `acquire` is cancelled (here by a
    // timeout): it must not consume the wakeup meant for a later waiter, and the wait
    // queue must be empty once everything has finished.
    // NOTE(review): relies on real 1ms sleeps/timeouts, so the outcome depends on the
    // scheduler honouring those durations closely.
    #[tokio::test]
    async fn test_abort_acquire() {
        let sem = Arc::new(Semaphore::new(1));
        // Hold the only permit so both spawned tasks below have to wait.
        let permit = sem.try_acquire().unwrap();
        // Task `a`: starts waiting first and gives up after 1ms.
        let a = {
            let sem = sem.clone();
            tokio::spawn(tokio::time::timeout(Duration::from_millis(1), async move {
                let _ = sem.acquire().await;
            }))
        };
        // Let `a` start waiting (and, presumably, time out) before `b` is spawned.
        tokio::time::sleep(Duration::from_millis(1)).await;
        // Task `b`: starts waiting after `a`, with a longer 2ms deadline.
        let b = {
            let sem = sem.clone();
            tokio::spawn(tokio::time::timeout(Duration::from_millis(2), async move {
                let _ = sem.acquire().await;
            }))
        };
        tokio::time::sleep(Duration::from_millis(1)).await;
        // Releasing the permit has to reach `b`, even though `a` queued first.
        drop(permit);
        // `a` timed out; `b` acquired (and immediately released) before its deadline.
        assert!(a.await.unwrap().is_err());
        assert!(b.await.unwrap().is_ok());
        // Nothing should be left in the wait queue.
        assert!(sem.waiters.is_empty());
    }
}