use ringbuffer::{AllocRingBuffer, RingBuffer};
use std::sync::atomic::{AtomicBool, Ordering::*};
use std::sync::{Arc, Mutex, MutexGuard};
use crate::alo::notify::Notify;
/// Upper bound on the number of read permits outstanding at the same time.
const MAX_READS: usize = 128;
/// Capacity of the wait queue, counted in groups of waiters (each group holds a count).
const MAX_ACQUIRES: usize = 16;
/// Length of the `_p` padding array carried by every queued `Acquire` (currently none).
const ACQUIRE_PADDING_BYTES: usize = 0;
/// Reasons an acquire operation can fail, for use with `AcquireResult`.
///
/// NOTE(review): none of the methods visible in this file produce these variants (they report
/// failure as `bool`); the per-variant notes below are inferred from the names — confirm.
#[derive(Debug)]
pub enum AcquireError {
    /// Presumably: no value was available.
    ValueNone,
    /// Presumably: a read permit could not be obtained.
    ReadUnavailable,
    /// Presumably: a write permit could not be obtained.
    WriteUnavailable,
    /// Presumably: the semaphore has been closed.
    Closed,
}
/// Kind of permit that a queued group of waiters is asking for.
#[derive(Debug)]
enum AcquireType {
    /// Shared access: several readers may hold permits together (up to `MAX_READS`).
    Read,
    /// Write access, intended to be exclusive.
    Write,
}
/// One entry of the wait queue: a group of waiters that all want the same kind of permit.
#[derive(Debug)]
struct Acquire {
    /// Kind of permit the whole group waits for.
    at: AcquireType,
    /// Number of waiters in this group that have not been granted a permit yet.
    ac: usize,
    /// Notifier shared by the group's waiters; signalled with `notify_one` per granted permit.
    an: Arc<Notify>,
    /// `ACQUIRE_PADDING_BYTES` bytes of padding, currently zero-sized.
    /// NOTE(review): presumably a knob for cache-line padding — confirm.
    _p: [u8; ACQUIRE_PADDING_BYTES],
}
/// Result of a fallible acquire operation; the success type defaults to `()`.
pub(super) type AcquireResult<T = ()> = Result<T, AcquireError>;

/// Mutex-guarded ring buffer of waiting groups, oldest group at the front.
type AcquireList = Mutex<AllocRingBuffer<Acquire>>;
/// Reader/writer semaphore that admits waiters in FIFO groups of the same permit kind.
pub(super) struct RwSemaphore {
    /// Count of permits currently held, as tracked by the scheduler. This is the outer lock:
    /// code needing both mutexes locks `rc` before `al`.
    rc: Mutex<usize>,
    /// Queue of waiting groups; the front group is the one being served.
    al: AcquireList,
    /// Broadcast (`notify_waiters`) when a group leaves the queue or the semaphore closes;
    /// `acquire_*_wait` callers block on it between attempts.
    an: Arc<Notify>,
    /// `true` while the semaphore is closed.
    sc: AtomicBool,
    /// Waited on by `evict` while outstanding permits drain.
    or: Arc<Notify>,
}
impl Acquire {
    /// Starts a new group of kind `at` that contains exactly one waiter.
    fn new(at: AcquireType) -> Self {
        let an = Notify::new();
        Self {
            at,
            ac: 1,
            an,
            _p: [0u8; ACQUIRE_PADDING_BYTES],
        }
    }

    /// Whether this group waits for read permits.
    #[inline(always)]
    fn is_read(&self) -> bool {
        match self.at {
            AcquireType::Read => true,
            AcquireType::Write => false,
        }
    }

