use ringbuffer::{AllocRingBuffer, RingBuffer};
use tokio::sync::{Mutex, Notify};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::*};
use crate::alo::notify::Notify as NotifyBlocking;
/// Maximum number of readers that may be granted the semaphore at the same time.
const MAX_READS: usize = 128;
/// Capacity of the wait queue, counted in groups of same-kind waiters rather than tasks.
const MAX_ACQUIRES: usize = 16;
/// Length of the `_p` padding array carried by every queued group (currently none).
const ACQUIRE_PADDING_BYTES: usize = 0;
/// Reasons an acquisition can fail.
///
/// Implements `Display` and `std::error::Error` so it can be propagated with `?` or boxed
/// as `dyn Error`, and is `Copy`/`Eq` so callers can compare variants directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcquireError {
    /// No value was available. NOTE(review): exact meaning is not visible in this file.
    ValueNone,
    /// The resource could not be acquired (presumably a transient condition — confirm).
    Unavailable,
    /// The underlying semaphore/resource has been closed.
    Closed,
}

impl std::fmt::Display for AcquireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::ValueNone => "no value available",
            Self::Unavailable => "resource unavailable",
            Self::Closed => "closed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for AcquireError {}
/// Kind of access requested by a queued group of waiters.
enum AcquireType {
    /// Read access; several readers may be granted at once.
    Read,
    /// Write access.
    Write,
}
struct Acquire {
at: AcquireType,
ac: usize,
an: Arc<Notify>,
_p: [u8; ACQUIRE_PADDING_BYTES],
}
pub(super) type AcquireResult<T = ()> = Result<T, AcquireError>;
type AcquireList = Mutex<AllocRingBuffer<Acquire>>;
pub(super) struct RwSemaphore {
rc: AtomicUsize,
al: AcquireList,
an: Arc<Notify>,
sc: AtomicBool,
or: Arc<NotifyBlocking>,
}
impl Acquire {
fn new(at: AcquireType) -> Self {
Self {
at,
ac: 1,
an: Notify::new().into(),
_p: [0; ACQUIRE_PADDING_BYTES],
}
}
#[inline(always)]
fn is_read(&self) -> bool {
matches!(self.at, AcquireType::Read)
}
#[allow(dead_code)]
#[inline(always)]
fn is_write(&self) -> bool {
matches!(self.at, AcquireType::Write)
}
}
impl RwSemaphore {
pub(super) fn new() -> Self {
Self {
rc: AtomicUsize::new(0),
al: Mutex::new(AllocRingBuffer::new(MAX_ACQUIRES)),
an: Notify::new().into(),
sc: AtomicBool::new(false),
or: NotifyBlocking::new(),
}
}
fn poll(&self, rc: usize) {
if rc > 0 && rc < MAX_READS {
#[cfg(feature = "tokio")]
let mut lock = self.al.blocking_lock();
#[cfg(not(feature = "tokio"))]
let mut lock = self.al.lock().unwrap();
let acquire = lock.front_mut().unwrap();
let permit_count = (MAX_READS - rc).min(acquire.ac);
acquire.ac -= permit_count;
self.rc.fetch_add(permit_count, Release);
for _ in 0..permit_count {
acquire.an.notify_one();
}
} else if rc == 0 {
#[cfg(feature = "tokio")]
let mut lock = self.al.blocking_lock();
#[cfg(not(feature = "tokio"))]
let mut lock = self.al.lock().unwrap();
if let Some(acquire) = lock.front_mut() {
if acquire.ac > 0 {
acquire.ac -= 1;
acquire.an.notify_one();
if acquire.is_read() {
self.rc.fetch_add(1, Release);
self.poll(1);
return;
}
} else {
let _ = lock.dequeue().unwrap();
self.an.notify_waiters();
}
self.poll(0);
}
}
}
pub(super) async fn acquire_read(&self) -> bool {
if self.sc.load(Acquire) {
return false;
}
let not = {
let mut lock = self.al.lock().await;
if lock.is_full() {
return false;
}
self.rc.fetch_add(1, Release);
if let Some(acquire) = lock.back_mut() {
if acquire.is_read() {
acquire.ac += 1;
acquire.an.clone()
} else {
lock.push(Acquire::new(AcquireType::Read));
lock.back().unwrap().an.clone()
}
} else {
lock.push(Acquire::new(AcquireType::Read));
lock.back().unwrap().an.clone()
}
};
self.poll(self.rc.load(Acquire));
not.notified().await;
!self.sc.load(Acquire)
}
pub(super) async fn acquire_write(&self) -> bool {
if self.sc.load(Acquire) {
return false;
}
let not = {
let mut lock = self.al.lock().await;
if lock.is_full() {
return false;
}
if let Some(acquire) = lock.back_mut() {
if acquire.is_read() {
lock.push(Acquire::new(AcquireType::Write));
lock.back().unwrap().an.clone()
} else {
acquire.ac += 1;
acquire.an.clone()
}
} else {
lock.push(Acquire::new(AcquireType::Write));
lock.back().unwrap().an.clone()
}
};
self.poll(self.rc.load(Acquire));
not.notified().await;
!self.sc.load(Acquire)
}
pub(super) async fn acquire_read_wait(&self) -> bool {
while !self.acquire_read().await {
if self.sc.load(Acquire) {
return false;
}
self.an.notified().await;
}
true
}
pub(super) async fn acquire_write_wait(&self) -> bool {
while !self.acquire_write().await {
if self.sc.load(Acquire) {
return false;
}
self.an.notified().await;
}
true
}
pub(super) fn release_read(&self) {
if !self.sc.load(Acquire) {
self.poll(self.rc.fetch_sub(1, Release) - 1);
}
}
pub(super) fn release_write(&self) {
if !self.sc.load(Acquire) {
self.poll(0);
}
}
#[allow(dead_code)]
pub(super) fn evict(&self) {
self.close();
while self.rc.load(Acquire) != 0 {
self.or.notified();
}
self.sc.fetch_not(Release);
}
pub(super) fn close(&self) {
self.sc.fetch_not(Release);
let lock = self.al.blocking_lock();
self.an.notify_waiters();
for acquire in lock.iter() {
acquire.an.notify_waiters();
}
}
}