    /// Whether this group waits for the write permit.
    #[inline(always)]
    fn is_write(&self) -> bool {
        // `AcquireType` has exactly two variants, so "not read" means "write".
        !self.is_read()
    }
}
impl RwSemaphore {
pub(super) fn new() -> Self {
Self {
rc: Mutex::new(0),
al: Mutex::new(AllocRingBuffer::new(MAX_ACQUIRES)),
an: Notify::new(),
sc: AtomicBool::new(false),
or: Notify::new(),
}
}
fn poll_inner(&self, rc: &mut MutexGuard<usize>) -> Option<usize> {
if **rc > 0 && **rc < MAX_READS {
let mut lock = self.al.lock().unwrap();
let acquire = lock.front_mut().unwrap();
let permit_count = (MAX_READS - **rc).min(acquire.ac);
acquire.ac -= permit_count;
**rc += permit_count;
for _ in 0..permit_count {
acquire.an.notify_one();
}
} else if **rc == 0 {
let mut lock = self.al.lock().unwrap();
if let Some(acquire) = lock.front_mut() {
if acquire.ac > 0 {
acquire.ac -= 1;
if acquire.is_read() {
**rc += 1;
acquire.an.notify_one();
} else {
acquire.an.notify_one();
}
} else {
let _ = lock.dequeue().unwrap();
self.an.notify_waiters();
}
}
}
None
}
fn poll(&self) {
let mut rc = self.rc.lock().unwrap();
if self.poll_inner(&mut rc).is_some() {
let _ = self.poll_inner(&mut rc);
}
}
#[inline(always)]
fn is_closed(&self) -> bool {
self.sc.load(Acquire)
}
#[inline]
fn is_writing(&self) -> bool {
self.al.lock().is_ok_and(|mtx| {
mtx.front().is_some_and(|ac| ac.is_write())
})
}
#[inline(always)]
fn is_reading(&self) -> bool {
*self.rc.lock().unwrap() > 0
}
pub(super) fn acquire_read(&self) -> bool {
if self.is_closed() || self.is_writing() {
return false;
}
let not = {
let lock = self.al.lock();
if lock.is_err() {
return false;
}
let mut lock = lock.unwrap();
let is_full = lock.is_full();
if let Some(acquire) = lock.back_mut() {
if acquire.is_read() && !is_full {
acquire.ac += 1;
acquire.an.clone()
} else if acquire.is_write() {
lock.push(Acquire::new(AcquireType::Read));
lock.back().unwrap().an.clone()
} else {
return false;
}
} else {
lock.push(Acquire::new(AcquireType::Read));
lock.back().unwrap().an.clone()
}
};
self.poll();
if not.notified().is_err() {
for ac in self.al.lock().unwrap().iter() {
println!("=============");
println!("{ac:?}");
}
println!("=============");
panic!("notify timed out",)
}
!self.is_closed()
}
pub(super) fn acquire_write(&self) -> bool {
if self.is_closed() || self.is_reading() {
return false;
}
let not = {
let lock = self.al.lock();
if lock.is_err() {
return false;
}
let mut lock = lock.unwrap();
let is_full = lock.is_full();
if let Some(acquire) = lock.back_mut() {
if acquire.is_read() && !is_full {
lock.push(Acquire::new(AcquireType::Write));
lock.back().unwrap().an.clone()
} else if acquire.is_write() {
acquire.ac += 1;
acquire.an.clone()
} else {
return false;
}
} else {
lock.push(Acquire::new(AcquireType::Write));
lock.back().unwrap().an.clone()
}
};
self.poll();
not.notified().expect("Notify timed out");
!self.is_closed()
}
pub(super) fn acquire_read_wait(&self) -> bool {
while !self.acquire_read() {
if self.is_closed() {
return false;
}
self.an.notified().unwrap();
}
true
}
pub(super) fn acquire_write_wait(&self) -> bool {
while !self.acquire_write() {
if self.is_closed() {
return false;
}
self.an.notified().unwrap();
}
true
}
pub(super) fn release_read(&self) {
if !self.is_closed() {
*self.rc.lock().unwrap() -= 1;
self.poll();
}
}
pub(super) fn release_write(&self) {
if !self.is_closed() {
self.poll();
}
}
#[allow(dead_code)]
pub(super) fn evict(&self) {
self.close();
while *self.rc.lock().unwrap() != 0 {
self.or.notified().unwrap();
}
self.sc.fetch_not(Release);
}
pub(super) fn close(&self) {
self.sc.fetch_not(Release);
let lock = self.al.lock().unwrap();
self.an.notify_waiters();
for acquire in lock.iter() {
acquire.an.notify_waiters();
}
}
}